101 lines
2.8 KiB
Python
101 lines
2.8 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
Created on Tue Dec 11:00:00 2023
|
|
|
|
@author: chun
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
import channel
|
|
|
|
|
|
class _ConvWithPReLU(nn.Module):
    """A 2-D convolution followed by a learnable PReLU activation.

    Args:
        in_channels: number of input channels.
        out_channels: number of convolution filters (output channels).
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: zero-padding added on each spatial side (default 0).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0):
        super().__init__()
        # The attribute names `conv` and `prelu` define the state_dict keys.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Convolve ``x`` and pass the result through the PReLU."""
        return self.prelu(self.conv(x))
|
|
|
|
|
|
class _TransConvWithPReLU(nn.Module):
    """A 2-D transposed convolution followed by an activation.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: transposed-convolution kernel size.
        stride: transposed-convolution stride.
        activate: module/callable applied after the transposed convolution.
            When ``None`` (the default), a new ``nn.PReLU()`` is created for
            this layer.
        padding: padding argument forwarded to ``nn.ConvTranspose2d`` (default 0).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, activate=None, padding=0):
        super(_TransConvWithPReLU, self).__init__()
        self.transconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
        # A default of `activate=nn.PReLU()` is evaluated once, when the class is
        # defined, so every layer built with the default would share a single
        # PReLU module (and its learnable slope). Create one per instance instead.
        self.activate = nn.PReLU() if activate is None else activate

    def forward(self, x):
        """Apply the transposed convolution, then the activation."""
        x = self.transconv(x)
        return self.activate(x)
|
|
|
|
|
|
def _image_normalization(tensor, norm_type):
    """Divide or multiply ``tensor`` by 255.

    Presumably used to map 8-bit pixel values to [0, 1] and back; the function
    itself does not check the value range.

    Args:
        tensor: tensor (or array) to scale.
        norm_type: ``'nomalization'`` (historical spelling, still accepted) or
            ``'normalization'`` to divide by 255; ``'denormalization'`` to
            multiply by 255.

    Returns:
        A new scaled tensor; the input is not modified.

    Raises:
        ValueError: if ``norm_type`` is not one of the values above. ValueError
            subclasses Exception, so existing ``except Exception`` handlers
            still catch it.
    """
    if norm_type in ('nomalization', 'normalization'):
        return tensor / 255.0
    if norm_type == 'denormalization':
        return tensor * 255.0
    raise ValueError(f'Unknown type of normalization: {norm_type!r}')
|
|
|
|
|
|
class _ImageNormalizationModule(nn.Module):
    """Parameter-free layer that applies ``_image_normalization`` with a fixed type."""

    def __init__(self, norm_type):
        super(_ImageNormalizationModule, self).__init__()
        self.norm_type = norm_type

    def forward(self, x):
        return _image_normalization(x, self.norm_type)

    def extra_repr(self):
        return f'norm_type={self.norm_type!r}'


def _NormlizationLayer(norm_type='nomalization'):
    """Return an ``nn.Module`` computing ``_image_normalization(x, norm_type)``.

    The returned layer is callable like any other module, so it can be stored as
    a submodule and applied in ``forward``. ``norm_type`` is validated lazily:
    an unknown value raises when the layer is first called.

    NOTE(review): ``Encoder`` applies this layer to its final output. If that
    call is meant to be DeepJSCC's channel power-normalization, this layer (a
    scale by 1/255 or 255) is not that — confirm intent.

    Args:
        norm_type: normalization type understood by ``_image_normalization``.

    Returns:
        nn.Module: the normalization layer (it has no learnable parameters).
    """
    return _ImageNormalizationModule(norm_type)
|
|
|
|
|
|
def ratio2filter_size(x, ratio):
    """Return the number of latent channels ``c`` that matches a size ratio.

    ``c`` is the output-channel count of the encoder's last convolution. It is
    chosen so that one encoded sample holds at most ``ratio`` times as many
    values as one input sample::

        c = int(C * H * W * ratio / (H_out * W_out))

    where ``H_out x W_out`` is the spatial size of the encoder output.

    Args:
        x: input tensor of shape (C, H, W) or (B, C, H, W). For batched input
            the size is computed per sample, so the batch size does not affect
            the result.
        ratio: target ratio of latent values to input values.

    Returns:
        int: the channel count to pass as ``Encoder(c=...)``.

    Raises:
        ValueError: if ``x`` is not 3-D or 4-D, or if ``ratio`` is too small to
            give at least one channel.
    """
    if x.dim() == 3:
        sample = x.unsqueeze(0)
    elif x.dim() == 4:
        # One sample is enough to find the output's spatial size.
        sample = x[:1]
    else:
        raise ValueError(f'Expected a 3-D or 4-D tensor, got shape {tuple(x.size())}')

    # Per-sample input size; excluding the batch dimension keeps c independent of B.
    before_size = int(np.prod(sample.size()[1:]))
    after_size = before_size * ratio

    # The output's spatial size does not depend on c, so probe with c=1. Use the
    # full encoder rather than is_temp=True: conv5 (kernel 5, padding 0) shrinks
    # the spatial dims further than the layers before it.
    encoder_temp = Encoder(c=1)
    with torch.no_grad():
        z_temp = encoder_temp(sample)
    spatial_size = int(np.prod(z_temp.size()[-2:]))

    c = int(after_size / spatial_size)
    if c < 1:
        raise ValueError(f'ratio={ratio!r} is too small: it yields {c} latent channels')
    return c
|
|
|
|
|
|
class Encoder(nn.Module):
    """Convolutional encoder used by ``DeepJSCC``.

    The pipeline is:

    1. scale the input by 1/255 via ``_image_normalization``;
    2. five Conv2d+PReLU stages, each with kernel 5 and no padding, with
       16, 32, 32, 32 and ``c`` filters and strides 2, 2, 1, 1, 1;
    3. ``self.norm``.

    Args:
        c: number of output channels of the last convolution (``conv5``).
        is_temp: if True, ``conv5`` is skipped, so the output has 32 channels.
            ``self.norm`` is still applied.
    """

    def __init__(self, c=1, is_temp=False):
        super(Encoder, self).__init__()
        self.is_temp = is_temp
        self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2)
        self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2)
        self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, stride=1)
        self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, stride=1)
        self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, stride=1)
        # Output normalization: whatever layer `_NormlizationLayer` builds.
        self.norm = _NormlizationLayer(norm_type='nomalization')

    def forward(self, x):
        """Encode ``x``; the input is expected to have 3 channels (conv1's in_channels)."""
        # Divide by 255. This assumes x holds 0-255 pixel values — TODO confirm at call sites.
        # Calling the helper here with the tensor fixes the old constructor, which
        # called it without a tensor and raised TypeError.
        x = _image_normalization(x, norm_type='nomalization')
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        if not self.is_temp:
            x = self.conv5(x)
        return self.norm(x)
|
|
|
|
|
|
class Decoder(nn.Module):
    """Placeholder decoder for ``DeepJSCC``; not implemented yet.

    ``forward`` raises ``NotImplementedError``, so an unfinished model fails
    loudly instead of silently returning ``None`` as its reconstruction.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        # TODO: define the decoding layers.
        # NOTE(review): `_TransConvWithPReLU` is defined in this module but unused;
        # presumably it is meant for this decoder — confirm.

    def forward(self, x):
        raise NotImplementedError('Decoder.forward is not implemented yet')
|
|
|
|
|
|
class DeepJSCC(nn.Module):
    """End-to-end model: ``Encoder`` followed by ``Decoder``.

    Args:
        c: number of latent channels produced by the encoder's last convolution,
            forwarded as ``Encoder(c=c)`` (default 1).
    """

    def __init__(self, c=1):
        super(DeepJSCC, self).__init__()
        self.encoder = Encoder(c=c)
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x`` and decode the latent; return the decoder's output."""
        z = self.encoder(x)
        # NOTE(review): `channel` is imported at module level, but no channel
        # model is applied to `z` before decoding — confirm whether one is intended.
        return self.decoder(z)
|