diff --git a/train.py b/train.py index b462a44..67f5674 100644 --- a/train.py +++ b/train.py @@ -93,8 +93,9 @@ def train(args: config_parser(), ratio: float, snr: float): optimizer.zero_grad() images = images.cuda() outputs = model(images) - loss = criterion(image_normalization('denormalization')(outputs), - image_normalization('denormalization')(images)) + outputs = image_normalization('denormalization')(outputs) + images = image_normalization('denormalization')(images) + loss = criterion(outputs, images) loss.backward() optimizer.step() run_loss += loss.item() diff --git a/utils.py b/utils.py index cf27d61..1f5c6f8 100644 --- a/utils.py +++ b/utils.py @@ -5,7 +5,7 @@ import torch.nn.functional as F def image_normalization(norm_type): def _inner(tensor: torch.Tensor): - if norm_type == 'nomalization': + if norm_type == 'normalization': return tensor / 255.0 elif norm_type == 'denormalization': return tensor * 255.0