requirements modified
This commit is contained in:
parent
8b79e816d1
commit
2c1a6ca92e
@ -16,7 +16,7 @@ pip install requirements.txt
|
||||
Run(example)
|
||||
```
|
||||
cd ./Deep-JSCC-PyTorch
|
||||
python train.py --seed 2048 --epochs 200 --batch_size 256 --channel 'AWGN' --saved ./saved --snr_list [1,4,7,13,19] --ratio_list [1/6,1/12]
|
||||
python train.py --seed 2048 --epochs 200 --batch_size 256 --channel 'AWGN' --saved ./saved --snr_list [1,4,7,13,19] --ratio_list [1/6,1/12] --dataset imagenet
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
|
||||
13
eval.py
13
eval.py
@ -1 +1,14 @@
|
||||
# to be implemented
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def config_parser():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
|
||||
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
|
||||
parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list')
|
||||
parser.add_argument('--demo_image', default='./demo/kodim08.png', type=str, help='demo_image')
|
||||
return parser.parse_args()
|
||||
|
||||
@ -3,4 +3,4 @@ torchvison
|
||||
matplotlib
|
||||
tqdm
|
||||
numpy
|
||||
fraction
|
||||
pillow
|
||||
@ -11,3 +11,8 @@ def image_normalization(norm_type):
|
||||
else:
|
||||
raise Exception('Unknown type of normalization')
|
||||
return _inner
|
||||
|
||||
def get_psnr(image,gt,max=255):
|
||||
psnr = 10 * torch.log10(max**2 / torch.mean((image - gt)**2))
|
||||
return psnr
|
||||
|
||||
29
train.py
29
train.py
@ -22,13 +22,14 @@ def config_parser():
|
||||
parser.add_argument('--seed', default=2048, type=int, help='Random seed')
|
||||
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
|
||||
parser.add_argument('--epochs', default=100, type=int, help='number of epochs')
|
||||
parser.add_argument('--batch_size', default=64, type=int, help='batch size')
|
||||
parser.add_argument('--batch_size', default=256, type=int, help='batch size')
|
||||
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
|
||||
parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
|
||||
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
|
||||
parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list')
|
||||
parser.add_argument('--ratio_list', default=[1/6, 1/12], type=list, help='ratio_list')
|
||||
parser.add_argument('--ratio_list', default=[1/3, 1/6, 1/12], type=list, help='ratio_list')
|
||||
parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
|
||||
parser.add_argument('--dataset', default='imagenet', type=str, help='dataset')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -36,16 +37,16 @@ def main():
|
||||
args = config_parser()
|
||||
|
||||
print("Training Start")
|
||||
# for ratio in args.ratio_list:
|
||||
# for snr in args.snr_list:
|
||||
# train(args, ratio, snr)
|
||||
train(args, 1/6, 20)
|
||||
for ratio in args.ratio_list:
|
||||
for snr in args.snr_list:
|
||||
train(args, ratio, snr)
|
||||
|
||||
|
||||
def train(args: config_parser(), ratio: float, snr: float):
|
||||
|
||||
device = torch.device('cuda')
|
||||
# load data
|
||||
if args.dataset == 'cifar10':
|
||||
transform = transforms.Compose([transforms.ToTensor(), ])
|
||||
train_dataset = datasets.CIFAR10(root='./Dataset/', train=True,
|
||||
download=True, transform=transform)
|
||||
@ -55,6 +56,18 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
test_dataset = datasets.CIFAR10(root='./Dataset/', train=False,
|
||||
download=True, transform=transform)
|
||||
test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size)
|
||||
elif args.dataset == 'imagenet':
|
||||
transform = transforms.Compose([transforms.ToTensor(), ])
|
||||
train_dataset = datasets.ImageNet(root='./Dataset/', train=True,
|
||||
download=True, transform=transform)
|
||||
|
||||
train_loader = DataLoader(train_dataset, shuffle=True,
|
||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||
test_dataset = datasets.ImageNet(root='./Dataset/', train=False,
|
||||
download=True, transform=transform)
|
||||
test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size)
|
||||
else:
|
||||
raise Exception('Unknown dataset')
|
||||
|
||||
print("training with ratio: {:2f}, snr_db: {}, channel: {}".format(ratio, snr, args.channel))
|
||||
|
||||
@ -79,12 +92,12 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
|
||||
epoch_loop.set_description(f'Epoch [{epoch}/{args.epochs}]')
|
||||
epoch_loop.set_postfix(loss=run_loss/len(train_loader))
|
||||
save_model(model, args.saved + '/model_{:2f}_{:2f}.pth'.format(ratio, snr))
|
||||
save_model(model, args.saved + '/model{}_{:2f}_{:2f}.pth'.format(args.dataset, ratio, snr))
|
||||
|
||||
|
||||
def save_model(model, path):
|
||||
os.makedirs(path, exist_ok=True)
|
||||
torch.save(model.state_dict(), path)
|
||||
torch.save(model, path)
|
||||
print("Model saved in {}".format(path))
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user