utils.py added
This commit is contained in:
parent
2c1a6ca92e
commit
acd646a9ea
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,3 +1,4 @@
|
|||||||
test.py
|
test.py
|
||||||
*.pyc
|
*.pyc
|
||||||
*.log
|
*.log
|
||||||
|
Dataset/*
|
||||||
10
README.md
10
README.md
@ -13,12 +13,18 @@ pip install requirements.txt
|
|||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
### Training Model
|
### Training Model
|
||||||
Run(example)
|
Run(example presented in paper)
|
||||||
```
|
```
|
||||||
cd ./Deep-JSCC-PyTorch
|
cd ./Deep-JSCC-PyTorch
|
||||||
python train.py --seed 2048 --epochs 200 --batch_size 256 --channel 'AWGN' --saved ./saved --snr_list [1,4,7,13,19] --ratio_list [1/6,1/12] --dataset imagenet
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
python train.py --lr 10e-4 --epochs 100 --batch_size 32 --channel 'AWGN' --saved ./saved --snr_list [1,4,7,13,19] --ratio_list [1/6,1/12] --dataset imagenet
|
||||||
|
```
|
||||||
|
or
|
||||||
|
```
|
||||||
|
python train.py --lr 10e-3 --epochs 100 --batch_size 64 --channel 'AWGN' --saved ./saved --snr_list [1,4,7,13,19] --ratio_list [1/6,1/12] --dataset cifar10
|
||||||
|
```
|
||||||
### Evaluation
|
### Evaluation
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
29
eval.py
29
eval.py
@ -2,13 +2,36 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from torchvision import transforms
|
||||||
|
from utils import get_psnr
|
||||||
|
|
||||||
def config_parser():
|
def config_parser():
|
||||||
import argparse
|
import argparse
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
|
parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
|
||||||
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
|
parser.add_argument('--saved', type=str, help='saved_path')
|
||||||
parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list')
|
parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list')
|
||||||
parser.add_argument('--demo_image', default='./demo/kodim08.png', type=str, help='demo_image')
|
parser.add_argument('--test_image', default='./demo/kodim08.png', type=str, help='demo_image')
|
||||||
|
parser.add_argument('--times', default=100, type=int, help='num_workers')
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = config_parser()
|
||||||
|
transform = transforms.Compose([transforms.ToTensor(), ])
|
||||||
|
|
||||||
|
test_image = Image.open(args.test_image)
|
||||||
|
test_image.load()
|
||||||
|
test_image = transform(test_image)
|
||||||
|
model = torch.load(args.saved)
|
||||||
|
psnr_all = 0.0
|
||||||
|
for i in range(args.times):
|
||||||
|
demo_image = model(test_image)
|
||||||
|
psnr_all += get_psnr(demo_image, test_image)
|
||||||
|
demo_image = torch.cat([test_image, demo_image], dim=1)
|
||||||
|
demo_image = transforms.ToPILImage()(demo_image)
|
||||||
|
demo_image.save('./demo/demo.png')
|
||||||
|
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|||||||
18
scripts.py
18
scripts.py
@ -1,18 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
def image_normalization(norm_type):
|
|
||||||
def _inner(tensor: torch.Tensor):
|
|
||||||
if norm_type == 'nomalization':
|
|
||||||
return tensor / 255.0
|
|
||||||
elif norm_type == 'denormalization':
|
|
||||||
return (tensor * 255.0).type(torch.FloatTensor)
|
|
||||||
else:
|
|
||||||
raise Exception('Unknown type of normalization')
|
|
||||||
return _inner
|
|
||||||
|
|
||||||
def get_psnr(image,gt,max=255):
|
|
||||||
psnr = 10 * torch.log10(max**2 / torch.mean((image - gt)**2))
|
|
||||||
return psnr
|
|
||||||
|
|
||||||
33
train.py
33
train.py
@ -14,6 +14,7 @@ import torch.optim as optim
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from model import DeepJSCC, ratio2filtersize
|
from model import DeepJSCC, ratio2filtersize
|
||||||
from torch.nn.parallel import DataParallel
|
from torch.nn.parallel import DataParallel
|
||||||
|
from utils import image_normalization
|
||||||
|
|
||||||
|
|
||||||
def config_parser():
|
def config_parser():
|
||||||
@ -24,12 +25,14 @@ def config_parser():
|
|||||||
parser.add_argument('--epochs', default=100, type=int, help='number of epochs')
|
parser.add_argument('--epochs', default=100, type=int, help='number of epochs')
|
||||||
parser.add_argument('--batch_size', default=256, type=int, help='batch size')
|
parser.add_argument('--batch_size', default=256, type=int, help='batch size')
|
||||||
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
|
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
|
||||||
parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
|
parser.add_argument('--channel', default='AWGN', type=str,
|
||||||
|
choices=['AWGN', 'Rayleigh'], help='channel type')
|
||||||
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
|
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
|
||||||
parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list')
|
parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list')
|
||||||
parser.add_argument('--ratio_list', default=[1/3, 1/6, 1/12], type=list, help='ratio_list')
|
parser.add_argument('--ratio_list', default=[1/3, 1/6, 1/12], type=list, help='ratio_list')
|
||||||
parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
|
parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
|
||||||
parser.add_argument('--dataset', default='imagenet', type=str, help='dataset')
|
parser.add_argument('--dataset', default='cifar10', type=str,
|
||||||
|
choices=['cifar10', 'imagenet'], help='dataset')
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -55,7 +58,8 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||||
test_dataset = datasets.CIFAR10(root='./Dataset/', train=False,
|
test_dataset = datasets.CIFAR10(root='./Dataset/', train=False,
|
||||||
download=True, transform=transform)
|
download=True, transform=transform)
|
||||||
test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size)
|
test_loader = DataLoader(test_dataset, shuffle=True,
|
||||||
|
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||||
elif args.dataset == 'imagenet':
|
elif args.dataset == 'imagenet':
|
||||||
transform = transforms.Compose([transforms.ToTensor(), ])
|
transform = transforms.Compose([transforms.ToTensor(), ])
|
||||||
train_dataset = datasets.ImageNet(root='./Dataset/', train=True,
|
train_dataset = datasets.ImageNet(root='./Dataset/', train=True,
|
||||||
@ -65,7 +69,8 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||||
test_dataset = datasets.ImageNet(root='./Dataset/', train=False,
|
test_dataset = datasets.ImageNet(root='./Dataset/', train=False,
|
||||||
download=True, transform=transform)
|
download=True, transform=transform)
|
||||||
test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size)
|
test_loader = DataLoader(test_dataset, shuffle=True,
|
||||||
|
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||||
else:
|
else:
|
||||||
raise Exception('Unknown dataset')
|
raise Exception('Unknown dataset')
|
||||||
|
|
||||||
@ -75,7 +80,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
c = ratio2filtersize(image_fisrt, ratio)
|
c = ratio2filtersize(image_fisrt, ratio)
|
||||||
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)
|
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)
|
||||||
|
|
||||||
criterion = nn.MSELoss(reduction='sum').cuda(device=device)
|
criterion = nn.MSELoss(reduction='mean').cuda(device=device)
|
||||||
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||||
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
|
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
|
||||||
|
|
||||||
@ -85,13 +90,23 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
images = images.cuda(device=device)
|
images = images.cuda(device=device)
|
||||||
outputs = model(images)
|
outputs = model(images)
|
||||||
loss = criterion(outputs, images) / args.batch_size
|
loss = criterion(image_normalization('denormalization')(outputs),
|
||||||
|
image_normalization('denormalization')(images))
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
run_loss += loss.item()
|
run_loss += loss.item()
|
||||||
|
with torch.no_grad():
|
||||||
epoch_loop.set_description(f'Epoch [{epoch}/{args.epochs}]')
|
model.eval()
|
||||||
epoch_loop.set_postfix(loss=run_loss/len(train_loader))
|
test_mse = 0.0
|
||||||
|
for images, _ in tqdm((test_loader), leave=False):
|
||||||
|
images = images.cuda(device=device)
|
||||||
|
outputs = model(images)
|
||||||
|
images = image_normalization('normalization')(images)
|
||||||
|
outputs = image_normalization('normalization')(outputs)
|
||||||
|
loss = criterion(outputs, images)
|
||||||
|
test_mse += loss.item()
|
||||||
|
model.train()
|
||||||
|
epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
|
||||||
save_model(model, args.saved + '/model{}_{:2f}_{:2f}.pth'.format(args.dataset, ratio, snr))
|
save_model(model, args.saved + '/model{}_{:2f}_{:2f}.pth'.format(args.dataset, ratio, snr))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
27
utils.py
Normal file
27
utils.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def image_normalization(norm_type):
|
||||||
|
def _inner(tensor: torch.Tensor):
|
||||||
|
if norm_type == 'nomalization':
|
||||||
|
return tensor / 255.0
|
||||||
|
elif norm_type == 'denormalization':
|
||||||
|
return tensor * 255.0
|
||||||
|
else:
|
||||||
|
raise Exception('Unknown type of normalization')
|
||||||
|
return _inner
|
||||||
|
|
||||||
|
|
||||||
|
def get_psnr(image, gt, max=255):
|
||||||
|
image = image_normalization('denormalization')(image)
|
||||||
|
gt = image_normalization('denormalization')(gt)
|
||||||
|
mse = F.mse_loss(image, gt)
|
||||||
|
|
||||||
|
psnr = 10 * torch.log10(max**2 / mse)
|
||||||
|
return psnr
|
||||||
|
|
||||||
|
|
||||||
|
a = torch.randn(2, 3, 32, 32)
|
||||||
|
b = image_normalization('nomalization')(a)
|
||||||
Loading…
Reference in New Issue
Block a user