# -*- coding: utf-8 -*-
"""
Created on Tue Dec  11:00:00 2023

@author: chun

Encoder / decoder building blocks and the ``DeepJSCC`` model that chains
encoder -> channel -> decoder.
"""
import torch
import torch.nn as nn
import numpy as np
import channel

# NOTE: the helper below is intentionally disabled (kept as a bare string).
"""
def _image_normalization(norm_type):
    def _inner(tensor: torch.Tensor):
        if norm_type == 'nomalization':
            return tensor / 255.0
        elif norm_type == 'denormalization':
            return (tensor * 255.0).type(torch.FloatTensor)
        else:
            raise Exception('Unknown type of normalization')
    return _inner
"""


def ratio2filtersize(x: torch.Tensor, ratio) -> int:
    """Return the encoder output channel count ``c`` for a given ratio.

    The value is ``size(x) * ratio / (H' * W')`` truncated to an int, where
    ``size(x)`` is the number of elements of one sample and ``H' x W'`` is the
    spatial size produced by the encoder's first four conv layers.

    Args:
        x: A 4-D tensor (the first dimension is excluded from the size) or a
            3-D tensor (all dimensions are counted).
        ratio: Multiplier applied to the per-sample element count.

    Returns:
        The truncated channel count.

    Raises:
        ValueError: If ``x`` is neither 3-D nor 4-D.
    """
    if x.dim() == 4:
        before_size = np.prod(x.size()[1:])
    elif x.dim() == 3:
        before_size = np.prod(x.size())
    else:
        raise ValueError('Unknown size of input')

    # A throw-away encoder is used only to discover the latent spatial size.
    # It is placed on the same device as ``x`` (a CPU-only probe would fail on
    # inputs living elsewhere) and run without autograd, since only the shape
    # of the result is needed.
    encoder_temp = _Encoder(is_temp=True).to(x.device)
    with torch.no_grad():
        z_temp = encoder_temp(x)

    c = before_size * ratio / np.prod(z_temp.size()[-2:])
    return int(c)


class _ConvWithPReLU(nn.Module):
    """``nn.Conv2d`` followed by its own ``nn.PReLU`` activation."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Apply the convolution, then the PReLU."""
        return self.prelu(self.conv(x))


class _TransConvWithPReLU(nn.Module):
    """``nn.ConvTranspose2d`` followed by an activation (PReLU by default).

    Args:
        in_channels, out_channels, kernel_size, stride, padding,
        output_padding: Forwarded positionally to ``nn.ConvTranspose2d``.
        activate: Activation module applied after the transposed convolution.
            When ``None`` (the default) a new ``nn.PReLU()`` is created for
            this layer.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 activate=None, padding=0, output_padding=0):
        super().__init__()
        self.transconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, output_padding)
        # A default of ``nn.PReLU()`` written in the signature is evaluated
        # once, at definition time, so every layer relying on the default
        # would share one module (and its learnable slope). Build a fresh one
        # per instance instead.
        self.activate = nn.PReLU() if activate is None else activate

    def forward(self, x):
        """Apply the transposed convolution, then the activation."""
        return self.activate(self.transconv(x))


class _Encoder(nn.Module):
    """Five-layer convolutional encoder with a final power normalization.

    Args:
        c: Number of output channels of the last convolution.
        is_temp: When ``True``, ``forward`` stops after ``conv4`` (no ``conv5``
            and no normalization); used by ``ratio2filtersize`` to probe the
            latent spatial size.
    """

    def __init__(self, c=1, is_temp=False):
        super().__init__()
        self.is_temp = is_temp
        # self.imgae_normalization = _image_normalization(norm_type='nomalization')
        self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2)
        self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2)
        self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32,
                                    kernel_size=5, padding=2)  # padding size could be changed here
        self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, padding=2)
        self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, padding=2)
        self.norm = self._normlizationLayer()

    @staticmethod
    def _normlizationLayer(P=1):
        """Build a callable that rescales a tensor to ``sqrt(P * k) * z / ||z||``.

        ``k`` is the number of elements per sample and ``||z||`` is the
        per-sample Euclidean norm (computed as ``sqrt(z^T z)``).

        NOTE: for a 3-D input the norm has shape ``(1, 1, 1, 1)``, so the
        division broadcasts and the returned tensor is 4-D with a leading
        dimension of 1.
        """
        def _inner(z_hat: torch.Tensor):
            if z_hat.dim() == 4:
                batch_size = z_hat.size()[0]
                k = np.prod(z_hat.size()[1:])
            elif z_hat.dim() == 3:
                batch_size = 1
                k = np.prod(z_hat.size())
            else:
                raise ValueError('Unknown size of input')
            k = torch.tensor(k)
            # Row vector @ column vector gives the per-sample squared norm,
            # shaped (batch_size, 1, 1, 1) so it broadcasts over z_hat.
            z_temp = z_hat.reshape(batch_size, 1, 1, -1)
            z_trans = z_hat.reshape(batch_size, 1, -1, 1)
            tensor = torch.sqrt(P * k) * z_hat / torch.sqrt((z_temp @ z_trans))
            return tensor
        return _inner

    def forward(self, x):
        """Run conv1..conv4, then (unless ``is_temp``) conv5 and normalization."""
        # x = self.imgae_normalization(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        if not self.is_temp:
            x = self.conv5(x)
            x = self.norm(x)
        return x


class _Decoder(nn.Module):
    """Five-layer transposed-convolution decoder ending in a sigmoid.

    Args:
        c: Number of input channels expected by the first layer.
    """

    def __init__(self, c=1):
        super().__init__()
        # self.imgae_normalization = _image_normalization(norm_type='denormalization')
        self.tconv1 = _TransConvWithPReLU(
            in_channels=c, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.tconv2 = _TransConvWithPReLU(
            in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.tconv3 = _TransConvWithPReLU(
            in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.tconv4 = _TransConvWithPReLU(in_channels=32, out_channels=16, kernel_size=6, stride=2)
        self.tconv5 = _TransConvWithPReLU(
            in_channels=16, out_channels=3, kernel_size=6, stride=2, activate=nn.Sigmoid())
        # There may be some problems in tconv4 and tconv5: the kernel_size is
        # not the same as in the paper, which uses 5.

    def forward(self, x):
        """Run tconv1..tconv5 in order and return the result."""
        x = self.tconv1(x)
        x = self.tconv2(x)
        x = self.tconv3(x)
        x = self.tconv4(x)
        x = self.tconv5(x)
        # x = self.imgae_normalization(x)
        return x


class DeepJSCC(nn.Module):
    """Encoder -> channel -> decoder model.

    Args:
        c: Channel count shared by the encoder output and decoder input.
        channel_type: Forwarded to ``channel.channel``.
        snr: Forwarded to ``channel.channel``.
            NOTE(review): presumably a signal-to-noise ratio in dB -- confirm
            against the ``channel`` module.
    """

    def __init__(self, c, channel_type='AWGN', snr=20):
        super().__init__()
        self.encoder = _Encoder(c=c)
        self.channel = channel.channel(channel_type, snr)
        self.decoder = _Decoder(c=c)

    def forward(self, x):
        """Encode ``x``, pass the latent through the channel, and decode it."""
        z = self.encoder(x)
        z = self.channel(z)
        x_hat = self.decoder(z)
        return x_hat