import os
|
||
import random
|
||
from collections import deque
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
from agents.actor import Actor
|
||
from agents.noise import OUNoise
|
||
|
||
"""
|
||
Baseline: SemanticOnly (仅语义基线)
|
||
=====================================
|
||
Purpose (ablation):
|
||
- This baseline removes the heterogeneous treatment of different user groups.
|
||
- It treats all users as semantic users and uses a single DDPG policy to control both groups.
|
||
- It serves as an ablation study to demonstrate the benefit of having heterogeneous, specialized policies for semantic vs. traditional users.
|
||
- 目的(消融实验):该基线移除了对不同用户组的异构处理。它将所有用户视为语义用户,并使用单一的 DDPG 策略同时控制两个用户组。作为消融实验,用于证明为语义用户和传统用户分别设计专门的异构策略的收益。
|
||
|
||
Difference from Co-MADDPG:
|
||
1. Heterogeneity: Homogeneous policy (all semantic) vs Heterogeneous policies.
|
||
2. Architecture: Single DDPG agent for both groups vs Multi-agent (Co-MADDPG).
|
||
3. 与 Co-MADDPG 的区别:
|
||
- 异构性:同构策略(全部视为语义用户) vs 异构策略。
|
||
- 架构:单 DDPG 智能体控制两组 vs 多智能体 (Co-MADDPG)。
|
||
|
||
Contribution:
|
||
- Contributes to performance analysis regarding user heterogeneity and specialized resource allocation.
|
||
- 贡献:用于关于用户异构性和专门化资源分配的性能分析。
|
||
"""
|
||
|
||
class SemanticCritic(nn.Module):
    """
    Single-agent critic network: maps (observation, action) -> scalar Q-value.

    Used by the SemanticOnly baseline, which trains one DDPG agent for both
    user groups.
    """
    def __init__(self, obs_dim, act_dim, hidden_sizes=(256, 256, 128)):
        """
        Build an MLP over the concatenated [obs, act] vector.

        Args:
            obs_dim: Dimension of the observation vector.
            act_dim: Dimension of the action vector.
            hidden_sizes: Hidden-layer widths. Generalized to accept any
                non-empty sequence (previously exactly 3 layers were required);
                the default matches the original architecture. A tuple default
                avoids the mutable-default-argument pitfall.

        Raises:
            ValueError: If ``hidden_sizes`` is empty (a raise survives ``-O``,
                unlike the original ``assert``).
        """
        super().__init__()
        if not hidden_sizes:
            raise ValueError("hidden_sizes must contain at least one layer size")
        layers = []
        in_dim = obs_dim + act_dim
        for width in hidden_sizes:
            layers.append(nn.Linear(in_dim, width))
            layers.append(nn.ReLU())
            in_dim = width
        # Final linear head produces the scalar Q-value.
        layers.append(nn.Linear(in_dim, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, obs, act):
        """Concatenate obs and act along the feature dim and return Q(s, a), shape (batch, 1)."""
        return self.net(torch.cat([obs, act], dim=1))
|
||
|
||
|
||
class SemanticBuffer:
    """
    Experience replay for the SemanticOnly baseline.

    Accepts the same 9-argument ``push`` as the multi-agent buffers, but keeps
    only the semantic agent's observation/action plus the mean of the two
    rewards, so a single DDPG agent can train on the stored transitions.
    """
    def __init__(self, capacity):
        # Bounded deque: once full, the oldest transition is evicted.
        self.buffer = deque(maxlen=capacity)

    def push(self, obs_s, obs_b, act_s, act_b, rew_s, rew_b,
             next_obs_s, next_obs_b, done=False):
        """
        Store one single-agent transition.

        The bandwidth-side arguments (obs_b, act_b, next_obs_b) are accepted
        for interface compatibility but discarded; the stored reward is the
        average of the two group rewards.
        """
        transition = (
            np.asarray(obs_s, dtype=np.float32),
            np.asarray(act_s, dtype=np.float32),
            float(0.5 * (rew_s + rew_b)),
            np.asarray(next_obs_s, dtype=np.float32),
            float(done),
        )
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw a uniform random batch and return it as five stacked numpy arrays."""
        chosen = random.sample(self.buffer, batch_size)
        columns = list(zip(*chosen))
        return (
            np.array(columns[0]),
            np.array(columns[1]),
            np.array(columns[2], dtype=np.float32),
            np.array(columns[3]),
            np.array(columns[4], dtype=np.float32),
        )

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)
|
||
|
||
|
||
class SemanticOnly:
    """
    SemanticOnly ablation baseline.

    Treats all users as semantic users and controls both user groups with a
    single DDPG agent (one actor + one critic), instead of the heterogeneous
    multi-agent policies of Co-MADDPG. See the module docstring for the
    ablation rationale.
    """
    def __init__(self, config):
        """
        Build the actor/critic pair (with frozen target copies), optimizers,
        replay buffer and OU exploration noise from the config dict.

        Args:
            config: Nested configuration dict; reads the 'training', 'env'
                and 'network' sections.
        """
        # Configuration and compute device (GPU if available).
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Core DDPG hyperparameters.
        self.gamma = config['training']['gamma']            # discount factor
        self.tau = config['training']['tau']                # target soft-update rate
        self.batch_size = config['training']['batch_size']

        # Observation/action dimensions.
        # NOTE(review): obs is num_subcarriers plus 4 extra features; the
        # meaning of those 4 extras is defined by the environment — confirm there.
        self.obs_dim = config['env']['num_subcarriers'] + 4
        self.act_dim = 3

        # Network layer sizes (critic sizes are fixed here, actor's come from config).
        hidden_a = config['network']['actor_hidden']
        critic_hidden = [256, 256, 128]

        # Single actor policy plus a frozen target copy, synced at init.
        self.actor = Actor(self.obs_dim, self.act_dim, hidden_a).to(self.device)
        self.actor_target = Actor(self.obs_dim, self.act_dim, hidden_a).to(self.device)
        self.actor_target.load_state_dict(self.actor.state_dict())

        # Single critic plus a frozen target copy, synced at init.
        self.critic = SemanticCritic(self.obs_dim, self.act_dim, critic_hidden).to(self.device)
        self.critic_target = SemanticCritic(self.obs_dim, self.act_dim, critic_hidden).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        # One optimizer per network.
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=config['training']['actor_lr'])
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=config['training']['critic_lr'])

        # Replay buffer and Ornstein-Uhlenbeck exploration noise.
        self.replay_buffer = SemanticBuffer(config['training']['buffer_capacity'])
        self.noise_s = OUNoise(self.act_dim, theta=config['training']['ou_theta'],
                               sigma_init=config['training']['ou_sigma_init'],
                               sigma_min=config['training']['ou_sigma_min'])
        # Alias so training loops written for two agents keep working.
        # NOTE(review): noise_b is the SAME OU process object as noise_s, so a
        # loop that samples/resets "both" advances or resets one process twice
        # — confirm the training loop accounts for this.
        self.noise_b = self.noise_s

    def select_action(self, obs_s, obs_b, explore=True):
        """
        Select actions for both groups using the single shared policy.

        Note: only obs_s is fed to the actor; obs_b is accepted for interface
        compatibility and ignored. Both groups receive copies of the same
        action, clipped to [0, 1].

        Args:
            obs_s: Semantic-group observation (1-D array-like of obs_dim).
            obs_b: Bandwidth-group observation (unused).
            explore: If True, add OU noise before clipping.

        Returns:
            Tuple (act_s, act_b) of two independent copies of the same action.
        """
        # eval() disables train-time layers (e.g. dropout/batchnorm, if any)
        # for deterministic inference; restored to train() right after.
        self.actor.eval()
        with torch.no_grad():
            obs_t = torch.FloatTensor(obs_s).unsqueeze(0).to(self.device)
            act = self.actor(obs_t).cpu().numpy()[0]
        self.actor.train()

        if explore:
            # Exploration: perturb with OU noise, then clip to the valid range.
            act = np.clip(act + self.noise_s.sample(), 0.0, 1.0)
        else:
            act = np.clip(act, 0.0, 1.0)

        # Independent copies so the caller can mutate one without the other.
        return act.copy(), act.copy()

    def compute_rewards(self, qoe_s, qoe_b, qoe_sys):
        """
        Compute rewards assuming full cooperation (lambda = 1).

        Both agents receive the same reward r = 0.5 * (qoe_s + qoe_b), i.e.
        the mean QoE of the two groups; qoe_sys is accepted for interface
        compatibility but unused here.

        Returns:
            Tuple (r_s, r_b, lam) where r_s == r_b and lam == 1.0.
        """
        lam = 1.0
        r = 0.5 * (qoe_s + qoe_b)
        return r, r, lam

    def update(self):
        """
        Run one DDPG update step on a sampled minibatch.

        Returns:
            Dict with 'actor_loss' and 'critic_loss' floats, or None if the
            buffer does not yet hold a full batch.
        """
        # Wait until at least one full batch is available.
        if len(self.replay_buffer) < self.batch_size:
            return None

        # Sample a minibatch of single-agent transitions.
        obs, act, rew, next_obs, dones = self.replay_buffer.sample(self.batch_size)

        # Move the batch to the training device; rewards/dones get a trailing
        # unit dim to broadcast against the critic's (batch, 1) output.
        obs_t = torch.FloatTensor(obs).to(self.device)
        act_t = torch.FloatTensor(act).to(self.device)
        rew_t = torch.FloatTensor(rew).unsqueeze(1).to(self.device)
        next_obs_t = torch.FloatTensor(next_obs).to(self.device)
        dones_t = torch.FloatTensor(dones).unsqueeze(1).to(self.device)

        # 1. Critic update: regress Q(s, a) toward the bootstrapped target
        #    r + gamma * (1 - done) * Q'(s', mu'(s')), computed without grads.
        with torch.no_grad():
            next_act = self.actor_target(next_obs_t)
            target_q = rew_t + self.gamma * (1 - dones_t) * self.critic_target(next_obs_t, next_act)

        current_q = self.critic(obs_t, act_t)
        critic_loss = F.mse_loss(current_q, target_q)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # 2. Actor update: ascend the critic's value of the actor's own
        #    actions (deterministic policy gradient, hence the minus sign).
        new_act = self.actor(obs_t)
        actor_loss = -self.critic(obs_t, new_act).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # 3. Polyak soft update of both target networks:
        #    theta' <- tau * theta + (1 - tau) * theta'.
        for target, source in [
            (self.critic_target, self.critic),
            (self.actor_target, self.actor),
        ]:
            for tp, sp in zip(target.parameters(), source.parameters()):
                tp.data.copy_(self.tau * sp.data + (1.0 - self.tau) * tp.data)

        return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}

    def save(self, path):
        """Save actor and critic state dicts under ``path`` (created if missing)."""
        os.makedirs(path, exist_ok=True)
        torch.save(self.actor.state_dict(), os.path.join(path, "actor.pth"))
        torch.save(self.critic.state_dict(), os.path.join(path, "critic.pth"))

    def load(self, path):
        """
        Load actor/critic weights from ``path`` and re-sync the target networks.

        NOTE(review): torch.load without weights_only unpickles arbitrary
        objects — only load checkpoints from trusted sources.
        """
        self.actor.load_state_dict(torch.load(os.path.join(path, "actor.pth"), map_location=self.device))
        self.critic.load_state_dict(torch.load(os.path.join(path, "critic.pth"), map_location=self.device))
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target.load_state_dict(self.critic.state_dict())
|