JSCC/channel.py
2023-12-21 18:54:52 +08:00

23 lines
656 B
Python

import torch
import torch.nn as nn
import numpy as np
def channel(channel_type='AWGN', snr=20):
    """Build a function that simulates passing a tensor through a noisy channel.

    Args:
        channel_type: Either ``'AWGN'`` or ``'Rayleigh'``.
        snr: Signal-to-noise ratio in dB; converted to a linear ratio as
            ``10 ** (snr / 10)``.

    Returns:
        A callable that takes a tensor ``z_hat`` and returns a tensor of the
        same shape with channel impairments applied.

    Raises:
        ValueError: If ``channel_type`` is not one of the supported names.
    """
    def AWGN_channel(z_hat: torch.Tensor) -> torch.Tensor:
        """Add white Gaussian noise whose power matches the requested SNR.

        Dimension 0 is treated as the batch axis. The average per-element
        signal power is computed separately for each sample over all
        remaining dimensions. Each element then receives noise with variance
        ``avg_power / 10 ** (snr / 10)``. Inputs of any rank >= 1 are
        accepted.
        """
        # Reduce over every non-batch dim. The original hard-coded (1, 2, 3)
        # and therefore only worked for 4-D inputs.
        reduce_dims = tuple(range(1, z_hat.dim()))
        elem_pwr = torch.abs(z_hat).square()
        # For 1-D input there is nothing to reduce (one scalar per sample).
        # Avoid passing an empty dim tuple, which some torch versions treat
        # as "reduce over all dims".
        if reduce_dims:
            sig_pwr = elem_pwr.mean(dim=reduce_dims, keepdim=True)
        else:
            sig_pwr = elem_pwr
        noi_pwr = sig_pwr / (10 ** (snr / 10))
        noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr)
        return z_hat + noise

    def Rayleigh_channel(z_hat: torch.Tensor) -> torch.Tensor:
        """Placeholder for a Rayleigh fading channel (not implemented)."""
        # The original body was `pass`, which silently returned None and made
        # the failure surface later in the caller; fail loudly instead.
        raise NotImplementedError('Rayleigh channel is not implemented yet')

    if channel_type == 'AWGN':
        return AWGN_channel
    elif channel_type == 'Rayleigh':
        return Rayleigh_channel
    # ValueError subclasses Exception, so existing `except Exception` callers
    # still catch this.
    raise ValueError(f'Unknown type of channel: {channel_type!r}')