debug norm
This commit is contained in:
parent
f93f85fed9
commit
5c5ff99bb9
8
train.py
8
train.py
@ -83,7 +83,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
if args.parallel and torch.cuda.device_count() > 1:
|
if args.parallel and torch.cuda.device_count() > 1:
|
||||||
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
|
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
criterion = nn.MSELoss(reduction='mean')
|
criterion = nn.MSELoss(reduction='mean').cuda()
|
||||||
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||||
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
|
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
|
||||||
|
|
||||||
@ -103,10 +103,10 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
model.eval()
|
model.eval()
|
||||||
test_mse = 0.0
|
test_mse = 0.0
|
||||||
for images, _ in tqdm((test_loader), leave=False):
|
for images, _ in tqdm((test_loader), leave=False):
|
||||||
images = images
|
images = images.cuda()
|
||||||
outputs = model(images)
|
outputs = model(images)
|
||||||
images = image_normalization('normalization')(images)
|
images = image_normalization('denormalization')(images)
|
||||||
outputs = image_normalization('normalization')(outputs)
|
outputs = image_normalization('denormalization')(outputs)
|
||||||
loss = criterion(outputs, images)
|
loss = criterion(outputs, images)
|
||||||
test_mse += loss.item()
|
test_mse += loss.item()
|
||||||
model.train()
|
model.train()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user