# JSCC/utils.py
# Last modified: 2023-12-23 17:30:19 +08:00 (23 lines, 518 B, Python)

import torch
import torch.nn as nn
import torch.nn.functional as F
def image_normalization(norm_type):
    """Return a function that scales a tensor by a factor of 255.

    Args:
        norm_type: ``'normalization'`` to divide by 255 (e.g. [0, 255] -> [0, 1]),
            or ``'denormalization'`` to multiply by 255 (the inverse).

    Returns:
        A callable ``f(tensor) -> tensor`` applying the selected scaling.

    Raises:
        ValueError: when the returned callable is invoked and ``norm_type`` is
            not one of the two supported values. The check is deferred to call
            time to keep the original behaviour of constructing the callable
            without error.
    """
    def _inner(tensor: torch.Tensor):
        if norm_type == 'normalization':
            return tensor / 255.0
        elif norm_type == 'denormalization':
            return tensor * 255.0
        else:
            # ValueError subclasses Exception, so callers catching Exception still work.
            raise ValueError(f'Unknown type of normalization: {norm_type!r}')
    return _inner
def get_psnr(image, gt, max=255):
    """Return the peak signal-to-noise ratio (in dB) between ``image`` and ``gt``.

    The MSE comes from ``F.mse_loss`` with its default ``'mean'`` reduction.
    It is therefore averaged over every element. For a batch, the result is
    the PSNR of the batch-mean MSE, not the mean of per-image PSNRs.

    Args:
        image: Reconstructed tensor.
        gt: Reference tensor; presumably the same shape as ``image`` (mse_loss
            broadcasts with a warning otherwise).
        max: Peak signal value, e.g. 255 for an 8-bit range or 1 for [0, 1].
            The name shadows the builtin but is kept for keyword-call compatibility.

    Returns:
        A 0-d tensor holding the PSNR in dB. It is ``inf`` when the inputs are
        identical, because the MSE is 0.
    """
    # Integer tensors (e.g. uint8 images) cannot be averaged by mse_loss, and
    # uint8 subtraction would wrap around, so compute in floating point.
    if not torch.is_floating_point(image):
        image = image.float()
    if not torch.is_floating_point(gt):
        gt = gt.float()
    mse = F.mse_loss(image, gt)
    return 10 * torch.log10(max ** 2 / mse)