SemanticCommunication/code/agents/replay_buffer.py

93 lines
4.0 KiB
Python

"""
Experience Replay Buffer for Multi-Agent RL / 多智能体强化学习的经验回放池
This file implements a fixed-size replay buffer to store and sample transitions.
Each transition contains observations, actions, and rewards for both semantic and traditional agents.
本文档实现了一个固定大小的回放池,用于存储和采样状态转移。
每个状态转移包含语义智能体和传统智能体的观测、动作及奖励。
Storage Format / 存储格式: 9-field transitions / 9 字段状态转移
(obs_s, obs_b, act_s, act_b, rew_s, rew_b, next_obs_s, next_obs_b, done)
Reference / 参考文献: Section 3.2.3 Experience Replay in the project paper.
"""
import random
from collections import deque
import numpy as np
class ReplayBuffer:
    """Fixed-size experience replay buffer for two-agent transitions.

    Each stored transition is a 9-tuple:
        (obs_s, obs_b, act_s, act_b, rew_s, rew_b, next_obs_s, next_obs_b, done)
    where the ``_s`` fields belong to the semantic agent and the ``_b``
    fields to the traditional (baseline) agent.

    Args:
        capacity (int): Maximum number of transitions retained; once the
            limit is reached the oldest entries are evicted automatically.
    """

    def __init__(self, capacity: int):
        # deque(maxlen=...) gives O(1) append and automatic eviction of
        # the oldest transition when the buffer is full.
        self.buffer = deque(maxlen=capacity)

    def push(self, obs_s, obs_b, act_s, act_b, rew_s, rew_b,
             next_obs_s, next_obs_b, done=False):
        """Insert one transition, normalising observation/action arrays
        to float32 and scalar fields to Python floats.

        Args:
            obs_s, obs_b: Current observations of the two agents.
            act_s, act_b: Actions taken by the two agents.
            rew_s, rew_b: Scalar rewards received by the two agents.
            next_obs_s, next_obs_b: Observations after the environment step.
            done (bool): True if the episode terminated at this step.
        """
        transition = (
            np.asarray(obs_s, dtype=np.float32),
            np.asarray(obs_b, dtype=np.float32),
            np.asarray(act_s, dtype=np.float32),
            np.asarray(act_b, dtype=np.float32),
            float(rew_s),
            float(rew_b),
            np.asarray(next_obs_s, dtype=np.float32),
            np.asarray(next_obs_b, dtype=np.float32),
            float(done),  # store done as 0.0/1.0 so it batches numerically
        )
        self.buffer.append(transition)

    def sample(self, batch_size: int):
        """Draw a uniform random mini-batch without replacement.

        Args:
            batch_size (int): Number of transitions to draw; must not
                exceed ``len(self)`` (``random.sample`` raises ValueError
                otherwise).

        Returns:
            tuple of np.ndarray: Nine arrays in the order
            (obs_s, obs_b, act_s, act_b, rew_s, rew_b,
            next_obs_s, next_obs_b, dones), each with leading dimension
            ``batch_size``.
        """
        picks = random.sample(self.buffer, batch_size)
        # Transpose the list of 9-tuples into nine per-field columns,
        # then stack each column into a single numpy array.
        columns = zip(*picks)
        return tuple(np.array(column) for column in columns)

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.buffer)