debug parallel

This commit is contained in:
chun 2024-01-20 11:57:37 +08:00
parent 55198377dd
commit 52b7efb407

View File

@ -117,7 +117,7 @@ def train(args: config_parser(), ratio: float, snr: float):
run_loss = 0.0
for images, _ in tqdm((train_loader), leave=False, disable=args.disable_tqdm):
optimizer.zero_grad()
images = images.cuda() if args.parallel else images.to(device)
images = images.cuda() if args.parallel and torch.cuda.device_count() > 1 else images.to(device)
outputs = model(images)
outputs = image_normalization('denormalization')(outputs)
images = image_normalization('denormalization')(images)
@ -131,7 +131,7 @@ def train(args: config_parser(), ratio: float, snr: float):
model.eval()
test_mse = 0.0
for images, _ in tqdm((test_loader), leave=False, disable=args.disable_tqdm):
images = images if args.parallel else images.to(device)
images = images.cuda() if args.parallel and torch.cuda.device_count() > 1 else images.to(device)
outputs = model(images)
images = image_normalization('denormalization')(images)
outputs = image_normalization('denormalization')(outputs)