diff --git a/channel.py b/channel.py index eb58d05..fb92a48 100644 --- a/channel.py +++ b/channel.py @@ -11,20 +11,37 @@ class Channel(nn.Module): self.snr = snr def forward(self, z_hat): - if z_hat.dim() == 4: - # k = np.prod(z_hat.size()[1:]) - k = torch.prod(torch.tensor(z_hat.size()[1:])) - sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True) / k - elif z_hat.dim() == 3: - # k = np.prod(z_hat.size()) - k = torch.prod(torch.tensor(z_hat.size())) - sig_pwr = torch.sum(torch.abs(z_hat).square()) / k + if z_hat.dim() not in {3, 4}: + raise ValueError('Input tensor must be 3D or 4D') + + # if z_hat.dim() == 4: + # # k = np.prod(z_hat.size()[1:]) + # k = torch.prod(torch.tensor(z_hat.size()[1:])) + # sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True) / k + # elif z_hat.dim() == 3: + # # k = np.prod(z_hat.size()) + # k = torch.prod(torch.tensor(z_hat.size())) + # sig_pwr = torch.sum(torch.abs(z_hat).square()) / k + + if z_hat.dim() == 3: + z_hat = z_hat.unsqueeze(0) + + k = z_hat[0].numel() + sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True) / k noi_pwr = sig_pwr / (10 ** (self.snr / 10)) - noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr) + noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr/2) if self.channel_type == 'Rayleigh': # hc = torch.randn_like(z_hat) wrong implement before - hc = torch.randn(1, device = z_hat.device) - z_hat = hc * z_hat + # hc = torch.randn(1, device = z_hat.device) + hc = torch.randn(2, device = z_hat.device) + + # clone for in-place operation + z_hat = z_hat.clone() + z_hat[:,:z_hat.size(1)//2] = hc[0] * z_hat[:,:z_hat.size(1)//2] + z_hat[:,z_hat.size(1)//2:] = hc[1] * z_hat[:,z_hat.size(1)//2:] + + + # z_hat = hc * z_hat return z_hat + noise @@ -40,6 +57,6 @@ if __name__ == '__main__': print(z_hat) channel = Channel(channel_type='Rayleigh', snr=10) - z_hat = torch.randn(64, 10, 5, 5) + z_hat = torch.randn(10, 5, 5) z_hat = channel(z_hat) print(z_hat)