formatter modified
This commit is contained in:
parent
14cd4ed4e4
commit
c81101c269
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def channel(channel_type='AWGN', snr=20):
|
||||
def AWGN_channel(z_hat: torch.Tensor):
|
||||
if z_hat.dim() == 4:
|
||||
|
||||
1
model.py
1
model.py
@ -43,7 +43,6 @@ class _ConvWithPReLU(nn.Module):
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
|
||||
self.prelu = nn.PReLU()
|
||||
|
||||
|
||||
nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
3
train.py
3
train.py
@ -19,6 +19,7 @@ from fractions import Fraction
|
||||
from dataset import Vanilla
|
||||
import numpy as np
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
@ -55,7 +56,6 @@ def config_parser():
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
args = config_parser()
|
||||
args.snr_list = list(map(float, args.snr_list))
|
||||
@ -66,6 +66,7 @@ def main():
|
||||
for snr in args.snr_list:
|
||||
train(args, ratio, snr)
|
||||
|
||||
|
||||
def train(args: config_parser(), ratio: float, snr: float):
|
||||
|
||||
device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
Loading…
Reference in New Issue
Block a user