This commit is contained in:
chun 2023-12-21 18:54:52 +08:00
parent 78ff76955c
commit a7900ec006
5 changed files with 133 additions and 51 deletions

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
test.py
*.pyc
*.log

View File

@ -1,9 +1,22 @@
import torch
import torch.nn as nn
import numpy as np
def channel(channel_type='AWGN', snr=20):
    """Return a callable that simulates the chosen noisy channel.

    Args:
        channel_type: 'AWGN' or 'Rayleigh'.
        snr: signal-to-noise ratio in dB used to size the noise power.

    Returns:
        A function mapping a latent tensor z_hat to its noisy version.

    Raises:
        Exception: if channel_type is not one of the supported names.
    """
    def AWGN_channel(z_hat: torch.Tensor):
        # k = number of transmitted symbols per sample (all dims but batch).
        k = np.prod(z_hat.size()[1:])
        # Per-sample signal power; assumes a 4-D (N, C, H, W) tensor — the
        # reduction dims (1, 2, 3) would fail on other ranks. TODO confirm.
        sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True)
        # Noise variance chosen so that SNR(dB) = 10*log10(signal/noise) per symbol.
        noi_pwr = sig_pwr / (k * 10 ** (snr / 10))
        noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr)
        return z_hat + noise

    def Rayleigh_channel(z_hat: torch.Tensor):
        # Not implemented yet; returns None by falling through.
        pass

    if channel_type == 'AWGN':
        return AWGN_channel
    elif channel_type == 'Rayleigh':
        return Rayleigh_channel
    else:
        raise Exception('Unknown type of channel')

106
model.py
View File

@ -9,10 +9,30 @@ import torch
import torch.nn as nn
import numpy as np
import channel
import torch.nn.functional as F
def _image_normalization(norm_type):
def _inner(tensor: torch.Tensor):
if norm_type == 'nomalization':
return tensor / 255.0
elif norm_type == 'denormalization':
return (tensor * 255.0).type(torch.FloatTensor)
else:
raise Exception('Unknown type of normalization')
return _inner
def ratio2filtersize(x, ratio):
    """Compute the encoder's final channel count c for a compression ratio.

    Args:
        x: a sample image tensor (its total element count is the "before" size).
        ratio: target compressed size / original size.

    Returns:
        int: channel count for the encoder's last conv (truncated toward zero).
    """
    before_size = np.prod(x.size())
    # Probe the conv stack without its last layer purely to learn the latent
    # spatial size; no gradients are needed for a shape probe.
    encoder_temp = _Encoder(is_temp=True)
    with torch.no_grad():
        z_temp = encoder_temp(x)
    c = before_size * ratio / np.prod(z_temp.size()[-2:])
    return int(c)
class _ConvWithPReLU(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
super(_ConvWithPReLU, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
self.prelu = nn.PReLU()
@ -24,9 +44,10 @@ class _ConvWithPReLU(nn.Module):
class _TransConvWithPReLU(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, activate=nn.PReLU(), padding=0):
def __init__(self, in_channels, out_channels, kernel_size, stride, activate=nn.PReLU(), padding=0, output_padding=0):
super(_TransConvWithPReLU, self).__init__()
self.transconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
self.transconv = nn.ConvTranspose2d(
in_channels, out_channels, kernel_size, stride, padding, output_padding)
self.activate = activate
def forward(self, x):
@ -35,37 +56,30 @@ class _TransConvWithPReLU(nn.Module):
return x
# NOTE(review): stale pre-image residue from the diff dump — this tensor-style
# _image_normalization was replaced by the closure-returning version earlier
# in this commit; if this dump were executed as one module, this definition
# would shadow the new one.
def _image_normalization(tensor, norm_type):
if norm_type == 'nomalization':
return tensor / 255.0
elif norm_type == 'denormalization':
# Unlike the new version, this one does not cast to float32.
return tensor * 255.0
else:
raise Exception('Unknown type of normalization')
# NOTE(review): removed stub (pre-image residue) — superseded by the
# _Encoder._normlizationLayer staticmethod introduced in this commit.
def _NormlizationLayer(norm_type='nomalization'):
pass
# NOTE(review): removed pre-image helper, renamed to ratio2filtersize in this
# commit. As written it computed after_size and x_temp but returned nothing.
def ratio2filter_size(x, ratio):
before_size = np.prod(x.size())
after_size = before_size*ratio
encoder_temp = Encoder(is_temp=True)
x_temp = encoder_temp(x)
class _Encoder(nn.Module):
    """DeepJSCC encoder: five conv blocks followed by power normalization.

    Args:
        c: channel count of the final conv layer (latent bandwidth).
        is_temp: when True the module is only a shape probe — forward()
            skips conv5/normalization (see ratio2filtersize).
    """

    def __init__(self, c=1, is_temp=False):
        super(_Encoder, self).__init__()
        self.is_temp = is_temp
        # NOTE(review): attribute name 'imgae_normalization' is misspelled
        # but kept — forward() references it by this exact name.
        self.imgae_normalization = _image_normalization(norm_type='nomalization')
        self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2)
        self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2)
        self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32,
                                    kernel_size=5, padding=2)  # padding size could be changed here
        self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, padding=2)
        self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, padding=2)
        self.norm = self._normlizationLayer()

    @staticmethod
    def _normlizationLayer(P=1):
        """Return a closure enforcing average symbol power P:
        z* = sqrt(k * P) * z / ||z||, computed per sample."""
        def _inner(z_hat: torch.Tensor):
            batch_size = z_hat.size()[0]
            k = np.prod(z_hat.size()[1:])
            k = torch.tensor(k)
            z_temp = z_hat.reshape(batch_size, 1, 1, -1)
            z_trans = z_hat.reshape(batch_size, 1, -1, 1)
            # (z_temp @ z_trans) is ||z||^2 per sample. The original divided
            # by it directly, i.e. by the power rather than the magnitude;
            # the power constraint needs division by sqrt(||z||^2).
            tensor = torch.sqrt(P * k) * z_hat / torch.sqrt(z_temp @ z_trans)
            return tensor
        return _inner
# NOTE(review): this forward() is reassembled from a diff dump — the hunk
# header below hides context lines (presumably the conv1..conv3 calls), so
# what follows is NOT the complete method body.
def forward(self, x):
# Scale raw pixels into [0, 1] before the conv stack.
x = self.imgae_normalization(x)
@ -75,26 +89,46 @@ class Encoder(nn.Module):
# The dump resumes mid-method here, after the hidden context lines.
x = self.conv4(x)
# conv5 and power normalization appear to be skipped in shape-probe mode
# (is_temp=True); exact nesting of norm/return is unclear from this dump.
if not self.is_temp:
x = self.conv5(x)
z = self.norm(x)
del x
return z
class _Decoder(nn.Module):
    """DeepJSCC decoder: five transposed-conv blocks, then rescale to [0, 255].

    Args:
        c: channel count of the incoming latent (must match the encoder).
    """

    def __init__(self, c=1):
        super(_Decoder, self).__init__()
        # 'denormalization' maps the final sigmoid output back to [0, 255].
        # NOTE(review): attribute keeps the original 'imgae' misspelling,
        # matching the encoder's naming.
        self.imgae_normalization = _image_normalization(norm_type='denormalization')
        self.tconv1 = _TransConvWithPReLU(
            in_channels=c, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.tconv2 = _TransConvWithPReLU(
            in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.tconv3 = _TransConvWithPReLU(
            in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.tconv4 = _TransConvWithPReLU(in_channels=32, out_channels=16, kernel_size=6, stride=2)
        self.tconv5 = _TransConvWithPReLU(
            in_channels=16, out_channels=3, kernel_size=6, stride=2, activate=nn.Sigmoid())
        # may be some problems in tconv4 and tconv5: the kernel_size is not
        # the same as the paper, which uses 5.

    def forward(self, x):
        x = self.tconv1(x)
        x = self.tconv2(x)
        x = self.tconv3(x)
        x = self.tconv4(x)
        x = self.tconv5(x)
        x = self.imgae_normalization(x)
        return x
class DeepJSCC(nn.Module):
    """End-to-end model: encoder -> noisy channel -> decoder.

    Args:
        c: latent channel count shared by encoder and decoder.
        channel_type: channel model name passed to channel.channel.
        snr: channel SNR in dB.
    """

    def __init__(self, c, channel_type='AWGN', snr=20):
        super(DeepJSCC, self).__init__()
        self.encoder = _Encoder(c=c)
        # channel.channel returns a plain callable (not an nn.Module),
        # so the channel itself carries no learnable parameters.
        self.channel = channel.channel(channel_type, snr)
        self.decoder = _Decoder(c=c)

    def forward(self, x):
        z = self.encoder(x)
        z = self.channel(z)
        x_hat = self.decoder(z)
        return x_hat

View File

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
Created on Tue Dec 11:00:00 2023
Created on Tue Dec 17:00:00 2023
@author: chun
"""
@ -10,10 +10,11 @@ import torch.nn as nn
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from model import DeepJSCC
from model import DeepJSCC, ratio2filtersize
from torch.nn.parallel import DataParallel
from channel import channel
def config_parser():
    """Build the training CLI and parse sys.argv into a namespace."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', default=2048, type=int, help='Random seed')
    # NOTE(review): 0.1 is unusually high for Adam — confirm intended value.
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--epochs', default=100, type=int, help='number of epochs')
    parser.add_argument('--batch_size', default=64, type=int, help='batch size')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--weight_decay', default=1e-3, type=float, help='weight decay')
    parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
    parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
    # NOTE(review): `type=list` does not parse CLI text into a real list (it
    # splits a string into characters); these defaults only behave when the
    # flags are left unset.
    parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list')
    parser.add_argument('--ratio_list', default=[1/6, 1/12], type=list, help='ratio_list')
    # NOTE(review): `type=bool` makes any non-empty string truthy, so
    # `--early_stop False` still yields True; consider action='store_true'.
    parser.add_argument('--early_stop', default=True, type=bool, help='early_stop')
    return parser.parse_args()
def main():
    """Entry point: parse CLI arguments and launch a single training run."""
    args = config_parser()
    print("Training Start")
    # Sweeping every (ratio, snr) pair remains disabled for now:
    # for ratio in args.ratio_list:
    #     for snr in args.snr_list:
    #         train(args, ratio, snr)
    train(args, 1 / 6, 20)
def train(args, ratio: float, snr: float):
    """Train DeepJSCC on CIFAR-10 for one (ratio, snr) configuration.

    Args:
        args: parsed namespace from config_parser().
        ratio: target compression ratio used to size the latent.
        snr: channel SNR in dB.

    Side effects: downloads CIFAR-10 under ./Dataset/ if absent and writes
    the trained weights under args.saved.

    NOTE(review): the original annotated `args: config_parser()`, which
    *called* parse_args() at function-definition time — removed.
    """
    print("training with ratio: {}, snr: {}, channel: {}".format(ratio, snr, args.channel))
    # load data
    transform = transforms.Compose([transforms.ToTensor(), ])
    train_dataset = datasets.CIFAR10(root='./Dataset/', train=True,
                                     download=True, transform=transform)
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
    # Fixed: the test split previously loaded MNIST, whose 1-channel 28x28
    # images cannot feed a CIFAR-10-trained 3-channel model.
    test_dataset = datasets.CIFAR10(root='./Dataset/', train=False,
                                    download=True, transform=transform)
    test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size)
    # Derive the latent channel count from one sample image and the ratio.
    image_first = train_dataset[0][0]
    c = ratio2filtersize(image_first, ratio)
    model = DeepJSCC(c=c, channel_type=args.channel, snr=snr)
    model.train()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Fixed: tqdm must wrap an iterable — the original passed the int
    # args.epochs and called len() on it; also the file does `import tqdm`,
    # so the progress-bar class is tqdm.tqdm.
    epoch_loop = tqdm.tqdm(range(args.epochs), total=args.epochs, leave=False)
    for epoch in epoch_loop:
        run_loss = 0.0
        for images, _ in tqdm.tqdm(train_loader, leave=False):
            optimizer.zero_grad()
            outputs = model(images)
            # NOTE(review): ToTensor yields images in [0, 1] while the decoder
            # denormalizes its output to [0, 255] — confirm the intended loss
            # scale; as-is the MSE compares mismatched ranges.
            loss = criterion(outputs, images)
            loss.backward()
            optimizer.step()
            run_loss += loss.item()
        epoch_loop.set_description(f'Epoch [{epoch}/{args.epochs}]')
        epoch_loop.set_postfix(loss=run_loss)
    save_model(model, args.saved + '/model_{}_{}.pth'.format(ratio, snr))
# NOTE(review): stale pre-image stub left by the diff dump — the real train()
# is defined above; executed as-is this no-op definition would shadow it.
def train():
pass
def save_model(model, path):
    """Persist a model's state_dict to `path`, creating the directory first.

    Only the state_dict is saved (not the pickled module), so loading
    requires rebuilding the architecture before load_state_dict.
    """
    import os
    # The original raised FileNotFoundError when the target directory
    # (e.g. ./saved) did not exist yet.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save(model.state_dict(), path)
if __name__ == '__main__':