56 lines
1.4 KiB
Python
56 lines
1.4 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
import os
|
|
import numpy as np
|
|
|
|
|
|
def image_normalization(norm_type):
    """Build a function that rescales a tensor by a factor of 255.

    Args:
        norm_type: ``'normalization'`` to divide the tensor by 255.0, or
            ``'denormalization'`` to multiply it by 255.0.

    Returns:
        A callable that takes a ``torch.Tensor`` and returns the rescaled
        tensor.

    Raises:
        ValueError: Raised by the returned callable (not by this factory)
            when ``norm_type`` is neither of the two supported values.
    """
    def _inner(tensor: torch.Tensor):
        if norm_type == 'normalization':
            return tensor / 255.0
        if norm_type == 'denormalization':
            return tensor * 255.0
        # ValueError is a subclass of Exception, so callers that caught the
        # previous generic Exception keep working.
        raise ValueError(
            'Unknown type of normalization: {!r}'.format(norm_type)
        )

    return _inner
|
|
|
|
|
|
def get_psnr(image, gt, max_val=255, mse=None):
    """Compute the peak signal-to-noise ratio ``10 * log10(max_val**2 / mse)``.

    Args:
        image: Tensor to evaluate. Not used when ``mse`` is supplied.
        gt: Reference tensor. Not used when ``mse`` is supplied.
        max_val: Peak signal value used in the numerator.
        mse: Optional precomputed mean squared error, as a tensor or a plain
            number. When ``None`` it is computed with
            ``F.mse_loss(image, gt)``.

    Returns:
        A tensor holding the PSNR. A zero MSE (identical inputs) gives ``inf``.
    """
    if mse is None:
        # mse_loss needs floating-point inputs, so promote integer/bool
        # tensors (e.g. uint8 images). Floating and complex inputs are passed
        # through unchanged so their dtype and result are unaffected.
        if not (image.is_floating_point() or image.is_complex()):
            image = image.float()
        if not (gt.is_floating_point() or gt.is_complex()):
            gt = gt.float()
        mse = F.mse_loss(image, gt)

    # Accept plain Python/NumPy numbers for a precomputed MSE.
    if not isinstance(mse, torch.Tensor):
        mse = torch.tensor(mse)

    return 10 * torch.log10(max_val**2 / mse)
|
|
|
|
|
|
def save_model(model, dir, path):
    """Save ``model.state_dict()`` to ``path`` without overwriting a file.

    Args:
        model: Object exposing ``state_dict()``.
        dir: Directory to create (including parents) before saving. It is not
            joined with ``path``.
        path: Desired file path. If it already exists, ``_1``, ``_2``, ... is
            appended to this path until an unused name is found.

    Returns:
        The path the state dict was actually written to.
    """
    os.makedirs(dir, exist_ok=True)

    # Build every candidate from the original path so repeated collisions
    # yield name_1, name_2, ... instead of stacking suffixes (name_1_2_3).
    target = path
    suffix = 1
    while os.path.exists(target):
        target = '{}_{}'.format(path, suffix)
        suffix += 1

    torch.save(model.state_dict(), target)
    print("Model saved in {}".format(target))
    return target
|
|
|
|
|
|
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also sets ``torch.backends.cudnn.deterministic`` to ``True`` and
    ``torch.backends.cudnn.benchmark`` to ``False``.

    Args:
        seed: Seed value passed to every generator.
    """
    # Imported locally so this function does not rely on a file-level
    # ``import random``.
    import random

    # Python's own RNG was previously left unseeded, so any code drawing
    # from ``random`` stayed non-deterministic across runs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
def view_model_param(model):
    """Return the total number of elements across ``model.parameters()``.

    Args:
        model: Object exposing ``parameters()`` that yields tensors.

    Returns:
        The summed element count as a Python ``int`` (``0`` when the model
        has no parameters).
    """
    # numel() keeps the total an exact Python int. Multiplying the shape with
    # np.prod returned NumPy scalars, and a float (1.0) for a 0-dim parameter
    # whose shape is empty.
    return sum(param.numel() for param in model.parameters())
|