add Rayleigh channel
This commit is contained in:
parent
ff3d377583
commit
b7bd3bdd42
41
channel.py
41
channel.py
@ -11,20 +11,37 @@ class Channel(nn.Module):
|
|||||||
self.snr = snr
|
self.snr = snr
|
||||||
|
|
||||||
def forward(self, z_hat):
|
def forward(self, z_hat):
|
||||||
if z_hat.dim() == 4:
|
if z_hat.dim() not in {3, 4}:
|
||||||
# k = np.prod(z_hat.size()[1:])
|
raise ValueError('Input tensor must be 3D or 4D')
|
||||||
k = torch.prod(torch.tensor(z_hat.size()[1:]))
|
|
||||||
sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True) / k
|
# if z_hat.dim() == 4:
|
||||||
elif z_hat.dim() == 3:
|
# # k = np.prod(z_hat.size()[1:])
|
||||||
# k = np.prod(z_hat.size())
|
# k = torch.prod(torch.tensor(z_hat.size()[1:]))
|
||||||
k = torch.prod(torch.tensor(z_hat.size()))
|
# sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True) / k
|
||||||
sig_pwr = torch.sum(torch.abs(z_hat).square()) / k
|
# elif z_hat.dim() == 3:
|
||||||
|
# # k = np.prod(z_hat.size())
|
||||||
|
# k = torch.prod(torch.tensor(z_hat.size()))
|
||||||
|
# sig_pwr = torch.sum(torch.abs(z_hat).square()) / k
|
||||||
|
|
||||||
|
if z_hat.dim() == 3:
|
||||||
|
z_hat = z_hat.unsqueeze(0)
|
||||||
|
|
||||||
|
k = z_hat[0].numel()
|
||||||
|
sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True) / k
|
||||||
noi_pwr = sig_pwr / (10 ** (self.snr / 10))
|
noi_pwr = sig_pwr / (10 ** (self.snr / 10))
|
||||||
noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr)
|
noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr/2)
|
||||||
if self.channel_type == 'Rayleigh':
|
if self.channel_type == 'Rayleigh':
|
||||||
# hc = torch.randn_like(z_hat) wrong implement before
|
# hc = torch.randn_like(z_hat) wrong implement before
|
||||||
hc = torch.randn(1, device = z_hat.device)
|
# hc = torch.randn(1, device = z_hat.device)
|
||||||
z_hat = hc * z_hat
|
hc = torch.randn(2, device = z_hat.device)
|
||||||
|
|
||||||
|
# clone for in-place operation
|
||||||
|
z_hat = z_hat.clone()
|
||||||
|
z_hat[:,:z_hat.size(1)//2] = hc[0] * z_hat[:,:z_hat.size(1)//2]
|
||||||
|
z_hat[:,z_hat.size(1)//2:] = hc[1] * z_hat[:,z_hat.size(1)//2:]
|
||||||
|
|
||||||
|
|
||||||
|
# z_hat = hc * z_hat
|
||||||
|
|
||||||
return z_hat + noise
|
return z_hat + noise
|
||||||
|
|
||||||
@ -40,6 +57,6 @@ if __name__ == '__main__':
|
|||||||
print(z_hat)
|
print(z_hat)
|
||||||
|
|
||||||
channel = Channel(channel_type='Rayleigh', snr=10)
|
channel = Channel(channel_type='Rayleigh', snr=10)
|
||||||
z_hat = torch.randn(64, 10, 5, 5)
|
z_hat = torch.randn(10, 5, 5)
|
||||||
z_hat = channel(z_hat)
|
z_hat = channel(z_hat)
|
||||||
print(z_hat)
|
print(z_hat)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user