SemanticCommunication/code/utils/visualization.py

314 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Co-MADDPG Visualization Module | Co-MADDPG 可视化模块
This module handles the generation of IEEE-standard figures for the resource
allocation performance evaluation. It maps specifically to Section VII (Experimental
Results) of the associated research paper.
本模块负责生成符合 IEEE 标准的资源分配性能评估图表。它专门对应于相关研究论文的
第七节(实验结果)。
Reference Figures (from Section VII):
- Fig 2: Training convergence curves | 训练收敛曲线
- Fig 3: System QoE vs. SNR | 系统 QoE 随 SNR 的变化
- Fig 4: Jain's Fairness Index vs. SNR | Jain 公平性指数随 SNR 的变化
- Fig 5: System QoE vs. Total Users (K) | 系统 QoE 随总用户数 (K) 的变化
- Fig 6: Rate Satisfaction Ratio vs. Total Users (K) | 速率满足率随总用户数 (K) 的变化
- Fig 7: Trajectory of λ(t) over time | λ(t) 随时间的变化轨迹
- Fig 8: Correlation scatter of λ and System QoE | λ 与系统 QoE 的相关性散点图
- Fig 9: System QoE vs. Semantic User Ratio | 系统 QoE 随语义用户比例的变化
- Fig 10: Ablation study results | 消融实验结果
- Fig 11: Sensitivity analysis of β | β 参数的敏感性分析
- Fig 12: Sensitivity analysis of Q_th | Q_th 阈值的敏感性分析
"""
import os
import numpy as np
import matplotlib
# Use non-interactive backend to avoid requiring an X server or display
# 使用非交互式后端以避免需要 X 服务器或显示器
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from utils.metrics import moving_average
# IEEE-quality plotting defaults, applied globally when this module is imported.
plt.rcParams.update({
    # Times-like serif text; DejaVu Serif (bundled with matplotlib) is the fallback.
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'DejaVu Serif'],
    'font.size': 12,
    'axes.grid': True,
    'figure.figsize': (8, 6),
    'figure.autolayout': True,  # Equivalent to calling tight_layout() on every figure
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    # Embed TrueType (Type 42) fonts instead of matplotlib's default Type 3
    # fonts: IEEE PDF eXpress flags Type 3 fonts in submitted figures.
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})
# Fixed line style per algorithm, so a method looks the same in every figure.
# Proposed method: solid red line. Baselines: dashed / dash-dot / dotted lines.
# Ablation variant: dotted green. Each row is (name, colour, marker, line style).
ALGO_STYLES = {
    name: {'color': color, 'marker': marker, 'linestyle': linestyle}
    for name, color, marker, linestyle in (
        ('Co-MADDPG', '#E24A33', 'o', '-'),           # proposed (red)
        ('Pure Cooperative', '#348ABD', 's', '--'),   # baseline (blue)
        ('Pure Competitive', '#988ED5', '^', '--'),   # baseline (purple)
        ('Single-Agent DQN', '#777777', 'D', '-.'),   # baseline (gray)
        ('IDDPG', '#FBC15E', 'v', '-.'),              # baseline (yellow)
        ('Fixed λ=0.5', '#8EBA42', 'p', ':'),         # ablation (green)
        ('Equal Allocation', '#FFB5B8', '*', ':'),    # baseline (pink)
        ('Semantic-Only', '#6d904f', 'h', ':'),       # baseline (olive)
    )
}
class Plotter:
    """IEEE-quality plotting helpers, one public ``plot_*`` method per paper figure.

    Each public method draws into a fresh figure and writes it to disk as both
    ``<base>.pdf`` and ``<base>.png`` (see ``_save_plot``). The figure is
    closed afterwards even when drawing or saving raises, so a caller that
    catches an error and carries on does not accumulate open figures.
    """

    # Default x-axis sweeps for Figs 3-6. Callers whose sweep differs can pass
    # ``snr_vals`` / ``users_vals`` explicitly instead.
    DEFAULT_SNR_DB = (0, 5, 10, 15, 20, 25, 30)
    DEFAULT_USERS = (4, 6, 8, 10, 12)

    # File suffixes that _save_plot treats as a caller-supplied extension and
    # strips before writing .pdf/.png. Any other dotted suffix (for example the
    # ".5" in "beta_1.5") is kept as part of the file name.
    _KNOWN_EXTENSIONS = ('.pdf', '.png', '.eps', '.ps', '.svg', '.jpg', '.jpeg', '.tif', '.tiff')

    def _get_style(self, algo_name):
        """Return a fresh style dict (plot kwargs) for ``algo_name``.

        Known algorithms get their fixed entry from ``ALGO_STYLES`` so they
        look the same in every figure. Unknown labels (for example the
        per-setting curves of Figs 11/12) get no explicit colour, so
        matplotlib's colour cycle gives each curve a distinct colour rather
        than drawing them all black. A copy is returned so callers may modify
        it without touching ``ALGO_STYLES``.
        """
        style = ALGO_STYLES.get(algo_name)
        if style is None:
            return {'marker': '', 'linestyle': '-'}
        return dict(style)

    @classmethod
    def _strip_known_extension(cls, save_path):
        """Return ``save_path`` without a trailing image-file extension.

        Only suffixes in ``_KNOWN_EXTENSIONS`` are removed (case-insensitive):
        "figs/fig3.png" -> "figs/fig3", while "figs/beta_1.5" is unchanged.
        """
        root, ext = os.path.splitext(save_path)
        return root if ext.lower() in cls._KNOWN_EXTENSIONS else save_path

    def _save_plot(self, save_path):
        """Save the current figure as PDF and 300-DPI PNG, then close it.

        ``save_path`` may include or omit an image extension; both files share
        the same base name. Missing parent directories are created.
        """
        try:
            os.makedirs(os.path.dirname(os.path.abspath(save_path)), exist_ok=True)
            base_path = self._strip_known_extension(save_path)
            plt.savefig(f"{base_path}.pdf", format='pdf')
            plt.savefig(f"{base_path}.png", format='png', dpi=300)
        finally:
            # Close even if saving failed so figures do not pile up.
            plt.close()

    def _finish_axes(self, xlabel, ylabel, title, legend=True):
        """Set axis labels and title on the current axes; add a legend if requested."""
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        if legend:
            plt.legend()

    def _line_figure(self, x_vals, data_dict, xlabel, ylabel, title, save_path):
        """Plot each series of ``data_dict`` against shared ``x_vals`` and save it.

        Args:
            x_vals: x coordinates shared by every series.
            data_dict: {label: y_values}; each series needs one value per x.
            xlabel, ylabel, title: axis labels and figure title.
            save_path: output path; see ``_save_plot``.

        Raises:
            ValueError: if a series length differs from ``len(x_vals)``.
        """
        fig = plt.figure()
        try:
            for algo, y_vals in data_dict.items():
                if len(y_vals) != len(x_vals):
                    raise ValueError(
                        f"Series '{algo}' has {len(y_vals)} values but the x-axis "
                        f"has {len(x_vals)} points"
                    )
                plt.plot(x_vals, y_vals, label=algo, **self._get_style(algo))
            # legend() warns when nothing labelled was drawn, so skip it then.
            self._finish_axes(xlabel, ylabel, title, legend=bool(data_dict))
            self._save_plot(save_path)
        finally:
            plt.close(fig)

    def plot_convergence(self, data_dict, save_path, window=50):
        """Fig 2: smoothed per-episode system QoE during training.

        Args:
            data_dict: {algo_name: [episode_qoe_values]}.
            save_path: output path; see ``_save_plot``.
            window: moving-average window passed to ``moving_average``
                (default 50, the previously hard-coded value).
        """
        fig = plt.figure()
        try:
            for algo, qoe_vals in data_dict.items():
                # Drop markers: on dense per-episode curves they clutter the plot.
                style = self._get_style(algo)
                style.pop('marker', None)
                smoothed_qoe = moving_average(qoe_vals, window=window)
                # NOTE(review): x starts at 0 whatever alignment moving_average
                # uses; if it is a 'valid'-mode average the curve is shifted
                # left by window-1 episodes - confirm against utils.metrics.
                plt.plot(np.arange(len(smoothed_qoe)), smoothed_qoe, label=algo, **style)
            self._finish_axes('Episode', 'System QoE', 'Training Convergence',
                              legend=bool(data_dict))
            self._save_plot(save_path)
        finally:
            plt.close(fig)

    def plot_qoe_vs_snr(self, data_dict, save_path, snr_vals=None):
        """Fig 3: system QoE versus SNR (robustness to noise level).

        Args:
            data_dict: {algo_name: [qoe_per_snr_point]}.
            save_path: output path; see ``_save_plot``.
            snr_vals: SNR points in dB; defaults to ``DEFAULT_SNR_DB``.
        """
        x_vals = self.DEFAULT_SNR_DB if snr_vals is None else snr_vals
        self._line_figure(x_vals, data_dict, 'SNR (dB)', 'System QoE',
                          'System QoE vs. SNR', save_path)

    def plot_fairness_vs_snr(self, data_dict, save_path, snr_vals=None):
        """Fig 4: Jain's fairness index versus SNR.

        Args:
            data_dict: {algo_name: [fairness_per_snr_point]}.
            save_path: output path; see ``_save_plot``.
            snr_vals: SNR points in dB; defaults to ``DEFAULT_SNR_DB``.
        """
        x_vals = self.DEFAULT_SNR_DB if snr_vals is None else snr_vals
        self._line_figure(x_vals, data_dict, 'SNR (dB)', 'Jain Fairness Index',
                          'Fairness vs. SNR', save_path)

    def plot_qoe_vs_users(self, data_dict, save_path, users_vals=None):
        """Fig 5: system QoE versus total number of users K (scalability).

        Args:
            data_dict: {algo_name: [qoe_per_k_point]}.
            save_path: output path; see ``_save_plot``.
            users_vals: K values; defaults to ``DEFAULT_USERS``.
        """
        x_vals = self.DEFAULT_USERS if users_vals is None else users_vals
        self._line_figure(x_vals, data_dict, 'Total Users (K)', 'System QoE',
                          'System QoE vs. Total Users', save_path)

    def plot_rate_satisfaction_vs_users(self, data_dict, save_path, users_vals=None):
        """Fig 6: rate satisfaction ratio versus total number of users K.

        Args:
            data_dict: {algo_name: [rate_satisfaction_per_k_point]}.
            save_path: output path; see ``_save_plot``.
            users_vals: K values; defaults to ``DEFAULT_USERS``.
        """
        x_vals = self.DEFAULT_USERS if users_vals is None else users_vals
        self._line_figure(x_vals, data_dict, 'Total Users (K)', 'Rate Satisfaction Ratio',
                          'Rate Satisfaction vs. Total Users', save_path)

    def plot_lambda_trajectory(self, lambda_values, save_path):
        """Fig 7: trajectory of lambda(t) over time steps.

        Args:
            lambda_values: sequence of lambda(t) values, one per time step.
            save_path: output path; see ``_save_plot``.
        """
        fig = plt.figure()
        try:
            time_steps = np.arange(len(lambda_values))
            plt.plot(time_steps, lambda_values, label=r'$\lambda(t)$', color='#348ABD', linestyle='-')
            # Horizontal reference at the fixed weighting lambda = 0.5.
            plt.axhline(y=0.5, color='#E24A33', linestyle='--', label=r'Reference ($\lambda=0.5$)')
            self._finish_axes('Time Step', r'$\lambda(t)$',
                              r'Trajectory of Allocation Parameter $\lambda$')
            self._save_plot(save_path)
        finally:
            plt.close(fig)

    def plot_lambda_qoe_scatter(self, lambdas, qoes, save_path):
        """Fig 8: scatter of (lambda, system QoE), coloured by time step.

        Args:
            lambdas: lambda values.
            qoes: system QoE values, same length as ``lambdas``.
            save_path: output path; see ``_save_plot``.
        """
        fig = plt.figure()
        try:
            # Colour each point by its index so the evolution over time is visible.
            time_steps = np.arange(len(lambdas))
            sc = plt.scatter(lambdas, qoes, c=time_steps, cmap='viridis', alpha=0.7)
            plt.colorbar(sc).set_label('Time Step')
            self._finish_axes(r'$\lambda$', 'System QoE',
                              r'Correlation between $\lambda$ and System QoE', legend=False)
            self._save_plot(save_path)
        finally:
            plt.close(fig)

    def plot_qoe_vs_ratio(self, data_dict, ratios, save_path):
        """Fig 9: system QoE versus the semantic-user ratio.

        Args:
            data_dict: {algo_name: [qoe_values]}, one value per ratio.
            ratios: semantic-user ratios used as the x-axis.
            save_path: output path; see ``_save_plot``.
        """
        self._line_figure(ratios, data_dict, 'Semantic User Ratio', 'System QoE',
                          'System QoE vs. Semantic User Ratio', save_path)

    def plot_ablation(self, data, save_path):
        """Fig 10: horizontal bar chart of the ablation study.

        Args:
            data: {variant_label: qoe_value}. A label containing both
                'Co-MADDPG' and 'Full' is drawn in red, the rest in blue.
            save_path: output path; see ``_save_plot``.
        """
        fig = plt.figure()
        try:
            labels = list(data.keys())
            values = list(data.values())
            y_pos = np.arange(len(labels))
            colors = ['#E24A33' if 'Co-MADDPG' in label and 'Full' in label else '#348ABD'
                      for label in labels]
            plt.barh(y_pos, values, align='center', color=colors)
            plt.yticks(y_pos, labels)
            plt.xlabel('System QoE')
            plt.title('Ablation Study')
            self._save_plot(save_path)
        finally:
            plt.close(fig)

    def plot_beta_sensitivity(self, data_dict, betas, save_path):
        """Fig 11: system QoE versus the sigmoid steepness parameter beta.

        Args:
            data_dict: {label: qoe_value_list}, one value per beta.
            betas: beta values used as the x-axis.
            save_path: output path; see ``_save_plot``.
        """
        self._line_figure(betas, data_dict, r'$\beta$ Parameter', 'System QoE',
                          r'Sensitivity Analysis of $\beta$', save_path)

    def plot_qth_sensitivity(self, data_dict, qths, save_path):
        """Fig 12: system QoE versus the cooperation threshold Q_th.

        Args:
            data_dict: {label: qoe_value_list}, one value per threshold.
            qths: Q_th values used as the x-axis.
            save_path: output path; see ``_save_plot``.
        """
        self._line_figure(qths, data_dict, r'$Q_{th}$ Threshold', 'System QoE',
                          r'Sensitivity Analysis of $Q_{th}$', save_path)