# API Reference

This document describes the interface of every public class and function in the Co-MADDPG project.

---

## Table of Contents

1. [Environment Module envs/](#1-environment-module-envs)
   - [ChannelModel](#channelmodel)
   - [SemanticModule](#semanticmodule)
   - [WirelessEnv](#wirelessenv)
2. [Algorithm Module agents/](#2-algorithm-module-agents)
   - [Actor](#actor)
   - [Critic](#critic)
   - [OUNoise](#ounoise)
   - [ReplayBuffer](#replaybuffer)
   - [CoMADDPG](#comaddpg)
3. [Baseline Module baselines/](#3-baseline-module-baselines)
   - [Common Interface](#common-interface)
   - [Baseline Differences](#baseline-differences)
4. [Utility Module utils/](#4-utility-module-utils)
   - [metrics.py](#metricspy)
   - [visualization.py](#visualizationpy)
5. [Entry Scripts](#5-entry-scripts)
   - [train.py](#trainpy)
   - [evaluate.py](#evaluatepy)

---

## 1. Environment Module envs/

### ChannelModel

**File**: `envs/channel_model.py`

3GPP Urban Micro (UMi) NLOS channel model. Handles path-loss calculation, complex channel-gain generation, and SNR computation (a usage sketch follows the method tables).

```python
class ChannelModel:
    def __init__(self, config: dict) -> None
```

| Parameter | Type | Description |
|---|---|---|
| `config` | dict | Full configuration dict; must contain `config["env"]["carrier_freq"]`, `config["env"]["noise_psd"]`, and `config["env"]["subcarrier_spacing"]` |

#### Methods

**`path_loss(distance) -> float`**

Computes the 3GPP UMi NLOS path loss.

| Parameter | Type | Description |
|---|---|---|
| `distance` | float / np.ndarray | Transmitter-receiver distance (m) |
| **Returns** | float / np.ndarray | Path loss (dB) |

Formula: `PL(d) = 36.7·log₁₀(d) + 22.7 + 26·log₁₀(fc)`

---

**`generate_channel(distances, num_subcarriers) -> np.ndarray`**

Generates the complex channel-gain matrix.

| Parameter | Type | Description |
|---|---|---|
| `distances` | np.ndarray (K,) | Distance of each user |
| `num_subcarriers` | int | Number of subcarriers N |
| **Returns** | np.ndarray (K, N) | Complex channel gains `h_{k,n} ~ CN(0, 10^{-PL/10})` |

---

**`compute_snr(channel_gains, power_alloc, noise_power) -> np.ndarray`**

Computes the SNR for each user on each subcarrier.

| Parameter | Type | Description |
|---|---|---|
| `channel_gains` | np.ndarray (K, N) | Complex channel gains |
| `power_alloc` | np.ndarray (K, N) | Power-allocation matrix (W) |
| `noise_power` | float | Noise power per subcarrier σ² (W) |
| **Returns** | np.ndarray (K, N) | SNR (linear scale) |

Formula: `γ_{k,n} = p_{k,n} · |h_{k,n}|² / σ²`

---

**`noise_power` (property) `-> float`**

Thermal noise power per subcarrier (W).

Formula: `σ² = 10^{(N₀_dBm - 30)/10} · Δf`
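
The snippet below is a usage sketch of the interface documented above, assuming the constructor and method signatures exactly as listed; the config values, distances, and the flat 0.1 W power allocation are illustrative, not project defaults.

```python
import numpy as np
from envs.channel_model import ChannelModel

# Illustrative config values; the real defaults live in configs/default.yaml.
config = {"env": {"carrier_freq": 3.5, "noise_psd": -174.0, "subcarrier_spacing": 15e3}}
channel = ChannelModel(config)

distances = np.array([50.0, 120.0])            # K = 2 users, distances in metres
H = channel.generate_channel(distances, 64)    # complex gains, shape (2, 64)

power = np.full(H.shape, 0.1)                  # flat 0.1 W per subcarrier (illustrative)
snr = channel.compute_snr(H, power, channel.noise_power)   # linear SNR, shape (2, 64)
```
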
---

### SemanticModule

**File**: `envs/semantic_module.py`

Semantic communication quality module; computes SSim and the semantic QoE.

```python
class SemanticModule:
    def __init__(self, config: dict) -> None
```

| Parameter | Type | Description |
|---|---|---|
| `config` | dict | Must contain `config["env"]["rho_max"]`, `rho_min`, `w1`, `w2` |

#### Methods

**`compute_ssim(avg_snr, rho) -> float`**

Computes the semantic similarity index.

| Parameter | Type | Description |
|---|---|---|
| `avg_snr` | float / np.ndarray | Average SNR (linear scale) |
| `rho` | float | Compression rate ρ ∈ [ρ_min, ρ_max] |
| **Returns** | float / np.ndarray | SSim ∈ [0, 1] |

Formula: `φ(γ̄, ρ) = 1 - exp(-a(ρ)·γ̄^{b(ρ)})`, where `a(ρ) = 0.8/(ρ+0.1)` and `b(ρ) = 0.6+0.2·ρ`

---

**`compute_avg_snr(snr_per_subcarrier, allocation_mask) -> float`**

Computes the average SNR over the allocated subcarriers.

| Parameter | Type | Description |
|---|---|---|
| `snr_per_subcarrier` | np.ndarray | SNR of every subcarrier |
| `allocation_mask` | np.ndarray | Binary mask (1 = allocated) |
| **Returns** | float | Average SNR (0.0 if nothing is allocated) |

---

**`compute_semantic_qoe(ssim, rho, w1=None, w2=None, rho_max=None) -> float`**

Computes the QoE of a semantic user.

| Parameter | Type | Description |
|---|---|---|
| `ssim` | float | Semantic similarity ∈ [0, 1] |
| `rho` | float | Compression rate |
| `w1`, `w2` | float, optional | Weights (config values by default) |
| `rho_max` | float, optional | Maximum compression rate (config value by default) |
| **Returns** | float | QoE ∈ [0, 1] |

Formula: `QoE_s = w1·SSim + w2·(1 - ρ/ρ_max)`
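
For reference, a self-contained re-implementation of the two closed-form expressions above (it mirrors the documented formulas, not the module's actual code); the weights `w1=0.7`, `w2=0.3` and `rho_max=1.0` in the example call are assumptions.

```python
import numpy as np

def ssim(avg_snr: float, rho: float) -> float:
    """SSim per the documented closed form: 1 - exp(-a(rho) * snr**b(rho))."""
    a = 0.8 / (rho + 0.1)
    b = 0.6 + 0.2 * rho
    return 1.0 - float(np.exp(-a * avg_snr ** b))

def semantic_qoe(s: float, rho: float, w1: float, w2: float, rho_max: float) -> float:
    """QoE_s = w1 * SSim + w2 * (1 - rho / rho_max)."""
    return w1 * s + w2 * (1.0 - rho / rho_max)

# Example: average SNR of 10 (linear), compression rate 0.5, illustrative weights.
s = ssim(10.0, 0.5)
print(semantic_qoe(s, 0.5, w1=0.7, w2=0.3, rho_max=1.0))
```
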
---

### WirelessEnv

**File**: `envs/wireless_env.py`

Gym-style wireless resource-allocation environment: it manages the channel state, applies the agents' actions, and computes QoE.

```python
class WirelessEnv:
    def __init__(self, config: dict)
```

| Attribute | Type | Description |
|---|---|---|
| `obs_dim` | int (property) | Observation dimension = N + 4 |
| `act_dim` | int (property) | Action dimension = 3 |
| `N` | int | Number of subcarriers |
| `K_s`, `K_b`, `K` | int | Number of semantic / traditional / total users |

#### Methods

**`reset() -> (obs_s, obs_b)`**

Resets the environment: user distances, channels, and auxiliary parameters are re-randomized.

| Returns | Type | Description |
|---|---|---|
| `obs_s` | np.ndarray (obs_dim,) | Semantic-agent observation (float32) |
| `obs_b` | np.ndarray (obs_dim,) | Traditional-agent observation (float32) |

---

**`step(action_s, action_b) -> (obs_s, obs_b, reward_s, reward_b, done, info)`**

Executes one environment step. A usage sketch follows the tables below.

| Parameter | Type | Description |
|---|---|---|
| `action_s` | np.ndarray (3,) | Semantic-agent action [sub_frac, power_frac, rho] |
| `action_b` | np.ndarray (3,) | Traditional-agent action [sub_frac, power_frac, _] |

| Returns | Type | Description |
|---|---|---|
| `obs_s`, `obs_b` | np.ndarray | New observations |
| `reward_s`, `reward_b` | float | Each group's mean QoE (used as the base reward) |
| `done` | bool | Whether max_steps has been reached |
| `info` | dict | Detailed metrics (see the table below) |

**Contents of the `info` dict:**

| Key | Type | Description |
|---|---|---|
| `qoe_semantic` | float | Mean QoE of the semantic group |
| `qoe_traditional` | float | Mean QoE of the traditional group |
| `qoe_sys` | float | System-wide mean QoE |
| `qoe_list` | list[float] | Per-user QoE |
| `rates` | list[float] | Traditional-user rates (bps) |
| `ssim_values` | list[float] | SSim values of the semantic users |
| `rate_satisfaction` | float | Fraction of rate requirements met ∈ [0, 1] |
| `rho` | float | Compression rate actually used |
| `n_sub_s`, `n_sub_b` | int | Number of subcarriers allocated to each group |
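
A minimal interaction loop against the documented `reset()` / `step()` signatures; the YAML loading and the random placeholder actions are illustrative, not taken from train.py.

```python
import numpy as np
import yaml

from envs.wireless_env import WirelessEnv

with open("configs/default.yaml") as f:          # documented default config path
    config = yaml.safe_load(f)

env = WirelessEnv(config)
obs_s, obs_b = env.reset()

done = False
while not done:
    # Random placeholder actions in [0, 1]^3 (act_dim = 3); a trained policy goes here.
    action_s = np.random.rand(3).astype(np.float32)   # [sub_frac, power_frac, rho]
    action_b = np.random.rand(3).astype(np.float32)   # [sub_frac, power_frac, unused]
    obs_s, obs_b, reward_s, reward_b, done, info = env.step(action_s, action_b)

print(info["qoe_sys"], info["rate_satisfaction"])
```
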
---

## 2. Algorithm Module agents/

### Actor

**File**: `agents/actor.py`

Deterministic policy network that outputs continuous actions in [0, 1] (a sketch follows the table below).

```python
class Actor(nn.Module):
    def __init__(self, obs_dim: int, act_dim: int, hidden_sizes: list = [256, 256, 128])
```

**`forward(obs) -> torch.Tensor`**

| Parameter | Type | Description |
|---|---|---|
| `obs` | Tensor (batch, obs_dim) | Observations |
| **Returns** | Tensor (batch, act_dim) | Actions ∈ [0, 1], obtained via `(tanh(x) + 1) / 2` |
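
A minimal network consistent with the signature and output mapping above; the MLP body (Linear + ReLU per hidden size) is an assumption about the architecture, while the `(tanh(x) + 1) / 2` output mapping and the default `hidden_sizes` come from the table.

```python
import torch
import torch.nn as nn

class ActorSketch(nn.Module):
    """Sketch of the actor interface above; not the project's exact code."""
    def __init__(self, obs_dim: int, act_dim: int, hidden_sizes=(256, 256, 128)):
        super().__init__()
        layers, last = [], obs_dim
        for h in hidden_sizes:
            layers += [nn.Linear(last, h), nn.ReLU()]
            last = h
        layers.append(nn.Linear(last, act_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        # Map raw outputs to [0, 1] via (tanh(x) + 1) / 2, as documented.
        return (torch.tanh(self.net(obs)) + 1.0) / 2.0
```
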
---

### Critic

**File**: `agents/critic.py`

Joint Q-value network (CTDE); its input is the observations and actions of all agents (a sketch follows the table below).

```python
class Critic(nn.Module):
    def __init__(self, obs_dim_total: int, act_dim_total: int, hidden_sizes: list = [512, 512, 256])
```

- `obs_dim_total` = obs_dim × 2 = 136
- `act_dim_total` = act_dim × 2 = 6
- Total input dimension = 142

**`forward(obs, act) -> torch.Tensor`**

| Parameter | Type | Description |
|---|---|---|
| `obs` | Tensor (batch, obs_dim_total) | Joint observation concat(obs_s, obs_b) |
| `act` | Tensor (batch, act_dim_total) | Joint action concat(act_s, act_b) |
| **Returns** | Tensor (batch, 1) | Q-value |
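
Likewise, a sketch of the joint critic consistent with the dimensions above; the plain Linear + ReLU hidden layers are an assumption.

```python
import torch
import torch.nn as nn

class CriticSketch(nn.Module):
    """Joint Q-network sketch: concatenates joint observations and joint actions."""
    def __init__(self, obs_dim_total: int, act_dim_total: int, hidden_sizes=(512, 512, 256)):
        super().__init__()
        layers, last = [], obs_dim_total + act_dim_total   # 136 + 6 = 142 with the defaults above
        for h in hidden_sizes:
            layers += [nn.Linear(last, h), nn.ReLU()]
            last = h
        layers.append(nn.Linear(last, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, obs: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([obs, act], dim=-1))     # (batch, 1) Q-value
```
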
---

### OUNoise

**File**: `agents/noise.py`

Ornstein-Uhlenbeck exploration noise with linear sigma decay (a sketch follows the tables below).

```python
class OUNoise:
    def __init__(self, size: int, mu: float = 0.0, theta: float = 0.15,
                 sigma_init: float = 0.2, sigma_min: float = 0.01, decay_period: int = 5000)
```

| Parameter | Description |
|---|---|
| `size` | Noise dimension (= act_dim = 3) |
| `theta` | Mean-reversion rate (default 0.15) |
| `sigma_init` | Initial standard deviation (default 0.2) |
| `sigma_min` | Minimum standard deviation (default 0.01) |
| `decay_period` | Linear decay period (default 5000 episodes) |

#### Methods

| Method | Description |
|---|---|
| `reset()` | Reset the noise state to μ |
| `sample() -> np.ndarray` | Sample one step of OU noise |
| `decay_sigma(episode)` | Linearly decay sigma: `σ = max(σ_min, σ_init - (σ_init - σ_min) · episode / decay_period)` |
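
A compact sketch of the OU process with the documented defaults and decay rule; the exact discretization in `agents/noise.py` may differ slightly.

```python
import numpy as np

class OUNoiseSketch:
    """OU process sketch matching the parameters and decay formula documented above."""
    def __init__(self, size, mu=0.0, theta=0.15, sigma_init=0.2,
                 sigma_min=0.01, decay_period=5000):
        self.size, self.mu, self.theta = size, mu, theta
        self.sigma_init, self.sigma_min, self.decay_period = sigma_init, sigma_min, decay_period
        self.sigma = sigma_init
        self.reset()

    def reset(self):
        self.state = np.full(self.size, self.mu)

    def sample(self) -> np.ndarray:
        # dx = theta * (mu - x) + sigma * N(0, 1)
        dx = self.theta * (self.mu - self.state) + self.sigma * np.random.randn(self.size)
        self.state = self.state + dx
        return self.state

    def decay_sigma(self, episode: int):
        # sigma = max(sigma_min, sigma_init - (sigma_init - sigma_min) * episode / decay_period)
        self.sigma = max(self.sigma_min,
                         self.sigma_init - (self.sigma_init - self.sigma_min)
                         * episode / self.decay_period)
```
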
---

### ReplayBuffer

**File**: `agents/replay_buffer.py`

Nine-field experience replay buffer (a sketch follows the table below).

```python
class ReplayBuffer:
    def __init__(self, capacity: int = 100000)
```

#### Methods

**`push(obs_s, obs_b, act_s, act_b, rew_s, rew_b, next_obs_s, next_obs_b, done)`**

Stores one transition. All arguments are numpy arrays or floats.

**`sample(batch_size) -> dict`**

Randomly samples a batch of transitions.

| Returned key | Type | Shape |
|---|---|---|
| `obs_s` | np.ndarray | (batch, obs_dim) |
| `obs_b` | np.ndarray | (batch, obs_dim) |
| `act_s` | np.ndarray | (batch, act_dim) |
| `act_b` | np.ndarray | (batch, act_dim) |
| `rew_s` | np.ndarray | (batch, 1) |
| `rew_b` | np.ndarray | (batch, 1) |
| `next_obs_s` | np.ndarray | (batch, obs_dim) |
| `next_obs_b` | np.ndarray | (batch, obs_dim) |
| `done` | np.ndarray | (batch, 1) |

**`__len__() -> int`**: number of transitions currently stored.
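
A deque-based sketch of the buffer contract above (9 positional fields in, dict of batched float32 arrays out); the actual storage layout in `agents/replay_buffer.py` may differ.

```python
import random
from collections import deque
import numpy as np

class ReplayBufferSketch:
    """Sketch of the 9-field FIFO buffer described above."""
    FIELDS = ("obs_s", "obs_b", "act_s", "act_b", "rew_s", "rew_b",
              "next_obs_s", "next_obs_b", "done")

    def __init__(self, capacity: int = 100000):
        self.storage = deque(maxlen=capacity)

    def push(self, *transition):
        assert len(transition) == len(self.FIELDS)
        self.storage.append(transition)

    def sample(self, batch_size: int) -> dict:
        batch = random.sample(self.storage, batch_size)
        out = {}
        for name, column in zip(self.FIELDS, zip(*batch)):
            arr = np.asarray(column, dtype=np.float32)
            if arr.ndim == 1:                 # scalar fields (rewards, done) -> (batch, 1)
                arr = arr.reshape(-1, 1)
            out[name] = arr
        return out

    def __len__(self) -> int:
        return len(self.storage)
```
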
---

### CoMADDPG

**File**: `agents/co_maddpg.py`

Core Co-MADDPG algorithm; implements the Stackelberg leader-follower update.

```python
class CoMADDPG:
    def __init__(self, config: dict)
```

| Key attribute | Type | Description |
|---|---|---|
| `actor_s`, `actor_b` | Actor | Semantic / traditional actor networks |
| `critic_s`, `critic_b` | Critic | Semantic / traditional critic networks |
| `actor_s_target`, ... | Actor/Critic | Target networks (4 in total) |
| `noise_s`, `noise_b` | OUNoise | Exploration noise |
| `buffer` | ReplayBuffer | Experience replay |
| `current_lambda` | float | Current value of λ(t) |
| `device` | torch.device | Compute device |

#### Methods

**`select_action(obs_s, obs_b, explore=True) -> (act_s, act_b)`**

| Parameter | Description |
|---|---|
| `obs_s`, `obs_b` | np.ndarray (obs_dim,), observation of each agent |
| `explore` | bool, whether to add OU noise |
| **Returns** | tuple(np.ndarray, np.ndarray), actions ∈ [0, 1]³ |

---

**`compute_rewards(info) -> (rew_s, rew_b)`**

Computes the mixed rewards from the `info` dict and updates `self.current_lambda` internally. A sketch of the mixing follows the formulas below.

| Parameter | Description |
|---|---|
| `info` | dict, as returned by env.step() |
| **Returns** | tuple(float, float), the mixed rewards |

Reward formulas:

```
r_coop_i = coop_self·qoe_i + coop_other·qoe_j + coop_sys·qoe_sys
r_comp_i = comp_self·qoe_i + comp_sys·qoe_sys
r_i = λ·r_coop_i + (1-λ)·r_comp_i
```
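
A direct transcription of the reward formulas into code; the weight names `coop_self`, `coop_other`, `coop_sys`, `comp_self`, `comp_sys` follow the formulas above and are assumptions about the keys carried by the reward configuration.

```python
def mixed_rewards(qoe_s: float, qoe_b: float, qoe_sys: float, lam: float, cfg: dict):
    """Blend cooperative and competitive rewards per the formulas above (sketch)."""
    def blend(own: float, other: float) -> float:
        r_coop = cfg["coop_self"] * own + cfg["coop_other"] * other + cfg["coop_sys"] * qoe_sys
        r_comp = cfg["comp_self"] * own + cfg["comp_sys"] * qoe_sys
        return lam * r_coop + (1.0 - lam) * r_comp

    # r_s uses (qoe_s, qoe_b), r_b uses (qoe_b, qoe_s)
    return blend(qoe_s, qoe_b), blend(qoe_b, qoe_s)
```
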
---

**`update() -> dict`**

Performs one Stackelberg update and returns a dict of losses; a call sketch follows the table below.

| Returned key | Description |
|---|---|
| `critic_loss_b` | Follower critic loss |
| `actor_loss_b` | Follower actor loss |
| `critic_loss_s` | Leader critic loss |
| `actor_loss_s` | Leader actor loss |
| `lambda` | Current λ(t) |
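
A typical call site inside a training loop, using only the documented return keys; the `config["train"]["batch_size"]` key is an assumption.

```python
# Only update once the replay buffer holds at least one batch (key name assumed).
if len(agent.buffer) >= config["train"]["batch_size"]:
    losses = agent.update()
    print(f"lambda={losses['lambda']:.3f}  "
          f"leader actor loss={losses['actor_loss_s']:.4f}  "
          f"follower actor loss={losses['actor_loss_b']:.4f}")
```
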
---

**`save(path)` / `load(path)`**

Save / load all network parameters to / from the given directory.

| File | Contents |
|---|---|
| `model_s.pth` | Actor S + Critic S + their target networks |
| `model_b.pth` | Actor B + Critic B + their target networks |

---

## 3. Baseline Module baselines/

### Common Interface

All 7 baselines implement the same interface as CoMADDPG (a sketch of the duck-typed noise check follows the listing):

```python
def __init__(self, config: dict)
def select_action(obs_s, obs_b, explore=True) -> (act_s, act_b)
def compute_rewards(info) -> (rew_s, rew_b)
def update() -> dict or None
def save(path)
def load(path)

# Attributes
self.buffer: ReplayBuffer   # or an equivalent
self.noise_s: OUNoise       # only on some baselines (used by the hasattr check in train.py)
self.noise_b: OUNoise
```
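
A sketch of the duck-typed check that train.py is described as performing; the exact code in the script may differ.

```python
# Only algorithms that expose OU noise get their sigma decayed at the end of an episode.
if hasattr(algo, "noise_s"):
    algo.noise_s.decay_sigma(episode)
    algo.noise_b.decay_sigma(episode)
```
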
### Baseline Differences

| Baseline class | File | λ | Update scheme | Critic | Special classes |
|---|---|---|---|---|---|
| `PureCooperative` | `pure_coop.py` | fixed 1.0 | Simultaneous | Joint | — |
| `PureCompetitive` | `pure_comp.py` | fixed 0.0 | Simultaneous | Joint | — |
| `FixedLambda` | `fixed_lambda.py` | fixed 0.5 | Stackelberg | Joint | — |
| `IndependentDDPG` | `iddpg.py` | 0.0 | Simultaneous | Independent | `IndependentCritic` |
| `SingleAgentDQN` | `single_dqn.py` | 0.5 | N/A (centralized) | Centralized | `DQNNet`, `DQNReplayBuffer`, `EpsilonAdapter` |
| `EqualAllocation` | `equal_alloc.py` | 0.5 | N/A (no learning) | None | `DummyBuffer` |
| `SemanticOnly` | `semantic_only.py` | 1.0 | N/A (single policy) | Single | `SemanticCritic`, `SemanticBuffer` |

#### Notes

**SingleAgentDQN**: 48 discrete actions = 4 (sub_levels) × 4 (power_levels) × 3 (rho_levels); an index-decoding sketch follows below. Uses `EpsilonAdapter` to adapt to the `noise_s.decay_sigma()` interface.

**EqualAllocation**: no learning; always outputs `[0.5, 0.5, 0.5]`. `DummyBuffer` exposes `push()` and `__len__()` but stores nothing.

**IndependentDDPG**: `IndependentCritic` takes a single agent's `(obs, act)` instead of the joint input, ablating the effect of CTDE.
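
An illustrative decoding of a discrete action index into the environment's continuous `[sub_frac, power_frac, rho]` layout; only the 4 × 4 × 3 factorization comes from the documentation, and the evenly spaced level values are assumptions.

```python
import numpy as np

SUB_LEVELS = np.linspace(0.0, 1.0, 4)      # assumed level spacing
POWER_LEVELS = np.linspace(0.0, 1.0, 4)
RHO_LEVELS = np.linspace(0.0, 1.0, 3)

def decode_action(index: int) -> np.ndarray:
    """Map one of the 48 discrete indices to a continuous [0, 1]^3 action (sketch)."""
    sub_i, rest = divmod(index, len(POWER_LEVELS) * len(RHO_LEVELS))
    pow_i, rho_i = divmod(rest, len(RHO_LEVELS))
    return np.array([SUB_LEVELS[sub_i], POWER_LEVELS[pow_i], RHO_LEVELS[rho_i]],
                    dtype=np.float32)
```
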
---

## 4. Utility Module utils/

### metrics.py

**File**: `utils/metrics.py`

#### Functions

**`jain_fairness(values) -> float`**

Jain fairness index: `J = (Σx_i)² / (n·Σx_i²)`, with range [1/n, 1].

---

**`rate_satisfaction(rates, min_rate) -> float`**

Rate-satisfaction ratio: the fraction of users with `R_k ≥ R_req`.

---

**`compute_system_qoe(qoe_list) -> float`**

System-level QoE = mean QoE over all users.

---

**`compute_lambda(qoe_sys, beta=5.0, q_threshold=0.6) -> float`**

Dynamic cooperation weight: `λ = 1 / (1 + exp(-β·(QoE_sys - Q_th)))`

---

**`compute_mixed_reward(qoe_s, qoe_b, qoe_sys, lam, reward_config) -> (float, float)`**

Computes the mixed rewards: `r_i = λ·r_coop_i + (1-λ)·r_comp_i`

---

**`moving_average(data, window) -> np.ndarray`**

Moving-average smoothing.
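
Reference implementations of two of the formulas above (a sketch, not the module's code); the guard that returns 1.0 for an all-zero input is an assumption.

```python
import numpy as np

def jain_fairness(values) -> float:
    """J = (sum x)^2 / (n * sum x^2); 1.0 returned for an all-zero input (assumed guard)."""
    x = np.asarray(values, dtype=float)
    denom = len(x) * np.sum(x ** 2)
    return float(np.sum(x) ** 2 / denom) if denom > 0 else 1.0

def compute_lambda(qoe_sys: float, beta: float = 5.0, q_threshold: float = 0.6) -> float:
    """Sigmoid cooperation weight: 1 / (1 + exp(-beta * (QoE_sys - Q_th)))."""
    return float(1.0 / (1.0 + np.exp(-beta * (qoe_sys - q_threshold))))

print(jain_fairness([0.8, 0.6, 0.7]))   # close to 1: allocation is nearly fair
print(compute_lambda(0.8))              # QoE above threshold, so lambda > 0.5
```
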
---

### visualization.py

**File**: `utils/visualization.py`

IEEE-style plotting utilities covering the 12 figures of Section VII of the paper (a usage sketch follows the table below).

```python
class Plotter:
    def __init__(self, save_dir: str = "results/figures")
```

#### ALGO_STYLES

Built-in style dictionary assigning a color, marker, and line style to each of the 8 algorithms:

```python
ALGO_STYLES = {
    "Co-MADDPG": {"color": "#E74C3C", "marker": "o", "linestyle": "-"},
    "PureCooperative": {"color": "#3498DB", "marker": "s", "linestyle": "--"},
    "PureCompetitive": {"color": "#2ECC71", "marker": "^", "linestyle": "--"},
    ...
}
```

#### Plotting methods

| Method | Figure | Argument |
|---|---|---|
| `plot_convergence(data)` | Fig. 2 | `{algo: [episode_rewards]}` |
| `plot_qoe_vs_snr(data)` | Fig. 3 | `{algo: {snr: qoe}}` |
| `plot_fairness_vs_snr(data)` | Fig. 4 | `{algo: {snr: fairness}}` |
| `plot_qoe_vs_users(data)` | Fig. 5 | `{algo: {n_users: qoe}}` |
| `plot_rate_satisfaction(data)` | Fig. 6 | `{algo: {n_users: rate_sat}}` |
| `plot_lambda_trajectory(lambdas)` | Fig. 7 | `[λ_1, λ_2, ...]` |
| `plot_lambda_qoe_scatter(lambdas, qoes)` | Fig. 8 | two lists of equal length |
| `plot_qoe_vs_semantic_ratio(data)` | Fig. 9 | `{algo: {ratio: qoe}}` |
| `plot_ablation(data)` | Fig. 10 | `{algo: qoe_mean}` |
| `plot_beta_sensitivity(data)` | Fig. 11 | `{beta: qoe}` |
| `plot_qth_sensitivity(data)` | Fig. 12 | `{qth: qoe}` |

All methods automatically save a 300-DPI PNG to `save_dir`.
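
A usage sketch of `Plotter`, assuming only the constructor and the `plot_convergence` signature from the table above; the reward curves are fabricated purely to show the expected `{algo: [episode_rewards]}` layout.

```python
import numpy as np
from utils.visualization import Plotter

plotter = Plotter(save_dir="results/figures")

# Fabricated curves, one list of per-episode rewards per algorithm.
data = {
    "Co-MADDPG": list(np.random.rand(5000)),
    "PureCooperative": list(np.random.rand(5000)),
}
plotter.plot_convergence(data)   # saves a 300-DPI PNG under results/figures
```
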
---

## 5. Entry Scripts

### train.py

Training entry point with CLI arguments.

```bash
python train.py [--algo ALGO] [--config PATH] [--episodes N] [--steps N] [--seed N]
```

| Argument | Default | Description |
|---|---|---|
| `--algo` | `co_maddpg` | Algorithm name (`co_maddpg`, `pure_coop`, `all`, etc.) |
| `--config` | `configs/default.yaml` | Path to the config file |
| `--episodes` | from config (5000) | Number of training episodes |
| `--steps` | from config (200) | Steps per episode |
| `--seed` | from config (42) | Random seed |

**Key functions:**

- `load_config(path)`: load the YAML config
- `get_algorithm(name, config)`: factory function returning an algorithm instance (sketched below)
- `train_single(algo_name, config)`: train a single algorithm
- `train_all(config)`: train all 8 algorithms
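
A sketch of how such a factory could dispatch on the algorithm name; the registry below is inferred from the baseline table, not copied from train.py, and the import paths are assumptions based on the listed file names.

```python
from agents.co_maddpg import CoMADDPG
from baselines.pure_coop import PureCooperative
from baselines.pure_comp import PureCompetitive

def get_algorithm_sketch(name: str, config: dict):
    """Map an --algo name to an algorithm instance (illustrative, partial registry)."""
    registry = {
        "co_maddpg": CoMADDPG,
        "pure_coop": PureCooperative,
        "pure_comp": PureCompetitive,
        # ... the remaining baselines follow the same pattern
    }
    return registry[name](config)
```
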
---

### evaluate.py

Evaluation entry point: runs 8 scenarios and produces 12+ figures.

```bash
python evaluate.py [--results_dir PATH] [--config PATH]
```

**The 8 evaluation scenarios:**

| # | Function | Description |
|---|---|---|
| 1 | `scenario_convergence()` | Plot the training convergence curves |
| 2 | `scenario_qoe_vs_snr()` | Sweep SNR (by adjusting noise_psd) |
| 3 | `scenario_fairness_vs_snr()` | Fairness under different SNRs |
| 4 | `scenario_qoe_vs_users()` | Sweep the number of users |
| 5 | `scenario_rate_satisfaction()` | Rate satisfaction for different user counts |
| 6 | `scenario_lambda_dynamics()` | Time evolution of λ(t) |
| 7 | `scenario_ablation()` | Ablation comparison |
| 8 | `scenario_sensitivity()` | Sensitivity to β and Q_th |

---

## Type Conventions

| Convention | Description |
|---|---|
| All observations / actions | numpy float32 |
| Neural-network inputs | torch.FloatTensor (converted automatically) |
| Config parameters | Loaded from YAML, original types preserved |
| Rewards / QoE | Python float |
| Channel gains | numpy complex128 |
| `done` flag | Python bool |