SemanticCommunication/code/agents/noise.py

75 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Ornstein-Uhlenbeck (OU) Exploration Noise / OU 探索噪声
This file implements the OU noise process for continuous action exploration.
The noise is temporally correlated and features linear sigma decay over training.
本文档实现了用于连续动作探索的 OU 噪声过程。
该噪声具有时间相关性并在训练过程中具有线性标准差sigma衰减特性。
Formula / 公式: dx = \u03b8(\u03bc - x)dt + \u03c3dW
Decay / 衰减: Linear sigma decay over specified decay period. / 在指定的衰减周期内线性衰减 sigma。
Reference / 参考文献: Section 3.2.2 Exploration Mechanism in the project paper.
"""
import numpy as np
class OUNoise:
    """Ornstein-Uhlenbeck process for temporally correlated exploration noise.

    The continuous-time process is ``dx = theta * (mu - x) * dt + sigma * dW``.
    ``sample`` advances it by one discrete step with an implicit ``dt = 1``:
    ``x <- x + theta * (mu - x) + sigma * N(0, 1)``.

    Args:
        action_dim: Dimensionality of the action space.
        mu: Long-term mean of the process.
        theta: Mean-reversion rate.
        sigma_init: Initial standard deviation.
        sigma_min: Minimum standard deviation after decay.
        decay_period: Number of episodes over which sigma decays linearly.
    """

    def __init__(self, action_dim: int, mu: float = 0.0, theta: float = 0.15,
                 sigma_init: float = 0.2, sigma_min: float = 0.01,
                 decay_period: int = 5000):
        # Process parameters.
        self.action_dim = action_dim
        self.mu = mu
        self.theta = theta
        # Decay schedule; ``sigma`` is the value currently used by ``sample``.
        self.sigma_init = sigma_init
        self.sigma_min = sigma_min
        self.sigma = sigma_init
        self.decay_period = decay_period
        # The process starts at its long-term mean.
        self.state = np.full(action_dim, mu, dtype=np.float64)

    def reset(self):
        """Reset the internal state to the long-term mean ``mu``."""
        self.state = np.full(self.action_dim, self.mu, dtype=np.float64)

    def decay_sigma(self, episode: int):
        """Linearly decay sigma from ``sigma_init`` to ``sigma_min``.

        Sigma is interpolated over ``decay_period`` episodes and then held at
        ``sigma_min``. Episodes below zero are treated as episode zero so that
        sigma never extrapolates beyond ``sigma_init``.

        Args:
            episode: Current episode number.
        """
        # max(1, ...) guards against division by zero when decay_period <= 0.
        progress = episode / max(1, self.decay_period)
        # Clamp to [0, 1]: without the lower bound a negative episode would
        # push sigma past sigma_init instead of keeping it on the schedule.
        frac = min(1.0, max(0.0, progress))
        self.sigma = self.sigma_init + frac * (self.sigma_min - self.sigma_init)

    def sample(self) -> np.ndarray:
        """Advance the OU process by one step and return the new state.

        Returns:
            Noise vector of shape ``(action_dim,)``. A copy is returned so
            callers cannot mutate the internal state.
        """
        # Mean-reverting pull toward mu plus Gaussian noise scaled by sigma.
        drift = self.theta * (self.mu - self.state)
        diffusion = self.sigma * np.random.randn(self.action_dim)
        dx = drift + diffusion
        self.state = self.state + dx
        return self.state.copy()