debug parallel
This commit is contained in:
parent
55198377dd
commit
52b7efb407
4
train.py
4
train.py
@ -117,7 +117,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
run_loss = 0.0
|
||||
for images, _ in tqdm((train_loader), leave=False, disable=args.disable_tqdm):
|
||||
optimizer.zero_grad()
|
||||
images = images.cuda() if args.parallel else images.to(device)
|
||||
images = images.cuda() if args.parallel and torch.cuda.device_count() > 1 else images.to(device)
|
||||
outputs = model(images)
|
||||
outputs = image_normalization('denormalization')(outputs)
|
||||
images = image_normalization('denormalization')(images)
|
||||
@ -131,7 +131,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
model.eval()
|
||||
test_mse = 0.0
|
||||
for images, _ in tqdm((test_loader), leave=False, disable=args.disable_tqdm):
|
||||
images = images if args.parallel else images.to(device)
|
||||
images = images.cuda() if args.parallel and torch.cuda.device_count() > 1 else images.to(device)
|
||||
outputs = model(images)
|
||||
images = image_normalization('denormalization')(images)
|
||||
outputs = image_normalization('denormalization')(outputs)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user