debug norm

This commit is contained in:
chun 2023-12-23 17:30:19 +08:00
parent f93f85fed9
commit 5c5ff99bb9
2 changed files with 4 additions and 8 deletions

View File

@ -83,7 +83,7 @@ def train(args: config_parser(), ratio: float, snr: float):
if args.parallel and torch.cuda.device_count() > 1: if args.parallel and torch.cuda.device_count() > 1:
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
model = model.cuda() model = model.cuda()
criterion = nn.MSELoss(reduction='mean') criterion = nn.MSELoss(reduction='mean').cuda()
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False) epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
@ -103,10 +103,10 @@ def train(args: config_parser(), ratio: float, snr: float):
model.eval() model.eval()
test_mse = 0.0 test_mse = 0.0
for images, _ in tqdm((test_loader), leave=False): for images, _ in tqdm((test_loader), leave=False):
images = images images = images.cuda()
outputs = model(images) outputs = model(images)
images = image_normalization('normalization')(images) images = image_normalization('denormalization')(images)
outputs = image_normalization('normalization')(outputs) outputs = image_normalization('denormalization')(outputs)
loss = criterion(outputs, images) loss = criterion(outputs, images)
test_mse += loss.item() test_mse += loss.item()
model.train() model.train()

View File

@ -20,7 +20,3 @@ def get_psnr(image, gt, max=255):
psnr = 10 * torch.log10(max**2 / mse) psnr = 10 * torch.log10(max**2 / mse)
return psnr return psnr
a = torch.randn(2, 3, 32, 32)
b = image_normalization('nomalization')(a)