formatter modified

This commit is contained in:
chun 2024-02-03 20:05:40 +08:00
parent 14cd4ed4e4
commit c81101c269
5 changed files with 13 additions and 12 deletions

View File

@ -1,6 +1,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
def channel(channel_type='AWGN', snr=20): def channel(channel_type='AWGN', snr=20):
def AWGN_channel(z_hat: torch.Tensor): def AWGN_channel(z_hat: torch.Tensor):
if z_hat.dim() == 4: if z_hat.dim() == 4:
@ -11,7 +12,7 @@ def channel(channel_type='AWGN', snr=20):
# k = np.prod(z_hat.size()) # k = np.prod(z_hat.size())
k = torch.prod(torch.tensor(z_hat.size())) k = torch.prod(torch.tensor(z_hat.size()))
sig_pwr = torch.sum(torch.abs(z_hat).square())/k sig_pwr = torch.sum(torch.abs(z_hat).square())/k
noi_pwr = sig_pwr / ( 10 ** (snr / 10)) noi_pwr = sig_pwr / (10 ** (snr / 10))
noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr) noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr)
return z_hat + noise return z_hat + noise

View File

@ -14,7 +14,7 @@ class Vanilla(Dataset):
img = Image.open(img_path).convert('RGB') img = Image.open(img_path).convert('RGB')
if self.transform is not None: if self.transform is not None:
img = self.transform(img) img = self.transform(img)
return img, 0 # 0 is a fake label not important return img, 0 # 0 is a fake label not important
def __len__(self): def __len__(self):
return len(self.imgs) return len(self.imgs)

View File

@ -34,7 +34,7 @@ def main():
model.change_channel(args.channel, args.snr) model.change_channel(args.channel, args.snr)
psnr_all = 0.0 psnr_all = 0.0
for i in range(args.times): for i in range(args.times):
demo_image = model(test_image) demo_image = model(test_image)
demo_image = image_normalization('denormalization')(demo_image) demo_image = image_normalization('denormalization')(demo_image)
@ -43,7 +43,7 @@ def main():
demo_image = image_normalization('normalization')(demo_image) demo_image = image_normalization('normalization')(demo_image)
demo_image = torch.cat([test_image, demo_image], dim=1) demo_image = torch.cat([test_image, demo_image], dim=1)
demo_image = transforms.ToPILImage()(demo_image) demo_image = transforms.ToPILImage()(demo_image)
demo_image.save('./run/{}_{}'.format(args.saved.split('/')[-1],args.test_image.split('/')[-1])) demo_image.save('./run/{}_{}'.format(args.saved.split('/')[-1], args.test_image.split('/')[-1]))
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times)) print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))

View File

@ -42,8 +42,7 @@ class _ConvWithPReLU(nn.Module):
super(_ConvWithPReLU, self).__init__() super(_ConvWithPReLU, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
self.prelu = nn.PReLU() self.prelu = nn.PReLU()
nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='leaky_relu') nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='leaky_relu')
def forward(self, x): def forward(self, x):
@ -62,7 +61,7 @@ class _TransConvWithPReLU(nn.Module):
nn.init.kaiming_normal_(self.transconv.weight, mode='fan_out', nonlinearity='leaky_relu') nn.init.kaiming_normal_(self.transconv.weight, mode='fan_out', nonlinearity='leaky_relu')
else: else:
nn.init.xavier_normal_(self.transconv.weight) nn.init.xavier_normal_(self.transconv.weight)
def forward(self, x): def forward(self, x):
x = self.transconv(x) x = self.transconv(x)
x = self.activate(x) x = self.activate(x)
@ -70,7 +69,7 @@ class _TransConvWithPReLU(nn.Module):
class _Encoder(nn.Module): class _Encoder(nn.Module):
def __init__(self, c=1, is_temp=False,P=1): def __init__(self, c=1, is_temp=False, P=1):
super(_Encoder, self).__init__() super(_Encoder, self).__init__()
self.is_temp = is_temp self.is_temp = is_temp
# self.imgae_normalization = _image_normalization(norm_type='nomalization') # self.imgae_normalization = _image_normalization(norm_type='nomalization')

View File

@ -19,6 +19,7 @@ from fractions import Fraction
from dataset import Vanilla from dataset import Vanilla
import numpy as np import numpy as np
def set_seed(seed): def set_seed(seed):
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
@ -55,7 +56,6 @@ def config_parser():
return parser.parse_args() return parser.parse_args()
def main(): def main():
args = config_parser() args = config_parser()
args.snr_list = list(map(float, args.snr_list)) args.snr_list = list(map(float, args.snr_list))
@ -66,6 +66,7 @@ def main():
for snr in args.snr_list: for snr in args.snr_list:
train(args, ratio, snr) train(args, ratio, snr)
def train(args: config_parser(), ratio: float, snr: float): def train(args: config_parser(), ratio: float, snr: float):
device = torch.device(args.device if torch.cuda.is_available() else 'cpu') device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
@ -94,7 +95,7 @@ def train(args: config_parser(), ratio: float, snr: float):
batch_size=args.batch_size, num_workers=args.num_workers) batch_size=args.batch_size, num_workers=args.num_workers)
else: else:
raise Exception('Unknown dataset') raise Exception('Unknown dataset')
print(args) print(args)
image_fisrt = train_dataset.__getitem__(0)[0] image_fisrt = train_dataset.__getitem__(0)[0]
c = ratio2filtersize(image_fisrt, ratio) c = ratio2filtersize(image_fisrt, ratio)
@ -125,7 +126,7 @@ def train(args: config_parser(), ratio: float, snr: float):
loss.backward() loss.backward()
optimizer.step() optimizer.step()
run_loss += loss.item() run_loss += loss.item()
if args.if_scheduler: # the scheduler is wrong before if args.if_scheduler: # the scheduler is wrong before
scheduler.step() scheduler.step()
with torch.no_grad(): with torch.no_grad():
model.eval() model.eval()
@ -142,7 +143,7 @@ def train(args: config_parser(), ratio: float, snr: float):
print("epoch: {}, loss: {:.4f}, test_mse: {:.4f}, lr:{}".format( print("epoch: {}, loss: {:.4f}, test_mse: {:.4f}, lr:{}".format(
epoch, run_loss/len(train_loader), test_mse/len(test_loader), optimizer.param_groups[0]['lr'])) epoch, run_loss/len(train_loader), test_mse/len(test_loader), optimizer.param_groups[0]['lr']))
save_model(model, args.saved, args.saved + save_model(model, args.saved, args.saved +
'/{}_{}_{:.2f}_{:.2f}_{}_{}.pth'.format(args.dataset, args.epochs, ratio, snr, args.batch_size,c)) '/{}_{}_{:.2f}_{:.2f}_{}_{}.pth'.format(args.dataset, args.epochs, ratio, snr, args.batch_size, c))
def save_model(model, dir, path): def save_model(model, dir, path):