# SemanticCommunication/code/agents/co_maddpg.py

"""
Co-MADDPG Algorithm for Wireless Resource Allocation / 无线资源分配中的 Co-MADDPG 算法
This file implements the Cooperative Multi-Agent Deep Deterministic Policy Gradient (Co-MADDPG) algorithm.
It features a Leader-Follower (Stackelberg) update structure for semantic and traditional agents.
本文档实现了协作式多智能体深度确定性策略梯度 (Co-MADDPG) 算法。
该算法针对语义智能体和传统智能体采用了领导者-跟随者Stackelberg更新结构。
Key Components / 关键组件:
- Actor-Critic Architecture / Actor-Critic 架构
- Stackelberg Update / Stackelberg 更新 (Follower update first, then Leader uses Follower's best response)
- Dynamic Cooperation Weight / 动态协作权重: \u03bb(t) = sigmoid(\u03b2*(QoE_sys - Q_th))
- Mixed Reward / 混合奖励: r_i = \u03bb*r_coop + (1-\u03bb)*r_comp
- Soft Update / 软更新: \u03b8_target \u2190 \u03c4*\u03b8 + (1-\u03c4)*\u03b8_target
Reference / 参考文献: Section 3.2 Leader-Follower Game and Co-MADDPG in the project paper.
"""
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from agents.actor import Actor
from agents.critic import Critic
from agents.replay_buffer import ReplayBuffer
from agents.noise import OUNoise
class CoMADDPG:
    """Co-MADDPG agent pair with a Leader-Follower (Stackelberg) update order.

    * Agent S: semantic agent, the Leader.
    * Agent B: traditional / bit-stream agent, the Follower.

    Each agent owns an actor that sees only its own observation and a
    centralized critic that sees both agents' observations and actions.
    Every network has a target copy that tracks it via Polyak averaging.

    Paper reference: Section 3.2 - Co-MADDPG implementation details.
    """

    def __init__(self, config):
        """Create networks, target networks, optimizers, buffer and noise.

        Args:
            config (dict): Nested configuration. ``config['env']['num_subcarriers']``
                is required. The optional ``'training'`` and ``'network'``
                sections override the defaults used below; the optional
                ``'reward'`` section is read later by ``compute_rewards``.
        """
        self.config = config

        # Per-agent sizes: observation = num_subcarriers + 4 values,
        # action = 3 values.
        self.obs_dim = config['env']['num_subcarriers'] + 4
        self.act_dim = 3

        # Centralized critics take both agents' observations and actions.
        obs_dim_total = self.obs_dim * 2
        act_dim_total = self.act_dim * 2

        # Use CUDA when available, otherwise CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Training hyperparameters.
        train_cfg = config.get('training', {})
        self.gamma = train_cfg.get('gamma', 0.95)
        self.tau = train_cfg.get('tau', 0.01)
        self.beta = train_cfg.get('beta', 5.0)
        self.q_threshold = train_cfg.get('q_threshold', 0.6)
        self.batch_size = train_cfg.get('batch_size', 256)
        actor_lr = train_cfg.get('actor_lr', 1e-4)
        critic_lr = train_cfg.get('critic_lr', 3e-4)
        buffer_capacity = train_cfg.get('buffer_capacity', 100000)

        # Hidden-layer sizes.
        net_cfg = config.get('network', {})
        actor_hidden = net_cfg.get('actor_hidden', [256, 256, 128])
        critic_hidden = net_cfg.get('critic_hidden', [512, 512, 256])

        # Online actors and their targets (targets start as exact copies).
        self.actor_s = Actor(self.obs_dim, self.act_dim, actor_hidden).to(self.device)
        self.actor_b = Actor(self.obs_dim, self.act_dim, actor_hidden).to(self.device)
        self.actor_s_target = Actor(self.obs_dim, self.act_dim, actor_hidden).to(self.device)
        self.actor_b_target = Actor(self.obs_dim, self.act_dim, actor_hidden).to(self.device)
        self.actor_s_target.load_state_dict(self.actor_s.state_dict())
        self.actor_b_target.load_state_dict(self.actor_b.state_dict())

        # Online centralized critics and their targets.
        self.critic_s = Critic(obs_dim_total, act_dim_total, critic_hidden).to(self.device)
        self.critic_b = Critic(obs_dim_total, act_dim_total, critic_hidden).to(self.device)
        self.critic_s_target = Critic(obs_dim_total, act_dim_total, critic_hidden).to(self.device)
        self.critic_b_target = Critic(obs_dim_total, act_dim_total, critic_hidden).to(self.device)
        self.critic_s_target.load_state_dict(self.critic_s.state_dict())
        self.critic_b_target.load_state_dict(self.critic_b.state_dict())

        # One Adam optimizer per online network.
        self.actor_optimizer_s = optim.Adam(self.actor_s.parameters(), lr=actor_lr)
        self.actor_optimizer_b = optim.Adam(self.actor_b.parameters(), lr=actor_lr)
        self.critic_optimizer_s = optim.Adam(self.critic_s.parameters(), lr=critic_lr)
        self.critic_optimizer_b = optim.Adam(self.critic_b.parameters(), lr=critic_lr)

        # Critics regress onto TD targets with a mean-squared error.
        self.critic_loss_fn = nn.MSELoss()

        self.replay_buffer = ReplayBuffer(buffer_capacity)

        # Ornstein-Uhlenbeck exploration noise, one process per agent.
        ou_sigma = train_cfg.get('ou_sigma_init', 0.2)
        ou_theta = train_cfg.get('ou_theta', 0.15)
        self.noise_s = OUNoise(self.act_dim, theta=ou_theta, sigma_init=ou_sigma)
        self.noise_b = OUNoise(self.act_dim, theta=ou_theta, sigma_init=ou_sigma)

    def select_action(self, obs_s, obs_b, explore=True):
        """Compute both agents' actions for a single step.

        Args:
            obs_s: Observation of agent S. A batch dimension is added
                internally, so a single (unbatched) observation is expected.
            obs_b: Observation of agent B, same convention.
            explore (bool): If True, add OU exploration noise to each action.

        Returns:
            tuple[numpy.ndarray, numpy.ndarray]: ``(act_s, act_b)``, each
            clipped to [0, 1].
        """
        # Forward pass in eval mode without autograd; both actors are put
        # back into train mode afterwards, regardless of their prior mode.
        self.actor_s.eval()
        self.actor_b.eval()
        with torch.no_grad():
            obs_s_t = torch.FloatTensor(obs_s).unsqueeze(0).to(self.device)
            obs_b_t = torch.FloatTensor(obs_b).unsqueeze(0).to(self.device)
            act_s = self.actor_s(obs_s_t).cpu().numpy().squeeze(0)
            act_b = self.actor_b(obs_b_t).cpu().numpy().squeeze(0)
        self.actor_s.train()
        self.actor_b.train()

        if explore:
            act_s += self.noise_s.sample()
            act_b += self.noise_b.sample()

        # Noise can push actions outside [0, 1]; clip them back. (The
        # original comments state the Actor's output activation is
        # (tanh + 1) / 2, which also lies in [0, 1].)
        act_s = np.clip(act_s, 0.0, 1.0)
        act_b = np.clip(act_b, 0.0, 1.0)
        return act_s, act_b

    def compute_lambda(self, qoe_sys):
        """Return the dynamic cooperation weight lambda(t).

        lambda(t) = sigmoid(beta * (qoe_sys - q_threshold)). It is 0.5 at
        the threshold; for positive beta it tends to 1 as system QoE rises
        above the threshold and to 0 as it falls below.

        Args:
            qoe_sys (float or array-like): Current system-level QoE.

        Returns:
            numpy.float64 for scalar input (ndarray for array input) in [0, 1].
        """
        z = self.beta * (np.asarray(qoe_sys, dtype=np.float64) - self.q_threshold)
        # sigmoid(z) == exp(-log(1 + exp(-z))). np.logaddexp evaluates the
        # log term stably, so a large beta or a QoE far below the threshold
        # no longer overflows the way a direct np.exp(-z) does.
        return np.exp(-np.logaddexp(0.0, -z))

    def compute_rewards(self, qoe_s, qoe_b, qoe_sys):
        """Blend cooperative and competitive rewards using lambda(t).

        r_i = lambda * r_coop_i + (1 - lambda) * r_comp_i, where
        r_coop_i = coop_self*qoe_i + coop_other*qoe_j + coop_sys*qoe_sys and
        r_comp_i = comp_self*qoe_i + comp_sys*qoe_sys. The weights come from
        ``config['reward']`` (defaults 0.5 / 0.3 / 0.2 and 0.8 / 0.2).

        Args:
            qoe_s: QoE of the semantic agent.
            qoe_b: QoE of the traditional agent.
            qoe_sys: System-level QoE (also drives lambda).

        Returns:
            tuple: ``(r_s, r_b, lambda_val)``.
        """
        lambda_val = self.compute_lambda(qoe_sys)

        rew_cfg = self.config.get('reward', {})
        coop_self = rew_cfg.get('coop_self', 0.5)
        coop_other = rew_cfg.get('coop_other', 0.3)
        coop_sys = rew_cfg.get('coop_sys', 0.2)
        comp_self = rew_cfg.get('comp_self', 0.8)
        comp_sys = rew_cfg.get('comp_sys', 0.2)

        # Cooperative term: each agent also values the other agent's QoE.
        r_coop_s = coop_self * qoe_s + coop_other * qoe_b + coop_sys * qoe_sys
        r_coop_b = coop_self * qoe_b + coop_other * qoe_s + coop_sys * qoe_sys

        # Competitive term: own QoE plus a system share only.
        r_comp_s = comp_self * qoe_s + comp_sys * qoe_sys
        r_comp_b = comp_self * qoe_b + comp_sys * qoe_sys

        r_s = lambda_val * r_coop_s + (1.0 - lambda_val) * r_comp_s
        r_b = lambda_val * r_coop_b + (1.0 - lambda_val) * r_comp_b
        return r_s, r_b, lambda_val

    def _to_tensor(self, data, column=False):
        """Convert one batch field to a float32 tensor on ``self.device``.

        With ``column=True`` the result is reshaped to ``(batch, 1)``. That
        works whether the buffer returned a flat ``(batch,)`` array or an
        already-shaped ``(batch, 1)`` one; ``unsqueeze(1)`` would turn the
        latter into ``(batch, 1, 1)`` and silently broadcast the TD target.
        """
        if not torch.is_tensor(data):
            data = np.asarray(data)
        tensor = torch.as_tensor(data, dtype=torch.float32, device=self.device)
        return tensor.reshape(-1, 1) if column else tensor

    @staticmethod
    def _optimize(optimizer, loss):
        """Zero ``optimizer``'s gradients, backpropagate ``loss`` and step."""
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    def _update_critic(self, critic, critic_target, optimizer, rewards, dones,
                       obs_all, act_all, next_obs_all, next_act_all):
        """Take one TD(0) step on ``critic`` and return the loss tensor."""
        with torch.no_grad():
            next_q = critic_target(next_obs_all, next_act_all)
            # Transitions with done == 1 do not bootstrap from the next state.
            target_q = rewards + self.gamma * (1.0 - dones) * next_q
        loss = self.critic_loss_fn(critic(obs_all, act_all), target_q)
        self._optimize(optimizer, loss)
        return loss

    def update(self):
        """Run one Leader-Follower gradient step on a sampled mini-batch.

        Order:
            1. Follower (agent B): critic, then actor. The Leader's action is
               held at its replay-buffer value.
            2. Leader (agent S): critic, then actor. The actor is trained
               against the Follower's just-updated policy.
            3. Soft-update all four target networks.

        Returns:
            tuple[float, float, float, float] | None:
                ``(critic_loss_s, critic_loss_b, actor_loss_s, actor_loss_b)``,
                or ``None`` while the buffer holds fewer than ``batch_size``
                transitions.
        """
        if len(self.replay_buffer) < self.batch_size:
            return None

        # The replay buffer is assumed to return fields in this order:
        # (obs_s, obs_b, act_s, act_b, rew_s, rew_b, next_obs_s, next_obs_b, dones)
        (obs_s, obs_b, act_s, act_b, rew_s, rew_b,
         next_obs_s, next_obs_b, dones) = self.replay_buffer.sample(self.batch_size)

        obs_s = self._to_tensor(obs_s)
        obs_b = self._to_tensor(obs_b)
        act_s = self._to_tensor(act_s)
        act_b = self._to_tensor(act_b)
        next_obs_s = self._to_tensor(next_obs_s)
        next_obs_b = self._to_tensor(next_obs_b)
        # Per-sample scalars become (batch, 1) columns, presumably matching
        # the critic's output shape.
        rew_s = self._to_tensor(rew_s, column=True)
        rew_b = self._to_tensor(rew_b, column=True)
        dones = self._to_tensor(dones, column=True)

        # Joint inputs for the centralized critics (agent S first, then B).
        obs_all = torch.cat([obs_s, obs_b], dim=1)
        next_obs_all = torch.cat([next_obs_s, next_obs_b], dim=1)
        act_all = torch.cat([act_s, act_b], dim=1)

        # Target-policy actions at the next state, shared by both TD targets.
        with torch.no_grad():
            next_act_all_target = torch.cat(
                [self.actor_s_target(next_obs_s), self.actor_b_target(next_obs_b)],
                dim=1,
            )

        # ---- Phase 1: Follower (agent B) is updated first. ----
        critic_loss_b = self._update_critic(
            self.critic_b, self.critic_b_target, self.critic_optimizer_b,
            rew_b, dones, obs_all, act_all, next_obs_all, next_act_all_target,
        )
        new_act_b = self.actor_b(obs_b)
        actor_loss_b = -self.critic_b(obs_all, torch.cat([act_s, new_act_b], dim=1)).mean()
        self._optimize(self.actor_optimizer_b, actor_loss_b)

        # ---- Phase 2: Leader (agent S), against the updated Follower. ----
        critic_loss_s = self._update_critic(
            self.critic_s, self.critic_s_target, self.critic_optimizer_s,
            rew_s, dones, obs_all, act_all, next_obs_all, next_act_all_target,
        )
        new_act_s = self.actor_s(obs_s)
        # The Follower's action comes from its freshly updated actor, with no
        # gradient flowing back into actor_b. Note that actor_b conditions
        # only on obs_b, so this "best response" does not react to the
        # Leader's new action.
        with torch.no_grad():
            follower_act = self.actor_b(obs_b)
        actor_loss_s = -self.critic_s(obs_all, torch.cat([new_act_s, follower_act], dim=1)).mean()
        self._optimize(self.actor_optimizer_s, actor_loss_s)

        # Polyak-average every target network towards its online network.
        for target, source in (
            (self.actor_s_target, self.actor_s),
            (self.actor_b_target, self.actor_b),
            (self.critic_s_target, self.critic_s),
            (self.critic_b_target, self.critic_b),
        ):
            self.soft_update(target, source, self.tau)

        return (critic_loss_s.item(), critic_loss_b.item(),
                actor_loss_s.item(), actor_loss_b.item())

    def soft_update(self, target, source, tau):
        """Polyak-average ``source``'s parameters into ``target`` in place.

        theta_target <- tau * theta_source + (1 - tau) * theta_target

        Only ``parameters()`` are blended; module buffers are not touched.

        Args:
            target (nn.Module): Network that is updated in place.
            source (nn.Module): Network whose parameters are blended in.
            tau (float): Interpolation factor (0 keeps target, 1 copies source).
        """
        with torch.no_grad():
            for target_param, source_param in zip(target.parameters(), source.parameters()):
                # lerp_(end, w) == self + w * (end - self)
                target_param.lerp_(source_param, tau)

    def save(self, path):
        """Save the four online networks and their optimizers to ``path``.

        Target networks are not stored; ``load`` re-creates them by copying
        the online networks.

        Args:
            path (str): Checkpoint file path. Missing parent directories
                are created.
        """
        directory = os.path.dirname(path)
        # os.makedirs('') raises FileNotFoundError, so only create a
        # directory when the path actually has one (not for 'model.pt').
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save({
            'actor_s': self.actor_s.state_dict(),
            'actor_b': self.actor_b.state_dict(),
            'critic_s': self.critic_s.state_dict(),
            'critic_b': self.critic_b.state_dict(),
            'actor_optimizer_s': self.actor_optimizer_s.state_dict(),
            'actor_optimizer_b': self.actor_optimizer_b.state_dict(),
            'critic_optimizer_s': self.critic_optimizer_s.state_dict(),
            'critic_optimizer_b': self.critic_optimizer_b.state_dict(),
        }, path)

    def load(self, path):
        """Restore online networks and optimizers from a ``save`` checkpoint.

        Tensors are mapped onto ``self.device``. Target networks are then
        hard-synced to the loaded online networks, because ``save`` does not
        store them.

        Args:
            path (str): Checkpoint file path.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.actor_s.load_state_dict(checkpoint['actor_s'])
        self.actor_b.load_state_dict(checkpoint['actor_b'])
        self.critic_s.load_state_dict(checkpoint['critic_s'])
        self.critic_b.load_state_dict(checkpoint['critic_b'])
        self.actor_optimizer_s.load_state_dict(checkpoint['actor_optimizer_s'])
        self.actor_optimizer_b.load_state_dict(checkpoint['actor_optimizer_b'])
        self.critic_optimizer_s.load_state_dict(checkpoint['critic_optimizer_s'])
        self.critic_optimizer_b.load_state_dict(checkpoint['critic_optimizer_b'])

        # Hard-sync the target networks after loading.
        self.actor_s_target.load_state_dict(self.actor_s.state_dict())
        self.actor_b_target.load_state_dict(self.actor_b.state_dict())
        self.critic_s_target.load_state_dict(self.critic_s.state_dict())
        self.critic_b_target.load_state_dict(self.critic_b.state_dict())