formatter modified
This commit is contained in:
parent
14cd4ed4e4
commit
c81101c269
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
def channel(channel_type='AWGN', snr=20):
|
def channel(channel_type='AWGN', snr=20):
|
||||||
def AWGN_channel(z_hat: torch.Tensor):
|
def AWGN_channel(z_hat: torch.Tensor):
|
||||||
if z_hat.dim() == 4:
|
if z_hat.dim() == 4:
|
||||||
@ -11,7 +12,7 @@ def channel(channel_type='AWGN', snr=20):
|
|||||||
# k = np.prod(z_hat.size())
|
# k = np.prod(z_hat.size())
|
||||||
k = torch.prod(torch.tensor(z_hat.size()))
|
k = torch.prod(torch.tensor(z_hat.size()))
|
||||||
sig_pwr = torch.sum(torch.abs(z_hat).square())/k
|
sig_pwr = torch.sum(torch.abs(z_hat).square())/k
|
||||||
noi_pwr = sig_pwr / ( 10 ** (snr / 10))
|
noi_pwr = sig_pwr / (10 ** (snr / 10))
|
||||||
noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr)
|
noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr)
|
||||||
return z_hat + noise
|
return z_hat + noise
|
||||||
|
|
||||||
|
|||||||
2
eval.py
2
eval.py
@ -43,7 +43,7 @@ def main():
|
|||||||
demo_image = image_normalization('normalization')(demo_image)
|
demo_image = image_normalization('normalization')(demo_image)
|
||||||
demo_image = torch.cat([test_image, demo_image], dim=1)
|
demo_image = torch.cat([test_image, demo_image], dim=1)
|
||||||
demo_image = transforms.ToPILImage()(demo_image)
|
demo_image = transforms.ToPILImage()(demo_image)
|
||||||
demo_image.save('./run/{}_{}'.format(args.saved.split('/')[-1],args.test_image.split('/')[-1]))
|
demo_image.save('./run/{}_{}'.format(args.saved.split('/')[-1], args.test_image.split('/')[-1]))
|
||||||
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
|
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
3
model.py
3
model.py
@ -43,7 +43,6 @@ class _ConvWithPReLU(nn.Module):
|
|||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
|
||||||
self.prelu = nn.PReLU()
|
self.prelu = nn.PReLU()
|
||||||
|
|
||||||
|
|
||||||
nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='leaky_relu')
|
nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -70,7 +69,7 @@ class _TransConvWithPReLU(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class _Encoder(nn.Module):
|
class _Encoder(nn.Module):
|
||||||
def __init__(self, c=1, is_temp=False,P=1):
|
def __init__(self, c=1, is_temp=False, P=1):
|
||||||
super(_Encoder, self).__init__()
|
super(_Encoder, self).__init__()
|
||||||
self.is_temp = is_temp
|
self.is_temp = is_temp
|
||||||
# self.imgae_normalization = _image_normalization(norm_type='nomalization')
|
# self.imgae_normalization = _image_normalization(norm_type='nomalization')
|
||||||
|
|||||||
5
train.py
5
train.py
@ -19,6 +19,7 @@ from fractions import Fraction
|
|||||||
from dataset import Vanilla
|
from dataset import Vanilla
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed):
|
def set_seed(seed):
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
@ -55,7 +56,6 @@ def config_parser():
|
|||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = config_parser()
|
args = config_parser()
|
||||||
args.snr_list = list(map(float, args.snr_list))
|
args.snr_list = list(map(float, args.snr_list))
|
||||||
@ -66,6 +66,7 @@ def main():
|
|||||||
for snr in args.snr_list:
|
for snr in args.snr_list:
|
||||||
train(args, ratio, snr)
|
train(args, ratio, snr)
|
||||||
|
|
||||||
|
|
||||||
def train(args: config_parser(), ratio: float, snr: float):
|
def train(args: config_parser(), ratio: float, snr: float):
|
||||||
|
|
||||||
device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
|
device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
|
||||||
@ -142,7 +143,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
print("epoch: {}, loss: {:.4f}, test_mse: {:.4f}, lr:{}".format(
|
print("epoch: {}, loss: {:.4f}, test_mse: {:.4f}, lr:{}".format(
|
||||||
epoch, run_loss/len(train_loader), test_mse/len(test_loader), optimizer.param_groups[0]['lr']))
|
epoch, run_loss/len(train_loader), test_mse/len(test_loader), optimizer.param_groups[0]['lr']))
|
||||||
save_model(model, args.saved, args.saved +
|
save_model(model, args.saved, args.saved +
|
||||||
'/{}_{}_{:.2f}_{:.2f}_{}_{}.pth'.format(args.dataset, args.epochs, ratio, snr, args.batch_size,c))
|
'/{}_{}_{:.2f}_{:.2f}_{}_{}.pth'.format(args.dataset, args.epochs, ratio, snr, args.batch_size, c))
|
||||||
|
|
||||||
|
|
||||||
def save_model(model, dir, path):
|
def save_model(model, dir, path):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user