utils.py added

This commit is contained in:
chun 2023-12-23 13:25:39 +08:00
parent 2c1a6ca92e
commit acd646a9ea
6 changed files with 87 additions and 33 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
test.py test.py
*.pyc *.pyc
*.log *.log
Dataset/*

View File

@ -13,12 +13,18 @@ pip install requirements.txt
## Usage ## Usage
### Training Model ### Training Model
Run(example) Run(example presented in paper)
``` ```
cd ./Deep-JSCC-PyTorch cd ./Deep-JSCC-PyTorch
python train.py --seed 2048 --epochs 200 --batch_size 256 --channel 'AWGN' --saved ./saved --snr_list [1,4,7,13,19] --ratio_list [1/6,1/12] --dataset imagenet
``` ```
```
python train.py --lr 10e-4 --epochs 100 --batch_size 32 --channel 'AWGN' --saved ./saved --snr_list [1,4,7,13,19] --ratio_list [1/6,1/12] --dataset imagenet
```
or
```
python train.py --lr 10e-3 --epochs 100 --batch_size 64 --channel 'AWGN' --saved ./saved --snr_list [1,4,7,13,19] --ratio_list [1/6,1/12] --dataset cifar10
```
### Evaluation ### Evaluation

29
eval.py
View File

@ -2,13 +2,36 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image from PIL import Image
from torchvision import transforms
from utils import get_psnr
def config_parser(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is what the script entry
            point relies on.

    Returns:
        argparse.Namespace with the attributes ``channel``, ``saved``,
        ``snr_list``, ``test_image`` and ``times``.
    """
    import argparse

    def _number(token):
        # Keep integers as ints so parsed values look like the default range.
        try:
            return int(token)
        except ValueError:
            return float(token)

    def _number_list(text):
        # ``type=list`` would split the raw string into single characters
        # ("[1,4]" -> ['[', '1', ',', '4', ']']).  Accept "[1,4,7]" or "1,4,7"
        # instead; a ValueError raised here is reported by argparse as a
        # regular usage error.
        tokens = [token.strip() for token in text.strip().strip('[]').split(',')]
        return [_number(token) for token in tokens if token]

    parser = argparse.ArgumentParser()
    parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
    # No usable default exists for the model file, so fail early with a clear
    # usage message instead of handing None to torch.load later on.
    parser.add_argument('--saved', type=str, required=True,
                        help='path of the saved model file')
    parser.add_argument('--snr_list', default=range(1, 19, 3), type=_number_list,
                        help='snr_list, e.g. [1,4,7] or 1,4,7')
    parser.add_argument('--test_image', default='./demo/kodim08.png', type=str,
                        help='path of the image to evaluate')
    parser.add_argument('--times', default=100, type=int,
                        help='number of forward passes the PSNR is averaged over')
    return parser.parse_args(argv)
def main():
    """Evaluate a saved model on one image and report the average PSNR.

    Loads the image given by ``--test_image``, runs the model loaded from
    ``--saved`` on it ``--times`` times, averages the PSNR between each output
    and the input, writes the input concatenated with the last output to
    ``./demo/demo.png`` and prints the averaged PSNR.

    Raises:
        ValueError: if ``--times`` is smaller than 1.
    """
    args = config_parser()
    if args.times < 1:
        # With zero iterations there is no output image to save and the
        # average below would divide by zero.
        raise ValueError('--times must be a positive integer')

    transform = transforms.Compose([transforms.ToTensor(), ])
    # The context manager closes the underlying file once the pixels have
    # been converted to a tensor.
    with Image.open(args.test_image) as source:
        test_image = transform(source)

    # NOTE(review): torch.load unpickles the file, so only load model files
    # from a trusted source.
    model = torch.load(args.saved)
    # Same evaluation set-up as the test loop in train.py: eval mode and no
    # autograd graph.  Without no_grad the output would carry gradient
    # history into the image conversion below.
    model.eval()

    psnr_all = 0.0
    with torch.no_grad():
        for _ in range(args.times):
            demo_image = model(test_image)
            psnr_all += get_psnr(demo_image, test_image).item()

    # Join the input and the last reconstruction into one picture along
    # dim 1 (presumably the height axis of a (C, H, W) tensor - TODO confirm).
    demo_image = torch.cat([test_image, demo_image], dim=1)
    demo_image = transforms.ToPILImage()(demo_image)
    demo_image.save('./demo/demo.png')
    print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()

View File

@ -1,18 +0,0 @@
import torch
import torch.nn as nn
def image_normalization(norm_type):
    """Return a function that rescales a tensor by a factor of 255.

    Args:
        norm_type: ``'normalization'`` (the historical misspelling
            ``'nomalization'`` is still accepted) to divide by 255, or
            ``'denormalization'`` to multiply by 255.

    Returns:
        A callable taking a ``torch.Tensor`` and returning the rescaled
        tensor.

    Raises:
        Exception: when the returned callable is invoked and ``norm_type`` is
            not one of the values above.
    """
    def _inner(tensor: torch.Tensor):
        if norm_type in ('normalization', 'nomalization'):
            return tensor / 255.0
        if norm_type == 'denormalization':
            # .float() yields float32 like .type(torch.FloatTensor) did, but
            # keeps the tensor on its current device instead of moving it to
            # the CPU.
            return (tensor * 255.0).float()
        raise Exception('Unknown type of normalization')
    return _inner
def get_psnr(image, gt, max=255):
    """Compute the peak signal-to-noise ratio between two tensors.

    Args:
        image: tensor to be compared against ``gt``.
        gt: reference tensor.
        max: peak value used in the numerator (squared).

    Returns:
        A tensor holding ``10 * log10(max**2 / mean((image - gt)**2))``.
    """
    squared_error = (image - gt) ** 2
    mean_squared_error = torch.mean(squared_error)
    return 10 * torch.log10(max ** 2 / mean_squared_error)

View File

@ -14,6 +14,7 @@ import torch.optim as optim
from tqdm import tqdm from tqdm import tqdm
from model import DeepJSCC, ratio2filtersize from model import DeepJSCC, ratio2filtersize
from torch.nn.parallel import DataParallel from torch.nn.parallel import DataParallel
from utils import image_normalization
def config_parser(): def config_parser():
@ -24,12 +25,14 @@ def config_parser():
parser.add_argument('--epochs', default=100, type=int, help='number of epochs') parser.add_argument('--epochs', default=100, type=int, help='number of epochs')
parser.add_argument('--batch_size', default=256, type=int, help='batch size') parser.add_argument('--batch_size', default=256, type=int, help='batch size')
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay') parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
parser.add_argument('--channel', default='AWGN', type=str, help='channel type') parser.add_argument('--channel', default='AWGN', type=str,
choices=['AWGN', 'Rayleigh'], help='channel type')
parser.add_argument('--saved', default='./saved', type=str, help='saved_path') parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list') parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list')
parser.add_argument('--ratio_list', default=[1/3, 1/6, 1/12], type=list, help='ratio_list') parser.add_argument('--ratio_list', default=[1/3, 1/6, 1/12], type=list, help='ratio_list')
parser.add_argument('--num_workers', default=0, type=int, help='num_workers') parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
parser.add_argument('--dataset', default='imagenet', type=str, help='dataset') parser.add_argument('--dataset', default='cifar10', type=str,
choices=['cifar10', 'imagenet'], help='dataset')
return parser.parse_args() return parser.parse_args()
@ -55,7 +58,8 @@ def train(args: config_parser(), ratio: float, snr: float):
batch_size=args.batch_size, num_workers=args.num_workers) batch_size=args.batch_size, num_workers=args.num_workers)
test_dataset = datasets.CIFAR10(root='./Dataset/', train=False, test_dataset = datasets.CIFAR10(root='./Dataset/', train=False,
download=True, transform=transform) download=True, transform=transform)
test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size) test_loader = DataLoader(test_dataset, shuffle=True,
batch_size=args.batch_size, num_workers=args.num_workers)
elif args.dataset == 'imagenet': elif args.dataset == 'imagenet':
transform = transforms.Compose([transforms.ToTensor(), ]) transform = transforms.Compose([transforms.ToTensor(), ])
train_dataset = datasets.ImageNet(root='./Dataset/', train=True, train_dataset = datasets.ImageNet(root='./Dataset/', train=True,
@ -65,7 +69,8 @@ def train(args: config_parser(), ratio: float, snr: float):
batch_size=args.batch_size, num_workers=args.num_workers) batch_size=args.batch_size, num_workers=args.num_workers)
test_dataset = datasets.ImageNet(root='./Dataset/', train=False, test_dataset = datasets.ImageNet(root='./Dataset/', train=False,
download=True, transform=transform) download=True, transform=transform)
test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size) test_loader = DataLoader(test_dataset, shuffle=True,
batch_size=args.batch_size, num_workers=args.num_workers)
else: else:
raise Exception('Unknown dataset') raise Exception('Unknown dataset')
@ -75,7 +80,7 @@ def train(args: config_parser(), ratio: float, snr: float):
c = ratio2filtersize(image_fisrt, ratio) c = ratio2filtersize(image_fisrt, ratio)
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device) model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)
criterion = nn.MSELoss(reduction='sum').cuda(device=device) criterion = nn.MSELoss(reduction='mean').cuda(device=device)
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False) epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
@ -85,13 +90,23 @@ def train(args: config_parser(), ratio: float, snr: float):
optimizer.zero_grad() optimizer.zero_grad()
images = images.cuda(device=device) images = images.cuda(device=device)
outputs = model(images) outputs = model(images)
loss = criterion(outputs, images) / args.batch_size loss = criterion(image_normalization('denormalization')(outputs),
image_normalization('denormalization')(images))
loss.backward() loss.backward()
optimizer.step() optimizer.step()
run_loss += loss.item() run_loss += loss.item()
with torch.no_grad():
epoch_loop.set_description(f'Epoch [{epoch}/{args.epochs}]') model.eval()
epoch_loop.set_postfix(loss=run_loss/len(train_loader)) test_mse = 0.0
for images, _ in tqdm((test_loader), leave=False):
images = images.cuda(device=device)
outputs = model(images)
images = image_normalization('normalization')(images)
outputs = image_normalization('normalization')(outputs)
loss = criterion(outputs, images)
test_mse += loss.item()
model.train()
epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
save_model(model, args.saved + '/model{}_{:2f}_{:2f}.pth'.format(args.dataset, ratio, snr)) save_model(model, args.saved + '/model{}_{:2f}_{:2f}.pth'.format(args.dataset, ratio, snr))

27
utils.py Normal file
View File

@ -0,0 +1,27 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
def image_normalization(norm_type):
    """Return a function that rescales a tensor by a factor of 255.

    Args:
        norm_type: ``'normalization'`` to divide by 255 or
            ``'denormalization'`` to multiply by 255.  The historical
            misspelling ``'nomalization'`` is still accepted so existing
            callers keep working.

    Returns:
        A callable taking a ``torch.Tensor`` and returning the rescaled
        tensor.

    Raises:
        Exception: when the returned callable is invoked and ``norm_type`` is
            not one of the values above.
    """
    def _inner(tensor: torch.Tensor):
        # Previously only the misspelled key matched, so callers passing the
        # correctly spelled 'normalization' hit the error branch.
        if norm_type in ('normalization', 'nomalization'):
            return tensor / 255.0
        if norm_type == 'denormalization':
            return tensor * 255.0
        raise Exception('Unknown type of normalization')
    return _inner
def get_psnr(image, gt, max=255):
    """Compute the peak signal-to-noise ratio between two tensors.

    Both inputs are multiplied by 255 before the mean squared error is taken
    (presumably they are expected in the [0, 1] range - TODO confirm with
    callers).

    Args:
        image: tensor to be compared against ``gt``.
        gt: reference tensor.
        max: peak value used in the numerator (squared).

    Returns:
        A tensor holding ``10 * log10(max**2 / mse)``.
    """
    to_pixel_range = image_normalization('denormalization')
    mse = F.mse_loss(to_pixel_range(image), to_pixel_range(gt))
    return 10 * torch.log10(max ** 2 / mse)
# NOTE(review): these two statements run every time the module is imported and
# look like leftover debugging code - confirm nothing relies on `a` / `b`
# before removing them or moving them under a script guard.
# Random tensor of shape (2, 3, 32, 32).
a = torch.randn(2, 3, 32, 32)
# `a` divided by 255 through the 'nomalization' branch of image_normalization.
b = image_normalization('nomalization')(a)