From 9c60ca0e2c1c16b7e9901f535138163e976d2855 Mon Sep 17 00:00:00 2001 From: chun Date: Fri, 22 Dec 2023 00:11:46 +0800 Subject: [PATCH] ver1.1 --- README.md | 31 +++++++++++++++++++++++++++++-- model.py | 44 +++++++++++++++++++++++++++----------------- train.py | 36 ++++++++++++++++++++---------------- 3 files changed, 76 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index 220f2d9..c0cfeac 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,34 @@ # Deep JSCC -This implements training of deep JSCC models for wireless image transmission as described in the paper [Deep Joint Source-Channel Coding for Wireless Image Transmission](https://ieeexplore.ieee.org/abstract/document/8723589) by Pytorch. And there has been a [Tensorflow implementation](https://github.com/irdanish11/DJSCC-for-Wireless-Image-Transmission). +This implements training of deep JSCC models for wireless image transmission, as described in the paper [Deep Joint Source-Channel Coding for Wireless Image Transmission](https://ieeexplore.ieee.org/abstract/document/8723589), in PyTorch. There are also [TensorFlow and Keras implementations](https://github.com/irdanish11/DJSCC-for-Wireless-Image-Transmission). This is my first time to use PyTorch and git to reproduce a paper, so there may be some mistakes. If you find any, please let me know. Thanks! -## Requirements +## Installation +Using conda or another virtual environment is recommended. 
+ +``` +git clone https://github.com/chunbaobao/Deep-JSCC-PyTorch.git +pip install -r requirements.txt +``` + +## Usage +### Training Model +Run (example) +``` +cd ./Deep-JSCC-PyTorch +python train.py --seed 2048 --epochs 200 --batch_size 256 --channel 'AWGN' --saved ./saved --snr_list [1,4,7,13,19] --ratio_list [1/6,1/12] +``` + +### Evaluation + + +## Citation +If you find (part of) this code useful for your research, please consider citing +``` +@misc{chunhang_Deep-JSCC, + author = {chunhang}, + title = {A PyTorch implementation of Deep JSCC}, + url = {https://github.com/chunbaobao/Deep-JSCC-PyTorch}, + year = {2023} +} diff --git a/model.py b/model.py index 2dacb0b..7025d58 100644 --- a/model.py +++ b/model.py @@ -9,10 +9,9 @@ import torch import torch.nn as nn import numpy as np import channel -import torch.nn.functional as F -def _image_normalization(norm_type): +""" def _image_normalization(norm_type): def _inner(tensor: torch.Tensor): if norm_type == 'nomalization': return tensor / 255.0 @@ -20,11 +19,16 @@ def _image_normalization(norm_type): return (tensor * 255.0).type(torch.FloatTensor) else: raise Exception('Unknown type of normalization') - return _inner + return _inner """ -def ratio2filtersize(x, ratio): - before_size = np.prod(x.size()) +def ratio2filtersize(x: torch.Tensor, ratio): + if x.dim() == 4: + before_size = np.prod(x.size()[1:]) + elif x.dim() == 3: + before_size = np.prod(x.size()) + else: + raise Exception('Unknown size of input') encoder_temp = _Encoder(is_temp=True) z_temp = encoder_temp(x) c = before_size * ratio / np.prod(z_temp.size()[-2:]) @@ -60,7 +64,7 @@ class _Encoder(nn.Module): def __init__(self, c=1, is_temp=False): super(_Encoder, self).__init__() self.is_temp = is_temp - self.imgae_normalization = _image_normalization(norm_type='nomalization') + # self.imgae_normalization = _image_normalization(norm_type='nomalization') self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2) self.conv2 = 
_ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2) self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32, @@ -72,33 +76,39 @@ class _Encoder(nn.Module): @staticmethod def _normlizationLayer(P=1): def _inner(z_hat: torch.Tensor): - batch_size = z_hat.size()[0] - k = np.prod(z_hat.size()[1:]) + if z_hat.dim() == 4: + batch_size = z_hat.size()[0] + k = np.prod(z_hat.size()[1:]) + elif z_hat.dim() == 3: + batch_size = 1 + k = np.prod(z_hat.size()) + else: + raise Exception('Unknown size of input') k = torch.tensor(k) z_temp = z_hat.reshape(batch_size, 1, 1, -1) z_trans = z_hat.reshape(batch_size, 1, -1, 1) - tensor = torch.sqrt(P * k) * z_hat / (z_temp @ z_trans) + temp = z_temp@z_trans + temp = torch.sqrt((z_temp @ z_trans)) + tensor = torch.sqrt(P * k) * z_hat / torch.sqrt((z_temp @ z_trans)) return tensor return _inner def forward(self, x): - x = self.imgae_normalization(x) + #x = self.imgae_normalization(x) x = self.conv1(x) x = self.conv2(x) x = self.conv3(x) x = self.conv4(x) if not self.is_temp: x = self.conv5(x) - - z = self.norm(x) - del x - return z + x = self.norm(x) + return x class _Decoder(nn.Module): def __init__(self, c=1): super(_Decoder, self).__init__() - self.imgae_normalization = _image_normalization(norm_type='denormalization') + #self.imgae_normalization = _image_normalization(norm_type='denormalization') self.tconv1 = _TransConvWithPReLU( in_channels=c, out_channels=32, kernel_size=5, stride=1, padding=2) self.tconv2 = _TransConvWithPReLU( @@ -116,7 +126,7 @@ class _Decoder(nn.Module): x = self.tconv3(x) x = self.tconv4(x) x = self.tconv5(x) - x = self.imgae_normalization(x) + #x = self.imgae_normalization(x) return x @@ -124,7 +134,7 @@ class DeepJSCC(nn.Module): def __init__(self, c, channel_type='AWGN', snr=20): super(DeepJSCC, self).__init__() self.encoder = _Encoder(c=c) - self.channel = channel.channel(channel_type,snr) + self.channel = channel.channel(channel_type, snr) self.decoder = _Decoder(c=c) def 
forward(self, x): diff --git a/train.py b/train.py index 153ebbc..57b1b72 100644 --- a/train.py +++ b/train.py @@ -4,33 +4,31 @@ Created on Tue Dec 17:00:00 2023 @author: chun """ - +import os import torch import torch.nn as nn from torchvision import transforms from torchvision import datasets -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, RandomSampler import torch.optim as optim -import tqdm +from tqdm import tqdm from model import DeepJSCC, ratio2filtersize from torch.nn.parallel import DataParallel -from channel import channel def config_parser(): import argparse parser = argparse.ArgumentParser() parser.add_argument('--seed', default=2048, type=int, help='Random seed') - parser.add_argument('--lr', default=0.1, type=float, help='learning rate') + parser.add_argument('--lr', default=1e-3, type=float, help='learning rate') parser.add_argument('--epochs', default=100, type=int, help='number of epochs') parser.add_argument('--batch_size', default=64, type=int, help='batch size') parser.add_argument('--momentum', default=0.9, type=float, help='momentum') - parser.add_argument('--weight_decay', default=1e-3, type=float, help='weight decay') + parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay') parser.add_argument('--channel', default='AWGN', type=str, help='channel type') parser.add_argument('--saved', default='./saved', type=str, help='saved_path') parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list') parser.add_argument('--ratio_list', default=[1/6, 1/12], type=list, help='ratio_list') - parser.add_argument('--early_stop', default=True, type=bool, help='early_stop') return parser.parse_args() @@ -46,40 +44,46 @@ def main(): def train(args: config_parser(), ratio: float, snr: float): - print("training with ratio: {}, snr: {}, channel: {}".format(ratio, snr, args.channel)) + print("training with ratio: {}, snr_db: {}, channel: {}".format(ratio, snr, args.channel)) 
+ device = torch.device('cuda:1') # load data transform = transforms.Compose([transforms.ToTensor(), ]) train_dataset = datasets.CIFAR10(root='./Dataset/', train=True, download=True, transform=transform) train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size) - test_dataset = datasets.MNIST(root='./Dataset/', train=False, - download=True, transform=transform) - test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size) + test_dataset = datasets.CIFAR10(root='./Dataset/', train=False, + download=True, transform=transform) + test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size) image_fisrt = train_dataset.__getitem__(0)[0] c = ratio2filtersize(image_fisrt, ratio) - model = DeepJSCC(c=c, channel_type=args.channel, snr=snr) + model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device) - criterion = nn.MSELoss() + criterion = nn.MSELoss().cuda(device=device) optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) - epoch_loop = tqdm((args.epochs), total=len(args.epochs), leave=False) + epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False) + for epoch in epoch_loop: run_loss = 0.0 for images, _ in tqdm((train_loader), leave=False): optimizer.zero_grad() + images = images.cuda(device=device) outputs = model(images) loss = criterion(outputs, images) loss.backward() optimizer.step() run_loss += loss.item() + epoch_loop.set_description(f'Epoch [{epoch}/{args.epochs}]') - epoch_loop.set_postfix(loss=run_loss) - save_model(model, args.saved + '/model_{}_{}.pth'.format(ratio, snr)) + epoch_loop.set_postfix(loss=run_loss/len(train_loader)) + save_model(model, args.saved + '/model_{:2f}_{:2f}.pth'.format(ratio, snr)) def save_model(model, path): + os.makedirs(path, exist_ok=True) torch.save(model.state_dict(), path) + print("Model saved in {}".format(path)) if __name__ == '__main__':