64 lines
3.2 KiB
Python
64 lines
3.2 KiB
Python
"""
|
|
Critic Network for Wireless Resource Allocation / 无线资源分配中的 Critic 网络
|
|
|
|
This file defines the Critic network architecture for the Co-MADDPG project.
|
|
The Critic estimates the joint Q-value based on the global observations and actions of all agents.
|
|
本文档定义了 Co-MADDPG 项目中的 Critic 网络架构。
|
|
Critic 网络基于所有智能体的全局观测和动作来估算联合 Q 值。
|
|
|
|
Network Architecture / 网络架构:
|
|
FC(obs_dim_total + act_dim_total → 512 → 512 → 256 → 1)
|
|
Input / 输入: Concatenated observations and actions / 拼接后的观测与动作
|
|
Reference / 参考文献: Section 3.2.1 Actor-Critic Structure in the project paper.
|
|
"""
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
class Critic(nn.Module):
    """Centralized critic that scores a joint action given the joint observation.

    Architecture: FC(obs_dim_total + act_dim_total → 512 → 512 → 256 → 1),
    with a ReLU after each hidden layer and a plain linear output layer.

    Paper ref: Section 3.2.1 - Centralized Critic implementation.

    Args:
        obs_dim_total (int): Total dimension of the concatenated observations
            of all agents.
        act_dim_total (int): Total dimension of the concatenated actions of
            all agents.
        hidden_sizes (Sequence[int]): Sizes of the three hidden layers
            (default: (512, 512, 256)). Any sequence of exactly three ints is
            accepted, e.g. a list.

    Raises:
        ValueError: If ``hidden_sizes`` does not contain exactly three sizes.
    """

    def __init__(self, obs_dim_total, act_dim_total, hidden_sizes=(512, 512, 256)):
        super().__init__()

        # Copy into a tuple: the default is immutable (no shared mutable
        # default) and a caller-supplied list is never aliased.
        hidden_sizes = tuple(hidden_sizes)

        # The model design calls for exactly 3 hidden layers. Raise rather
        # than `assert`, because asserts are stripped under `python -O`.
        if len(hidden_sizes) != 3:
            raise ValueError(
                "Critic requires exactly 3 hidden layer sizes, "
                f"got {len(hidden_sizes)}"
            )

        # Layer widths: input -> hidden[0] -> hidden[1] -> hidden[2].
        dims = (obs_dim_total + act_dim_total, *hidden_sizes)

        # Alternate Linear/ReLU for each hidden layer, then a final Linear
        # producing one Q-value. Module indices are 0..6 with Linear layers at
        # net.0, net.2, net.4 and net.6, and the Linear layers are created in
        # input-to-output order, so state_dict keys and seeded init are stable.
        layers = []
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(dims[-1], 1))
        self.net = nn.Sequential(*layers)

    def forward(self, obs_all, act_all):
        """Estimate the joint Q-value Q(obs_all, act_all).

        Args:
            obs_all (torch.Tensor): Joint observation with features on the last
                dimension, e.g. shape (batch, obs_dim_total) or (obs_dim_total,).
            act_all (torch.Tensor): Joint action with features on the last
                dimension, e.g. shape (batch, act_dim_total). Its leading
                dimensions must match those of ``obs_all``.

        Returns:
            torch.Tensor: Q-value estimates of shape (..., 1), e.g. (batch, 1)
            for batched input or (1,) for a single unbatched sample.
        """
        # x = [obs_total, act_total], joined along the feature (last) axis.
        # For 2-D batched input this equals dim=1; dim=-1 also supports
        # unbatched 1-D input and extra leading dimensions.
        x = torch.cat([obs_all, act_all], dim=-1)
        return self.net(x)
|