# -*- coding: utf-8 -*-
"""
Created on Tue Dec 17:00:00 2023

@author: chun
"""
import os
import glob
import time
from fractions import Fraction

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DataParallel
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from tensorboardX import SummaryWriter

from model import DeepJSCC, ratio2filtersize
from dataset import Vanilla
from utils import image_normalization, set_seed, save_model, view_model_param


def train_epoch(model, optimizer, param, data_loader):
    model.train()
    epoch_loss = 0
    for batch_idx, (images, _) in enumerate(data_loader):
        images = images.cuda() if param['parallel'] and torch.cuda.device_count() > 1 \
            else images.to(param['device'])
        optimizer.zero_grad()
        outputs = model(images)
        # compare reconstructions in pixel space
        outputs = image_normalization('denormalization')(outputs)
        images = image_normalization('denormalization')(images)
        loss = model.loss(images, outputs) if not param['parallel'] \
            else model.module.loss(images, outputs)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (batch_idx + 1)
    return epoch_loss, optimizer


def evaluate_epoch(model, param, data_loader):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(data_loader):
            images = images.cuda() if param['parallel'] and torch.cuda.device_count() > 1 \
                else images.to(param['device'])
            outputs = model(images)
            outputs = image_normalization('denormalization')(outputs)
            images = image_normalization('denormalization')(images)
            loss = model.loss(images, outputs) if not param['parallel'] \
                else model.module.loss(images, outputs)
            epoch_loss += loss.detach().item()
    epoch_loss /= (batch_idx + 1)
    return epoch_loss


def config_parser_pipeline():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='cifar10', type=str,
                        choices=['cifar10', 'imagenet'], help='dataset')
    parser.add_argument('--out', default='./out', type=str, help='output path')
    # NOTE: boolean flags use store_true; `type=bool` would convert any
    # non-empty string (including "False") to True.
    parser.add_argument('--disable_tqdm', action='store_true',
                        help='disable tqdm progress bars')
    parser.add_argument('--device', default='cuda:0', type=str, help='device')
    parser.add_argument('--parallel', action='store_true',
                        help='use DataParallel on all GPUs')
    parser.add_argument('--snr_list', default=['19', '13', '7', '4', '1'],
                        nargs='+', help='SNR values (dB) to sweep')
    parser.add_argument('--ratio_list', default=['1/6', '1/12'],
                        nargs='+', help='bandwidth ratios to sweep')
    parser.add_argument('--channel', default='AWGN', type=str,
                        choices=['AWGN', 'Rayleigh'], help='channel model')
    return parser.parse_args()


def main_pipeline():
    args = config_parser_pipeline()
    print("Training Start")
    dataset_name = args.dataset
    out_dir = args.out
    args.snr_list = list(map(float, args.snr_list))
    args.ratio_list = list(map(lambda x: float(Fraction(x)), args.ratio_list))
    params = {}
    params['disable_tqdm'] = args.disable_tqdm
    params['dataset'] = dataset_name
    params['out_dir'] = out_dir
    params['device'] = args.device
    params['snr_list'] = args.snr_list
    params['ratio_list'] = args.ratio_list
    params['channel'] = args.channel
    if dataset_name == 'cifar10':
        params['batch_size'] = 64  # 1024
        params['num_workers'] = 4
        params['epochs'] = 1000
        params['init_lr'] = 1e-3  # 1e-2
        params['weight_decay'] = 5e-4
        params['parallel'] = False
        params['if_scheduler'] = True
        params['step_size'] = 640
        params['gamma'] = 0.1
        params['seed'] = 42
        params['ReduceLROnPlateau'] = False
        params['lr_reduce_factor'] = 0.5
        params['lr_schedule_patience'] = 15
        params['max_time'] = 12
        params['min_lr'] = 1e-5
    elif dataset_name == 'imagenet':
        params['batch_size'] = 32
        params['num_workers'] = 4
        params['epochs'] = 300
        params['init_lr'] = 1e-4
        params['weight_decay'] = 5e-4
        params['parallel'] = True
        params['if_scheduler'] = True
        params['gamma'] = 0.1
        params['seed'] = 42
        params['ReduceLROnPlateau'] = True
        params['lr_reduce_factor'] = 0.5
        params['lr_schedule_patience'] = 15
        params['max_time'] = 12
        params['min_lr'] = 1e-5
    else:
        raise Exception('Unknown dataset')
    set_seed(params['seed'])
    # run one training job per (ratio, snr) combination
    for ratio in params['ratio_list']:
        for snr in params['snr_list']:
            params['ratio'] = ratio
            params['snr'] = snr
            train_pipeline(params)
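# Example invocation of the pipeline entry point (a sketch: the script file
# name and the relative dataset paths hard-coded below are assumptions about
# the repo layout):
#
#   python train.py --dataset cifar10 --out ./out \
#       --snr_list 19 13 7 --ratio_list 1/6 1/12 --channel AWGN
#
# Each (ratio, snr) pair gets its own run with separate TensorBoard logs,
# checkpoints, and a yaml config under --out.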
# train_pipeline runs a single (ratio, snr) configuration end to end.
def train_pipeline(params):
    dataset_name = params['dataset']
    # load data
    if dataset_name == 'cifar10':
        transform = transforms.Compose([transforms.ToTensor(), ])
        train_dataset = datasets.CIFAR10(root='../dataset/', train=True,
                                         download=True, transform=transform)
        train_loader = DataLoader(train_dataset, shuffle=True,
                                  batch_size=params['batch_size'],
                                  num_workers=params['num_workers'])
        test_dataset = datasets.CIFAR10(root='../dataset/', train=False,
                                        download=True, transform=transform)
        test_loader = DataLoader(test_dataset, shuffle=True,
                                 batch_size=params['batch_size'],
                                 num_workers=params['num_workers'])
    elif dataset_name == 'imagenet':
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Resize((128, 128))])  # the paper uses 128x128 inputs
        print("loading data of imagenet")
        train_dataset = datasets.ImageFolder(root='../dataset/ImageNet/train',
                                             transform=transform)
        train_loader = DataLoader(train_dataset, shuffle=True,
                                  batch_size=params['batch_size'],
                                  num_workers=params['num_workers'])
        test_dataset = Vanilla(root='../dataset/ImageNet/val', transform=transform)
        test_loader = DataLoader(test_dataset, shuffle=True,
                                 batch_size=params['batch_size'],
                                 num_workers=params['num_workers'])
    else:
        raise Exception('Unknown dataset')

    # create model
    image_first = train_dataset[0][0]
    c = ratio2filtersize(image_first, params['ratio'])
    print("The snr is {}, the inner channel is {}, the ratio is {:.2f}".format(
        params['snr'], c, params['ratio']))
    model = DeepJSCC(c=c, channel_type=params['channel'], snr=params['snr'])

    # init experiment directories
    out_dir = params['out_dir']
    phaser = dataset_name.upper() + '_' + str(c) + '_' + str(params['snr']) + '_' + \
        "{:.2f}".format(params['ratio']) + '_' + str(params['channel']) + \
        '_' + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    root_log_dir = out_dir + '/' + 'logs/' + phaser
    root_ckpt_dir = out_dir + '/' + 'checkpoints/' + phaser
    root_config_dir = out_dir + '/' + 'configs/' + phaser
    writer = SummaryWriter(log_dir=root_log_dir)

    # model init
    device = torch.device(params['device'] if torch.cuda.is_available() else 'cpu')
    if params['parallel'] and torch.cuda.device_count() > 1:
        model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
        model = model.cuda()
    else:
        model = model.to(device)

    # optimizer and scheduler
    optimizer = optim.Adam(model.parameters(), lr=params['init_lr'],
                           weight_decay=params['weight_decay'])
    if params['if_scheduler'] and not params['ReduceLROnPlateau']:
        scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=params['step_size'], gamma=params['gamma'])
    elif params['ReduceLROnPlateau']:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=params['lr_reduce_factor'],
            patience=params['lr_schedule_patience'])
    else:
        print("No scheduler")
        scheduler = None
    writer.add_text('config', str(params))

    t0 = time.time()
    epoch_train_losses, epoch_val_losses = [], []
    per_epoch_time = []

    # train
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        with tqdm(range(params['epochs']), disable=params['disable_tqdm']) as t:
            for epoch in t:
                t.set_description('Epoch %d' % epoch)
                start = time.time()
                epoch_train_loss, optimizer = train_epoch(
                    model, optimizer, params, train_loader)
                epoch_val_loss = evaluate_epoch(model, params, test_loader)
                epoch_train_losses.append(epoch_train_loss)
                epoch_val_losses.append(epoch_val_loss)
                writer.add_scalar('train/_loss', epoch_train_loss, epoch)
                writer.add_scalar('val/_loss', epoch_val_loss, epoch)
                writer.add_scalar('learning_rate',
                                  optimizer.param_groups[0]['lr'], epoch)
                t.set_postfix(time=time.time() - start,
                              lr=optimizer.param_groups[0]['lr'],
                              train_loss=epoch_train_loss,
                              val_loss=epoch_val_loss)
                per_epoch_time.append(time.time() - start)

                # save a checkpoint and keep only the two most recent epochs
                if not os.path.exists(root_ckpt_dir):
                    os.makedirs(root_ckpt_dir)
                torch.save(model.state_dict(), '{}.pkl'.format(
                    root_ckpt_dir + "/epoch_" + str(epoch)))
                files = glob.glob(root_ckpt_dir + '/*.pkl')
                for file in files:
                    epoch_nb = file.split('_')[-1]
                    epoch_nb = int(epoch_nb.split('.')[0])
                    if epoch_nb < epoch - 1:
                        os.remove(file)

                if params['ReduceLROnPlateau'] and scheduler is not None:
                    # ReduceLROnPlateau uses only the validation loss
                    scheduler.step(epoch_val_loss)
                elif params['if_scheduler'] and not params['ReduceLROnPlateau']:
                    scheduler.step()

                if optimizer.param_groups[0]['lr'] < params['min_lr']:
                    print("\n!! LR EQUAL TO MIN LR SET.")
                    break

                # stop training after params['max_time'] hours
                if time.time() - t0 > params['max_time'] * 3600:
                    print('-' * 89)
                    print("Max_time for training elapsed {:.2f} hours, so stopping".format(
                        params['max_time']))
                    break
    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early because of KeyboardInterrupt')

    test_loss = evaluate_epoch(model, params, test_loader)
    train_loss = evaluate_epoch(model, params, train_loader)
    print("Test Loss: {:.4f}".format(test_loss))
    print("Train Loss: {:.4f}".format(train_loss))
    print("Convergence Time (Epochs): {:.4f}".format(epoch))
    print("TOTAL TIME TAKEN: {:.4f}s".format(time.time() - t0))
    print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time)))

    # write the final results to the TensorBoard log and a yaml config
    writer.add_text(tag='result', text_string="""Dataset: {}\nparams={}\n\nTotal Parameters: {}\n\n
    FINAL RESULTS\nTEST Loss: {:.4f}\nTRAIN Loss: {:.4f}\n\n
    Convergence Time (Epochs): {:.4f}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n"""
                    .format(dataset_name, params, view_model_param(model),
                            test_loss, train_loss, epoch,
                            (time.time() - t0) / 3600, np.mean(per_epoch_time)))
    writer.close()

    if not os.path.exists(os.path.dirname(root_config_dir)):
        os.makedirs(os.path.dirname(root_config_dir))
    import yaml
    with open(root_config_dir + '.yaml', 'w') as f:
        dict_yaml = {'dataset_name': dataset_name, 'params': params,
                     'inner_channel': c, 'total_parameters': view_model_param(model)}
        yaml.dump(dict_yaml, f)

    del model, optimizer, scheduler, train_loader, test_loader
    del writer
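# The loops above log raw MSE computed on denormalized pixels (assumed to be
# in [0, 255] after image_normalization('denormalization')). A small optional
# helper, not called by the pipeline itself, to convert that MSE into the
# PSNR figure usually reported for DeepJSCC:
def mse_to_psnr(mse, max_pixel=255.0):
    """Return PSNR in dB for an MSE measured on images in [0, max_pixel]."""
    return 10.0 * np.log10(max_pixel ** 2 / mse)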
def train(args, ratio: float, snr: float):
    # deprecated: superseded by train_pipeline
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    # load data
    if args.dataset == 'cifar10':
        transform = transforms.Compose([transforms.ToTensor(), ])
        train_dataset = datasets.CIFAR10(root='../dataset/', train=True,
                                         download=True, transform=transform)
        train_loader = DataLoader(train_dataset, shuffle=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)
        test_dataset = datasets.CIFAR10(root='../dataset/', train=False,
                                        download=True, transform=transform)
        test_loader = DataLoader(test_dataset, shuffle=True,
                                 batch_size=args.batch_size,
                                 num_workers=args.num_workers)
    elif args.dataset == 'imagenet':
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Resize((128, 128))])  # the paper uses 128x128 inputs
        print("loading data of imagenet")
        train_dataset = datasets.ImageFolder(root='./dataset/ImageNet/train',
                                             transform=transform)
        train_loader = DataLoader(train_dataset, shuffle=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)
        test_dataset = Vanilla(root='./dataset/ImageNet/val', transform=transform)
        test_loader = DataLoader(test_dataset, shuffle=True,
                                 batch_size=args.batch_size,
                                 num_workers=args.num_workers)
    else:
        raise Exception('Unknown dataset')
    print(args)
    image_first = train_dataset[0][0]
    c = ratio2filtersize(image_first, ratio)
    print("the inner channel is {}".format(c))
    model = DeepJSCC(c=c, channel_type=args.channel, snr=snr)
    if args.parallel and torch.cuda.device_count() > 1:
        model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
        model = model.cuda()
        criterion = nn.MSELoss(reduction='mean').cuda()
    else:
        model = model.to(device)
        criterion = nn.MSELoss(reduction='mean').to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr,
                           weight_decay=args.weight_decay)
    if args.if_scheduler:
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size,
                                              gamma=args.gamma)
    epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=True,
                      disable=args.disable_tqdm)
    for epoch in epoch_loop:
        run_loss = 0.0
        for images, _ in tqdm(train_loader, leave=False, disable=args.disable_tqdm):
            optimizer.zero_grad()
            images = images.cuda() if args.parallel and torch.cuda.device_count() > 1 \
                else images.to(device)
            outputs = model(images)
            outputs = image_normalization('denormalization')(outputs)
            images = image_normalization('denormalization')(images)
            loss = criterion(outputs, images)
            loss.backward()
            optimizer.step()
            run_loss += loss.item()
        if args.if_scheduler:  # step once per epoch (it was mis-placed before)
            scheduler.step()
        with torch.no_grad():
            model.eval()
            test_mse = 0.0
            for images, _ in tqdm(test_loader, leave=False, disable=args.disable_tqdm):
                images = images.cuda() if args.parallel and torch.cuda.device_count() > 1 \
                    else images.to(device)
                outputs = model(images)
                images = image_normalization('denormalization')(images)
                outputs = image_normalization('denormalization')(outputs)
                loss = criterion(outputs, images)
                test_mse += loss.item()
            model.train()
        # epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
        print("epoch: {}, loss: {:.4f}, test_mse: {:.4f}, lr:{}".format(
            epoch, run_loss / len(train_loader), test_mse / len(test_loader),
            optimizer.param_groups[0]['lr']))
    save_model(model, args.saved, args.saved + '/{}_{}_{:.2f}_{:.2f}_{}_{}.pth'
               .format(args.dataset, args.epochs, ratio, snr, args.batch_size, c))
def config_parser():
    # deprecated: superseded by config_parser_pipeline
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', default=2048, type=int, help='Random seed')
    parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
    parser.add_argument('--epochs', default=256, type=int, help='number of epochs')
    parser.add_argument('--batch_size', default=256, type=int, help='batch size')
    parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
    parser.add_argument('--channel', default='AWGN', type=str,
                        choices=['AWGN', 'Rayleigh'], help='channel type')
    parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
    parser.add_argument('--snr_list', default=['19', '13', '7', '4', '1'],
                        nargs='+', help='snr_list')
    parser.add_argument('--ratio_list', default=['1/3', '1/6', '1/12'],
                        nargs='+', help='ratio_list')
    parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
    parser.add_argument('--dataset', default='cifar10', type=str,
                        choices=['cifar10', 'imagenet'], help='dataset')
    # NOTE: `type=bool` is a pitfall: argparse converts any non-empty string
    # (including "False") to True. Kept as-is only because this parser is deprecated.
    parser.add_argument('--parallel', default=False, type=bool, help='parallel')
    parser.add_argument('--if_scheduler', default=False, type=bool, help='if_scheduler')
    parser.add_argument('--step_size', default=640, type=int, help='scheduler step size')
    parser.add_argument('--device', default='cuda:0', type=str, help='device')
    parser.add_argument('--gamma', default=0.5, type=float, help='gamma')
    parser.add_argument('--disable_tqdm', default=True, type=bool, help='disable_tqdm')
    return parser.parse_args()
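# For new code, the usual fix for the boolean pitfall noted above is a
# store_true flag, e.g. (hypothetical standalone snippet):
#
#   parser.add_argument('--parallel', action='store_true',
#                       help='use DataParallel on all GPUs')
#
# so the flag stays False unless explicitly passed on the command line.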
def main():
    # deprecated: superseded by main_pipeline
    args = config_parser()
    args.snr_list = list(map(float, args.snr_list))
    args.ratio_list = list(map(lambda x: float(Fraction(x)), args.ratio_list))
    set_seed(args.seed)
    print("Training Start")
    for ratio in args.ratio_list:
        for snr in args.snr_list:
            train(args, ratio, snr)


if __name__ == '__main__':
    main_pipeline()
    # main()
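# For reference, a minimal sketch of how an AWGN channel at a given SNR (dB)
# is typically simulated on a power-normalized latent z in DeepJSCC-style
# models (an assumption for illustration; the actual channel used here lives
# in model.DeepJSCC and is selected by --channel):
#
#   snr_linear = 10.0 ** (snr_db / 10.0)
#   signal_power = z.pow(2).mean()
#   noise_std = torch.sqrt(signal_power / snr_linear)
#   z_noisy = z + noise_std * torch.randn_like(z)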