diff --git a/train.py b/train.py index 62583fe..410279b 100644 --- a/train.py +++ b/train.py @@ -34,6 +34,8 @@ def config_parser(): parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset') parser.add_argument('--parallel', default=False, type=bool, help='parallel') + parser.add_argument('--if_scheduler', default=True, type=bool, help='if_scheduler') + parser.add_argument('--step_size', default=640, type=int, help='scheduler') return parser.parse_args() @@ -80,13 +82,17 @@ def train(args: config_parser(), ratio: float, snr: float): image_fisrt = train_dataset.__getitem__(0)[0] c = ratio2filtersize(image_fisrt, ratio) model = DeepJSCC(c=c, channel_type=args.channel, snr=snr) + if args.parallel and torch.cuda.device_count() > 1: model = DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) model = model.cuda() + criterion = nn.MSELoss(reduction='mean').cuda() optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) - epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False) + if args.if_scheduler: + scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1) + epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False) for epoch in epoch_loop: run_loss = 0.0 for images, _ in tqdm((train_loader), leave=False): @@ -99,6 +105,8 @@ def train(args: config_parser(), ratio: float, snr: float): loss.backward() optimizer.step() run_loss += loss.item() + if args.if_scheduler: + scheduler.step() with torch.no_grad(): model.eval() test_mse = 0.0