From 2c1a6ca92e7aa449e6761699ca5534c1a0ef8c61 Mon Sep 17 00:00:00 2001 From: chun Date: Sat, 23 Dec 2023 10:48:54 +0800 Subject: [PATCH] requirements modified --- README.md | 2 +- eval.py | 13 +++++++++++++ requirements.txt | 2 +- scripts.py | 5 +++++ train.py | 45 +++++++++++++++++++++++++++++---------------- 5 files changed, 49 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index c0cfeac..14a9050 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ pip install requirements.txt Run(example) ``` cd ./Deep-JSCC-PyTorch -python train.py --seed 2048 --epochs 200 --batch_size 256 --channel 'AWGN' --saved ./saved --snr_list [1,4,7,13,19] --ratio_list [1/6,1/12] +python train.py --seed 2048 --epochs 200 --batch_size 256 --channel 'AWGN' --saved ./saved --snr_list [1,4,7,13,19] --ratio_list [1/6,1/12] --dataset imagenet ``` ### Evaluation diff --git a/eval.py b/eval.py index 2624b4c..4f33592 100644 --- a/eval.py +++ b/eval.py @@ -1 +1,14 @@ # to be implemented +import torch +import torch.nn as nn +from PIL import Image + + +def config_parser(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--channel', default='AWGN', type=str, help='channel type') + parser.add_argument('--saved', default='./saved', type=str, help='saved_path') + parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list') + parser.add_argument('--demo_image', default='./demo/kodim08.png', type=str, help='demo_image') + return parser.parse_args() diff --git a/requirements.txt b/requirements.txt index 617f283..e5e287d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,4 @@ torchvison matplotlib tqdm numpy -fraction \ No newline at end of file +pillow \ No newline at end of file diff --git a/scripts.py b/scripts.py index 310c2db..05a01c7 100644 --- a/scripts.py +++ b/scripts.py @@ -11,3 +11,8 @@ def image_normalization(norm_type): else: raise Exception('Unknown type of normalization') return _inner + +def get_psnr(image,gt,max=255): + psnr = 10 * torch.log10(max**2 / torch.mean((image - gt)**2)) + return psnr + \ No newline at end of file diff --git a/train.py b/train.py index 9257d74..ae5a0d9 100644 --- a/train.py +++ b/train.py @@ -22,13 +22,14 @@ def config_parser(): parser.add_argument('--seed', default=2048, type=int, help='Random seed') parser.add_argument('--lr', default=1e-3, type=float, help='learning rate') parser.add_argument('--epochs', default=100, type=int, help='number of epochs') - parser.add_argument('--batch_size', default=64, type=int, help='batch size') + parser.add_argument('--batch_size', default=256, type=int, help='batch size') parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay') parser.add_argument('--channel', default='AWGN', type=str, help='channel type') parser.add_argument('--saved', default='./saved', type=str, help='saved_path') parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list') - parser.add_argument('--ratio_list', default=[1/6, 1/12], type=list, help='ratio_list') + parser.add_argument('--ratio_list', default=[1/3, 1/6, 1/12], type=list, help='ratio_list') parser.add_argument('--num_workers', default=0, type=int, help='num_workers') + parser.add_argument('--dataset', default='imagenet', type=str, help='dataset') return parser.parse_args() @@ -36,25 +37,37 @@ def main(): args = config_parser() print("Training Start") - # for ratio in args.ratio_list: - # for snr in args.snr_list: - # train(args, ratio, snr) - train(args, 1/6, 20) + for ratio in args.ratio_list: + for snr in args.snr_list: + train(args, ratio, snr) def train(args: config_parser(), ratio: float, snr: float): device = torch.device('cuda') # load data - transform = transforms.Compose([transforms.ToTensor(), ]) - train_dataset = datasets.CIFAR10(root='./Dataset/', train=True, - download=True, transform=transform) + if args.dataset == 'cifar10': + transform = transforms.Compose([transforms.ToTensor(), ]) + train_dataset = datasets.CIFAR10(root='./Dataset/', train=True, + download=True, transform=transform) - train_loader = DataLoader(train_dataset, shuffle=True, - batch_size=args.batch_size, num_workers=args.num_workers) - test_dataset = datasets.CIFAR10(root='./Dataset/', train=False, - download=True, transform=transform) - test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size) + train_loader = DataLoader(train_dataset, shuffle=True, + batch_size=args.batch_size, num_workers=args.num_workers) + test_dataset = datasets.CIFAR10(root='./Dataset/', train=False, + download=True, transform=transform) + test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size) + elif args.dataset == 'imagenet': + transform = transforms.Compose([transforms.ToTensor(), ]) + train_dataset = datasets.ImageNet(root='./Dataset/', train=True, + download=True, transform=transform) + + train_loader = DataLoader(train_dataset, shuffle=True, + batch_size=args.batch_size, num_workers=args.num_workers) + test_dataset = datasets.ImageNet(root='./Dataset/', train=False, + download=True, transform=transform) + test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size) + else: + raise Exception('Unknown dataset') print("training with ratio: {:2f}, snr_db: {}, channel: {}".format(ratio, snr, args.channel)) @@ -79,12 +92,12 @@ def train(args: config_parser(), ratio: float, snr: float): epoch_loop.set_description(f'Epoch [{epoch}/{args.epochs}]') epoch_loop.set_postfix(loss=run_loss/len(train_loader)) - save_model(model, args.saved + '/model_{:2f}_{:2f}.pth'.format(ratio, snr)) + save_model(model, args.saved + '/model{}_{:2f}_{:2f}.pth'.format(args.dataset, ratio, snr)) def save_model(model, path): os.makedirs(path, exist_ok=True) - torch.save(model.state_dict(), path) + torch.save(model, path) print("Model saved in {}".format(path))