JSCC/model.py
2024-02-03 20:05:40 +08:00

158 lines
5.7 KiB
Python

# -*- coding: utf-8 -*-
"""
Created on Tue Dec 11:00:00 2023
@author: chun
"""
import torch
import torch.nn as nn
import channel
""" def _image_normalization(norm_type):
def _inner(tensor: torch.Tensor):
if norm_type == 'nomalization':
return tensor / 255.0
elif norm_type == 'denormalization':
return (tensor * 255.0).type(torch.FloatTensor)
else:
raise Exception('Unknown type of normalization')
return _inner """
def ratio2filtersize(x: torch.Tensor, ratio):
    """Return the encoder output-channel count ``c`` for a given ratio.

    A probe encoder (``_Encoder(is_temp=True)``, which skips its last
    convolution and the normalisation step) is run on ``x`` to obtain the
    size of the last two dimensions of the latent feature map.  ``c`` is
    then ``per_sample_elements * ratio / (latent_h * latent_w)``, truncated
    to an integer.

    Args:
        x: A 3-D tensor (treated as one sample) or a 4-D tensor (dimension 0
            is treated as the batch dimension and excluded from the
            per-sample element count).
        ratio: Multiplier applied to the per-sample element count.

    Returns:
        int: The truncated channel count.

    Raises:
        Exception: If ``x`` is neither 3-D nor 4-D.
    """
    if x.dim() == 4:
        per_sample_shape = x.size()[1:]
    elif x.dim() == 3:
        per_sample_shape = x.size()
    else:
        raise Exception('Unknown size of input')
    before_size = torch.prod(torch.tensor(per_sample_shape))

    # A freshly built module lives on the CPU; move it next to ``x`` so the
    # probe forward pass does not fail with a device mismatch.
    encoder_temp = _Encoder(is_temp=True).to(x.device)
    # Only the output shape is needed, so skip building an autograd graph.
    with torch.no_grad():
        z_temp = encoder_temp(x)

    c = before_size * ratio / torch.prod(torch.tensor(z_temp.size()[-2:]))
    return int(c)
class _ConvWithPReLU(nn.Module):
    """A ``Conv2d`` layer whose output is passed through a ``PReLU``.

    Args:
        in_channels: Forwarded to ``nn.Conv2d``.
        out_channels: Forwarded to ``nn.Conv2d``.
        kernel_size: Forwarded to ``nn.Conv2d``.
        stride: Forwarded to ``nn.Conv2d`` (default 1).
        padding: Forwarded to ``nn.Conv2d`` (default 0).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.prelu = nn.PReLU()
        # Overwrite the default Conv2d weight initialisation.
        nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='leaky_relu')

    def forward(self, x):
        """Apply the convolution, then the PReLU, and return the result."""
        return self.prelu(self.conv(x))
class _TransConvWithPReLU(nn.Module):
    """A ``ConvTranspose2d`` layer followed by an activation module.

    Args:
        in_channels: Forwarded to ``nn.ConvTranspose2d``.
        out_channels: Forwarded to ``nn.ConvTranspose2d``.
        kernel_size: Forwarded to ``nn.ConvTranspose2d``.
        stride: Forwarded to ``nn.ConvTranspose2d``.
        activate: Activation module applied to the transposed-convolution
            output.  ``None`` (the default) creates a new ``nn.PReLU`` for
            this instance.
        padding: Forwarded to ``nn.ConvTranspose2d`` (default 0).
        output_padding: Forwarded to ``nn.ConvTranspose2d`` (default 0).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, activate=None, padding=0, output_padding=0):
        super(_TransConvWithPReLU, self).__init__()
        self.transconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, output_padding)
        # A module used as a default argument is created once at definition
        # time and would be shared (including its learnable slope) by every
        # layer relying on the default, so build it per instance instead.
        self.activate = nn.PReLU() if activate is None else activate
        # nn.Module has no value equality: comparing against a fresh
        # ``nn.PReLU()`` with ``==`` is an identity check that is always
        # False.  Select the initialisation by activation type instead.
        if isinstance(self.activate, nn.PReLU):
            nn.init.kaiming_normal_(self.transconv.weight, mode='fan_out', nonlinearity='leaky_relu')
        else:
            nn.init.xavier_normal_(self.transconv.weight)

    def forward(self, x):
        """Apply the transposed convolution, then the activation."""
        x = self.transconv(x)
        x = self.activate(x)
        return x
class _Encoder(nn.Module):
    """Five conv+PReLU stages followed by a per-sample power normalisation.

    Args:
        c: Number of output channels of the last convolution (``conv5``).
        is_temp: When True, ``forward`` stops after ``conv4`` and skips both
            ``conv5`` and the normalisation; ``ratio2filtersize`` uses this
            mode to probe the latent feature-map size.
        P: Factor passed to the normalisation; after normalisation each
            sample's sum of squares equals ``P * k``, where ``k`` is the
            number of elements in that sample.
    """

    def __init__(self, c=1, is_temp=False, P=1):
        super(_Encoder, self).__init__()
        self.is_temp = is_temp
        self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2)
        self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2)
        self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32,
                                    kernel_size=5, padding=2)  # padding size could be changed here
        self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, padding=2)
        self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, padding=2)
        self.norm = self._normlizationLayer(P=P)

    @staticmethod
    def _normlizationLayer(P=1):
        """Return a function that rescales each sample to sum-of-squares ``P * k``.

        The returned function accepts a 4-D tensor (dimension 0 is the batch)
        or a 3-D tensor (one sample), returns a tensor of the same shape, and
        raises ``Exception`` for any other number of dimensions.
        """
        def _inner(z_hat: torch.Tensor):
            unbatched = z_hat.dim() == 3
            if unbatched:
                # Give a lone sample a batch dimension so one code path
                # handles both cases.
                z_hat = z_hat.unsqueeze(0)
            elif z_hat.dim() != 4:
                raise Exception('Unknown size of input')
            # k: number of elements in one sample.
            k = z_hat[0].numel()
            # Per-sample sum of squares, kept broadcastable against z_hat.
            energy = z_hat.pow(2).sum(dim=(1, 2, 3), keepdim=True)
            tensor = (P * k) ** 0.5 * z_hat / torch.sqrt(energy)
            # Drop the batch dimension only if this function added it; a real
            # batch that happens to hold one sample must stay 4-D.
            if unbatched:
                return tensor.squeeze(0)
            return tensor
        return _inner

    def forward(self, x):
        """Encode ``x``; in ``is_temp`` mode return the ``conv4`` output."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        if not self.is_temp:
            x = self.conv5(x)
            x = self.norm(x)
        return x
class _Decoder(nn.Module):
    """Stack of five transposed-convolution stages; the last uses a Sigmoid.

    Args:
        c: Number of input channels expected by the first stage.
    """

    def __init__(self, c=1):
        super().__init__()
        self.tconv1 = _TransConvWithPReLU(c, 32, 5, 1, padding=2)
        self.tconv2 = _TransConvWithPReLU(32, 32, 5, 1, padding=2)
        self.tconv3 = _TransConvWithPReLU(32, 32, 5, 1, padding=2)
        # NOTE: tconv4 and tconv5 may be problematic: kernel_size is 6 here,
        # which differs from the paper's value of 5.
        self.tconv4 = _TransConvWithPReLU(32, 16, 6, 2)
        self.tconv5 = _TransConvWithPReLU(16, 3, 6, 2, activate=nn.Sigmoid())

    def forward(self, x):
        """Run ``x`` through ``tconv1`` .. ``tconv5`` in order."""
        stages = (self.tconv1, self.tconv2, self.tconv3, self.tconv4, self.tconv5)
        for stage in stages:
            x = stage(x)
        return x
class DeepJSCC(nn.Module):
    """Encoder -> channel -> decoder pipeline.

    Args:
        c: Channel count shared by the encoder output and decoder input.
        channel_type: Forwarded to ``channel.channel`` (default ``'AWGN'``).
        snr: Forwarded to ``channel.channel`` (default 20).
    """

    def __init__(self, c, channel_type='AWGN', snr=20):
        super().__init__()
        self.encoder = _Encoder(c=c)
        self.channel = channel.channel(channel_type, snr)
        self.decoder = _Decoder(c=c)

    def forward(self, x):
        """Encode ``x``, pass it through the channel, and decode the result."""
        latent = self.encoder(x)
        received = self.channel(latent)
        return self.decoder(received)

    def change_channel(self, channel_type, snr):
        """Replace ``self.channel`` with a new ``channel.channel(channel_type, snr)``."""
        self.channel = channel.channel(channel_type, snr)