From 0e8aaeda075e42a85af2ced86ac5c47eaadde0eb Mon Sep 17 00:00:00 2001 From: chun Date: Sat, 23 Dec 2023 17:16:39 +0800 Subject: [PATCH] para modified --- train.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 7c0a8cb..b462a44 100644 --- a/train.py +++ b/train.py @@ -82,7 +82,7 @@ def train(args: config_parser(), ratio: float, snr: float): model = DeepJSCC(c=c, channel_type=args.channel, snr=snr) if args.parallel and torch.cuda.device_count() > 1: model = DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) - model.cuda() + model = model.cuda() criterion = nn.MSELoss(reduction='mean') optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False) @@ -91,8 +91,7 @@ def train(args: config_parser(), ratio: float, snr: float): run_loss = 0.0 for images, _ in tqdm((train_loader), leave=False): optimizer.zero_grad() - if not args.parallel: - images = images.cuda(device=device) + images = images.cuda() outputs = model(images) loss = criterion(image_normalization('denormalization')(outputs), image_normalization('denormalization')(images))