75 lines
3.4 KiB
Python
75 lines
3.4 KiB
Python
"""
Ornstein-Uhlenbeck (OU) Exploration Noise / OU 探索噪声

This file implements the OU noise process for continuous-action exploration.
The noise is temporally correlated, and its scale (sigma) decays linearly over training.
本文件实现了用于连续动作探索的 OU 噪声过程。
该噪声具有时间相关性，其标准差（sigma）在训练过程中线性衰减。

Formula / 公式: dx = θ(μ - x)dt + σdW
Decay / 衰减: Sigma decays linearly over the specified decay period. / 在指定的衰减周期内线性衰减 sigma。
Reference / 参考文献: Section 3.2.2 (Exploration Mechanism) of the project paper.
"""
|
||
import numpy as np
|
||
|
||
|
||
class OUNoise:
    """Ornstein-Uhlenbeck process for temporally correlated exploration noise.

    The process follows ``dx = θ(μ - x)dt + σ dW`` and is discretized with an
    Euler-Maruyama step, so every call to :meth:`sample` applies::

        x <- x + θ(μ - x)·dt + σ·√dt·N(0, 1)

    With the default ``dt = 1.0`` this is exactly ``x <- x + θ(μ - x) + σ·N(0, 1)``.
    The noise scale σ is decayed linearly per episode by :meth:`decay_sigma`;
    :meth:`reset` restores the state only and leaves σ unchanged.

    Args:
        action_dim: Dimensionality of the action space (length of each sample).
        mu: Long-term mean μ the process reverts to.
        theta: Mean-reversion rate θ.
        sigma_init: Initial noise scale σ.
        sigma_min: Noise scale reached at the end of the decay period.
        decay_period: Number of episodes over which σ decays linearly.
        dt: Integration time step. The default 1.0 reproduces the unit-step
            update. Must be positive.

    Raises:
        ValueError: If ``dt`` is not a positive number.
    """

    def __init__(self, action_dim: int, mu: float = 0.0, theta: float = 0.15,
                 sigma_init: float = 0.2, sigma_min: float = 0.01,
                 decay_period: int = 5000, dt: float = 1.0):
        # `not dt > 0` also rejects NaN, which `dt <= 0` would let through.
        if not dt > 0:
            raise ValueError(f"dt must be positive, got {dt}")
        self.action_dim = action_dim
        self.mu = mu
        self.theta = theta
        self.sigma_init = sigma_init
        self.sigma_min = sigma_min
        # Current noise scale; updated by decay_sigma(), not modified by reset().
        self.sigma = sigma_init
        self.decay_period = decay_period
        self.dt = dt
        # Process state x: one float64 entry per action dimension, starting at mu.
        self.state = np.full(action_dim, mu, dtype=np.float64)

    def reset(self):
        """Reset the process state to the long-term mean ``mu``.

        ``sigma`` is not modified, so the schedule applied by
        :meth:`decay_sigma` carries over across resets.
        """
        self.state = np.full(self.action_dim, self.mu, dtype=np.float64)

    def decay_sigma(self, episode: int):
        """Linearly interpolate ``sigma`` from ``sigma_init`` to ``sigma_min``.

        Args:
            episode: Current episode number. Values <= 0 give ``sigma_init``;
                values >= ``decay_period`` give ``sigma_min``.
        """
        # max(1, ...) avoids division by zero when decay_period <= 0; clamping
        # frac to [0, 1] keeps sigma between sigma_min and sigma_init even for
        # negative episode numbers.
        frac = min(1.0, max(0.0, episode / max(1, self.decay_period)))
        self.sigma = self.sigma_init + frac * (self.sigma_min - self.sigma_init)

    def sample(self) -> np.ndarray:
        """Advance the OU process by one step and return the new state.

        Draws from NumPy's global RNG (``np.random.randn``), so results are
        reproducible only through ``np.random.seed``.

        Returns:
            np.ndarray: Copy of the updated state, shape ``(action_dim,)``.
                Mutating it does not affect the internal state.
        """
        # Mean-reverting drift: θ(μ - x)·dt
        drift = self.theta * (self.mu - self.state) * self.dt
        # Brownian increment: σ·√dt·N(0, 1)
        diffusion = self.sigma * np.sqrt(self.dt) * np.random.randn(self.action_dim)
        dx = drift + diffusion
        self.state = self.state + dx
        return self.state.copy()