debug and demo
This commit is contained in:
parent
c7baea5a24
commit
0d7a9da03c
10
channel.py
10
channel.py
@ -5,9 +5,13 @@ import numpy as np
|
|||||||
|
|
||||||
def channel(channel_type='AWGN', snr=20):
|
def channel(channel_type='AWGN', snr=20):
|
||||||
def AWGN_channel(z_hat: torch.Tensor):
|
def AWGN_channel(z_hat: torch.Tensor):
|
||||||
k = np.prod(z_hat.size()[1:])
|
if z_hat.dim() == 4:
|
||||||
sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True)
|
k = np.prod(z_hat.size()[1:])
|
||||||
noi_pwr = sig_pwr / (k * 10 ** (snr / 10))
|
sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True)/k
|
||||||
|
elif z_hat.dim() == 3:
|
||||||
|
k = np.prod(z_hat.size())
|
||||||
|
sig_pwr = torch.sum(torch.abs(z_hat).square())/k
|
||||||
|
noi_pwr = sig_pwr / ( 10 ** (snr / 10))
|
||||||
noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr)
|
noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr)
|
||||||
return z_hat + noise
|
return z_hat + noise
|
||||||
|
|
||||||
|
|||||||
BIN
demo/demo.png
Normal file
BIN
demo/demo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.5 MiB |
4
eval.py
4
eval.py
@ -21,8 +21,8 @@ def config_parser():
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = config_parser()
|
args = config_parser()
|
||||||
transform = transforms.Compose([transforms.ToTensor(), ])
|
transform = transforms.Compose([transforms.ToTensor()])
|
||||||
|
args.saved = './saved/model_cifar10_0.33_19.00_40.pth' # to be deleted
|
||||||
test_image = Image.open(args.test_image)
|
test_image = Image.open(args.test_image)
|
||||||
test_image.load()
|
test_image.load()
|
||||||
test_image = transform(test_image)
|
test_image = transform(test_image)
|
||||||
|
|||||||
6
model.py
6
model.py
@ -61,7 +61,7 @@ class _TransConvWithPReLU(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class _Encoder(nn.Module):
|
class _Encoder(nn.Module):
|
||||||
def __init__(self, c=1, is_temp=False):
|
def __init__(self, c=1, is_temp=False,P=1):
|
||||||
super(_Encoder, self).__init__()
|
super(_Encoder, self).__init__()
|
||||||
self.is_temp = is_temp
|
self.is_temp = is_temp
|
||||||
# self.imgae_normalization = _image_normalization(norm_type='nomalization')
|
# self.imgae_normalization = _image_normalization(norm_type='nomalization')
|
||||||
@ -71,7 +71,7 @@ class _Encoder(nn.Module):
|
|||||||
kernel_size=5, padding=2) # padding size could be changed here
|
kernel_size=5, padding=2) # padding size could be changed here
|
||||||
self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, padding=2)
|
self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, padding=2)
|
||||||
self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, padding=2)
|
self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, padding=2)
|
||||||
self.norm = self._normlizationLayer()
|
self.norm = self._normlizationLayer(P=P)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _normlizationLayer(P=1):
|
def _normlizationLayer(P=1):
|
||||||
@ -88,6 +88,8 @@ class _Encoder(nn.Module):
|
|||||||
z_temp = z_hat.reshape(batch_size, 1, 1, -1)
|
z_temp = z_hat.reshape(batch_size, 1, 1, -1)
|
||||||
z_trans = z_hat.reshape(batch_size, 1, -1, 1)
|
z_trans = z_hat.reshape(batch_size, 1, -1, 1)
|
||||||
tensor = torch.sqrt(P * k) * z_hat / torch.sqrt((z_temp @ z_trans))
|
tensor = torch.sqrt(P * k) * z_hat / torch.sqrt((z_temp @ z_trans))
|
||||||
|
if batch_size == 1:
|
||||||
|
return tensor.squeeze(0)
|
||||||
return tensor
|
return tensor
|
||||||
return _inner
|
return _inner
|
||||||
|
|
||||||
|
|||||||
BIN
saved/model_cifar10_0.33_19.00_40.pth
Normal file
BIN
saved/model_cifar10_0.33_19.00_40.pth
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user