23 lines
518 B
Python
23 lines
518 B
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def image_normalization(norm_type):
    """Build a function that scales an image tensor to or from the [0, 1] range.

    Args:
        norm_type: ``'normalization'`` to divide by 255.0, or
            ``'denormalization'`` to multiply by 255.0.

    Returns:
        A callable taking a ``torch.Tensor`` and returning a new, scaled
        tensor (the input is not modified in place).

    Raises:
        ValueError: raised by the returned callable (not by this factory)
            when ``norm_type`` is neither supported value. ``ValueError``
            subclasses ``Exception``, so existing ``except Exception``
            handlers still catch it.
    """

    def _inner(tensor: torch.Tensor):
        if norm_type == 'normalization':
            return tensor / 255.0
        elif norm_type == 'denormalization':
            return tensor * 255.0
        else:
            # Validation stays at call time for backward compatibility:
            # building the closure with a bad type has never failed.
            raise ValueError(f'Unknown type of normalization: {norm_type!r}')

    return _inner
|
|
|
|
|
|
def get_psnr(image, gt, max=255):
    """Compute the peak signal-to-noise ratio (PSNR) between two tensors.

    PSNR = 10 * log10(max**2 / MSE), where MSE is the mean squared error
    over all elements.

    Args:
        image: Predicted / reconstructed tensor.
        gt: Reference tensor; presumably the same shape as ``image`` (no
            shape check is made here).
        max: Peak possible signal value. The default of 255 presumably
            assumes an 8-bit intensity range; pass 1.0 for images already
            scaled to [0, 1]. (The name shadows the builtin but is kept for
            caller compatibility.)

    Returns:
        A 0-dim tensor holding the PSNR in decibels. Identical inputs give
        ``inf`` because the MSE is zero.
    """
    # F.mse_loss requires floating-point inputs, so integer images (e.g.
    # uint8) are promoted to a common floating dtype first. Inputs that are
    # already the same floating dtype pass through unchanged (.to is a no-op).
    dtype = torch.promote_types(image.dtype, gt.dtype)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    image = image.to(dtype)
    gt = gt.to(dtype)

    mse = F.mse_loss(image, gt)
    psnr = 10 * torch.log10(max**2 / mse)
    return psnr
|