JSCC/train.py
2024-02-03 20:05:40 +08:00

164 lines
7.2 KiB
Python

# -*- coding: utf-8 -*-
"""
Created on Tue Dec 17:00:00 2023
@author: chun
"""
import os
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
from model import DeepJSCC, ratio2filtersize
from torch.nn.parallel import DataParallel
from utils import image_normalization
from fractions import Fraction
from dataset import Vanilla
import numpy as np
def set_seed(seed):
    """Seed NumPy and PyTorch RNGs and put cuDNN into deterministic mode.

    Args:
        seed: integer seed passed to NumPy, the PyTorch CPU generator, the
            current CUDA device and all CUDA devices.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Current device first, then every visible device (both are no-ops
    # queued lazily when CUDA is unavailable).
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)
    cudnn = torch.backends.cudnn
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn.deterministic, cudnn.benchmark = True, False
def config_parser(argv=None):
    """Build the command-line parser and parse the training options.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with the parsed options. ``snr_list`` and
        ``ratio_list`` are returned as lists of strings; ``main`` converts
        them to floats.

    Boolean flags (``--parallel``, ``--if_scheduler``, ``--disable_tqdm``)
    accept true/false, yes/no, t/f, y/n or 1/0 (case-insensitive). Any other
    value is rejected with a usage error.
    """
    import argparse

    def str2bool(value):
        # argparse's ``type=bool`` would call bool("False") == True, so every
        # non-empty string enabled the flag. Parse the text explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', default=2048, type=int, help='Random seed')
    parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
    parser.add_argument('--epochs', default=256, type=int, help='number of epochs')
    parser.add_argument('--batch_size', default=256, type=int, help='batch size')
    parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
    parser.add_argument('--channel', default='AWGN', type=str,
                        choices=['AWGN', 'Rayleigh'], help='channel type')
    parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
    parser.add_argument('--snr_list', default=['19', '13',
                        '7', '4', '1'], nargs='+', help='snr_list')
    parser.add_argument('--ratio_list', default=['1/3',
                        '1/6', '1/12'], nargs='+', help='ratio_list')
    parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
    parser.add_argument('--dataset', default='cifar10', type=str,
                        choices=['cifar10', 'imagenet'], help='dataset')
    parser.add_argument('--parallel', default=False, type=str2bool, help='parallel')
    parser.add_argument('--if_scheduler', default=False, type=str2bool, help='if_scheduler')
    parser.add_argument('--step_size', default=640, type=int, help='scheduler')
    parser.add_argument('--device', default='cuda:0', type=str, help='device')
    parser.add_argument('--gamma', default=0.5, type=float, help='gamma')
    parser.add_argument('--disable_tqdm', default=True, type=str2bool, help='disable_tqdm')
    return parser.parse_args(argv)
def main():
    """Parse CLI options, seed the RNGs once, and train one model per setting.

    Every ratio in ``--ratio_list`` is combined with every SNR in
    ``--snr_list``; ratios form the outer loop and SNRs the inner loop.
    """
    args = config_parser()
    # CLI values arrive as strings; ratios may be fractions such as "1/6".
    args.snr_list = [float(value) for value in args.snr_list]
    args.ratio_list = [float(Fraction(value)) for value in args.ratio_list]
    set_seed(args.seed)
    print("Training Start")
    schedule = [(ratio, snr) for ratio in args.ratio_list for snr in args.snr_list]
    for ratio, snr in schedule:
        train(args, ratio, snr)
def _build_loaders(args):
    """Return ``(train_dataset, train_loader, test_loader)`` for ``args.dataset``.

    Raises:
        ValueError: if ``args.dataset`` is neither 'cifar10' nor 'imagenet'.
    """
    if args.dataset == 'cifar10':
        transform = transforms.Compose([transforms.ToTensor(), ])
        train_dataset = datasets.CIFAR10(root='./dataset/', train=True,
                                         download=True, transform=transform)
        test_dataset = datasets.CIFAR10(root='./dataset/', train=False,
                                        download=True, transform=transform)
    elif args.dataset == 'imagenet':
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Resize((128, 128))])  # the size of paper is 128
        print("loading data of imagenet")
        train_dataset = datasets.ImageFolder(root='./dataset/ImageNet/train', transform=transform)
        test_dataset = Vanilla(root='./dataset/ImageNet/val', transform=transform)
    else:
        raise ValueError('Unknown dataset')
    train_loader = DataLoader(train_dataset, shuffle=True,
                              batch_size=args.batch_size, num_workers=args.num_workers)
    # Evaluation order does not affect the metric; keep it deterministic.
    test_loader = DataLoader(test_dataset, shuffle=False,
                             batch_size=args.batch_size, num_workers=args.num_workers)
    return train_dataset, train_loader, test_loader


def _evaluate(model, loader, criterion, move, disable_tqdm):
    """Return the per-image mean of ``criterion`` over ``loader`` without gradients.

    The model is put in eval mode for the pass and always restored to train
    mode afterwards, even if evaluation raises.
    """
    model.eval()
    total, seen = 0.0, 0
    try:
        with torch.no_grad():
            for images, _ in tqdm(loader, leave=False, disable=disable_tqdm):
                images = move(images)
                outputs = model(images)
                images = image_normalization('denormalization')(images)
                outputs = image_normalization('denormalization')(outputs)
                batch = images.size(0)
                # MSELoss(reduction='mean') is a per-batch mean; weight it by the
                # batch size so a smaller final batch is not over-counted.
                total += criterion(outputs, images).item() * batch
                seen += batch
    finally:
        model.train()
    return total / seen


def train(args: "argparse.Namespace", ratio: float, snr: float):
    """Train one DeepJSCC model for a (bandwidth ratio, SNR) pair and save it.

    The inner channel count ``c`` is derived from the first training image
    and ``ratio`` via ``ratio2filtersize``. The model is trained with Adam on
    the MSE between the denormalized reconstruction and the denormalized
    input (presumably pixel-range values; see ``utils.image_normalization``).
    If ``args.if_scheduler`` is set, StepLR is stepped once per epoch. After
    every epoch the per-image train loss, test MSE and learning rate are
    printed. The final weights are written under ``args.saved``.

    Args:
        args: parsed options from ``config_parser``. The annotation is a
            string so that defining this function does not parse
            ``sys.argv``.
        ratio: bandwidth compression ratio (e.g. 1/6).
        snr: channel SNR passed to ``DeepJSCC``.
    """
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    train_dataset, train_loader, test_loader = _build_loaders(args)
    print(args)
    image_first = train_dataset[0][0]
    c = ratio2filtersize(image_first, ratio)
    print("the inner channel is {}".format(c))
    model = DeepJSCC(c=c, channel_type=args.channel, snr=snr)

    use_parallel = args.parallel and torch.cuda.device_count() > 1
    if use_parallel:
        # NOTE: the saved state_dict keys will carry DataParallel's "module." prefix.
        model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
        model = model.cuda()
        criterion = nn.MSELoss(reduction='mean').cuda()
    else:
        model = model.to(device)
        criterion = nn.MSELoss(reduction='mean').to(device)

    def move(images):
        # Put a batch on the same device as the model.
        return images.cuda() if use_parallel else images.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = None
    if args.if_scheduler:
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

    epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=True, disable=args.disable_tqdm)
    for epoch in epoch_loop:
        run_loss, seen = 0.0, 0
        for images, _ in tqdm(train_loader, leave=False, disable=args.disable_tqdm):
            optimizer.zero_grad()
            images = move(images)
            outputs = model(images)
            outputs = image_normalization('denormalization')(outputs)
            images = image_normalization('denormalization')(images)
            loss = criterion(outputs, images)
            loss.backward()
            optimizer.step()
            batch = images.size(0)
            run_loss += loss.item() * batch
            seen += batch
        if scheduler is not None:
            scheduler.step()  # stepped once per epoch
        test_mse = _evaluate(model, test_loader, criterion, move, args.disable_tqdm)
        print("epoch: {}, loss: {:.4f}, test_mse: {:.4f}, lr:{}".format(
            epoch, run_loss / seen, test_mse, optimizer.param_groups[0]['lr']))
    save_model(model, args.saved, args.saved +
               '/{}_{}_{:.2f}_{:.2f}_{}_{}.pth'.format(args.dataset, args.epochs, ratio, snr, args.batch_size, c))
def save_model(model, dir, path):
    """Save ``model.state_dict()`` to ``path`` without overwriting existing files.

    ``dir`` is created if needed. If ``path`` already exists, the first free
    name among ``path_1``, ``path_2``, ... is used instead. Previously the
    suffixes accumulated (``path_1_2``, ``path_1_2_3``).

    Args:
        model: module whose ``state_dict()`` is saved.
        dir: directory to create before saving (normally the parent of ``path``).
        path: preferred output file path.
    """
    os.makedirs(dir, exist_ok=True)
    candidate = path
    flag = 1
    while os.path.exists(candidate):
        # Always derive from the original path so suffixes do not pile up.
        candidate = '{}_{}'.format(path, flag)
        flag += 1
    torch.save(model.state_dict(), candidate)
    print("Model saved in {}".format(candidate))
if __name__ == "__main__":
    # Run training only when executed as a script, not on import.
    main()