formatter modified

This commit is contained in:
chun 2024-02-03 20:05:40 +08:00
parent 14cd4ed4e4
commit c81101c269
5 changed files with 13 additions and 12 deletions

View File

@ -1,6 +1,7 @@
import torch
import torch.nn as nn
def channel(channel_type='AWGN', snr=20):
def AWGN_channel(z_hat: torch.Tensor):
if z_hat.dim() == 4:

View File

@ -43,7 +43,6 @@ class _ConvWithPReLU(nn.Module):
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
self.prelu = nn.PReLU()
nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='leaky_relu')
def forward(self, x):

View File

@ -19,6 +19,7 @@ from fractions import Fraction
from dataset import Vanilla
import numpy as np
def set_seed(seed):
np.random.seed(seed)
torch.manual_seed(seed)
@ -55,7 +56,6 @@ def config_parser():
return parser.parse_args()
def main():
args = config_parser()
args.snr_list = list(map(float, args.snr_list))
@ -66,6 +66,7 @@ def main():
for snr in args.snr_list:
train(args, ratio, snr)
def train(args: config_parser(), ratio: float, snr: float):
device = torch.device(args.device if torch.cuda.is_available() else 'cpu')