debug and demo
This commit is contained in:
parent
c7baea5a24
commit
0d7a9da03c
10
channel.py
10
channel.py
@ -5,9 +5,13 @@ import numpy as np
|
||||
|
||||
def channel(channel_type='AWGN', snr=20):
|
||||
def AWGN_channel(z_hat: torch.Tensor):
|
||||
k = np.prod(z_hat.size()[1:])
|
||||
sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True)
|
||||
noi_pwr = sig_pwr / (k * 10 ** (snr / 10))
|
||||
if z_hat.dim() == 4:
|
||||
k = np.prod(z_hat.size()[1:])
|
||||
sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True)/k
|
||||
elif z_hat.dim() == 3:
|
||||
k = np.prod(z_hat.size())
|
||||
sig_pwr = torch.sum(torch.abs(z_hat).square())/k
|
||||
noi_pwr = sig_pwr / ( 10 ** (snr / 10))
|
||||
noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr)
|
||||
return z_hat + noise
|
||||
|
||||
|
||||
BIN
demo/demo.png
Normal file
BIN
demo/demo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.5 MiB |
4
eval.py
4
eval.py
@ -21,8 +21,8 @@ def config_parser():
|
||||
|
||||
def main():
|
||||
args = config_parser()
|
||||
transform = transforms.Compose([transforms.ToTensor(), ])
|
||||
|
||||
transform = transforms.Compose([transforms.ToTensor()])
|
||||
args.saved = './saved/model_cifar10_0.33_19.00_40.pth' # to be deleted
|
||||
test_image = Image.open(args.test_image)
|
||||
test_image.load()
|
||||
test_image = transform(test_image)
|
||||
|
||||
6
model.py
6
model.py
@ -61,7 +61,7 @@ class _TransConvWithPReLU(nn.Module):
|
||||
|
||||
|
||||
class _Encoder(nn.Module):
|
||||
def __init__(self, c=1, is_temp=False):
|
||||
def __init__(self, c=1, is_temp=False,P=1):
|
||||
super(_Encoder, self).__init__()
|
||||
self.is_temp = is_temp
|
||||
# self.imgae_normalization = _image_normalization(norm_type='nomalization')
|
||||
@ -71,7 +71,7 @@ class _Encoder(nn.Module):
|
||||
kernel_size=5, padding=2) # padding size could be changed here
|
||||
self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, padding=2)
|
||||
self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, padding=2)
|
||||
self.norm = self._normlizationLayer()
|
||||
self.norm = self._normlizationLayer(P=P)
|
||||
|
||||
@staticmethod
|
||||
def _normlizationLayer(P=1):
|
||||
@ -88,6 +88,8 @@ class _Encoder(nn.Module):
|
||||
z_temp = z_hat.reshape(batch_size, 1, 1, -1)
|
||||
z_trans = z_hat.reshape(batch_size, 1, -1, 1)
|
||||
tensor = torch.sqrt(P * k) * z_hat / torch.sqrt((z_temp @ z_trans))
|
||||
if batch_size == 1:
|
||||
return tensor.squeeze(0)
|
||||
return tensor
|
||||
return _inner
|
||||
|
||||
|
||||
BIN
saved/model_cifar10_0.33_19.00_40.pth
Normal file
BIN
saved/model_cifar10_0.33_19.00_40.pth
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user