debug parallel
This commit is contained in:
parent
55198377dd
commit
52b7efb407
4
train.py
4
train.py
@ -117,7 +117,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
run_loss = 0.0
|
run_loss = 0.0
|
||||||
for images, _ in tqdm((train_loader), leave=False, disable=args.disable_tqdm):
|
for images, _ in tqdm((train_loader), leave=False, disable=args.disable_tqdm):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
images = images.cuda() if args.parallel else images.to(device)
|
images = images.cuda() if args.parallel and torch.cuda.device_count() > 1 else images.to(device)
|
||||||
outputs = model(images)
|
outputs = model(images)
|
||||||
outputs = image_normalization('denormalization')(outputs)
|
outputs = image_normalization('denormalization')(outputs)
|
||||||
images = image_normalization('denormalization')(images)
|
images = image_normalization('denormalization')(images)
|
||||||
@ -131,7 +131,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
model.eval()
|
model.eval()
|
||||||
test_mse = 0.0
|
test_mse = 0.0
|
||||||
for images, _ in tqdm((test_loader), leave=False, disable=args.disable_tqdm):
|
for images, _ in tqdm((test_loader), leave=False, disable=args.disable_tqdm):
|
||||||
images = images if args.parallel else images.to(device)
|
images = images.cuda() if args.parallel and torch.cuda.device_count() > 1 else images.to(device)
|
||||||
outputs = model(images)
|
outputs = model(images)
|
||||||
images = image_normalization('denormalization')(images)
|
images = image_normalization('denormalization')(images)
|
||||||
outputs = image_normalization('denormalization')(outputs)
|
outputs = image_normalization('denormalization')(outputs)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user