#!/usr/bin/env python3
"""
Co-MADDPG Training Entry Point

This script manages the training process for Co-MADDPG and various baseline
algorithms in a semantic-traditional hybrid wireless resource allocation
environment. It handles configuration loading, environment initialization,
the training loop, and result logging.

Core Components:
- CLI Argument Parsing
- Training Loop with Dynamic Rewards
- Model Saving & Checkpointing
- Performance Metric Tracking

Reference:
- "Dynamic Cooperative-Competitive Multi-Agent Reinforcement Learning for
  Resource Allocation in Semantic-Traditional Hybrid Wireless Networks"
"""

import os
|
||
import sys
|
||
import argparse
|
||
import time
|
||
import json
|
||
import yaml
|
||
import numpy as np
|
||
import torch
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
|
||
# Add project root to path | 将项目根目录添加到路径中
|
||
PROJECT_ROOT = Path(__file__).parent
|
||
sys.path.insert(0, str(PROJECT_ROOT))
|
||
|
||
from envs.wireless_env import WirelessEnv
|
||
from agents.co_maddpg import CoMADDPG
|
||
from baselines.pure_coop import PureCooperative
|
||
from baselines.pure_comp import PureCompetitive
|
||
from baselines.single_dqn import SingleAgentDQN
|
||
from baselines.iddpg import IndependentDDPG
|
||
from baselines.fixed_lambda import FixedLambda
|
||
from baselines.equal_alloc import EqualAllocation
|
||
from baselines.semantic_only import SemanticOnly
|
||
from utils.metrics import compute_system_qoe, jain_fairness, rate_satisfaction, moving_average
|
||
|
||
|
||
def load_config(config_path: str) -> dict:
|
||
"""
|
||
Load YAML configuration file. | 加载 YAML 配置文件。
|
||
"""
|
||
with open(config_path, 'r', encoding='utf-8') as f:
|
||
config = yaml.safe_load(f)
|
||
return config
|
||
|
||
|
||
def get_algorithm(name: str, config: dict):
|
||
"""
|
||
Instantiate algorithm by name. | 按名称实例化算法。
|
||
|
||
The agent interface typically follows this contract:
|
||
1. select_action(obs_s, obs_b, explore): Choose actions for semantic/traditional agents.
|
||
2. compute_rewards(qoe_s, qoe_b, qoe_sys): Calculate mixed rewards based on lambda.
|
||
3. update(): Perform one optimization step using replay buffer data.
|
||
4. buffer / replay_buffer: Store transitions (s, a, r, s', done).
|
||
|
||
智能体接口通常遵循以下契约:
|
||
1. select_action(obs_s, obs_b, explore):为语义/传统智能体选择动作。
|
||
2. compute_rewards(qoe_s, qoe_b, qoe_sys):基于 lambda 计算混合奖励。
|
||
3. update():使用经验回放池数据执行一个优化步骤。
|
||
4. buffer / replay_buffer:存储转换元组 (s, a, r, s', done)。
|
||
"""
|
||
algorithms = {
|
||
'co_maddpg': CoMADDPG,
|
||
'pure_coop': PureCooperative,
|
||
'pure_comp': PureCompetitive,
|
||
'single_dqn': SingleAgentDQN,
|
||
'iddpg': IndependentDDPG,
|
||
'fixed_lambda': FixedLambda,
|
||
'equal_alloc': EqualAllocation,
|
||
'semantic_only': SemanticOnly,
|
||
}
|
||
if name not in algorithms:
|
||
raise ValueError(f"Unknown algorithm: {name}. Choose from {list(algorithms.keys())}")
|
||
return algorithms[name](config)
|
||
|
||
|
||
ALGO_DISPLAY_NAMES = {
|
||
'co_maddpg': 'Co-MADDPG',
|
||
'pure_coop': 'Pure Cooperative',
|
||
'pure_comp': 'Pure Competitive',
|
||
'single_dqn': 'Single-Agent DQN',
|
||
'iddpg': 'IDDPG',
|
||
'fixed_lambda': 'Fixed λ=0.5',
|
||
'equal_alloc': 'Equal Allocation',
|
||
'semantic_only': 'Semantic-Only',
|
||
}
|
||
|
||
|
||
def train_single(algorithm_name: str, config: dict, save_dir: str) -> dict:
|
||
"""
|
||
Train a single algorithm and return training history. | 训练单个算法并返回训练历史。
|
||
|
||
Workflow / 工作流程:
|
||
1. Initialize Environment & Agent / 初始化环境和智能体
|
||
2. Outer Loop: Episodes / 外层循环:回合
|
||
3. Inner Loop: Steps / 内层循环:步数
|
||
4. Reward computation & Replay Buffer storage / 奖励计算与经验回放存储
|
||
5. Policy update & Noise decay / 策略更新与噪声衰减
|
||
|
||
Returns:
|
||
dict with training metrics over episodes. | 包含各回合训练指标的字典。
|
||
"""
|
||
print(f"\n{'='*60}")
|
||
print(f"Training: {ALGO_DISPLAY_NAMES.get(algorithm_name, algorithm_name)}")
|
||
print(f"{'='*60}")
|
||
|
||
# Set random seeds | 设置随机种子
|
||
seed = config['training'].get('seed', 42)
|
||
np.random.seed(seed)
|
||
torch.manual_seed(seed)
|
||
if torch.cuda.is_available():
|
||
torch.cuda.manual_seed(seed)
|
||
|
||
# Create environment and agent | 创建环境和智能体
|
||
env = WirelessEnv(config)
|
||
agent = get_algorithm(algorithm_name, config)
|
||
|
||
max_episodes = config['training']['max_episodes']
|
||
max_steps = config['training']['max_steps']
|
||
update_interval = config['training'].get('update_interval', 5)
|
||
|
||
# Training history initialization | 训练历史记录初始化
|
||
history = {
|
||
'episode_qoe_sys': [],
|
||
'episode_qoe_semantic': [],
|
||
'episode_qoe_traditional': [],
|
||
'episode_lambda': [],
|
||
'episode_fairness': [],
|
||
'episode_rate_satisfaction': [],
|
||
'episode_reward_s': [],
|
||
'episode_reward_b': [],
|
||
}
|
||
|
||
start_time = time.time()
|
||
best_qoe = -float('inf')
|
||
|
||
# Episode loop | 回合循环
|
||
for episode in range(1, max_episodes + 1):
|
||
obs_s, obs_b = env.reset()
|
||
|
||
ep_qoe_sys_list = []
|
||
ep_qoe_s_list = []
|
||
ep_qoe_b_list = []
|
||
ep_lambda_list = []
|
||
ep_fairness_list = []
|
||
ep_rate_sat_list = []
|
||
ep_reward_s_total = 0.0
|
||
ep_reward_b_total = 0.0
|
||
|
||
# Noise decay mechanism | 噪声衰减机制
|
||
# Reduces exploration over time to stabilize policy | 随着时间推移减少探索以稳定策略
|
||
if hasattr(agent, 'noise_s'):
|
||
agent.noise_s.decay_sigma(episode)
|
||
if hasattr(agent, 'noise_b'):
|
||
agent.noise_b.decay_sigma(episode)
|
||
|
||
# Step loop | 步数循环
|
||
for step in range(1, max_steps + 1):
|
||
# 1. Action Selection | 动作选择
|
||
act_s, act_b = agent.select_action(obs_s, obs_b, explore=True)
|
||
|
||
# 2. Environment Interaction | 环境交互
|
||
next_obs_s, next_obs_b, qoe_s, qoe_b, done, info = env.step(act_s, act_b)
|
||
|
||
qoe_sys = info['qoe_sys']
|
||
|
||
# 3. Reward Calculation | 奖励计算
|
||
# Uses dynamic lambda for Co-MADDPG | Co-MADDPG 使用动态 lambda
|
||
if hasattr(agent, 'compute_rewards'):
|
||
r_s, r_b, lambda_val = agent.compute_rewards(qoe_s, qoe_b, qoe_sys)
|
||
else:
|
||
r_s, r_b, lambda_val = qoe_s, qoe_b, 0.5
|
||
|
||
# 4. Storage in Replay Buffer | 存储在经验回放池
|
||
if hasattr(agent, 'buffer'):
|
||
agent.buffer.push(obs_s, obs_b, act_s, act_b, r_s, r_b, next_obs_s, next_obs_b, done)
|
||
elif hasattr(agent, 'replay_buffer'):
|
||
agent.replay_buffer.push(obs_s, obs_b, act_s, act_b, r_s, r_b, next_obs_s, next_obs_b, done)
|
||
|
||
# 5. Agent Update | 智能体更新
|
||
# Updates occur every few steps to improve training efficiency | 每隔几步更新一次以提高训练效率
|
||
if step % update_interval == 0:
|
||
agent.update()
|
||
|
||
# Metric tracking | 指标追踪
|
||
ep_qoe_sys_list.append(qoe_sys)
|
||
ep_qoe_s_list.append(qoe_s)
|
||
ep_qoe_b_list.append(qoe_b)
|
||
ep_lambda_list.append(lambda_val)
|
||
ep_fairness_list.append(jain_fairness(info['qoe_list']))
|
||
ep_rate_sat_list.append(info['rate_satisfaction'])
|
||
ep_reward_s_total += r_s
|
||
ep_reward_b_total += r_b
|
||
|
||
obs_s = next_obs_s
|
||
obs_b = next_obs_b
|
||
|
||
if done:
|
||
break
|
||
|
||
# Record episode average metrics | 记录回合平均指标
|
||
avg_qoe_sys = np.mean(ep_qoe_sys_list)
|
||
avg_qoe_s = np.mean(ep_qoe_s_list)
|
||
avg_qoe_b = np.mean(ep_qoe_b_list)
|
||
avg_lambda = np.mean(ep_lambda_list)
|
||
avg_fairness = np.mean(ep_fairness_list)
|
||
avg_rate_sat = np.mean(ep_rate_sat_list)
|
||
|
||
history['episode_qoe_sys'].append(avg_qoe_sys)
|
||
history['episode_qoe_semantic'].append(avg_qoe_s)
|
||
history['episode_qoe_traditional'].append(avg_qoe_b)
|
||
history['episode_lambda'].append(avg_lambda)
|
||
history['episode_fairness'].append(avg_fairness)
|
||
history['episode_rate_satisfaction'].append(avg_rate_sat)
|
||
history['episode_reward_s'].append(ep_reward_s_total)
|
||
history['episode_reward_b'].append(ep_reward_b_total)
|
||
|
||
# Best model checkpointing | 最佳模型断点保存
|
||
if avg_qoe_sys > best_qoe:
|
||
best_qoe = avg_qoe_sys
|
||
model_path = os.path.join(save_dir, f'{algorithm_name}_best.pt')
|
||
if hasattr(agent, 'save'):
|
||
agent.save(model_path)
|
||
|
||
# Logging / 日志记录
|
||
if episode % 100 == 0 or episode == 1:
|
||
elapsed = time.time() - start_time
|
||
print(f" Ep {episode:5d}/{max_episodes} | "
|
||
f"QoE_sys: {avg_qoe_sys:.4f} | "
|
||
f"QoE_s: {avg_qoe_s:.4f} | "
|
||
f"QoE_b: {avg_qoe_b:.4f} | "
|
||
f"λ: {avg_lambda:.3f} | "
|
||
f"Fair: {avg_fairness:.3f} | "
|
||
f"RateSat: {avg_rate_sat:.2f} | "
|
||
f"Time: {elapsed:.0f}s")
|
||
|
||
total_time = time.time() - start_time
|
||
history['training_time'] = total_time
|
||
|
||
# Final model save | 最终模型保存
|
||
final_model_path = os.path.join(save_dir, f'{algorithm_name}_final.pt')
|
||
if hasattr(agent, 'save'):
|
||
agent.save(final_model_path)
|
||
|
||
# Save training history as JSON | 以 JSON 格式保存训练历史
|
||
history_path = os.path.join(save_dir, f'{algorithm_name}_history.json')
|
||
serializable_history = {k: [float(v) for v in vals] if isinstance(vals, list) else float(vals)
|
||
for k, vals in history.items()}
|
||
with open(history_path, 'w') as f:
|
||
json.dump(serializable_history, f, indent=2)
|
||
|
||
print(f"\n Training complete! Time: {total_time:.1f}s | Best QoE_sys: {best_qoe:.4f}")
|
||
print(f" Model saved to: {final_model_path}")
|
||
print(f" History saved to: {history_path}")
|
||
|
||
return history
|
||
|
||
|
||
def train_all(config: dict, save_dir: str) -> dict:
|
||
"""
|
||
Train all algorithms sequentially and return combined results.
|
||
按顺序训练所有算法并返回组合结果。
|
||
|
||
This facilitates large-scale comparison across all baselines.
|
||
这有助于跨所有基准算法进行大规模比较。
|
||
"""
|
||
all_results = {}
|
||
algorithms = ['co_maddpg', 'pure_coop', 'pure_comp', 'single_dqn',
|
||
'iddpg', 'fixed_lambda', 'equal_alloc', 'semantic_only']
|
||
|
||
for algo_name in algorithms:
|
||
try:
|
||
history = train_single(algo_name, config, save_dir)
|
||
all_results[algo_name] = history
|
||
except Exception as e:
|
||
print(f"\n ERROR training {algo_name}: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
all_results[algo_name] = None
|
||
|
||
# Save combined results | 保存组合结果
|
||
combined_path = os.path.join(save_dir, 'all_results.json')
|
||
serializable = {}
|
||
for k, v in all_results.items():
|
||
if v is None:
|
||
serializable[k] = None
|
||
else:
|
||
serializable[k] = {
|
||
kk: [float(x) for x in vv] if isinstance(vv, list) else float(vv)
|
||
for kk, vv in v.items()
|
||
}
|
||
with open(combined_path, 'w') as f:
|
||
json.dump(serializable, f, indent=2)
|
||
|
||
print(f"\nAll results saved to: {combined_path}")
|
||
return all_results
|
||
|
||
|
||
def main():
|
||
"""
|
||
Main entry point with CLI argument parsing. | 带有命令行参数解析的主入口。
|
||
"""
|
||
parser = argparse.ArgumentParser(description='Co-MADDPG Training')
|
||
# Config arguments | 配置参数
|
||
parser.add_argument('--config', type=str, default='configs/default.yaml',
|
||
help='Path to config YAML file')
|
||
parser.add_argument('--algorithm', type=str, default='co_maddpg',
|
||
choices=['co_maddpg', 'pure_coop', 'pure_comp', 'single_dqn',
|
||
'iddpg', 'fixed_lambda', 'equal_alloc', 'semantic_only', 'all'],
|
||
help='Algorithm to train')
|
||
|
||
# Override hyperparameters | 覆盖超参数
|
||
parser.add_argument('--episodes', type=int, default=None,
|
||
help='Override max episodes')
|
||
parser.add_argument('--steps', type=int, default=None,
|
||
help='Override max steps per episode')
|
||
parser.add_argument('--seed', type=int, default=None,
|
||
help='Override random seed')
|
||
|
||
# Resource / output settings | 资源 / 输出设置
|
||
parser.add_argument('--save_dir', type=str, default=None,
|
||
help='Directory to save results')
|
||
parser.add_argument('--gpu', type=int, default=0,
|
||
help='GPU device index (-1 for CPU)')
|
||
|
||
args = parser.parse_args()
|
||
|
||
# Load config | 加载配置
|
||
config_path = os.path.join(PROJECT_ROOT, args.config)
|
||
config = load_config(config_path)
|
||
|
||
# Override config with CLI args | 用命令行参数覆盖配置
|
||
if args.episodes is not None:
|
||
config['training']['max_episodes'] = args.episodes
|
||
if args.steps is not None:
|
||
config['training']['max_steps'] = args.steps
|
||
if args.seed is not None:
|
||
config['training']['seed'] = args.seed
|
||
|
||
# Hardware selection | 硬件选择
|
||
if args.gpu >= 0 and torch.cuda.is_available():
|
||
torch.cuda.set_device(args.gpu)
|
||
print(f"Using GPU: {torch.cuda.get_device_name(args.gpu)}")
|
||
else:
|
||
print("Using CPU")
|
||
|
||
# Create timestamped save directory | 创建带有时间戳的保存目录
|
||
if args.save_dir:
|
||
save_dir = args.save_dir
|
||
else:
|
||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||
save_dir = os.path.join(PROJECT_ROOT, 'results', f'run_{timestamp}')
|
||
os.makedirs(save_dir, exist_ok=True)
|
||
|
||
# Save a snapshot of the final config for reproducibility | 保存最终配置快照以确保可重复性
|
||
config_snapshot_path = os.path.join(save_dir, 'config.yaml')
|
||
with open(config_snapshot_path, 'w') as f:
|
||
yaml.dump(config, f, default_flow_style=False)
|
||
|
||
print(f"Config: {config_path}")
|
||
print(f"Save directory: {save_dir}")
|
||
print(f"Subcarriers: {config['env']['num_subcarriers']}")
|
||
print(f"Users: {config['env']['num_semantic_users']}S + {config['env']['num_traditional_users']}B")
|
||
print(f"Episodes: {config['training']['max_episodes']}")
|
||
print(f"Steps/episode: {config['training']['max_steps']}")
|
||
|
||
# Start training | 开始训练
|
||
if args.algorithm == 'all':
|
||
train_all(config, save_dir)
|
||
else:
|
||
train_single(args.algorithm, config, save_dir)
|
||
|
||
print(f"\nDone! Results in: {save_dir}")
|
||
|
||
|
||
if __name__ == '__main__':
|
||
main()
|