Update channel.py

This commit is contained in:
ZHENG Chunhang 2024-01-16 22:55:35 +08:00 committed by GitHub
parent 2064fd28aa
commit e142bb899a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,15 +1,15 @@
import torch
import torch.nn as nn
import numpy as np
def channel(channel_type='AWGN', snr=20):
def AWGN_channel(z_hat: torch.Tensor):
if z_hat.dim() == 4:
k = np.prod(z_hat.size()[1:])
# k = np.prod(z_hat.size()[1:])
k = torch.prod(torch.tensor(z_hat.size()[1:]))
sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True)/k
elif z_hat.dim() == 3:
k = np.prod(z_hat.size())
# k = np.prod(z_hat.size())
k = torch.prod(torch.tensor(z_hat.size()))
sig_pwr = torch.sum(torch.abs(z_hat).square())/k
noi_pwr = sig_pwr / ( 10 ** (snr / 10))
noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr)