"""
Co-MADDPG Algorithm for Wireless Resource Allocation.

This file implements the Cooperative Multi-Agent Deep Deterministic Policy
Gradient (Co-MADDPG) algorithm. It features a Leader-Follower (Stackelberg)
update structure for semantic and traditional agents.

Key components:
- Actor-Critic architecture
- Stackelberg update (follower updates first, then the leader uses the
  follower's best response)
- Dynamic cooperation weight: lambda(t) = sigmoid(beta * (QoE_sys - Q_th))
- Mixed reward: r_i = lambda * r_coop + (1 - lambda) * r_comp
- Soft update: theta_target <- tau * theta + (1 - tau) * theta_target

Reference: Section 3.2, Leader-Follower Game and Co-MADDPG, in the project
paper.
"""
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from agents.actor import Actor
from agents.critic import Critic
from agents.noise import OUNoise
from agents.replay_buffer import ReplayBuffer
class CoMADDPG:
    """Co-MADDPG algorithm with a Leader-Follower (Stackelberg) update structure.

    Agent S: semantic agent (leader).
    Agent B: traditional/bit-stream agent (follower).

    Paper ref: Section 3.2 - Co-MADDPG implementation details.
    """

    def __init__(self, config):
        """Build actors, critics, target networks, optimizers, buffer and noise.

        Args:
            config (dict): Nested configuration with 'env', 'training',
                'network' and (optionally) 'reward' sections.
        """
        self.config = config

        # Per-agent observation/action dimensions derived from the environment
        # config (one channel-state entry per subcarrier plus 4 extra features
        # -- TODO confirm against the environment's observation builder).
        self.obs_dim = config['env']['num_subcarriers'] + 4
        self.act_dim = 3

        # The centralized critics observe the joint state and joint action.
        obs_dim_total = self.obs_dim * 2
        act_dim_total = self.act_dim * 2

        # Pick CUDA when available, otherwise CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Hyperparameters (defaults used when absent from config).
        train_cfg = config.get('training', {})
        self.gamma = train_cfg.get('gamma', 0.95)
        self.tau = train_cfg.get('tau', 0.01)
        self.beta = train_cfg.get('beta', 5.0)
        self.q_threshold = train_cfg.get('q_threshold', 0.6)
        self.batch_size = train_cfg.get('batch_size', 256)

        actor_lr = train_cfg.get('actor_lr', 1e-4)
        critic_lr = train_cfg.get('critic_lr', 3e-4)
        buffer_capacity = train_cfg.get('buffer_capacity', 100000)

        # Network layer configurations.
        net_cfg = config.get('network', {})
        actor_hidden = net_cfg.get('actor_hidden', [256, 256, 128])
        critic_hidden = net_cfg.get('critic_hidden', [512, 512, 256])

        # Actor networks (decentralized policies) and their targets; targets
        # start as exact copies of the online networks.
        self.actor_s = Actor(self.obs_dim, self.act_dim, actor_hidden).to(self.device)
        self.actor_b = Actor(self.obs_dim, self.act_dim, actor_hidden).to(self.device)
        self.actor_s_target = Actor(self.obs_dim, self.act_dim, actor_hidden).to(self.device)
        self.actor_b_target = Actor(self.obs_dim, self.act_dim, actor_hidden).to(self.device)
        self.actor_s_target.load_state_dict(self.actor_s.state_dict())
        self.actor_b_target.load_state_dict(self.actor_b.state_dict())

        # Centralized critic networks and their targets.
        self.critic_s = Critic(obs_dim_total, act_dim_total, critic_hidden).to(self.device)
        self.critic_b = Critic(obs_dim_total, act_dim_total, critic_hidden).to(self.device)
        self.critic_s_target = Critic(obs_dim_total, act_dim_total, critic_hidden).to(self.device)
        self.critic_b_target = Critic(obs_dim_total, act_dim_total, critic_hidden).to(self.device)
        self.critic_s_target.load_state_dict(self.critic_s.state_dict())
        self.critic_b_target.load_state_dict(self.critic_b.state_dict())

        # One Adam optimizer per network.
        self.actor_optimizer_s = optim.Adam(self.actor_s.parameters(), lr=actor_lr)
        self.actor_optimizer_b = optim.Adam(self.actor_b.parameters(), lr=actor_lr)
        self.critic_optimizer_s = optim.Adam(self.critic_s.parameters(), lr=critic_lr)
        self.critic_optimizer_b = optim.Adam(self.critic_b.parameters(), lr=critic_lr)

        # MSE loss for the critics' TD targets.
        self.critic_loss_fn = nn.MSELoss()

        # Shared experience replay buffer.
        self.replay_buffer = ReplayBuffer(buffer_capacity)

        # Ornstein-Uhlenbeck exploration noise, one process per agent.
        ou_sigma = train_cfg.get('ou_sigma_init', 0.2)
        ou_theta = train_cfg.get('ou_theta', 0.15)
        self.noise_s = OUNoise(self.act_dim, theta=ou_theta, sigma_init=ou_sigma)
        self.noise_b = OUNoise(self.act_dim, theta=ou_theta, sigma_init=ou_sigma)

    def select_action(self, obs_s, obs_b, explore=True):
        """Select actions with the actor networks, optionally adding OU noise.

        Args:
            obs_s, obs_b: Observations for agents S and B (1-D, length obs_dim).
            explore (bool): Whether to add exploration noise.

        Returns:
            tuple: (act_s, act_b) numpy actions, each clipped to [0, 1].
        """
        # Switch to eval mode for deterministic inference (e.g. if the actors
        # contain normalization/dropout layers), then restore train mode.
        self.actor_s.eval()
        self.actor_b.eval()

        with torch.no_grad():
            obs_s_t = torch.as_tensor(obs_s, dtype=torch.float32, device=self.device).unsqueeze(0)
            obs_b_t = torch.as_tensor(obs_b, dtype=torch.float32, device=self.device).unsqueeze(0)
            act_s = self.actor_s(obs_s_t).cpu().numpy().squeeze(0)
            act_b = self.actor_b(obs_b_t).cpu().numpy().squeeze(0)

        self.actor_s.train()
        self.actor_b.train()

        # Apply OU noise only when exploring.
        if explore:
            act_s = act_s + self.noise_s.sample()
            act_b = act_b + self.noise_b.sample()

        # The actor's (tanh + 1)/2 head bounds outputs to [0, 1]; re-clip after
        # noise so the environment never receives out-of-range actions.
        act_s = np.clip(act_s, 0.0, 1.0)
        act_b = np.clip(act_b, 0.0, 1.0)

        return act_s, act_b

    def compute_lambda(self, qoe_sys):
        """Compute the dynamic cooperation weight lambda(t).

        Formula: lambda(t) = sigmoid(beta * (QoE_sys - Q_th))

        Args:
            qoe_sys (float): Current system QoE.

        Returns:
            float: Cooperation weight lambda(t) in [0, 1].
        """
        x = self.beta * (qoe_sys - self.q_threshold)
        # Clamp the exponent so np.exp cannot overflow for extreme inputs;
        # sigmoid is numerically saturated well before |x| = 60.
        x = np.clip(x, -60.0, 60.0)
        return float(1.0 / (1.0 + np.exp(-x)))

    def compute_rewards(self, qoe_s, qoe_b, qoe_sys):
        """Compute the dynamically weighted mixed rewards.

        Formula: r_i = lambda * r_coop_i + (1 - lambda) * r_comp_i

        Args:
            qoe_s, qoe_b, qoe_sys: QoE values for the semantic agent, the
                traditional agent, and the system.

        Returns:
            tuple: (r_s, r_b, lambda_val) final mixed rewards and the
            cooperation weight.
        """
        lambda_val = self.compute_lambda(qoe_sys)

        # Mixing coefficients, overridable via the 'reward' config section.
        rew_cfg = self.config.get('reward', {})
        coop_self = rew_cfg.get('coop_self', 0.5)
        coop_other = rew_cfg.get('coop_other', 0.3)
        coop_sys = rew_cfg.get('coop_sys', 0.2)
        comp_self = rew_cfg.get('comp_self', 0.8)
        comp_sys = rew_cfg.get('comp_sys', 0.2)

        # Cooperative term (shared-benefit mindset):
        #   r_coop_i = 0.5*qoe_i + 0.3*qoe_j + 0.2*qoe_sys
        r_coop_s = coop_self * qoe_s + coop_other * qoe_b + coop_sys * qoe_sys
        r_coop_b = coop_self * qoe_b + coop_other * qoe_s + coop_sys * qoe_sys

        # Competitive term (individual-maximization mindset):
        #   r_comp_i = 0.8*qoe_i + 0.2*qoe_sys
        r_comp_s = comp_self * qoe_s + comp_sys * qoe_sys
        r_comp_b = comp_self * qoe_b + comp_sys * qoe_sys

        # Dynamic balance between cooperation and competition, driven by the
        # system QoE relative to the threshold.
        r_s = lambda_val * r_coop_s + (1.0 - lambda_val) * r_comp_s
        r_b = lambda_val * r_coop_b + (1.0 - lambda_val) * r_comp_b

        return r_s, r_b, lambda_val

    def update(self):
        """Run one Stackelberg-ordered gradient update.

        Update order:
            1. Follower (agent B) critic & actor.
            2. Leader (agent S) critic & actor, using B's freshly updated
               policy as the follower's best response.
            3. Soft-update all four target networks.

        Returns:
            tuple | None: (critic_loss_s, critic_loss_b, actor_loss_s,
            actor_loss_b) as floats, or None while the buffer holds fewer
            than batch_size transitions.
        """
        if len(self.replay_buffer) < self.batch_size:
            return None

        # Sample a batch; expected tuple order:
        # (obs_s, obs_b, act_s, act_b, rew_s, rew_b, next_obs_s, next_obs_b, dones)
        batch = self.replay_buffer.sample(self.batch_size)
        obs_s, obs_b, act_s, act_b, rew_s, rew_b, next_obs_s, next_obs_b, dones = batch

        def to_tensor(x):
            # float32 on the training device; avoids a copy when possible.
            return torch.as_tensor(np.asarray(x), dtype=torch.float32, device=self.device)

        obs_s = to_tensor(obs_s)
        obs_b = to_tensor(obs_b)
        act_s = to_tensor(act_s)
        act_b = to_tensor(act_b)
        rew_s = to_tensor(rew_s).unsqueeze(1)
        rew_b = to_tensor(rew_b).unsqueeze(1)
        next_obs_s = to_tensor(next_obs_s)
        next_obs_b = to_tensor(next_obs_b)
        dones = to_tensor(dones).unsqueeze(1)

        # Joint quantities for the centralized critics.
        obs_all = torch.cat([obs_s, obs_b], dim=1)
        next_obs_all = torch.cat([next_obs_s, next_obs_b], dim=1)
        act_all = torch.cat([act_s, act_b], dim=1)

        # Target joint action for the next state (frozen target actors).
        with torch.no_grad():
            next_act_s_target = self.actor_s_target(next_obs_s)
            next_act_b_target = self.actor_b_target(next_obs_b)
            next_act_all_target = torch.cat([next_act_s_target, next_act_b_target], dim=1)

        # ---- PHASE 1: update follower (agent B) FIRST --------------------
        # Stackelberg methodology: the follower responds to the leader.

        # Critic B: TD target y = r + gamma * (1 - done) * Q'_b(s', a').
        with torch.no_grad():
            target_q_b_next = self.critic_b_target(next_obs_all, next_act_all_target)
            target_q_b = rew_b + self.gamma * (1.0 - dones) * target_q_b_next

        critic_loss_b = self.critic_loss_fn(self.critic_b(obs_all, act_all), target_q_b)
        self.critic_optimizer_b.zero_grad()
        critic_loss_b.backward()
        self.critic_optimizer_b.step()

        # Actor B: maximize Q_b with the leader's action taken from the
        # replay buffer (follower assumes the leader's recorded move).
        # Loss: -mean(critic_b(obs_all, [act_s_from_buffer, actor_b(obs_b)]))
        act_all_for_b = torch.cat([act_s, self.actor_b(obs_b)], dim=1)
        actor_loss_b = -self.critic_b(obs_all, act_all_for_b).mean()
        self.actor_optimizer_b.zero_grad()
        actor_loss_b.backward()
        self.actor_optimizer_b.step()

        # ---- PHASE 2: update leader (agent S) with the UPDATED follower --
        # Leader S uses follower B's best response.

        with torch.no_grad():
            target_q_s_next = self.critic_s_target(next_obs_all, next_act_all_target)
            target_q_s = rew_s + self.gamma * (1.0 - dones) * target_q_s_next

        critic_loss_s = self.critic_loss_fn(self.critic_s(obs_all, act_all), target_q_s)
        self.critic_optimizer_s.zero_grad()
        critic_loss_s.backward()
        self.critic_optimizer_s.step()

        # Actor S: KEY step -- anticipate the follower's freshly updated
        # best response, detached so no gradient flows into actor B.
        act_all_for_s = torch.cat([self.actor_s(obs_s), self.actor_b(obs_b).detach()], dim=1)
        actor_loss_s = -self.critic_s(obs_all, act_all_for_s).mean()
        self.actor_optimizer_s.zero_grad()
        actor_loss_s.backward()
        self.actor_optimizer_s.step()

        # ---- Target soft update: theta' <- tau*theta + (1 - tau)*theta' --
        self.soft_update(self.actor_s_target, self.actor_s, self.tau)
        self.soft_update(self.actor_b_target, self.actor_b, self.tau)
        self.soft_update(self.critic_s_target, self.critic_s, self.tau)
        self.soft_update(self.critic_b_target, self.critic_b, self.tau)

        return critic_loss_s.item(), critic_loss_b.item(), actor_loss_s.item(), actor_loss_b.item()

    def soft_update(self, target, source, tau):
        """Polyak-average source parameters into the target network.

        Args:
            target: Target network (updated in place).
            source: Source (online) network.
            tau (float): Soft-update interpolation factor.
        """
        for target_param, source_param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(tau * source_param.data + (1.0 - tau) * target_param.data)

    def save(self, path):
        """Save all four networks' state_dicts plus the optimizer states.

        Args:
            path (str): File path for the checkpoint (directories are created
                as needed).
        """
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save({
            'actor_s': self.actor_s.state_dict(),
            'actor_b': self.actor_b.state_dict(),
            'critic_s': self.critic_s.state_dict(),
            'critic_b': self.critic_b.state_dict(),
            'actor_optimizer_s': self.actor_optimizer_s.state_dict(),
            'actor_optimizer_b': self.actor_optimizer_b.state_dict(),
            'critic_optimizer_s': self.critic_optimizer_s.state_dict(),
            'critic_optimizer_b': self.critic_optimizer_b.state_dict(),
        }, path)

    def load(self, path):
        """Load networks and optimizer states from a checkpoint.

        Args:
            path (str): Checkpoint file path.
        """
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # checkpoints from trusted sources (consider weights_only=True on
        # torch >= 1.13 if the checkpoint format allows it).
        checkpoint = torch.load(path, map_location=self.device)
        self.actor_s.load_state_dict(checkpoint['actor_s'])
        self.actor_b.load_state_dict(checkpoint['actor_b'])
        self.critic_s.load_state_dict(checkpoint['critic_s'])
        self.critic_b.load_state_dict(checkpoint['critic_b'])

        self.actor_optimizer_s.load_state_dict(checkpoint['actor_optimizer_s'])
        self.actor_optimizer_b.load_state_dict(checkpoint['actor_optimizer_b'])
        self.critic_optimizer_s.load_state_dict(checkpoint['critic_optimizer_s'])
        self.critic_optimizer_b.load_state_dict(checkpoint['critic_optimizer_b'])

        # Hard-sync the target networks after loading (targets are not saved).
        self.actor_s_target.load_state_dict(self.actor_s.state_dict())
        self.actor_b_target.load_state_dict(self.actor_b.state_dict())
        self.critic_s_target.load_state_dict(self.critic_s.state_dict())
        self.critic_b_target.load_state_dict(self.critic_b.state_dict())