"""
|
|
Experience Replay Buffer for Multi-Agent RL / 多智能体强化学习的经验回放池
|
|
|
|
This file implements a fixed-size replay buffer to store and sample transitions.
|
|
Each transition contains observations, actions, and rewards for both semantic and traditional agents.
|
|
本文档实现了一个固定大小的回放池,用于存储和采样状态转移。
|
|
每个状态转移包含语义智能体和传统智能体的观测、动作及奖励。
|
|
|
|
Storage Format / 存储格式: 9-field transitions / 9 字段状态转移
|
|
(obs_s, obs_b, act_s, act_b, rew_s, rew_b, next_obs_s, next_obs_b, done)
|
|
Reference / 参考文献: Section 3.2.3 Experience Replay in the project paper.
|
|
"""
import random
from collections import deque

import numpy as np


class ReplayBuffer:
    """Fixed-size experience replay buffer for two-agent transitions.

    Each stored transition is a 9-field tuple:
        (obs_s, obs_b, act_s, act_b, rew_s, rew_b, next_obs_s, next_obs_b, done)
    where the ``_s`` fields belong to the semantic agent and the ``_b``
    fields to the traditional agent.

    Args:
        capacity (int): Maximum number of transitions retained; once full,
            the oldest transition is evicted automatically.
    """

    def __init__(self, capacity: int):
        # A bounded deque gives O(1) append with automatic eviction of the
        # oldest entry when `capacity` is exceeded.
        self.buffer = deque(maxlen=capacity)

    def push(self, obs_s, obs_b, act_s, act_b, rew_s, rew_b,
             next_obs_s, next_obs_b, done=False):
        """Store one transition in the buffer.

        Args:
            obs_s, obs_b: Current observations of the semantic and
                traditional agents.
            act_s, act_b: Actions taken by each agent.
            rew_s, rew_b: Scalar rewards received by each agent.
            next_obs_s, next_obs_b: Observations after the environment step.
            done (bool): Whether the episode terminated at this step.
        """
        # Coerce array-like fields to float32 up front so sampled batches
        # stack into uniform arrays; scalars are stored as plain floats.
        transition = (
            np.asarray(obs_s, dtype=np.float32),
            np.asarray(obs_b, dtype=np.float32),
            np.asarray(act_s, dtype=np.float32),
            np.asarray(act_b, dtype=np.float32),
            float(rew_s),
            float(rew_b),
            np.asarray(next_obs_s, dtype=np.float32),
            np.asarray(next_obs_b, dtype=np.float32),
            float(done),
        )
        self.buffer.append(transition)

    def sample(self, batch_size: int):
        """Draw a uniform random batch of transitions for training.

        Args:
            batch_size (int): Number of transitions to draw (uniformly,
                without replacement).

        Returns:
            tuple of np.ndarray: Nine arrays in storage-field order
            (obs_s, obs_b, act_s, act_b, rew_s, rew_b, next_obs_s,
            next_obs_b, dones), each with leading dimension ``batch_size``.
        """
        drawn = random.sample(self.buffer, batch_size)
        # Transpose the list of 9-tuples into nine per-field columns,
        # then stack each column into a single array.
        return tuple(np.array(column) for column in zip(*drawn))

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.buffer)
|