diff --git a/.gitignore b/.gitignore index 444aaee..6311b9c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ test.py *.pyc -*.log \ No newline at end of file +*.log +Dataset/* \ No newline at end of file diff --git a/README.md b/README.md index 14a9050..31ab06a 100644 --- a/README.md +++ b/README.md @@ -13,12 +13,18 @@ pip install requirements.txt ## Usage ### Training Model -Run(example) +Run(example presented in paper) ``` cd ./Deep-JSCC-PyTorch -python train.py --seed 2048 --epochs 200 --batch_size 256 --channel 'AWGN' --saved ./saved --snr_list [1,4,7,13,19] --ratio_list [1/6,1/12] --dataset imagenet ``` +``` +python train.py --lr 10e-4 --epochs 100 --batch_size 32 --channel 'AWGN' --saved ./saved --snr_list [1,4,7,13,19] --ratio_list [1/6,1/12] --dataset imagenet +``` +or +``` +python train.py --lr 10e-3 --epochs 100 --batch_size 64 --channel 'AWGN' --saved ./saved --snr_list [1,4,7,13,19] --ratio_list [1/6,1/12] --dataset cifar10 +``` ### Evaluation diff --git a/eval.py b/eval.py index 4f33592..5719a3e 100644 --- a/eval.py +++ b/eval.py @@ -2,13 +2,36 @@ import torch import torch.nn as nn from PIL import Image - +from torchvision import transforms +from utils import get_psnr def config_parser(): import argparse parser = argparse.ArgumentParser() parser.add_argument('--channel', default='AWGN', type=str, help='channel type') - parser.add_argument('--saved', default='./saved', type=str, help='saved_path') + parser.add_argument('--saved', type=str, help='saved_path') parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list') - parser.add_argument('--demo_image', default='./demo/kodim08.png', type=str, help='demo_image') + parser.add_argument('--test_image', default='./demo/kodim08.png', type=str, help='demo_image') + parser.add_argument('--times', default=100, type=int, help='num_workers') return parser.parse_args() + + +def main(): + args = config_parser() + transform = transforms.Compose([transforms.ToTensor(), ]) + + test_image = Image.open(args.test_image) + test_image.load() + test_image = transform(test_image) + model = torch.load(args.saved) + psnr_all = 0.0 + for i in range(args.times): + demo_image = model(test_image) + psnr_all += get_psnr(demo_image, test_image) + demo_image = torch.cat([test_image, demo_image], dim=1) + demo_image = transforms.ToPILImage()(demo_image) + demo_image.save('./demo/demo.png') + print("psnr on {} is {}".format(args.test_image, psnr_all / args.times)) + +if __name__ == '__main__': + main() diff --git a/scripts.py b/scripts.py deleted file mode 100644 index 05a01c7..0000000 --- a/scripts.py +++ /dev/null @@ -1,18 +0,0 @@ -import torch -import torch.nn as nn - - -def image_normalization(norm_type): - def _inner(tensor: torch.Tensor): - if norm_type == 'nomalization': - return tensor / 255.0 - elif norm_type == 'denormalization': - return (tensor * 255.0).type(torch.FloatTensor) - else: - raise Exception('Unknown type of normalization') - return _inner - -def get_psnr(image,gt,max=255): - psnr = 10 * torch.log10(max**2 / torch.mean((image - gt)**2)) - return psnr - \ No newline at end of file diff --git a/train.py b/train.py index ae5a0d9..dab7b3b 100644 --- a/train.py +++ b/train.py @@ -14,6 +14,7 @@ import torch.optim as optim from tqdm import tqdm from model import DeepJSCC, ratio2filtersize from torch.nn.parallel import DataParallel +from utils import image_normalization def config_parser(): @@ -24,12 +25,14 @@ def config_parser(): parser.add_argument('--epochs', default=100, type=int, help='number of epochs') parser.add_argument('--batch_size', default=256, type=int, help='batch size') parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay') - parser.add_argument('--channel', default='AWGN', type=str, help='channel type') + parser.add_argument('--channel', default='AWGN', type=str, + choices=['AWGN', 'Rayleigh'], help='channel type') parser.add_argument('--saved', default='./saved', type=str, help='saved_path') parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list') parser.add_argument('--ratio_list', default=[1/3, 1/6, 1/12], type=list, help='ratio_list') parser.add_argument('--num_workers', default=0, type=int, help='num_workers') - parser.add_argument('--dataset', default='imagenet', type=str, help='dataset') + parser.add_argument('--dataset', default='cifar10', type=str, + choices=['cifar10', 'imagenet'], help='dataset') return parser.parse_args() @@ -55,7 +58,8 @@ def train(args: config_parser(), ratio: float, snr: float): batch_size=args.batch_size, num_workers=args.num_workers) test_dataset = datasets.CIFAR10(root='./Dataset/', train=False, download=True, transform=transform) - test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size) + test_loader = DataLoader(test_dataset, shuffle=True, + batch_size=args.batch_size, num_workers=args.num_workers) elif args.dataset == 'imagenet': transform = transforms.Compose([transforms.ToTensor(), ]) train_dataset = datasets.ImageNet(root='./Dataset/', train=True, @@ -65,7 +69,8 @@ def train(args: config_parser(), ratio: float, snr: float): batch_size=args.batch_size, num_workers=args.num_workers) test_dataset = datasets.ImageNet(root='./Dataset/', train=False, download=True, transform=transform) - test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size) + test_loader = DataLoader(test_dataset, shuffle=True, + batch_size=args.batch_size, num_workers=args.num_workers) else: raise Exception('Unknown dataset') @@ -75,7 +80,7 @@ def train(args: config_parser(), ratio: float, snr: float): c = ratio2filtersize(image_fisrt, ratio) model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device) - criterion = nn.MSELoss(reduction='sum').cuda(device=device) + criterion = nn.MSELoss(reduction='mean').cuda(device=device) optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False) @@ -85,13 +90,23 @@ def train(args: config_parser(), ratio: float, snr: float): optimizer.zero_grad() images = images.cuda(device=device) outputs = model(images) - loss = criterion(outputs, images) / args.batch_size + loss = criterion(image_normalization('denormalization')(outputs), + image_normalization('denormalization')(images)) loss.backward() optimizer.step() run_loss += loss.item() - - epoch_loop.set_description(f'Epoch [{epoch}/{args.epochs}]') - epoch_loop.set_postfix(loss=run_loss/len(train_loader)) + with torch.no_grad(): + model.eval() + test_mse = 0.0 + for images, _ in tqdm((test_loader), leave=False): + images = images.cuda(device=device) + outputs = model(images) + images = image_normalization('normalization')(images) + outputs = image_normalization('normalization')(outputs) + loss = criterion(outputs, images) + test_mse += loss.item() + model.train() + epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader)) save_model(model, args.saved + '/model{}_{:2f}_{:2f}.pth'.format(args.dataset, ratio, snr)) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..c7cdbb1 --- /dev/null +++ b/utils.py @@ -0,0 +1,27 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def image_normalization(norm_type): + def _inner(tensor: torch.Tensor): + if norm_type == 'nomalization': + return tensor / 255.0 + elif norm_type == 'denormalization': + return tensor * 255.0 + else: + raise Exception('Unknown type of normalization') + return _inner + + +def get_psnr(image, gt, max=255): + image = image_normalization('denormalization')(image) + gt = image_normalization('denormalization')(gt) + mse = F.mse_loss(image, gt) + + psnr = 10 * torch.log10(max**2 / mse) + return psnr + + +a = torch.randn(2, 3, 32, 32) +b = image_normalization('nomalization')(a)