debug norm
This commit is contained in:
parent
0e8aaeda07
commit
f93f85fed9
5
train.py
5
train.py
@ -93,8 +93,9 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
images = images.cuda()
|
images = images.cuda()
|
||||||
outputs = model(images)
|
outputs = model(images)
|
||||||
loss = criterion(image_normalization('denormalization')(outputs),
|
outputs = image_normalization('denormalization')(outputs)
|
||||||
image_normalization('denormalization')(images))
|
images = image_normalization('denormalization')(images)
|
||||||
|
loss = criterion(outputs, images)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
run_loss += loss.item()
|
run_loss += loss.item()
|
||||||
|
|||||||
2
utils.py
2
utils.py
@ -5,7 +5,7 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
def image_normalization(norm_type):
|
def image_normalization(norm_type):
|
||||||
def _inner(tensor: torch.Tensor):
|
def _inner(tensor: torch.Tensor):
|
||||||
if norm_type == 'nomalization':
|
if norm_type == 'normalization':
|
||||||
return tensor / 255.0
|
return tensor / 255.0
|
||||||
elif norm_type == 'denormalization':
|
elif norm_type == 'denormalization':
|
||||||
return tensor * 255.0
|
return tensor * 255.0
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user