From f35c7fe716af619bf14a5550bf5183830068b557 Mon Sep 17 00:00:00 2001 From: chun Date: Sat, 23 Dec 2023 13:40:56 +0800 Subject: [PATCH] v1.1 --- eval.py | 9 +++++++-- model.py | 3 +++ train.py | 30 +++++++++++++++--------------- utils.py | 3 +-- 4 files changed, 26 insertions(+), 19 deletions(-) diff --git a/eval.py b/eval.py index 5719a3e..d54da77 100644 --- a/eval.py +++ b/eval.py @@ -3,14 +3,15 @@ import torch import torch.nn as nn from PIL import Image from torchvision import transforms -from utils import get_psnr +from utils import get_psnr, image_normalization + def config_parser(): import argparse parser = argparse.ArgumentParser() parser.add_argument('--channel', default='AWGN', type=str, help='channel type') parser.add_argument('--saved', type=str, help='saved_path') - parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list') + parser.add_argument('--snr', default=20, type=int, help='snr') parser.add_argument('--test_image', default='./demo/kodim08.png', type=str, help='demo_image') parser.add_argument('--times', default=100, type=int, help='num_workers') return parser.parse_args() @@ -24,14 +25,18 @@ def main(): test_image.load() test_image = transform(test_image) model = torch.load(args.saved) + model.change_channel(args.channel, args.snr) psnr_all = 0.0 for i in range(args.times): demo_image = model(test_image) + image = image_normalization('denormalization')(image) + gt = image_normalization('denormalization')(gt) psnr_all += get_psnr(demo_image, test_image) demo_image = torch.cat([test_image, demo_image], dim=1) demo_image = transforms.ToPILImage()(demo_image) demo_image.save('./demo/demo.png') print("psnr on {} is {}".format(args.test_image, psnr_all / args.times)) + if __name__ == '__main__': main() diff --git a/model.py b/model.py index 588814a..2b4c5ab 100644 --- a/model.py +++ b/model.py @@ -140,3 +140,6 @@ class DeepJSCC(nn.Module): z = self.channel(z) x_hat = self.decoder(z) return x_hat + + def change_channel(self, channel_type, snr): + self.channel = channel.channel(channel_type, snr) diff --git a/train.py b/train.py index dab7b3b..3e99b38 100644 --- a/train.py +++ b/train.py @@ -9,7 +9,7 @@ import torch import torch.nn as nn from torchvision import transforms from torchvision import datasets -from torch.utils.data import DataLoader, RandomSampler +from torch.utils.data import DataLoader import torch.optim as optim from tqdm import tqdm from model import DeepJSCC, ratio2filtersize @@ -79,31 +79,31 @@ def train(args: config_parser(), ratio: float, snr: float): image_fisrt = train_dataset.__getitem__(0)[0] c = ratio2filtersize(image_fisrt, ratio) model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device) - - criterion = nn.MSELoss(reduction='mean').cuda(device=device) - optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) - epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False) + model = DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) + criterion=nn.MSELoss(reduction='mean').cuda(device=device) + optimizer=optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + epoch_loop=tqdm(range(args.epochs), total=args.epochs, leave=False) for epoch in epoch_loop: - run_loss = 0.0 + run_loss=0.0 for images, _ in tqdm((train_loader), leave=False): optimizer.zero_grad() - images = images.cuda(device=device) - outputs = model(images) - loss = criterion(image_normalization('denormalization')(outputs), + # images = images.cuda(device=device) + outputs=model(images) + loss=criterion(image_normalization('denormalization')(outputs), image_normalization('denormalization')(images)) loss.backward() optimizer.step() run_loss += loss.item() with torch.no_grad(): model.eval() - test_mse = 0.0 + test_mse=0.0 for images, _ in tqdm((test_loader), leave=False): - images = images.cuda(device=device) - outputs = model(images) - images = image_normalization('normalization')(images) - outputs = image_normalization('normalization')(outputs) - loss = criterion(outputs, images) + images=images.cuda(device=device) + outputs=model(images) + images=image_normalization('normalization')(images) + outputs=image_normalization('normalization')(outputs) + loss=criterion(outputs, images) test_mse += loss.item() model.train() epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader)) diff --git a/utils.py b/utils.py index c7cdbb1..cf27d61 100644 --- a/utils.py +++ b/utils.py @@ -15,8 +15,7 @@ def image_normalization(norm_type): def get_psnr(image, gt, max=255): - image = image_normalization('denormalization')(image) - gt = image_normalization('denormalization')(gt) + mse = F.mse_loss(image, gt) psnr = 10 * torch.log10(max**2 / mse)