This commit is contained in:
chun 2023-12-21 18:54:52 +08:00
parent 78ff76955c
commit a7900ec006
5 changed files with 133 additions and 51 deletions

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
test.py test.py
*.pyc *.pyc
*.log

View File

@ -1,9 +1,22 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np
def channel(channel_type='AWGN', snr=20):
    """Return a channel-simulation function for the given channel type.

    Args:
        channel_type: 'AWGN' or 'Rayleigh'.
        snr: signal-to-noise ratio in dB used to size the noise power.

    Returns:
        A callable mapping a transmitted tensor ``z_hat`` (batch-first) to
        the received tensor.

    Raises:
        ValueError: if ``channel_type`` is unknown.
    """
    def AWGN_channel(z_hat: torch.Tensor):
        # number of symbols per sample = product of all non-batch dims
        k = np.prod(z_hat.size()[1:])
        # Per-sample signal power, summed over every non-batch dim so the
        # function works for any input rank (original hard-coded dims 1-3,
        # which silently assumed a 4-D tensor).
        non_batch_dims = tuple(range(1, z_hat.dim()))
        sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=non_batch_dims, keepdim=True)
        # noise power chosen so that 10*log10(signal/noise) == snr per symbol
        noi_pwr = sig_pwr / (k * 10 ** (snr / 10))
        noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr)
        return z_hat + noise

    def Rayleigh_channel(z_hat: torch.Tensor):
        # TODO: Rayleigh fading not implemented yet; stub returns None,
        # matching the original placeholder behavior.
        pass

    if channel_type == 'AWGN':
        return AWGN_channel
    elif channel_type == 'Rayleigh':
        return Rayleigh_channel
    else:
        # ValueError is more precise than a bare Exception and is still
        # caught by any caller that catches Exception.
        raise ValueError('Unknown type of channel')

106
model.py
View File

@ -9,10 +9,30 @@ import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
import channel import channel
import torch.nn.functional as F
def _image_normalization(norm_type):
def _inner(tensor: torch.Tensor):
if norm_type == 'nomalization':
return tensor / 255.0
elif norm_type == 'denormalization':
return (tensor * 255.0).type(torch.FloatTensor)
else:
raise Exception('Unknown type of normalization')
return _inner
def ratio2filtersize(x, ratio):
    """Compute the encoder output channel count for a target compression ratio.

    Runs a throwaway encoder (is_temp=True) on ``x`` to discover the spatial
    size of the latent map, then sizes the channel dimension so that
    latent_elements / input_elements approximates ``ratio``.

    Args:
        x: a sample input tensor (e.g. one image).
        ratio: desired compression ratio.

    Returns:
        int: number of filters ``c`` for the final encoder conv (truncated).
    """
    input_elements = np.prod(x.size())
    probe_encoder = _Encoder(is_temp=True)
    latent = probe_encoder(x)
    spatial_elements = np.prod(latent.size()[-2:])
    return int(input_elements * ratio / spatial_elements)
class _ConvWithPReLU(nn.Module): class _ConvWithPReLU(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
super(_ConvWithPReLU, self).__init__() super(_ConvWithPReLU, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
self.prelu = nn.PReLU() self.prelu = nn.PReLU()
@ -24,9 +44,10 @@ class _ConvWithPReLU(nn.Module):
class _TransConvWithPReLU(nn.Module): class _TransConvWithPReLU(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, activate=nn.PReLU(), padding=0): def __init__(self, in_channels, out_channels, kernel_size, stride, activate=nn.PReLU(), padding=0, output_padding=0):
super(_TransConvWithPReLU, self).__init__() super(_TransConvWithPReLU, self).__init__()
self.transconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding) self.transconv = nn.ConvTranspose2d(
in_channels, out_channels, kernel_size, stride, padding, output_padding)
self.activate = activate self.activate = activate
def forward(self, x): def forward(self, x):
@ -35,37 +56,30 @@ class _TransConvWithPReLU(nn.Module):
return x return x
def _image_normalization(tensor, norm_type): class _Encoder(nn.Module):
if norm_type == 'nomalization':
return tensor / 255.0
elif norm_type == 'denormalization':
return tensor * 255.0
else:
raise Exception('Unknown type of normalization')
def _NormlizationLayer(norm_type='nomalization'):
pass
def ratio2filter_size(x, ratio):
before_size = np.prod(x.size())
after_size = before_size*ratio
encoder_temp = Encoder(is_temp=True)
x_temp = encoder_temp(x)
class Encoder(nn.Module):
def __init__(self, c=1, is_temp=False): def __init__(self, c=1, is_temp=False):
super(Encoder, self).__init__() super(_Encoder, self).__init__()
self.is_temp = is_temp self.is_temp = is_temp
self.imgae_normalization = _image_normalization(norm_type='nomalization') self.imgae_normalization = _image_normalization(norm_type='nomalization')
self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2) self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2)
self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2) self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2)
self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, stride=1) self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32,
self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, stride=1) kernel_size=5, padding=2) # padding size could be changed here
self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, stride=1) self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, padding=2)
self.norm = _NormlizationLayer(norm_type='nomalization') self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, padding=2)
self.norm = self._normlizationLayer()
@staticmethod
def _normlizationLayer(P=1):
def _inner(z_hat: torch.Tensor):
batch_size = z_hat.size()[0]
k = np.prod(z_hat.size()[1:])
k = torch.tensor(k)
z_temp = z_hat.reshape(batch_size, 1, 1, -1)
z_trans = z_hat.reshape(batch_size, 1, -1, 1)
tensor = torch.sqrt(P * k) * z_hat / (z_temp @ z_trans)
return tensor
return _inner
def forward(self, x): def forward(self, x):
x = self.imgae_normalization(x) x = self.imgae_normalization(x)
@ -75,26 +89,46 @@ class Encoder(nn.Module):
x = self.conv4(x) x = self.conv4(x)
if not self.is_temp: if not self.is_temp:
x = self.conv5(x) x = self.conv5(x)
z = self.norm(x) z = self.norm(x)
del x del x
return z return z
class _Decoder(nn.Module):
    """Transposed-conv decoder: maps the received latent back to an RGB image.

    Mirrors _Encoder: three size-preserving stride-1 layers, then two
    stride-2 upsampling layers. A final Sigmoid keeps activations in [0, 1]
    before denormalization back to the 0-255 pixel range.
    """

    def __init__(self, c=1):
        super(_Decoder, self).__init__()
        # inverse of the encoder's input scaling (multiplies by 255)
        self.imgae_normalization = _image_normalization(norm_type='denormalization')
        self.tconv1 = _TransConvWithPReLU(c, 32, 5, 1, padding=2)
        self.tconv2 = _TransConvWithPReLU(32, 32, 5, 1, padding=2)
        self.tconv3 = _TransConvWithPReLU(32, 32, 5, 1, padding=2)
        self.tconv4 = _TransConvWithPReLU(32, 16, 6, 2)
        self.tconv5 = _TransConvWithPReLU(16, 3, 6, 2, activate=nn.Sigmoid())
        # may be some problems in tconv4 and tconv5, the kernal_size is not
        # the same as the paper which is 5

    def forward(self, x):
        # run the five upsampling stages in order
        for stage in (self.tconv1, self.tconv2, self.tconv3, self.tconv4, self.tconv5):
            x = stage(x)
        # map [0, 1] activations back to pixel scale
        return self.imgae_normalization(x)
class DeepJSCC(nn.Module):
    """End-to-end joint source-channel coding model: encoder -> channel -> decoder.

    Args:
        c: latent channel count shared by encoder and decoder.
        channel_type: channel model name passed to channel.channel ('AWGN', ...).
        snr: signal-to-noise ratio in dB for the simulated channel.
    """

    def __init__(self, c, channel_type='AWGN', snr=20):
        super(DeepJSCC, self).__init__()
        self.encoder = _Encoder(c=c)
        # the channel is a plain callable, not an nn.Module (no parameters)
        self.channel = channel.channel(channel_type, snr)
        self.decoder = _Decoder(c=c)

    def forward(self, x):
        code = self.encoder(x)
        received = self.channel(code)
        return self.decoder(received)

View File

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
Created on Tue Dec 11:00:00 2023 Created on Tue Dec 17:00:00 2023
@author: chun @author: chun
""" """
@ -10,10 +10,11 @@ import torch.nn as nn
from torchvision import transforms from torchvision import transforms
from torchvision import datasets from torchvision import datasets
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import tqdm import tqdm
from model import DeepJSCC from model import DeepJSCC, ratio2filtersize
from torch.nn.parallel import DataParallel
from channel import channel
def config_parser(argv=None):
    """Build and parse the command-line arguments for training.

    Args:
        argv: optional list of argument strings; None (default) reads
            sys.argv, preserving the original call signature.

    Returns:
        argparse.Namespace with the training configuration.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', default=2048, type=int, help='Random seed')
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--epochs', default=100, type=int, help='number of epochs')
    parser.add_argument('--batch_size', default=64, type=int, help='batch size')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--weight_decay', default=1e-3, type=float, help='weight decay')
    parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
    parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
    # BUG FIX: type=list would split a single argument string into characters;
    # nargs='+' with an element type parses each value properly.
    parser.add_argument('--snr_list', default=list(range(1, 19, 3)), type=int,
                        nargs='+', help='snr_list')
    parser.add_argument('--ratio_list', default=[1 / 6, 1 / 12], type=float,
                        nargs='+', help='ratio_list')
    # BUG FIX: type=bool is a trap — bool('False') is True. Parse explicitly.
    parser.add_argument('--early_stop', default=True,
                        type=lambda s: s.lower() in ('1', 'true', 'yes'),
                        help='early_stop')
    return parser.parse_args(argv)
def main():
    """Entry point: parse CLI arguments and launch a single training run."""
    args = config_parser()
    print("Training Start")
    # NOTE: the full sweep over the ratio/snr grids is left disabled,
    # exactly as in the original:
    # for ratio in args.ratio_list:
    #     for snr in args.snr_list:
    #         train(args, ratio, snr)
    train(args, 1 / 6, 20)
def train(args: argparse.Namespace, ratio: float, snr: float):
    """Train a DeepJSCC autoencoder on CIFAR-10 for one (ratio, snr) setting.

    Args:
        args: parsed CLI configuration (see config_parser).
        ratio: target compression ratio used to size the latent channels.
        snr: channel signal-to-noise ratio in dB.

    Side effects:
        Downloads CIFAR-10 into ./Dataset/ if absent and writes the trained
        weights to args.saved.
    """
    # BUG FIX: the original annotation `args: config_parser()` *called*
    # parse_args() when the function was defined; annotate the type instead.
    print("training with ratio: {}, snr: {}, channel: {}".format(ratio, snr, args.channel))
    # load data
    transform = transforms.Compose([transforms.ToTensor(), ])
    train_dataset = datasets.CIFAR10(root='./Dataset/', train=True,
                                     download=True, transform=transform)
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
    # NOTE(review): the original used datasets.MNIST here — almost certainly a
    # slip, since the model expects 3-channel CIFAR images (MNIST is 1x28x28).
    # Switched to the CIFAR-10 test split; confirm against the evaluation code.
    test_dataset = datasets.CIFAR10(root='./Dataset/', train=False,
                                    download=True, transform=transform)
    test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size)
    # size the latent channel count from one sample image
    image_first = train_dataset[0][0]
    c = ratio2filtersize(image_first, ratio)
    model = DeepJSCC(c=c, channel_type=args.channel, snr=snr)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # BUG FIX: tqdm is imported as a module, so the callable is tqdm.tqdm;
    # also an int is not iterable and len(int) raises — wrap epochs in range().
    epoch_loop = tqdm.tqdm(range(args.epochs), total=args.epochs, leave=False)
    for epoch in epoch_loop:
        run_loss = 0.0
        for images, _ in tqdm.tqdm(train_loader, leave=False):
            optimizer.zero_grad()
            # autoencoder objective: reconstruct the input through the channel
            outputs = model(images)
            loss = criterion(outputs, images)
            loss.backward()
            optimizer.step()
            run_loss += loss.item()
        epoch_loop.set_description(f'Epoch [{epoch}/{args.epochs}]')
        epoch_loop.set_postfix(loss=run_loss)
    save_model(model, args.saved + '/model_{}_{}.pth'.format(ratio, snr))
def save_model(model, path):
    """Persist a model's state_dict to ``path``, creating parent dirs as needed.

    Args:
        model: any nn.Module whose state_dict should be serialized.
        path: destination file path (e.g. './saved/model_0.16_20.pth').
    """
    import os  # local import: keeps the file's visible import block untouched
    parent = os.path.dirname(path)
    # BUG FIX: torch.save fails with FileNotFoundError if the target
    # directory (e.g. ./saved) does not exist yet — create it first.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), path)
if __name__ == '__main__': if __name__ == '__main__':