#!/usr/bin/env python3
"""
Co-MADDPG Evaluation & Figure Generation

This script evaluates trained models across various network scenarios and
generates the primary result figures of the research paper (Fig. 2 - Fig. 12).
It covers robustness tests (SNR), scalability (user load) and internal
dynamics (lambda).

Scenarios documented:
    1. Convergence                    (Fig 2)
    2. SNR sensitivity                (Fig 3, 4)
    3. User-load scalability          (Fig 5, 6)
    4. Dynamic lambda trajectory      (Fig 7, 8)
    5. Semantic/traditional ratio     (Fig 9)
    6. Component ablation             (Fig 10)
    7. Beta parameter sensitivity     (Fig 11)
    8. Q_th threshold sensitivity     (Fig 12)

Reference:
    - Section VII: Experimental Results
"""

import os
import sys
import argparse
import json
import yaml
import numpy as np
import torch
from pathlib import Path
from copy import deepcopy

PROJECT_ROOT = Path(__file__).parent
sys.path.insert(0, str(PROJECT_ROOT))

from envs.wireless_env import WirelessEnv
from agents.co_maddpg import CoMADDPG
from baselines.pure_coop import PureCooperative
from baselines.pure_comp import PureCompetitive
from baselines.single_dqn import SingleAgentDQN
from baselines.iddpg import IndependentDDPG
from baselines.fixed_lambda import FixedLambda
from baselines.equal_alloc import EqualAllocation
from baselines.semantic_only import SemanticOnly
from utils.metrics import jain_fairness, rate_satisfaction, compute_system_qoe, moving_average
from utils.visualization import Plotter

# Map internal algorithm keys to (display name, agent class).
# The key is also the file prefix of '<key>_best.pt' / '<key>_history.json'.
ALGO_MAP = {
    'co_maddpg': ('Co-MADDPG', CoMADDPG),
    'pure_coop': ('Pure Cooperative', PureCooperative),
    'pure_comp': ('Pure Competitive', PureCompetitive),
    'single_dqn': ('Single-Agent DQN', SingleAgentDQN),
    'iddpg': ('IDDPG', IndependentDDPG),
    'fixed_lambda': ('Fixed λ=0.5', FixedLambda),
    'equal_alloc': ('Equal Allocation', EqualAllocation),
    'semantic_only': ('Semantic-Only', SemanticOnly),
}

# SNR sweep calibration: setting env.noise_psd to _REFERENCE_NOISE_PSD is
# assumed to yield roughly _REFERENCE_SNR_DB of SNR; every +1 dB of target SNR
# lowers the noise PSD by 1 dB. NOTE(review): units presumably dBm/Hz
# (thermal noise floor) — confirm against WirelessEnv.
_REFERENCE_NOISE_PSD = -174
_REFERENCE_SNR_DB = 15

# Number of trailing training episodes averaged for the ablation bar chart.
_ABLATION_WINDOW = 500


def load_config(config_path: str) -> dict:
    """Load a YAML configuration file (with ``yaml.safe_load``)."""
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)


def evaluate_episode(env, agent, config, num_episodes=10):
    """
    Run deterministic evaluation episodes and return averaged metrics.

    Parameters
    ----------
    env : WirelessEnv
        Environment providing ``reset() -> (obs_s, obs_b)`` and
        ``step(act_s, act_b) -> (obs_s, obs_b, qoe_s, qoe_b, done, info)``.
    agent : BaseAgent
        Agent exposing ``select_action(obs_s, obs_b, explore=False)``. If it
        also exposes ``compute_lambda(qoe_sys)``, the per-step lambda is
        recorded; otherwise a constant 0.5 is used.
    config : dict
        Configuration; ``config['training']['max_steps']`` caps episode length.
    num_episodes : int
        Number of episodes to average over.

    Returns
    -------
    dict
        ``qoe_sys`` / ``qoe_sys_std``: mean / std over episodes of each
        episode's mean system QoE. ``qoe_semantic``, ``qoe_traditional``,
        ``fairness`` (Jain index of ``info['qoe_list']``) and
        ``rate_satisfaction`` are read from each episode's final step and
        averaged over episodes. ``avg_lambda`` is the mean of the per-episode
        mean lambda; ``lambda_trajectory`` lists those per-episode means.

    Raises
    ------
    ValueError
        If ``max_steps`` or ``num_episodes`` is smaller than 1 (no step would
        run, so there would be no metrics to report).
    """
    max_steps = config['training']['max_steps']
    if max_steps < 1:
        raise ValueError(f"config['training']['max_steps'] must be >= 1, got {max_steps}")
    if num_episodes < 1:
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")

    has_lambda = hasattr(agent, 'compute_lambda')

    all_qoe_sys = []
    all_qoe_s = []
    all_qoe_b = []
    all_fairness = []
    all_rate_sat = []
    all_lambda = []

    for _ in range(num_episodes):
        obs_s, obs_b = env.reset()
        ep_qoe_sys = []
        ep_lambda = []

        for _step in range(max_steps):
            # Deterministic action selection (no exploration noise).
            act_s, act_b = agent.select_action(obs_s, obs_b, explore=False)
            obs_s, obs_b, _qoe_s, _qoe_b, done, info = env.step(act_s, act_b)
            qoe_sys = info['qoe_sys']

            lambda_val = agent.compute_lambda(qoe_sys) if has_lambda else 0.5

            ep_qoe_sys.append(qoe_sys)
            ep_lambda.append(lambda_val)
            if done:
                break

        # NOTE(review): apart from qoe_sys and lambda, the per-episode metrics
        # below come from the *final* step's info only — confirm whether the
        # environment reports them cumulatively or an episode mean was intended.
        all_qoe_sys.append(np.mean(ep_qoe_sys))
        all_qoe_s.append(info['qoe_semantic'])
        all_qoe_b.append(info['qoe_traditional'])
        all_fairness.append(jain_fairness(info['qoe_list']))
        all_rate_sat.append(info['rate_satisfaction'])
        all_lambda.append(np.mean(ep_lambda))

    return {
        'qoe_sys': np.mean(all_qoe_sys),
        'qoe_sys_std': np.std(all_qoe_sys),
        'qoe_semantic': np.mean(all_qoe_s),
        'qoe_traditional': np.mean(all_qoe_b),
        'fairness': np.mean(all_fairness),
        'rate_satisfaction': np.mean(all_rate_sat),
        'avg_lambda': np.mean(all_lambda),
        'lambda_trajectory': all_lambda,
    }


# ============================================================
# Shared helpers
# ============================================================

def _load_weights(agent, results_dir: str, algo_key: str) -> bool:
    """
    Best-effort load of ``<results_dir>/<algo_key>_best.pt`` into *agent*.

    A missing checkpoint, or an agent without a ``load`` method, leaves the
    agent as constructed (silently, as before). A checkpoint that exists but
    fails to load is reported on stdout instead of being swallowed, so that
    figures produced from freshly-initialised weights do not go unnoticed.

    Returns True if weights were loaded.
    """
    model_path = os.path.join(results_dir, f'{algo_key}_best.pt')
    if not os.path.exists(model_path) or not hasattr(agent, 'load'):
        return False
    try:
        agent.load(model_path)
    except Exception as exc:  # best-effort: keep evaluating, but say so
        print(f"  [warn] could not load {model_path} ({type(exc).__name__}: {exc}); "
              f"evaluating untrained weights")
        return False
    return True


def _evaluate_algo(algo_key: str, algo_cls, test_config: dict, results_dir: str, num_eval: int) -> dict:
    """Build env and agent from *test_config*, load the agent's best checkpoint, and evaluate it."""
    env = WirelessEnv(test_config)
    agent = algo_cls(test_config)
    _load_weights(agent, results_dir, algo_key)
    return evaluate_episode(env, agent, test_config, num_episodes=num_eval)


def _read_history(results_dir: str, algo_key: str):
    """Return the parsed ``<algo_key>_history.json`` from *results_dir*, or None if it does not exist."""
    history_path = os.path.join(results_dir, f'{algo_key}_history.json')
    if not os.path.exists(history_path):
        return None
    with open(history_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _tail_mean(series, window: int = _ABLATION_WINDOW) -> float:
    """
    Mean of the last *window* entries of *series*.

    Shorter series use their last ceil(len/5) entries (the last ~20 %, at
    least one); an empty series yields 0.0.
    """
    n = len(series)
    if n >= window:
        return np.mean(series[-window:])
    if n > 0:
        return np.mean(series[-((n + 4) // 5):])
    return 0.0


def _ratio_to_counts(ratio: float, total_users: int):
    """
    Map a semantic-user fraction to integer ``(k_s, k_b)`` counts.

    The counts are clamped so that at least one user of each type remains,
    presumably because the hybrid environment needs both — TODO confirm.
    """
    k_s = max(0, min(total_users, int(round(ratio * total_users))))
    k_b = total_users - k_s
    if k_s == 0:
        k_s, k_b = 1, total_users - 1
    if k_b == 0:
        k_s, k_b = total_users - 1, 1
    return k_s, k_b


# ============================================================
# Scenario 1: Convergence (Fig 2)
# ============================================================

def scenario_convergence(results_dir: str, save_dir: str):
    """
    Plot convergence curves from the saved training histories.

    Reads ``<algo_key>_history.json`` for each algorithm in ALGO_MAP and plots
    its ``episode_qoe_sys`` series; algorithms without a history are skipped.
    """
    print("\n[Scenario 1] Convergence curves (Fig 2)")
    plotter = Plotter()
    data_dict = {}

    for algo_key, (display_name, _) in ALGO_MAP.items():
        history = _read_history(results_dir, algo_key)
        if history is not None and 'episode_qoe_sys' in history:
            data_dict[display_name] = history['episode_qoe_sys']

    if data_dict:
        plotter.plot_convergence(data_dict, os.path.join(save_dir, 'fig2_convergence'))
        print(f" Saved fig2_convergence")
    else:
        print(" No training history found. Run training first.")


# ============================================================
# Scenario 2: QoE vs SNR (Fig 3, 4)
# ============================================================

def scenario_snr(config: dict, results_dir: str, save_dir: str, num_eval=5):
    """
    Evaluate every algorithm across SNR levels from 0 to 30 dB.

    The target SNR is emulated by shifting ``config['env']['noise_psd']``
    relative to the reference operating point (_REFERENCE_NOISE_PSD is assumed
    to give about _REFERENCE_SNR_DB).

    Returns a dict with the SNR levels and the per-algorithm QoE and fairness
    series.
    """
    print("\n[Scenario 2] QoE vs SNR (Fig 3, 4)")
    plotter = Plotter()
    snr_levels_db = np.arange(0, 31, 5)  # 0, 5, 10, 15, 20, 25, 30
    qoe_data = {}
    fairness_data = {}

    for algo_key, (display_name, algo_cls) in ALGO_MAP.items():
        qoe_vals = []
        fair_vals = []
        for snr_db in snr_levels_db:
            test_config = deepcopy(config)
            # Each +1 dB of target SNR lowers the noise PSD by 1 dB.
            snr_offset = snr_db - _REFERENCE_SNR_DB
            test_config['env']['noise_psd'] = _REFERENCE_NOISE_PSD - snr_offset

            result = _evaluate_algo(algo_key, algo_cls, test_config, results_dir, num_eval)
            qoe_vals.append(result['qoe_sys'])
            fair_vals.append(result['fairness'])

        qoe_data[display_name] = qoe_vals
        fairness_data[display_name] = fair_vals
        print(f" {display_name}: QoE range [{min(qoe_vals):.3f}, {max(qoe_vals):.3f}]")

    plotter.plot_qoe_vs_snr(qoe_data, os.path.join(save_dir, 'fig3_qoe_vs_snr'))
    plotter.plot_fairness_vs_snr(fairness_data, os.path.join(save_dir, 'fig4_fairness_vs_snr'))
    print(f" Saved fig3, fig4")

    return {'snr_levels': snr_levels_db.tolist(), 'qoe': qoe_data, 'fairness': fairness_data}


# ============================================================
# Scenario 3: QoE vs User Load (Fig 5, 6)
# ============================================================

def scenario_user_load(config: dict, results_dir: str, save_dir: str, num_eval=5):
    """
    Evaluate every algorithm for total user counts K in {4, 6, 8, 10, 12}.

    K is split as evenly as possible: K // 2 semantic users and the remainder
    traditional users.
    """
    print("\n[Scenario 3] QoE vs User Load (Fig 5, 6)")
    plotter = Plotter()
    user_counts = [4, 6, 8, 10, 12]  # total K
    qoe_data = {}
    rate_sat_data = {}

    for algo_key, (display_name, algo_cls) in ALGO_MAP.items():
        qoe_vals = []
        rate_vals = []
        for k_total in user_counts:
            test_config = deepcopy(config)
            k_s = k_total // 2
            test_config['env']['num_semantic_users'] = k_s
            test_config['env']['num_traditional_users'] = k_total - k_s

            result = _evaluate_algo(algo_key, algo_cls, test_config, results_dir, num_eval)
            qoe_vals.append(result['qoe_sys'])
            rate_vals.append(result['rate_satisfaction'])

        qoe_data[display_name] = qoe_vals
        rate_sat_data[display_name] = rate_vals
        print(f" {display_name}: QoE range [{min(qoe_vals):.3f}, {max(qoe_vals):.3f}]")

    plotter.plot_qoe_vs_users(qoe_data, os.path.join(save_dir, 'fig5_qoe_vs_users'))
    plotter.plot_rate_satisfaction_vs_users(rate_sat_data, os.path.join(save_dir, 'fig6_rate_sat_vs_users'))
    print(f" Saved fig5, fig6")


# ============================================================
# Scenario 4: Lambda Dynamics (Fig 7, 8)
# ============================================================

def scenario_lambda_dynamics(config: dict, results_dir: str, save_dir: str):
    """
    Record Co-MADDPG's per-step lambda and system QoE over one deterministic
    episode, then plot the lambda trajectory and a lambda-vs-QoE scatter.
    """
    print("\n[Scenario 4] Lambda Dynamics (Fig 7, 8)")
    plotter = Plotter()

    env = WirelessEnv(config)
    agent = CoMADDPG(config)
    _load_weights(agent, results_dir, 'co_maddpg')

    obs_s, obs_b = env.reset()
    lambda_vals = []
    qoe_vals = []

    for _ in range(config['training']['max_steps']):
        act_s, act_b = agent.select_action(obs_s, obs_b, explore=False)
        obs_s, obs_b, _qoe_s, _qoe_b, done, info = env.step(act_s, act_b)
        qoe_sys = info['qoe_sys']
        lambda_vals.append(float(agent.compute_lambda(qoe_sys)))
        qoe_vals.append(float(qoe_sys))
        if done:
            break

    plotter.plot_lambda_trajectory(lambda_vals, os.path.join(save_dir, 'fig7_lambda_trajectory'))
    plotter.plot_lambda_qoe_scatter(lambda_vals, qoe_vals, os.path.join(save_dir, 'fig8_lambda_qoe_scatter'))
    print(f" Saved fig7, fig8")


# ============================================================
# Scenario 5: Semantic/Traditional Ratio (Fig 9)
# ============================================================

def scenario_user_ratio(config: dict, results_dir: str, save_dir: str, num_eval=5):
    """
    Evaluate every algorithm at different semantic/traditional user mixes.

    Requested semantic fractions are mapped to integer counts out of 6 users
    (at least one of each type). Fractions that collapse onto the same user
    mix after clamping are evaluated once. The figure is plotted against the
    *actual* semantic fraction, so the clamped endpoints are not mislabeled
    as 0.0 / 1.0.
    """
    print("\n[Scenario 5] User Ratio Analysis (Fig 9)")
    plotter = Plotter()
    total_users = 6
    ratios = [0.0, 0.17, 0.33, 0.5, 0.67, 0.83, 1.0]  # requested semantic fraction

    # Distinct (k_s, k_b) mixes in the order first requested.
    user_mixes = []
    for ratio in ratios:
        counts = _ratio_to_counts(ratio, total_users)
        if counts not in user_mixes:
            user_mixes.append(counts)
    effective_ratios = [round(k_s / total_users, 2) for k_s, _ in user_mixes]

    qoe_data = {}
    for algo_key, (display_name, algo_cls) in ALGO_MAP.items():
        qoe_vals = []
        for k_s, k_b in user_mixes:
            test_config = deepcopy(config)
            test_config['env']['num_semantic_users'] = k_s
            test_config['env']['num_traditional_users'] = k_b

            result = _evaluate_algo(algo_key, algo_cls, test_config, results_dir, num_eval)
            qoe_vals.append(result['qoe_sys'])

        qoe_data[display_name] = qoe_vals

    plotter.plot_qoe_vs_ratio(qoe_data, effective_ratios, os.path.join(save_dir, 'fig9_qoe_vs_ratio'))
    print(f" Saved fig9")


# ============================================================
# Scenario 6: Ablation Study (Fig 10)
# ============================================================

def scenario_ablation(config: dict, results_dir: str, save_dir: str, num_eval=5):
    """
    Compare Co-MADDPG against baselines that each remove one component.

    Ablation mapping:
        - w/o Stackelberg : Pure Cooperative (simultaneous update)
        - w/o Dynamic λ   : Fixed Lambda (λ=0.5)
        - w/o Cooperation : Pure Competitive (λ=0)
        - w/o CTDE        : IDDPG (independent critics)

    A variant's score is the tail mean of its training ``episode_qoe_sys``
    history (see _tail_mean). If no history file exists, the best checkpoint
    is evaluated directly instead.
    """
    print("\n[Scenario 6] Ablation Study (Fig 10)")
    plotter = Plotter()

    ablation_keys = {
        'Co-MADDPG (Full)': 'co_maddpg',
        'w/o Stackelberg': 'pure_coop',
        'w/o Dynamic λ': 'fixed_lambda',
        'w/o Cooperation': 'pure_comp',
        'w/o CTDE': 'iddpg',
    }

    ablation_data = {}
    for label, algo_key in ablation_keys.items():
        history = _read_history(results_dir, algo_key)
        if history is not None:
            ablation_data[label] = _tail_mean(history.get('episode_qoe_sys', []))
        else:
            # No training history: fall back to evaluating the checkpoint.
            algo_cls = ALGO_MAP[algo_key][1]
            result = _evaluate_algo(algo_key, algo_cls, config, results_dir, num_eval)
            ablation_data[label] = result['qoe_sys']

    plotter.plot_ablation(ablation_data, os.path.join(save_dir, 'fig10_ablation'))
    print(f" Saved fig10")


# ============================================================
# Scenario 7: β Sensitivity (Fig 11)
# ============================================================

def scenario_beta_sensitivity(config: dict, results_dir: str, save_dir: str, num_eval=5):
    """
    Evaluate Co-MADDPG for several values of ``config['training']['beta']``.

    Per the paper, β controls the steepness of the sigmoid that switches
    between competition and cooperation. The same best checkpoint is loaded
    for every β.
    """
    print("\n[Scenario 7] β Sensitivity (Fig 11)")
    plotter = Plotter()
    betas = [1, 3, 5, 7, 10]
    qoe_data = {}

    for beta in betas:
        test_config = deepcopy(config)
        test_config['training']['beta'] = float(beta)

        result = _evaluate_algo('co_maddpg', CoMADDPG, test_config, results_dir, num_eval)
        qoe_data[f'β={beta}'] = result['qoe_sys']
        print(f" β={beta}: QoE_sys={result['qoe_sys']:.4f}")

    plotter.plot_beta_sensitivity(qoe_data, betas, os.path.join(save_dir, 'fig11_beta_sensitivity'))
    print(f" Saved fig11")


# ============================================================
# Scenario 8: Q_th Sensitivity (Fig 12)
# ============================================================

def scenario_qth_sensitivity(config: dict, results_dir: str, save_dir: str, num_eval=5):
    """
    Evaluate Co-MADDPG for several values of ``config['training']['q_threshold']``.

    Per the paper, Q_th is the target QoE level below which cooperation is
    triggered. The same best checkpoint is loaded for every threshold.
    """
    print("\n[Scenario 8] Q_th Sensitivity (Fig 12)")
    plotter = Plotter()
    qths = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    qoe_data = {}

    for qth in qths:
        test_config = deepcopy(config)
        test_config['training']['q_threshold'] = float(qth)

        result = _evaluate_algo('co_maddpg', CoMADDPG, test_config, results_dir, num_eval)
        qoe_data[f'Q_th={qth}'] = result['qoe_sys']
        print(f" Q_th={qth}: QoE_sys={result['qoe_sys']:.4f}")

    plotter.plot_qth_sensitivity(qoe_data, qths, os.path.join(save_dir, 'fig12_qth_sensitivity'))
    print(f" Saved fig12")


# ============================================================
# Run all scenarios
# ============================================================

def run_all_scenarios(config: dict, results_dir: str, save_dir: str, num_eval: int = 5):
    """
    Run every evaluation scenario and write all figures to *save_dir*.

    *num_eval* (evaluation episodes per setting) is forwarded to every
    scenario that takes it, so the CLI ``--num_eval`` also applies to
    ``--scenario all``.
    """
    os.makedirs(save_dir, exist_ok=True)

    scenario_convergence(results_dir, save_dir)
    scenario_snr(config, results_dir, save_dir, num_eval)
    scenario_user_load(config, results_dir, save_dir, num_eval)
    scenario_lambda_dynamics(config, results_dir, save_dir)
    scenario_user_ratio(config, results_dir, save_dir, num_eval)
    scenario_ablation(config, results_dir, save_dir, num_eval)
    scenario_beta_sensitivity(config, results_dir, save_dir, num_eval)
    scenario_qth_sensitivity(config, results_dir, save_dir, num_eval)

    print(f"\nAll figures saved to: {save_dir}")


def main():
    """Command-line entry point: parse arguments and run the selected scenario."""
    parser = argparse.ArgumentParser(description='Co-MADDPG Evaluation')
    parser.add_argument('--config', type=str, default='configs/default.yaml',
                        help='Path to config YAML')
    parser.add_argument('--results_dir', type=str, required=True,
                        help='Directory with trained models and history')
    parser.add_argument('--save_dir', type=str, default=None,
                        help='Directory to save figures (default: results_dir/figures)')
    parser.add_argument('--scenario', type=str, default='all',
                        choices=['all', 'convergence', 'snr', 'user_load', 'lambda',
                                 'ratio', 'ablation', 'beta', 'qth'],
                        help='Evaluation scenario to run')
    parser.add_argument('--num_eval', type=int, default=10,
                        help='Number of evaluation episodes per setting')
    args = parser.parse_args()

    # Relative config paths resolve against the project root; an absolute
    # path is kept as-is by os.path.join.
    config_path = os.path.join(PROJECT_ROOT, args.config)
    config = load_config(config_path)

    save_dir = args.save_dir or os.path.join(args.results_dir, 'figures')
    os.makedirs(save_dir, exist_ok=True)

    results_dir = args.results_dir
    num_eval = args.num_eval
    scenario_map = {
        'all': lambda: run_all_scenarios(config, results_dir, save_dir, num_eval),
        'convergence': lambda: scenario_convergence(results_dir, save_dir),
        'snr': lambda: scenario_snr(config, results_dir, save_dir, num_eval),
        'user_load': lambda: scenario_user_load(config, results_dir, save_dir, num_eval),
        'lambda': lambda: scenario_lambda_dynamics(config, results_dir, save_dir),
        'ratio': lambda: scenario_user_ratio(config, results_dir, save_dir, num_eval),
        'ablation': lambda: scenario_ablation(config, results_dir, save_dir, num_eval),
        'beta': lambda: scenario_beta_sensitivity(config, results_dir, save_dir, num_eval),
        'qth': lambda: scenario_qth_sensitivity(config, results_dir, save_dir, num_eval),
    }

    scenario_map[args.scenario]()
    print("\nEvaluation complete!")


if __name__ == '__main__':
    main()