From c81101c26951dce3cf73e9ea025ddae8958cd808 Mon Sep 17 00:00:00 2001
From: chun
Date: Sat, 3 Feb 2024 20:05:40 +0800
Subject: [PATCH] formatter modified

---
 channel.py | 3 ++-
 dataset.py | 2 +-
 eval.py    | 4 ++--
 model.py   | 7 +++----
 train.py   | 9 +++++----
 5 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/channel.py b/channel.py
index 8fb15ea..650020a 100644
--- a/channel.py
+++ b/channel.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 
+
 def channel(channel_type='AWGN', snr=20):
     def AWGN_channel(z_hat: torch.Tensor):
         if z_hat.dim() == 4:
@@ -11,7 +12,7 @@ def channel(channel_type='AWGN', snr=20):
         # k = np.prod(z_hat.size())
         k = torch.prod(torch.tensor(z_hat.size()))
         sig_pwr = torch.sum(torch.abs(z_hat).square())/k
-        noi_pwr = sig_pwr / ( 10 ** (snr / 10))
+        noi_pwr = sig_pwr / (10 ** (snr / 10))
         noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr)
         return z_hat + noise
 
diff --git a/dataset.py b/dataset.py
index 7e33f1a..7e6a85f 100644
--- a/dataset.py
+++ b/dataset.py
@@ -14,7 +14,7 @@ class Vanilla(Dataset):
         img = Image.open(img_path).convert('RGB')
         if self.transform is not None:
             img = self.transform(img)
-        return img, 0 # 0 is a fake label not important
+        return img, 0  # 0 is a fake label not important
 
     def __len__(self):
         return len(self.imgs)
diff --git a/eval.py b/eval.py
index d6fdd80..d5e4d51 100644
--- a/eval.py
+++ b/eval.py
@@ -34,7 +34,7 @@ def main():
     model.change_channel(args.channel, args.snr)
 
     psnr_all = 0.0
-    
+
     for i in range(args.times):
         demo_image = model(test_image)
         demo_image = image_normalization('denormalization')(demo_image)
@@ -43,7 +43,7 @@ def main():
         demo_image = image_normalization('normalization')(demo_image)
         demo_image = torch.cat([test_image, demo_image], dim=1)
         demo_image = transforms.ToPILImage()(demo_image)
-        demo_image.save('./run/{}_{}'.format(args.saved.split('/')[-1],args.test_image.split('/')[-1]))
+        demo_image.save('./run/{}_{}'.format(args.saved.split('/')[-1], args.test_image.split('/')[-1]))
 
     print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
 
diff --git a/model.py b/model.py
index 544f71d..da5f0ea 100644
--- a/model.py
+++ b/model.py
@@ -42,8 +42,7 @@ class _ConvWithPReLU(nn.Module):
         super(_ConvWithPReLU, self).__init__()
         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
         self.prelu = nn.PReLU()
-
-
+
         nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='leaky_relu')
 
     def forward(self, x):
@@ -62,7 +61,7 @@ class _TransConvWithPReLU(nn.Module):
             nn.init.kaiming_normal_(self.transconv.weight, mode='fan_out', nonlinearity='leaky_relu')
         else:
             nn.init.xavier_normal_(self.transconv.weight)
-    
+
     def forward(self, x):
         x = self.transconv(x)
         x = self.activate(x)
@@ -70,7 +69,7 @@ class _TransConvWithPReLU(nn.Module):
 
 
 class _Encoder(nn.Module):
-    def __init__(self, c=1, is_temp=False,P=1):
+    def __init__(self, c=1, is_temp=False, P=1):
         super(_Encoder, self).__init__()
         self.is_temp = is_temp
         # self.imgae_normalization = _image_normalization(norm_type='nomalization')
diff --git a/train.py b/train.py
index 17a6288..e74cebd 100644
--- a/train.py
+++ b/train.py
@@ -19,6 +19,7 @@ from fractions import Fraction
 from dataset import Vanilla
 import numpy as np
 
+
 def set_seed(seed):
     np.random.seed(seed)
     torch.manual_seed(seed)
@@ -55,7 +56,6 @@ def config_parser():
     return parser.parse_args()
 
 
-
 def main():
     args = config_parser()
     args.snr_list = list(map(float, args.snr_list))
@@ -66,6 +66,7 @@ def main():
         for snr in args.snr_list:
             train(args, ratio, snr)
 
+
 def train(args: config_parser(), ratio: float, snr: float):
     device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
 
@@ -94,7 +95,7 @@ def train(args: config_parser(), ratio: float, snr: float):
                                  batch_size=args.batch_size, num_workers=args.num_workers)
     else:
         raise Exception('Unknown dataset')
-    
+
     print(args)
     image_fisrt = train_dataset.__getitem__(0)[0]
     c = ratio2filtersize(image_fisrt, ratio)
@@ -125,7 +126,7 @@ def train(args: config_parser(), ratio: float, snr: float):
             loss.backward()
             optimizer.step()
             run_loss += loss.item()
-        if args.if_scheduler: # the scheduler is wrong before
+        if args.if_scheduler:  # the scheduler is wrong before
            scheduler.step()
        with torch.no_grad():
            model.eval()
@@ -142,7 +143,7 @@ def train(args: config_parser(), ratio: float, snr: float):
         print("epoch: {}, loss: {:.4f}, test_mse: {:.4f}, lr:{}".format(
             epoch, run_loss/len(train_loader), test_mse/len(test_loader), optimizer.param_groups[0]['lr']))
     save_model(model, args.saved, args.saved +
-               '/{}_{}_{:.2f}_{:.2f}_{}_{}.pth'.format(args.dataset, args.epochs, ratio, snr, args.batch_size,c))
+               '/{}_{}_{:.2f}_{:.2f}_{}_{}.pth'.format(args.dataset, args.epochs, ratio, snr, args.batch_size, c))
 
 
 def save_model(model, dir, path):
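
Note on the channel.py hunk: the reformatted line implements the standard dB-to-linear SNR conversion. For average signal power sig_pwr, an SNR of snr dB corresponds to a noise power of sig_pwr / 10**(snr / 10). A minimal self-contained sketch (the helper name `awgn` and the random test tensor are illustrative stand-ins, not part of this repo) that checks the SNR realized by the patched computation:

    import torch

    def awgn(z: torch.Tensor, snr_db: float = 20.0) -> torch.Tensor:
        # Same computation as AWGN_channel in channel.py: scale unit
        # Gaussian noise so that sig_pwr / noi_pwr == 10 ** (snr_db / 10).
        sig_pwr = z.square().sum() / z.numel()
        noi_pwr = sig_pwr / (10 ** (snr_db / 10))
        return z + torch.randn_like(z) * noi_pwr.sqrt()

    z = torch.randn(4, 8, 16, 16)  # stand-in latent tensor
    noise = awgn(z, snr_db=20.0) - z
    measured = 10 * torch.log10(z.square().mean() / noise.square().mean())
    print(measured.item())  # close to 20.0 for a tensor this large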
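
Note on the eval.py hunks: the loop averages PSNR over args.times forward passes because the AWGN channel draws fresh noise on every call, so each reconstruction differs. The psnr helper itself is not shown in this patch; a common definition for images normalized to [0, 1] (the name, signature, and max_val default below are assumptions, not the repo's API) is:

    import torch

    def psnr(x: torch.Tensor, y: torch.Tensor, max_val: float = 1.0) -> float:
        # Peak signal-to-noise ratio in dB: 10 * log10(max_val**2 / MSE).
        mse = torch.mean((x - y) ** 2)
        return (10 * torch.log10(max_val ** 2 / mse)).item()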