JSCC/utils.py
2023-12-23 13:40:56 +08:00

27 lines
592 B
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
def image_normalization(norm_type: str):
    """Build a function that rescales a tensor by a factor of 255.

    Args:
        norm_type: Which direction to scale in.
            ``'normalization'`` (or the legacy misspelling
            ``'nomalization'``, kept so existing callers keep working)
            divides the tensor by 255.0; ``'denormalization'`` multiplies
            it by 255.0.

    Returns:
        A callable taking a ``torch.Tensor`` and returning the rescaled
        tensor.

    Raises:
        ValueError: Raised by the returned callable (not by this factory)
            when ``norm_type`` is not one of the names above.
    """
    def _inner(tensor: torch.Tensor):
        # Accept the correct spelling as well as the historical typo.
        if norm_type in ('normalization', 'nomalization'):
            return tensor / 255.0
        if norm_type == 'denormalization':
            return tensor * 255.0
        # ValueError subclasses Exception, so `except Exception` still works.
        raise ValueError('Unknown type of normalization')

    return _inner
def get_psnr(image, gt, max=255):
    """Compute the peak signal-to-noise ratio between two tensors.

    The value is ``10 * log10(max**2 / MSE)``, where MSE is the mean
    squared error over all elements of ``image`` and ``gt``.

    Args:
        image: Tensor to evaluate.
        gt: Reference tensor compared against ``image``.
        max: Peak signal value used in the numerator (255 by default).
            The name shadows the builtin, but it is kept because callers
            may pass it by keyword.

    Returns:
        A scalar tensor holding the PSNR. When the inputs are identical
        the MSE is 0 and the result is ``inf``.
    """
    # F.mse_loss is meant for floating-point tensors; promote integer
    # inputs (e.g. uint8 images) so they can be compared directly.
    # Floating-point and complex inputs are passed through untouched.
    if not (image.is_floating_point() or image.is_complex()):
        image = image.float()
    if not (gt.is_floating_point() or gt.is_complex()):
        gt = gt.float()

    mse = F.mse_loss(image, gt)
    psnr = 10 * torch.log10(max**2 / mse)
    return psnr
# Module-level smoke check: draw a random 2x3x32x32 tensor and push it
# through the divide-by-255 normalizer.
# NOTE(review): this runs on every import of the module; it looks like
# leftover debugging code -- confirm whether `a` / `b` are used elsewhere
# before moving it under an `if __name__ == "__main__":` guard.
a = torch.randn((2, 3, 32, 32))
b = image_normalization(norm_type='nomalization')(a)