578 lines
22 KiB
Python
578 lines
22 KiB
Python
#!/usr/bin/env python3
"""
Co-MADDPG Evaluation & Figure Generation | Co-MADDPG 评估与图表生成

This script evaluates trained models across various network scenarios and
generates the 12 primary figures for the research paper. It covers robustness
tests (SNR), scalability (User Load), and internal dynamics (Lambda).

本脚本在各种网络场景下评估已训练的模型,并为研究论文生成 12 张主要图表。
它涵盖了鲁棒性测试 (SNR)、可扩展性 (用户负载) 和内部动态 (Lambda)。

Scenarios Documented:
1. Convergence / 收敛性 (Fig 2)
2. SNR Sensitivity / SNR 敏感性 (Fig 3, 4)
3. User Load Scalability / 用户负载可扩展性 (Fig 5, 6)
4. Dynamic Lambda Trajectory / 动态 Lambda 轨迹 (Fig 7, 8)
5. Semantic-Traditional Ratio / 语义-传统比例 (Fig 9)
6. Component Ablation / 组件消融实验 (Fig 10)
7. Beta Parameter Sensitivity / Beta 参数敏感性 (Fig 11)
8. Q_th Threshold Sensitivity / Q_th 阈值敏感性 (Fig 12)

Reference:
- Section VII: Experimental Results
"""
|
||
|
||
import os
import sys
import argparse
import json
import yaml
import numpy as np
import torch
from pathlib import Path
from copy import deepcopy

PROJECT_ROOT = Path(__file__).parent
sys.path.insert(0, str(PROJECT_ROOT))

from envs.wireless_env import WirelessEnv
from agents.co_maddpg import CoMADDPG
from baselines.pure_coop import PureCooperative
from baselines.pure_comp import PureCompetitive
from baselines.single_dqn import SingleAgentDQN
from baselines.iddpg import IndependentDDPG
from baselines.fixed_lambda import FixedLambda
from baselines.equal_alloc import EqualAllocation
from baselines.semantic_only import SemanticOnly
from utils.metrics import jain_fairness, rate_satisfaction, compute_system_qoe, moving_average
from utils.visualization import Plotter
|
||
|
||
|
||
# Mapping internal keys to display names and classes.
# Each entry: file-name prefix (used for '<key>_best.pt' / '<key>_history.json')
# -> (legend label shown in figures, agent class constructed from a config dict).
# NOTE: insertion order determines iteration order in the scenario loops and
# therefore the plotting/legend order — do not reorder casually.
ALGO_MAP = {
    'co_maddpg': ('Co-MADDPG', CoMADDPG),
    'pure_coop': ('Pure Cooperative', PureCooperative),
    'pure_comp': ('Pure Competitive', PureCompetitive),
    'single_dqn': ('Single-Agent DQN', SingleAgentDQN),
    'iddpg': ('IDDPG', IndependentDDPG),
    'fixed_lambda': ('Fixed λ=0.5', FixedLambda),
    'equal_alloc': ('Equal Allocation', EqualAllocation),
    'semantic_only': ('Semantic-Only', SemanticOnly),
}
|
||
|
||
|
||
def load_config(config_path: str) -> dict:
    """Read the YAML file at *config_path* and return its contents as a dict."""
    with open(config_path, 'r', encoding='utf-8') as fh:
        cfg = yaml.safe_load(fh)
    return cfg
|
||
|
||
|
||
def evaluate_episode(env, agent, config, num_episodes=10):
    """
    Run deterministic evaluation episodes and return averaged metrics.
    执行评估回合并返回平均指标。

    Actions are selected with ``explore=False`` so no exploration noise is
    injected.  Per-step system QoE and lambda are averaged within each
    episode; the remaining metrics (semantic/traditional QoE, fairness,
    rate satisfaction) are taken from the final step's ``info`` dict of
    each episode and then averaged across episodes.

    Parameters
    ----------
    env : WirelessEnv
        The wireless environment instance. | 无线环境实例。
    agent : BaseAgent
        The trained agent model. | 已训练的智能体模型。
    config : dict
        Configuration parameters; ``config['training']['max_steps']``
        bounds the episode length. | 配置参数。
    num_episodes : int
        Number of episodes to average over. | 用于计算平均值的回合数。

    Returns
    -------
    dict
        Mean/std of system QoE, mean semantic and traditional QoE,
        fairness, rate satisfaction, mean lambda, and the per-episode
        mean-lambda series under ``'lambda_trajectory'``.
    """
    max_steps = config['training']['max_steps']

    all_qoe_sys = []
    all_qoe_s = []
    all_qoe_b = []
    all_fairness = []
    all_rate_sat = []
    all_lambda = []
    # NOTE: a previous version also accumulated info['rates'] into an
    # `all_rates` list that was never used or returned; that dead code has
    # been removed.

    for _ in range(num_episodes):
        obs_s, obs_b = env.reset()
        ep_qoe_sys = []
        ep_lambda = []

        for step in range(max_steps):
            # Deterministic action selection (no exploration noise).
            act_s, act_b = agent.select_action(obs_s, obs_b, explore=False)
            next_obs_s, next_obs_b, qoe_s, qoe_b, done, info = env.step(act_s, act_b)

            qoe_sys = info['qoe_sys']
            # Agents without a dynamic-lambda mechanism are reported at the
            # neutral value 0.5.
            if hasattr(agent, 'compute_lambda'):
                lambda_val = agent.compute_lambda(qoe_sys)
            else:
                lambda_val = 0.5

            ep_qoe_sys.append(qoe_sys)
            ep_lambda.append(lambda_val)

            obs_s = next_obs_s
            obs_b = next_obs_b
            if done:
                break

        # Episode-level aggregation; `info` here is from the episode's
        # last executed step (assumes max_steps >= 1).
        all_qoe_sys.append(np.mean(ep_qoe_sys))
        all_qoe_s.append(info['qoe_semantic'])
        all_qoe_b.append(info['qoe_traditional'])
        all_fairness.append(jain_fairness(info['qoe_list']))
        all_rate_sat.append(info['rate_satisfaction'])
        all_lambda.append(np.mean(ep_lambda))

    return {
        'qoe_sys': np.mean(all_qoe_sys),
        'qoe_sys_std': np.std(all_qoe_sys),
        'qoe_semantic': np.mean(all_qoe_s),
        'qoe_traditional': np.mean(all_qoe_b),
        'fairness': np.mean(all_fairness),
        'rate_satisfaction': np.mean(all_rate_sat),
        'avg_lambda': np.mean(all_lambda),
        'lambda_trajectory': all_lambda,
    }
|
||
|
||
|
||
# ============================================================
# Scenario 1: Convergence (Fig 2)
# ============================================================
def scenario_convergence(results_dir: str, save_dir: str):
    """
    Generate convergence curves from training history (Fig 2).
    根据训练历史生成收敛曲线。

    Reads each algorithm's ``<key>_history.json`` under *results_dir*
    and plots every available ``episode_qoe_sys`` series.
    """
    print("\n[Scenario 1] Convergence curves (Fig 2)")
    plotter = Plotter()
    data_dict = {}

    for algo_key, (display_name, _) in ALGO_MAP.items():
        history_path = os.path.join(results_dir, f'{algo_key}_history.json')
        if not os.path.exists(history_path):
            continue
        with open(history_path, 'r') as f:
            history = json.load(f)
        if 'episode_qoe_sys' in history:
            data_dict[display_name] = history['episode_qoe_sys']

    if not data_dict:
        print(" No training history found. Run training first.")
        return
    plotter.plot_convergence(data_dict, os.path.join(save_dir, 'fig2_convergence'))
    print(" Saved fig2_convergence")
|
||
|
||
|
||
# ============================================================
# Scenario 2: QoE vs SNR (Fig 3, 4)
# ============================================================
def scenario_snr(config: dict, results_dir: str, save_dir: str, num_eval=5):
    """
    Evaluate performance across SNR levels from 0 to 30 dB (Fig 3, 4).
    在不同 SNR 水平下评估性能。

    Simulation method: the target SNR is reached by shifting the noise
    power spectral density relative to the -174 dBm/Hz floor; 15 dB is
    treated as the baseline operating point.

    Returns
    -------
    dict
        Raw data: SNR levels plus per-algorithm QoE and fairness series.
    """
    print("\n[Scenario 2] QoE vs SNR (Fig 3, 4)")
    plotter = Plotter()
    snr_levels_db = np.arange(0, 31, 5)  # 0, 5, 10, 15, 20, 25, 30

    qoe_data = {}
    fairness_data = {}

    for algo_key, (display_name, AlgoClass) in ALGO_MAP.items():
        qoe_vals = []
        fair_vals = []
        # Model path is loop-invariant per algorithm; hoisted out of the
        # SNR sweep.
        model_path = os.path.join(results_dir, f'{algo_key}_best.pt')

        for snr_db in snr_levels_db:
            # SNR = Signal_Power - Noise_Power, so lowering noise_psd by
            # `snr_offset` dB raises the SNR by the same amount.
            test_config = deepcopy(config)
            snr_offset = snr_db - 15  # 15 dB is roughly the baseline SNR
            test_config['env']['noise_psd'] = -174 - snr_offset

            env = WirelessEnv(test_config)
            agent = AlgoClass(test_config)

            # Best-effort weight loading: evaluation proceeds with fresh
            # weights on failure, but the failure is now reported instead
            # of being silently swallowed.
            if os.path.exists(model_path) and hasattr(agent, 'load'):
                try:
                    agent.load(model_path)
                except Exception as exc:
                    print(f" Warning: failed to load {model_path}: {exc}")

            result = evaluate_episode(env, agent, test_config, num_episodes=num_eval)
            qoe_vals.append(result['qoe_sys'])
            fair_vals.append(result['fairness'])

        qoe_data[display_name] = qoe_vals
        fairness_data[display_name] = fair_vals
        print(f" {display_name}: QoE range [{min(qoe_vals):.3f}, {max(qoe_vals):.3f}]")

    plotter.plot_qoe_vs_snr(qoe_data, os.path.join(save_dir, 'fig3_qoe_vs_snr'))
    plotter.plot_fairness_vs_snr(fairness_data, os.path.join(save_dir, 'fig4_fairness_vs_snr'))
    print(" Saved fig3, fig4")

    return {'snr_levels': snr_levels_db.tolist(), 'qoe': qoe_data, 'fairness': fairness_data}
|
||
|
||
|
||
# ============================================================
# Scenario 3: QoE vs User Load (Fig 5, 6)
# ============================================================
def scenario_user_load(config: dict, results_dir: str, save_dir: str, num_eval=5):
    """
    Evaluate performance as the total user count grows (Fig 5, 6).
    评估不同用户数量下的性能。

    The total user count K is swept over {4, 6, 8, 10, 12} and split
    evenly between semantic and traditional users (any odd remainder
    goes to the traditional side).
    """
    print("\n[Scenario 3] QoE vs User Load (Fig 5, 6)")
    plotter = Plotter()
    user_counts = [4, 6, 8, 10, 12]  # total K

    qoe_data = {}
    rate_sat_data = {}

    for algo_key, (display_name, AlgoClass) in ALGO_MAP.items():
        qoe_vals, rate_vals = [], []

        for k_total in user_counts:
            test_config = deepcopy(config)
            # Even split between the two user types.
            k_s = k_total // 2
            k_b = k_total - k_s
            test_config['env']['num_semantic_users'] = k_s
            test_config['env']['num_traditional_users'] = k_b

            env = WirelessEnv(test_config)
            agent = AlgoClass(test_config)

            model_path = os.path.join(results_dir, f'{algo_key}_best.pt')
            if os.path.exists(model_path) and hasattr(agent, 'load'):
                try:
                    agent.load(model_path)
                except Exception:
                    pass  # best-effort: evaluate with fresh weights

            result = evaluate_episode(env, agent, test_config, num_episodes=num_eval)
            qoe_vals.append(result['qoe_sys'])
            rate_vals.append(result['rate_satisfaction'])

        qoe_data[display_name] = qoe_vals
        rate_sat_data[display_name] = rate_vals
        print(f" {display_name}: QoE range [{min(qoe_vals):.3f}, {max(qoe_vals):.3f}]")

    plotter.plot_qoe_vs_users(qoe_data, os.path.join(save_dir, 'fig5_qoe_vs_users'))
    plotter.plot_rate_satisfaction_vs_users(rate_sat_data, os.path.join(save_dir, 'fig6_rate_sat_vs_users'))
    print(" Saved fig5, fig6")
|
||
|
||
|
||
# ============================================================
# Scenario 4: Lambda Dynamics (Fig 7, 8)
# ============================================================
def scenario_lambda_dynamics(config: dict, results_dir: str, save_dir: str):
    """
    Trace the dynamic λ switching behavior of Co-MADDPG (Fig 7, 8).
    分析 Co-MADDPG 的动态 λ 切换行为。

    Runs a single deterministic episode, recording λ and system QoE at
    every step, then plots the λ trajectory and the λ-vs-QoE scatter.
    """
    print("\n[Scenario 4] Lambda Dynamics (Fig 7, 8)")
    plotter = Plotter()

    env = WirelessEnv(config)
    agent = CoMADDPG(config)

    model_path = os.path.join(results_dir, 'co_maddpg_best.pt')
    if os.path.exists(model_path):
        try:
            agent.load(model_path)
        except Exception:
            pass  # best-effort: fall back to untrained weights

    # One evaluation episode collecting the λ trajectory.
    obs_s, obs_b = env.reset()
    lambda_vals, qoe_vals = [], []

    for _ in range(config['training']['max_steps']):
        act_s, act_b = agent.select_action(obs_s, obs_b, explore=False)
        obs_s_next, obs_b_next, qoe_s, qoe_b, done, info = env.step(act_s, act_b)

        system_qoe = info['qoe_sys']
        lambda_vals.append(float(agent.compute_lambda(system_qoe)))
        qoe_vals.append(float(system_qoe))

        obs_s, obs_b = obs_s_next, obs_b_next
        if done:
            break

    plotter.plot_lambda_trajectory(lambda_vals, os.path.join(save_dir, 'fig7_lambda_trajectory'))
    plotter.plot_lambda_qoe_scatter(lambda_vals, qoe_vals, os.path.join(save_dir, 'fig8_lambda_qoe_scatter'))
    print(" Saved fig7, fig8")
|
||
|
||
|
||
# ============================================================
# Scenario 5: Semantic/Traditional Ratio (Fig 9)
# ============================================================
def scenario_user_ratio(config: dict, results_dir: str, save_dir: str, num_eval=5):
    """
    Evaluate performance across semantic/traditional user mixes (Fig 9).
    评估不同语义/传统用户比例下的性能。

    With the total user count fixed at 6, the fraction of semantic users
    is swept from 0 to 1.  Each fraction is rounded to integer counts and
    then clamped so both user types keep at least one user (required by
    the hybrid environment), so the 0.0 and 1.0 endpoints are actually
    evaluated at 1/5 and 5/1.
    """
    print("\n[Scenario 5] User Ratio Analysis (Fig 9)")
    plotter = Plotter()

    total_users = 6
    ratios = [0.0, 0.17, 0.33, 0.5, 0.67, 0.83, 1.0]  # semantic fraction

    qoe_data = {}

    for algo_key, (display_name, AlgoClass) in ALGO_MAP.items():
        qoe_vals = []

        for ratio in ratios:
            # Round the fraction to integer counts, then clamp so neither
            # user type ends up empty.
            k_s = max(0, min(total_users, int(round(ratio * total_users))))
            k_b = total_users - k_s
            if k_s == 0:
                k_s, k_b = 1, total_users - 1
            if k_b == 0:
                k_b, k_s = 1, total_users - 1

            test_config = deepcopy(config)
            test_config['env']['num_semantic_users'] = k_s
            test_config['env']['num_traditional_users'] = k_b

            env = WirelessEnv(test_config)
            agent = AlgoClass(test_config)

            model_path = os.path.join(results_dir, f'{algo_key}_best.pt')
            if os.path.exists(model_path) and hasattr(agent, 'load'):
                try:
                    agent.load(model_path)
                except Exception:
                    pass  # best-effort: evaluate with fresh weights

            result = evaluate_episode(env, agent, test_config, num_episodes=num_eval)
            qoe_vals.append(result['qoe_sys'])

        qoe_data[display_name] = qoe_vals

    plotter.plot_qoe_vs_ratio(qoe_data, ratios, os.path.join(save_dir, 'fig9_qoe_vs_ratio'))
    print(" Saved fig9")
|
||
|
||
|
||
# ============================================================
# Scenario 6: Ablation Study (Fig 10)
# ============================================================
def scenario_ablation(config: dict, results_dir: str, save_dir: str, num_eval=5):
    """
    Run the component ablation study (Fig 10).
    运行消融实验比较核心组件。

    Ablation Mapping:
    - w/o Stackelberg: Pure Cooperative (simultaneous update)
    - w/o Dynamic λ: Fixed Lambda (λ=0.5)
    - w/o Cooperation: Pure Competitive (λ=0)
    - w/o CTDE: IDDPG (Independent Critics)

    For each variant the steady-state system QoE is the mean of the last
    500 training episodes (or the last 20%, at least one episode, when
    fewer than 500 exist).  If no training history is found, the variant
    is evaluated directly in the environment instead.
    """
    print("\n[Scenario 6] Ablation Study (Fig 10)")
    plotter = Plotter()

    ablation_keys = {
        'Co-MADDPG (Full)': 'co_maddpg',
        'w/o Stackelberg': 'pure_coop',
        'w/o Dynamic λ': 'fixed_lambda',
        'w/o Cooperation': 'pure_comp',
        'w/o CTDE': 'iddpg',
    }

    ablation_data = {}
    for label, algo_key in ablation_keys.items():
        history_path = os.path.join(results_dir, f'{algo_key}_history.json')
        if os.path.exists(history_path):
            with open(history_path, 'r') as f:
                history = json.load(f)
            # Average the tail of the series for a stable steady-state value.
            qoe_series = history.get('episode_qoe_sys', [])
            if len(qoe_series) >= 500:
                ablation_data[label] = np.mean(qoe_series[-500:])
            elif len(qoe_series) > 0:
                # BUGFIX: for fewer than 5 episodes len//5 == 0, and the
                # old slice [-0:] silently averaged the WHOLE series
                # instead of its tail.  Clamp the tail to >= 1 episode.
                tail = max(1, len(qoe_series) // 5)
                ablation_data[label] = np.mean(qoe_series[-tail:])
            else:
                ablation_data[label] = 0.0
        else:
            # Fallback to direct evaluation if history is missing.
            env = WirelessEnv(config)
            AlgoClass = ALGO_MAP[algo_key][1]
            agent = AlgoClass(config)
            model_path = os.path.join(results_dir, f'{algo_key}_best.pt')
            if os.path.exists(model_path) and hasattr(agent, 'load'):
                try:
                    agent.load(model_path)
                except Exception:
                    pass  # best-effort: evaluate with fresh weights
            result = evaluate_episode(env, agent, config, num_episodes=num_eval)
            ablation_data[label] = result['qoe_sys']

    plotter.plot_ablation(ablation_data, os.path.join(save_dir, 'fig10_ablation'))
    print(" Saved fig10")
|
||
|
||
|
||
# ============================================================
# Scenario 7: β Sensitivity (Fig 11)
# ============================================================
def scenario_beta_sensitivity(config: dict, results_dir: str, save_dir: str, num_eval=5):
    """
    Evaluate sensitivity to the sigmoid steepness parameter β (Fig 11).
    评估 Sigmoid 函数中 β 参数的敏感性。

    β controls how sharply the agent switches between competition and
    cooperation; every candidate value is evaluated with the same trained
    Co-MADDPG weights.
    """
    print("\n[Scenario 7] β Sensitivity (Fig 11)")
    plotter = Plotter()

    betas = [1, 3, 5, 7, 10]
    qoe_data = {}
    model_path = os.path.join(results_dir, 'co_maddpg_best.pt')

    for beta in betas:
        test_config = deepcopy(config)
        test_config['training']['beta'] = float(beta)

        env = WirelessEnv(test_config)
        agent = CoMADDPG(test_config)

        if os.path.exists(model_path):
            try:
                agent.load(model_path)
            except Exception:
                pass  # best-effort: fall back to untrained weights

        result = evaluate_episode(env, agent, test_config, num_episodes=num_eval)
        qoe_data[f'β={beta}'] = result['qoe_sys']
        print(f" β={beta}: QoE_sys={result['qoe_sys']:.4f}")

    plotter.plot_beta_sensitivity(qoe_data, betas, os.path.join(save_dir, 'fig11_beta_sensitivity'))
    print(" Saved fig11")
|
||
|
||
|
||
# ============================================================
# Scenario 8: Q_th Sensitivity (Fig 12)
# ============================================================
def scenario_qth_sensitivity(config: dict, results_dir: str, save_dir: str, num_eval=5):
    """
    Evaluate sensitivity to the Q_th threshold parameter (Fig 12).
    评估 Q_th 阈值参数的敏感性。

    Q_th is the target QoE level below which cooperation is triggered;
    every candidate value is evaluated with the same trained Co-MADDPG
    weights.
    """
    print("\n[Scenario 8] Q_th Sensitivity (Fig 12)")
    plotter = Plotter()

    qths = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    qoe_data = {}
    model_path = os.path.join(results_dir, 'co_maddpg_best.pt')

    for qth in qths:
        test_config = deepcopy(config)
        test_config['training']['q_threshold'] = float(qth)

        env = WirelessEnv(test_config)
        agent = CoMADDPG(test_config)

        if os.path.exists(model_path):
            try:
                agent.load(model_path)
            except Exception:
                pass  # best-effort: fall back to untrained weights

        result = evaluate_episode(env, agent, test_config, num_episodes=num_eval)
        qoe_data[f'Q_th={qth}'] = result['qoe_sys']
        print(f" Q_th={qth}: QoE_sys={result['qoe_sys']:.4f}")

    plotter.plot_qth_sensitivity(qoe_data, qths, os.path.join(save_dir, 'fig12_qth_sensitivity'))
    print(" Saved fig12")
|
||
|
||
|
||
# ============================================================
# Run all scenarios | 执行所有场景
# ============================================================
def run_all_scenarios(config: dict, results_dir: str, save_dir: str, num_eval: int = 5):
    """
    Run every evaluation scenario and generate all figures.
    执行所有评估场景并生成所有图表。

    Parameters
    ----------
    config : dict
        Evaluation configuration.
    results_dir : str
        Directory holding trained models and training histories.
    save_dir : str
        Output directory for the figures (created if missing).
    num_eval : int
        Number of evaluation episodes per setting, forwarded to every
        scenario that evaluates in the environment.  Previously this was
        not forwarded at all, so the scenarios silently fell back to
        their own defaults; the default of 5 matches those defaults, so
        existing callers are unaffected.
    """
    os.makedirs(save_dir, exist_ok=True)

    scenario_convergence(results_dir, save_dir)
    scenario_snr(config, results_dir, save_dir, num_eval)
    scenario_user_load(config, results_dir, save_dir, num_eval)
    scenario_lambda_dynamics(config, results_dir, save_dir)
    scenario_user_ratio(config, results_dir, save_dir, num_eval)
    scenario_ablation(config, results_dir, save_dir, num_eval)
    scenario_beta_sensitivity(config, results_dir, save_dir, num_eval)
    scenario_qth_sensitivity(config, results_dir, save_dir, num_eval)

    print(f"\nAll figures saved to: {save_dir}")
|
||
|
||
|
||
def main():
    """Command-line entry point for evaluation. | 评估主入口。"""
    parser = argparse.ArgumentParser(description='Co-MADDPG Evaluation')
    parser.add_argument('--config', type=str, default='configs/default.yaml',
                        help='Path to config YAML')
    parser.add_argument('--results_dir', type=str, required=True,
                        help='Directory with trained models and history')
    parser.add_argument('--save_dir', type=str, default=None,
                        help='Directory to save figures (default: results_dir/figures)')
    parser.add_argument('--scenario', type=str, default='all',
                        choices=['all', 'convergence', 'snr', 'user_load',
                                 'lambda', 'ratio', 'ablation', 'beta', 'qth'],
                        help='Evaluation scenario to run')
    parser.add_argument('--num_eval', type=int, default=10,
                        help='Number of evaluation episodes per setting')
    args = parser.parse_args()

    config = load_config(os.path.join(PROJECT_ROOT, args.config))

    save_dir = args.save_dir or os.path.join(args.results_dir, 'figures')
    os.makedirs(save_dir, exist_ok=True)

    # Dispatch table: CLI scenario name -> zero-argument runner.
    scenario_map = {
        'all': lambda: run_all_scenarios(config, args.results_dir, save_dir),
        'convergence': lambda: scenario_convergence(args.results_dir, save_dir),
        'snr': lambda: scenario_snr(config, args.results_dir, save_dir, args.num_eval),
        'user_load': lambda: scenario_user_load(config, args.results_dir, save_dir, args.num_eval),
        'lambda': lambda: scenario_lambda_dynamics(config, args.results_dir, save_dir),
        'ratio': lambda: scenario_user_ratio(config, args.results_dir, save_dir, args.num_eval),
        'ablation': lambda: scenario_ablation(config, args.results_dir, save_dir, args.num_eval),
        'beta': lambda: scenario_beta_sensitivity(config, args.results_dir, save_dir, args.num_eval),
        'qth': lambda: scenario_qth_sensitivity(config, args.results_dir, save_dir, args.num_eval),
    }

    scenario_map[args.scenario]()
    print("\nEvaluation complete!")


if __name__ == '__main__':
    main()
|