From a7900ec0061d7e41dc2d8601f80267c77ec14c28 Mon Sep 17 00:00:00 2001 From: chun Date: Thu, 21 Dec 2023 18:54:52 +0800 Subject: [PATCH] v1.0 --- .gitignore | 1 + channel.py | 21 +++++++++-- model.py | 106 +++++++++++++++++++++++++++++++++++------------------ scripts.py | 0 train.py | 56 ++++++++++++++++++++++------ 5 files changed, 133 insertions(+), 51 deletions(-) delete mode 100644 scripts.py diff --git a/.gitignore b/.gitignore index a8dd0a0..444aaee 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ test.py *.pyc +*.log \ No newline at end of file diff --git a/channel.py b/channel.py index 7009645..4bbf558 100644 --- a/channel.py +++ b/channel.py @@ -1,9 +1,22 @@ import torch import torch.nn as nn +import numpy as np -def AWGN_channel(): - pass +def channel(channel_type='AWGN', snr=20): + def AWGN_channel(z_hat: torch.Tensor): + k = np.prod(z_hat.size()[1:]) + sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True) + noi_pwr = sig_pwr / (k * 10 ** (snr / 10)) + noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr) + return z_hat + noise -def Rayleigh_channel(): - pass \ No newline at end of file + def Rayleigh_channel(z_hat: torch.Tensor): + pass + + if channel_type == 'AWGN': + return AWGN_channel + elif channel_type == 'Rayleigh': + return Rayleigh_channel + else: + raise Exception('Unknown type of channel') diff --git a/model.py b/model.py index 1e9ecc1..2dacb0b 100644 --- a/model.py +++ b/model.py @@ -9,10 +9,30 @@ import torch import torch.nn as nn import numpy as np import channel +import torch.nn.functional as F + + +def _image_normalization(norm_type): + def _inner(tensor: torch.Tensor): + if norm_type == 'nomalization': + return tensor / 255.0 + elif norm_type == 'denormalization': + return (tensor * 255.0).type(torch.FloatTensor) + else: + raise Exception('Unknown type of normalization') + return _inner + + +def ratio2filtersize(x, ratio): + before_size = np.prod(x.size()) + encoder_temp = _Encoder(is_temp=True) + z_temp = encoder_temp(x) + c = before_size * ratio / np.prod(z_temp.size()[-2:]) + return int(c) class _ConvWithPReLU(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0): super(_ConvWithPReLU, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) self.prelu = nn.PReLU() @@ -24,9 +44,10 @@ class _ConvWithPReLU(nn.Module): class _TransConvWithPReLU(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride, activate=nn.PReLU(), padding=0): + def __init__(self, in_channels, out_channels, kernel_size, stride, activate=nn.PReLU(), padding=0, output_padding=0): super(_TransConvWithPReLU, self).__init__() - self.transconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding) + self.transconv = nn.ConvTranspose2d( + in_channels, out_channels, kernel_size, stride, padding, output_padding) self.activate = activate def forward(self, x): @@ -35,37 +56,30 @@ class _TransConvWithPReLU(nn.Module): return x -def _image_normalization(tensor, norm_type): - if norm_type == 'nomalization': - return tensor / 255.0 - elif norm_type == 'denormalization': - return tensor * 255.0 - else: - raise Exception('Unknown type of normalization') - - -def _NormlizationLayer(norm_type='nomalization'): - pass - - -def ratio2filter_size(x, ratio): - before_size = np.prod(x.size()) - after_size = before_size*ratio - encoder_temp = Encoder(is_temp=True) - x_temp = encoder_temp(x) - - -class Encoder(nn.Module): +class _Encoder(nn.Module): def __init__(self, c=1, is_temp=False): - super(Encoder, self).__init__() + super(_Encoder, self).__init__() self.is_temp = is_temp self.imgae_normalization = _image_normalization(norm_type='nomalization') self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2) self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2) - self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, stride=1) - self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, stride=1) - self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, stride=1) - self.norm = _NormlizationLayer(norm_type='nomalization') + self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32, + kernel_size=5, padding=2) # padding size could be changed here + self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, padding=2) + self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, padding=2) + self.norm = self._normlizationLayer() + + @staticmethod + def _normlizationLayer(P=1): + def _inner(z_hat: torch.Tensor): + batch_size = z_hat.size()[0] + k = np.prod(z_hat.size()[1:]) + k = torch.tensor(k) + z_temp = z_hat.reshape(batch_size, 1, 1, -1) + z_trans = z_hat.reshape(batch_size, 1, -1, 1) + tensor = torch.sqrt(P * k) * z_hat / (z_temp @ z_trans) + return tensor + return _inner def forward(self, x): x = self.imgae_normalization(x) @@ -75,26 +89,46 @@ class Encoder(nn.Module): x = self.conv4(x) if not self.is_temp: x = self.conv5(x) + z = self.norm(x) del x return z -class Decoder(nn.Module): - def __init__(self): - super(Decoder, self).__init__() +class _Decoder(nn.Module): + def __init__(self, c=1): + super(_Decoder, self).__init__() + self.imgae_normalization = _image_normalization(norm_type='denormalization') + self.tconv1 = _TransConvWithPReLU( + in_channels=c, out_channels=32, kernel_size=5, stride=1, padding=2) + self.tconv2 = _TransConvWithPReLU( + in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2) + self.tconv3 = _TransConvWithPReLU( + in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2) + self.tconv4 = _TransConvWithPReLU(in_channels=32, out_channels=16, kernel_size=6, stride=2) + self.tconv5 = _TransConvWithPReLU( + in_channels=16, out_channels=3, kernel_size=6, stride=2, activate=nn.Sigmoid()) + # may be some problems in tconv4 and tconv5, the kernal_size is not the same as the paper which is 5 def forward(self, x): - pass + x = self.tconv1(x) + x = self.tconv2(x) + x = self.tconv3(x) + x = self.tconv4(x) + x = self.tconv5(x) + x = self.imgae_normalization(x) + return x class DeepJSCC(nn.Module): - def __init__(self): + def __init__(self, c, channel_type='AWGN', snr=20): super(DeepJSCC, self).__init__() - self.encoder = Encoder() - self.decoder = Decoder() + self.encoder = _Encoder(c=c) + self.channel = channel.channel(channel_type,snr) + self.decoder = _Decoder(c=c) def forward(self, x): z = self.encoder(x) + z = self.channel(z) x_hat = self.decoder(z) return x_hat diff --git a/scripts.py b/scripts.py deleted file mode 100644 index e69de29..0000000 diff --git a/train.py b/train.py index 93029d7..153ebbc 100644 --- a/train.py +++ b/train.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Created on Tue Dec 11:00:00 2023 +Created on Tue Dec 17:00:00 2023 @author: chun """ @@ -10,10 +10,11 @@ import torch.nn as nn from torchvision import transforms from torchvision import datasets from torch.utils.data import DataLoader -import torch.nn.functional as F import torch.optim as optim import tqdm -from model import DeepJSCC +from model import DeepJSCC, ratio2filtersize +from torch.nn.parallel import DataParallel +from channel import channel def config_parser(): @@ -21,31 +22,64 @@ def config_parser(): parser = argparse.ArgumentParser() parser.add_argument('--seed', default=2048, type=int, help='Random seed') parser.add_argument('--lr', default=0.1, type=float, help='learning rate') - parser.add_argument('--batch_size', default=128, type=int, help='batch size') - parser.add_argument('optimizer', default='Adam', type=str, help='optimizer') + parser.add_argument('--epochs', default=100, type=int, help='number of epochs') + parser.add_argument('--batch_size', default=64, type=int, help='batch size') parser.add_argument('--momentum', default=0.9, type=float, help='momentum') - parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay') - parser.add_argument('--channel', default='AWGN', type=str, help='weight decay') + parser.add_argument('--weight_decay', default=1e-3, type=float, help='weight decay') + parser.add_argument('--channel', default='AWGN', type=str, help='channel type') parser.add_argument('--saved', default='./saved', type=str, help='saved_path') + parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list') + parser.add_argument('--ratio_list', default=[1/6, 1/12], type=list, help='ratio_list') + parser.add_argument('--early_stop', default=True, type=bool, help='early_stop') return parser.parse_args() def main(): args = config_parser() + print("Training Start") + # for ratio in args.ratio_list: + # for snr in args.snr_list: + # train(args, ratio, snr) + train(args, 1/6, 20) + + +def train(args: config_parser(), ratio: float, snr: float): + + print("training with ratio: {}, snr: {}, channel: {}".format(ratio, snr, args.channel)) + # load data transform = transforms.Compose([transforms.ToTensor(), ]) - train_dataset = datasets.CIFAR10(root='./Dataset/cifar-10-batches-py/', train=True, + train_dataset = datasets.CIFAR10(root='./Dataset/', train=True, download=True, transform=transform) train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size) - test_dataset = datasets.MNIST(root='./Dataset/cifar-10-batches-py/', train=False, + test_dataset = datasets.MNIST(root='./Dataset/', train=False, download=True, transform=transform) test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size) + image_fisrt = train_dataset.__getitem__(0)[0] + c = ratio2filtersize(image_fisrt, ratio) + model = DeepJSCC(c=c, channel_type=args.channel, snr=snr) + + criterion = nn.MSELoss() + optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + epoch_loop = tqdm((args.epochs), total=len(args.epochs), leave=False) + for epoch in epoch_loop: + run_loss = 0.0 + for images, _ in tqdm((train_loader), leave=False): + optimizer.zero_grad() + outputs = model(images) + loss = criterion(outputs, images) + loss.backward() + optimizer.step() + run_loss += loss.item() + epoch_loop.set_description(f'Epoch [{epoch}/{args.epochs}]') + epoch_loop.set_postfix(loss=run_loss) + save_model(model, args.saved + '/model_{}_{}.pth'.format(ratio, snr)) -def train(): - pass +def save_model(model, path): + torch.save(model.state_dict(), path) if __name__ == '__main__':