debug norm

This commit is contained in:
chun 2023-12-23 17:19:49 +08:00
parent 0e8aaeda07
commit f93f85fed9
2 changed files with 4 additions and 3 deletions

View File

@ -93,8 +93,9 @@ def train(args: config_parser(), ratio: float, snr: float):
optimizer.zero_grad() optimizer.zero_grad()
images = images.cuda() images = images.cuda()
outputs = model(images) outputs = model(images)
loss = criterion(image_normalization('denormalization')(outputs), outputs = image_normalization('denormalization')(outputs)
image_normalization('denormalization')(images)) images = image_normalization('denormalization')(images)
loss = criterion(outputs, images)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
run_loss += loss.item() run_loss += loss.item()

View File

@ -5,7 +5,7 @@ import torch.nn.functional as F
def image_normalization(norm_type): def image_normalization(norm_type):
def _inner(tensor: torch.Tensor): def _inner(tensor: torch.Tensor):
if norm_type == 'nomalization': if norm_type == 'normalization':
return tensor / 255.0 return tensor / 255.0
elif norm_type == 'denormalization': elif norm_type == 'denormalization':
return tensor * 255.0 return tensor * 255.0