diff --git a/train.py b/train.py index 8eda390..27bea95 100644 --- a/train.py +++ b/train.py @@ -82,6 +82,7 @@ def train(args: config_parser(), ratio: float, snr: float): model = DeepJSCC(c=c, channel_type=args.channel, snr=snr) if args.parallel and torch.cuda.device_count() > 1: model = DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) + model.cuda() criterion = nn.MSELoss(reduction='mean') optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False) @@ -92,7 +93,6 @@ def train(args: config_parser(), ratio: float, snr: float): optimizer.zero_grad() if not args.parallel: images = images.cuda(device=device) - # images = images.cuda(device=device) outputs = model(images) loss = criterion(image_normalization('denormalization')(outputs), image_normalization('denormalization')(images))