debug and demo

This commit is contained in:
chun 2023-12-23 22:58:38 +08:00
parent c7baea5a24
commit 0d7a9da03c
5 changed files with 13 additions and 7 deletions

View File

@ -5,9 +5,13 @@ import numpy as np
def channel(channel_type='AWGN', snr=20):
def AWGN_channel(z_hat: torch.Tensor):
k = np.prod(z_hat.size()[1:])
sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True)
noi_pwr = sig_pwr / (k * 10 ** (snr / 10))
if z_hat.dim() == 4:
k = np.prod(z_hat.size()[1:])
sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True)/k
elif z_hat.dim() == 3:
k = np.prod(z_hat.size())
sig_pwr = torch.sum(torch.abs(z_hat).square())/k
noi_pwr = sig_pwr / ( 10 ** (snr / 10))
noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr)
return z_hat + noise

BIN
demo/demo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 MiB

View File

@ -21,8 +21,8 @@ def config_parser():
def main():
args = config_parser()
transform = transforms.Compose([transforms.ToTensor(), ])
transform = transforms.Compose([transforms.ToTensor()])
args.saved = './saved/model_cifar10_0.33_19.00_40.pth' # to be deleted
test_image = Image.open(args.test_image)
test_image.load()
test_image = transform(test_image)

View File

@ -61,7 +61,7 @@ class _TransConvWithPReLU(nn.Module):
class _Encoder(nn.Module):
def __init__(self, c=1, is_temp=False):
def __init__(self, c=1, is_temp=False,P=1):
super(_Encoder, self).__init__()
self.is_temp = is_temp
# self.imgae_normalization = _image_normalization(norm_type='nomalization')
@ -71,7 +71,7 @@ class _Encoder(nn.Module):
kernel_size=5, padding=2) # padding size could be changed here
self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, padding=2)
self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, padding=2)
self.norm = self._normlizationLayer()
self.norm = self._normlizationLayer(P=P)
@staticmethod
def _normlizationLayer(P=1):
@ -88,6 +88,8 @@ class _Encoder(nn.Module):
z_temp = z_hat.reshape(batch_size, 1, 1, -1)
z_trans = z_hat.reshape(batch_size, 1, -1, 1)
tensor = torch.sqrt(P * k) * z_hat / torch.sqrt((z_temp @ z_trans))
if batch_size == 1:
return tensor.squeeze(0)
return tensor
return _inner

Binary file not shown.