# -*- coding: utf-8 -*-
"""
Created on Tue Dec 17:00:00 2023

@author: chun
"""

import glob
import os
import time
from fractions import Fraction

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.nn.parallel import DataParallel
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

from dataset import Vanilla
from model import DeepJSCC, ratio2filtersize
from utils import image_normalization, set_seed, save_model, view_model_param
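

# Both train_epoch and evaluate_epoch below compute their loss on denormalized
# pixel values. `image_normalization` is defined in utils.py; a minimal sketch
# of the assumed contract (not the actual implementation) is:
#
#     def image_normalization(mode):
#         if mode == 'normalization':
#             return lambda images: images / 255.0
#         return lambda images: images * 255.0  # 'denormalization'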


def train_epoch(model, optimizer, param, data_loader):
    """Run one training epoch and return (mean loss, optimizer)."""
    model.train()
    epoch_loss = 0

    for batch_idx, (images, _) in enumerate(data_loader):
        images = images.cuda() if param['parallel'] and torch.cuda.device_count(
        ) > 1 else images.to(param['device'])
        optimizer.zero_grad()
        outputs = model(images)
        # The loss is computed on pixel-range (0-255) images, so map both back.
        outputs = image_normalization('denormalization')(outputs)
        images = image_normalization('denormalization')(images)
        # DataParallel wraps the model, so reach through .module for the loss.
        loss = model.loss(images, outputs) if not param['parallel'] else model.module.loss(
            images, outputs)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (batch_idx + 1)

    return epoch_loss, optimizer


def evaluate_epoch(model, param, data_loader):
    """Evaluate the model and return the mean loss over the loader."""
    model.eval()
    epoch_loss = 0

    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(data_loader):
            images = images.cuda() if param['parallel'] and torch.cuda.device_count(
            ) > 1 else images.to(param['device'])
            outputs = model(images)
            outputs = image_normalization('denormalization')(outputs)
            images = image_normalization('denormalization')(images)
            loss = model.loss(images, outputs) if not param['parallel'] else model.module.loss(
                images, outputs)
            epoch_loss += loss.detach().item()
        epoch_loss /= (batch_idx + 1)

    return epoch_loss
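

# Since the loss above is an MSE over 0-255 pixel values, PSNR follows
# directly from the returned value:
#     psnr = 10 * math.log10(255 ** 2 / mse)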


def config_parser_pipeline():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='cifar10', type=str,
                        choices=['cifar10', 'imagenet'], help='dataset')
    parser.add_argument('--out', default='./out', type=str, help='out_path')
    # store_true instead of `type=bool`: bool('False') is True in Python.
    parser.add_argument('--disable_tqdm', action='store_true', help='disable tqdm progress bars')
    parser.add_argument('--device', default='cuda:0', type=str, help='device')
    parser.add_argument('--parallel', action='store_true', help='use DataParallel on all GPUs')
    parser.add_argument('--snr_list', default=['19', '13',
                        '7', '4', '1'], nargs='+', help='snr_list')
    parser.add_argument('--ratio_list', default=['1/6', '1/12'], nargs='+', help='ratio_list')
    parser.add_argument('--channel', default='AWGN', type=str,
                        choices=['AWGN', 'Rayleigh'], help='channel')

    return parser.parse_args()
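

# Example invocation (the script name is whatever this file is saved as):
#     python train.py --dataset cifar10 --snr_list 19 13 7 4 1 --ratio_list 1/6 1/12
# The list flags arrive as strings; main_pipeline converts them below, e.g.
# float(Fraction('1/12')) == 0.0833..., i.e. a 1/12 bandwidth compression ratio.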


def main_pipeline():
    args = config_parser_pipeline()

    print("Training Start")
    dataset_name = args.dataset
    out_dir = args.out
    args.snr_list = list(map(float, args.snr_list))
    args.ratio_list = list(map(lambda x: float(Fraction(x)), args.ratio_list))
    params = {}
    params['disable_tqdm'] = args.disable_tqdm
    params['dataset'] = dataset_name
    params['out_dir'] = out_dir
    params['device'] = args.device
    params['snr_list'] = args.snr_list
    params['ratio_list'] = args.ratio_list
    params['channel'] = args.channel
    if dataset_name == 'cifar10':
        params['batch_size'] = 64  # 1024
        params['num_workers'] = 4
        params['epochs'] = 1000
        params['init_lr'] = 1e-3  # 1e-2
        params['weight_decay'] = 5e-4
        params['parallel'] = False  # note: overrides the --parallel flag
        params['if_scheduler'] = True
        params['step_size'] = 640
        params['gamma'] = 0.1
        params['seed'] = 42
        params['ReduceLROnPlateau'] = False
        params['lr_reduce_factor'] = 0.5
        params['lr_schedule_patience'] = 15
        params['max_time'] = 12  # hours
        params['min_lr'] = 1e-5
    elif dataset_name == 'imagenet':
        params['batch_size'] = 32
        params['num_workers'] = 4
        params['epochs'] = 300
        params['init_lr'] = 1e-4
        params['weight_decay'] = 5e-4
        params['parallel'] = True  # note: overrides the --parallel flag
        params['if_scheduler'] = True
        params['gamma'] = 0.1
        params['seed'] = 42
        params['ReduceLROnPlateau'] = True
        params['lr_reduce_factor'] = 0.5
        params['lr_schedule_patience'] = 15
        params['max_time'] = 12  # hours
        params['min_lr'] = 1e-5
    else:
        raise ValueError('Unknown dataset')

    set_seed(params['seed'])

    # Train one model per (ratio, snr) combination.
    for ratio in params['ratio_list']:
        for snr in params['snr_list']:
            params['ratio'] = ratio
            params['snr'] = snr

            train_pipeline(params)


def train_pipeline(params):
    """Train a single DeepJSCC model for the (ratio, snr) pair in `params`."""
    dataset_name = params['dataset']
    # load data
    if dataset_name == 'cifar10':
        transform = transforms.Compose([transforms.ToTensor()])
        train_dataset = datasets.CIFAR10(root='../dataset/', train=True,
                                         download=True, transform=transform)
        train_loader = DataLoader(train_dataset, shuffle=True,
                                  batch_size=params['batch_size'], num_workers=params['num_workers'])
        test_dataset = datasets.CIFAR10(root='../dataset/', train=False,
                                        download=True, transform=transform)
        test_loader = DataLoader(test_dataset, shuffle=True,
                                 batch_size=params['batch_size'], num_workers=params['num_workers'])
    elif dataset_name == 'imagenet':
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Resize((128, 128))])  # the paper uses 128x128
        print("loading data of imagenet")
        train_dataset = datasets.ImageFolder(root='../dataset/ImageNet/train', transform=transform)
        train_loader = DataLoader(train_dataset, shuffle=True,
                                  batch_size=params['batch_size'], num_workers=params['num_workers'])
        test_dataset = Vanilla(root='../dataset/ImageNet/val', transform=transform)
        test_loader = DataLoader(test_dataset, shuffle=True,
                                 batch_size=params['batch_size'], num_workers=params['num_workers'])
    else:
        raise ValueError('Unknown dataset')
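
    # `ratio2filtersize` (model.py) maps the requested bandwidth ratio to the
    # number of filters c of the encoder's bottleneck layer -- presumably the
    # largest c whose latent size keeps the channel-use/source ratio at or
    # below `ratio`. That reading is an assumption; see model.py for the
    # actual rule.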

    # create model
    image_first = train_dataset[0][0]
    c = ratio2filtersize(image_first, params['ratio'])
    print("The snr is {}, the inner channel is {}, the ratio is {:.2f}".format(
        params['snr'], c, params['ratio']))
    model = DeepJSCC(c=c, channel_type=params['channel'], snr=params['snr'])

    # init exp dir
    out_dir = params['out_dir']
    phaser = dataset_name.upper() + '_' + str(c) + '_' + str(params['snr']) + '_' + \
        "{:.2f}".format(params['ratio']) + '_' + str(params['channel']) + \
        '_' + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    root_log_dir = out_dir + '/' + 'logs/' + phaser
    root_ckpt_dir = out_dir + '/' + 'checkpoints/' + phaser
    root_config_dir = out_dir + '/' + 'configs/' + phaser
    writer = SummaryWriter(log_dir=root_log_dir)
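
    # Example run name produced by `phaser` (values illustrative):
    #     CIFAR10_8_19.0_0.17_AWGN_12h30m05s_on_Dec_17_2023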

    # model init
    device = torch.device(params['device'] if torch.cuda.is_available() else 'cpu')
    if params['parallel'] and torch.cuda.device_count() > 1:
        model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
        model = model.cuda()
    else:
        model = model.to(device)

    # opt
    optimizer = optim.Adam(
        model.parameters(), lr=params['init_lr'], weight_decay=params['weight_decay'])
    if params['if_scheduler'] and not params['ReduceLROnPlateau']:
        scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=params['step_size'], gamma=params['gamma'])
    elif params['ReduceLROnPlateau']:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                         factor=params['lr_reduce_factor'],
                                                         patience=params['lr_schedule_patience'])
    else:
        print("No scheduler")
        scheduler = None
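
    # With the CIFAR-10 defaults (init_lr=1e-3, step_size=640, gamma=0.1),
    # StepLR holds the LR at 1e-3 for epochs 0-639, then drops it to 1e-4.
    # ReduceLROnPlateau instead halves the LR after `lr_schedule_patience`
    # epochs without improvement in the validation loss.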

    writer.add_text('config', str(params))
    t0 = time.time()
    epoch_train_losses, epoch_val_losses = [], []
    per_epoch_time = []

    # train
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        with tqdm(range(params['epochs']), disable=params['disable_tqdm']) as t:
            for epoch in t:

                t.set_description('Epoch %d' % epoch)

                start = time.time()

                epoch_train_loss, optimizer = train_epoch(
                    model, optimizer, params, train_loader)

                epoch_val_loss = evaluate_epoch(model, params, test_loader)

                epoch_train_losses.append(epoch_train_loss)
                epoch_val_losses.append(epoch_val_loss)

                writer.add_scalar('train/_loss', epoch_train_loss, epoch)
                writer.add_scalar('val/_loss', epoch_val_loss, epoch)
                writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)

                t.set_postfix(time=time.time() - start, lr=optimizer.param_groups[0]['lr'],
                              train_loss=epoch_train_loss, val_loss=epoch_val_loss)

                per_epoch_time.append(time.time() - start)

                # Saving checkpoint: keep only the two most recent epochs on disk.
                if not os.path.exists(root_ckpt_dir):
                    os.makedirs(root_ckpt_dir)
                torch.save(model.state_dict(), '{}.pkl'.format(
                    root_ckpt_dir + "/epoch_" + str(epoch)))

                files = glob.glob(root_ckpt_dir + '/*.pkl')
                for file in files:
                    epoch_nb = file.split('_')[-1]
                    epoch_nb = int(epoch_nb.split('.')[0])
                    if epoch_nb < epoch - 1:
                        os.remove(file)

                if params['ReduceLROnPlateau'] and scheduler is not None:
                    scheduler.step(epoch_val_loss)  # plateau scheduler tracks the validation loss
                elif params['if_scheduler'] and not params['ReduceLROnPlateau']:
                    scheduler.step()  # StepLR only counts epochs

                if optimizer.param_groups[0]['lr'] < params['min_lr']:
                    print("\n!! LR BELOW THE MIN LR SET.")
                    break

                # Stop training after params['max_time'] hours
                if time.time() - t0 > params['max_time'] * 3600:
                    print('-' * 89)
                    print("Max training time of {:.2f} hours elapsed, so stopping".format(
                        params['max_time']))
                    break

    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early because of KeyboardInterrupt')

    test_loss = evaluate_epoch(model, params, test_loader)
    train_loss = evaluate_epoch(model, params, train_loader)
    print("Test Loss: {:.4f}".format(test_loss))
    print("Train Loss: {:.4f}".format(train_loss))
    print("Convergence Time (Epochs): {}".format(epoch))
    print("TOTAL TIME TAKEN: {:.4f}s".format(time.time() - t0))
    print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time)))

    # Write the final results to TensorBoard and dump the run config as YAML.
    writer.add_text(tag='result', text_string="""Dataset: {}\nparams={}\n\nTotal Parameters: {}\n\n
    FINAL RESULTS\nTEST Loss: {:.4f}\nTRAIN Loss: {:.4f}\n\n
    Convergence Time (Epochs): {}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n"""
                    .format(dataset_name, params, view_model_param(model), test_loss,
                            train_loss, epoch, (time.time() - t0) / 3600, np.mean(per_epoch_time)))
    writer.close()
    if not os.path.exists(os.path.dirname(root_config_dir)):
        os.makedirs(os.path.dirname(root_config_dir))
    with open(root_config_dir + '.yaml', 'w') as f:
        import yaml
        dict_yaml = {'dataset_name': dataset_name, 'params': params,
                     'inner_channel': c, 'total_parameters': view_model_param(model)}
        yaml.dump(dict_yaml, f)

    del model, optimizer, scheduler, train_loader, test_loader
    del writer
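

# Restoring a checkpoint saved by train_pipeline (a sketch -- the directory
# depends on the run's `phaser` string, and epoch_42 is an illustrative name):
#     state = torch.load('./out/checkpoints/<phaser>/epoch_42.pkl')
#     model.load_state_dict(state)
# Runs that used DataParallel save keys with a 'module.' prefix; wrap the
# model in DataParallel first or strip the prefix before loading.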


def train(args, ratio: float, snr: float):  # deprecated
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    # load data
    if args.dataset == 'cifar10':
        transform = transforms.Compose([transforms.ToTensor()])
        train_dataset = datasets.CIFAR10(root='../dataset/', train=True,
                                         download=True, transform=transform)
        train_loader = DataLoader(train_dataset, shuffle=True,
                                  batch_size=args.batch_size, num_workers=args.num_workers)
        test_dataset = datasets.CIFAR10(root='../dataset/', train=False,
                                        download=True, transform=transform)
        test_loader = DataLoader(test_dataset, shuffle=True,
                                 batch_size=args.batch_size, num_workers=args.num_workers)
    elif args.dataset == 'imagenet':
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Resize((128, 128))])  # the paper uses 128x128
        print("loading data of imagenet")
        # note: uses './dataset/...' while train_pipeline uses '../dataset/...'
        train_dataset = datasets.ImageFolder(root='./dataset/ImageNet/train', transform=transform)
        train_loader = DataLoader(train_dataset, shuffle=True,
                                  batch_size=args.batch_size, num_workers=args.num_workers)
        test_dataset = Vanilla(root='./dataset/ImageNet/val', transform=transform)
        test_loader = DataLoader(test_dataset, shuffle=True,
                                 batch_size=args.batch_size, num_workers=args.num_workers)
    else:
        raise ValueError('Unknown dataset')

    print(args)
    image_first = train_dataset[0][0]
    c = ratio2filtersize(image_first, ratio)
    print("the inner channel is {}".format(c))
    model = DeepJSCC(c=c, channel_type=args.channel, snr=snr)

    if args.parallel and torch.cuda.device_count() > 1:
        model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
        model = model.cuda()
        criterion = nn.MSELoss(reduction='mean').cuda()
    else:
        model = model.to(device)
        criterion = nn.MSELoss(reduction='mean').to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = None
    if args.if_scheduler:
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

    epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=True, disable=args.disable_tqdm)
    for epoch in epoch_loop:
        run_loss = 0.0
        for images, _ in tqdm(train_loader, leave=False, disable=args.disable_tqdm):
            optimizer.zero_grad()
            images = images.cuda() if args.parallel and torch.cuda.device_count() > 1 else images.to(device)
            outputs = model(images)
            outputs = image_normalization('denormalization')(outputs)
            images = image_normalization('denormalization')(images)
            loss = criterion(outputs, images)
            loss.backward()
            optimizer.step()
            run_loss += loss.item()
        if args.if_scheduler:
            scheduler.step()  # step once per epoch, not per batch
        with torch.no_grad():
            model.eval()
            test_mse = 0.0
            for images, _ in tqdm(test_loader, leave=False, disable=args.disable_tqdm):
                images = images.cuda() if args.parallel and torch.cuda.device_count() > 1 else images.to(device)
                outputs = model(images)
                images = image_normalization('denormalization')(images)
                outputs = image_normalization('denormalization')(outputs)
                loss = criterion(outputs, images)
                test_mse += loss.item()
            model.train()
        # epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
        print("epoch: {}, loss: {:.4f}, test_mse: {:.4f}, lr: {}".format(
            epoch, run_loss / len(train_loader), test_mse / len(test_loader), optimizer.param_groups[0]['lr']))
    save_model(model, args.saved, args.saved + '/{}_{}_{:.2f}_{:.2f}_{}_{}.pth'
               .format(args.dataset, args.epochs, ratio, snr, args.batch_size, c))


def config_parser():  # deprecated
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', default=2048, type=int, help='Random seed')
    parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
    parser.add_argument('--epochs', default=256, type=int, help='number of epochs')
    parser.add_argument('--batch_size', default=256, type=int, help='batch size')
    parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
    parser.add_argument('--channel', default='AWGN', type=str,
                        choices=['AWGN', 'Rayleigh'], help='channel type')
    parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
    parser.add_argument('--snr_list', default=['19', '13',
                        '7', '4', '1'], nargs='+', help='snr_list')
    parser.add_argument('--ratio_list', default=['1/3',
                        '1/6', '1/12'], nargs='+', help='ratio_list')
    parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
    parser.add_argument('--dataset', default='cifar10', type=str,
                        choices=['cifar10', 'imagenet'], help='dataset')
    parser.add_argument('--parallel', action='store_true', help='use DataParallel on all GPUs')
    parser.add_argument('--if_scheduler', action='store_true', help='enable the StepLR scheduler')
    parser.add_argument('--step_size', default=640, type=int, help='scheduler step size (epochs)')
    parser.add_argument('--device', default='cuda:0', type=str, help='device')
    parser.add_argument('--gamma', default=0.5, type=float, help='gamma')
    # default is True; BooleanOptionalAction (Python 3.9+) adds --no-disable_tqdm
    parser.add_argument('--disable_tqdm', default=True,
                        action=argparse.BooleanOptionalAction, help='disable tqdm progress bars')
    return parser.parse_args()
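

# Why not `type=bool` for the flags above? argparse calls bool() on the raw
# string and bool('False') is True, so any value passed on the command line
# would enable the flag. store_true / BooleanOptionalAction avoid that.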


def main():  # deprecated
    args = config_parser()
    args.snr_list = list(map(float, args.snr_list))
    args.ratio_list = list(map(lambda x: float(Fraction(x)), args.ratio_list))
    set_seed(args.seed)
    print("Training Start")
    for ratio in args.ratio_list:
        for snr in args.snr_list:
            train(args, ratio, snr)


if __name__ == '__main__':
    main_pipeline()
    # main()