This commit is contained in:
chun 2023-12-22 00:11:46 +08:00
parent a7900ec006
commit 9c60ca0e2c
3 changed files with 76 additions and 35 deletions

View File

@ -1,7 +1,34 @@
# Deep JSCC
This implements training of deep JSCC models for wireless image transmission as described in the paper [Deep Joint Source-Channel Coding for Wireless Image Transmission](https://ieeexplore.ieee.org/abstract/document/8723589) by Pytorch. And there has been a [Tensorflow implementation](https://github.com/irdanish11/DJSCC-for-Wireless-Image-Transmission).
This implements training of deep JSCC models for wireless image transmission as described in the paper [Deep Joint Source-Channel Coding for Wireless Image Transmission](https://ieeexplore.ieee.org/abstract/document/8723589) by Pytorch. And there has been a [TensorFlow and Keras implementation](https://github.com/irdanish11/DJSCC-for-Wireless-Image-Transmission).
This is my first time using PyTorch and Git to reproduce a paper, so there may be some mistakes. If you find any, please let me know. Thanks!
## Requirements
## Installation
Using conda or another virtual environment is recommended.
```
git clone https://github.com/chunbaobao/Deep-JSCC-PyTorch.git
pip install -r requirements.txt
```
## Usage
### Training Model
Run(example)
```
cd ./Deep-JSCC-PyTorch
python train.py --seed 2048 --epochs 200 --batch_size 256 --channel 'AWGN' --saved ./saved --snr_list [1,4,7,13,19] --ratio_list [1/6,1/12]
```
### Evaluation
## Citation
If you find (part of) this code useful for your research, please consider citing
```
@misc{chunhang_Deep-JSCC,
author = {chunhang},
title = {A PyTorch implementation of Deep JSCC},
url ={https://github.com/chunbaobao/Deep-JSCC-PyTorch},
year = {2023}
}

View File

@ -9,10 +9,9 @@ import torch
import torch.nn as nn
import numpy as np
import channel
import torch.nn.functional as F
def _image_normalization(norm_type):
""" def _image_normalization(norm_type):
def _inner(tensor: torch.Tensor):
if norm_type == 'nomalization':
return tensor / 255.0
@ -20,11 +19,16 @@ def _image_normalization(norm_type):
return (tensor * 255.0).type(torch.FloatTensor)
else:
raise Exception('Unknown type of normalization')
return _inner
return _inner """
def ratio2filtersize(x, ratio):
def ratio2filtersize(x: torch.Tensor, ratio):
if x.dim() == 4:
before_size = np.prod(x.size()[1:])
elif x.dim() == 3:
before_size = np.prod(x.size())
else:
raise Exception('Unknown size of input')
encoder_temp = _Encoder(is_temp=True)
z_temp = encoder_temp(x)
c = before_size * ratio / np.prod(z_temp.size()[-2:])
@ -60,7 +64,7 @@ class _Encoder(nn.Module):
def __init__(self, c=1, is_temp=False):
super(_Encoder, self).__init__()
self.is_temp = is_temp
self.imgae_normalization = _image_normalization(norm_type='nomalization')
# self.imgae_normalization = _image_normalization(norm_type='nomalization')
self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2)
self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2)
self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32,
@ -72,33 +76,39 @@ class _Encoder(nn.Module):
@staticmethod
def _normlizationLayer(P=1):
def _inner(z_hat: torch.Tensor):
if z_hat.dim() == 4:
batch_size = z_hat.size()[0]
k = np.prod(z_hat.size()[1:])
elif z_hat.dim() == 3:
batch_size = 1
k = np.prod(z_hat.size())
else:
raise Exception('Unknown size of input')
k = torch.tensor(k)
z_temp = z_hat.reshape(batch_size, 1, 1, -1)
z_trans = z_hat.reshape(batch_size, 1, -1, 1)
tensor = torch.sqrt(P * k) * z_hat / (z_temp @ z_trans)
temp = z_temp@z_trans
temp = torch.sqrt((z_temp @ z_trans))
tensor = torch.sqrt(P * k) * z_hat / torch.sqrt((z_temp @ z_trans))
return tensor
return _inner
def forward(self, x):
x = self.imgae_normalization(x)
#x = self.imgae_normalization(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
if not self.is_temp:
x = self.conv5(x)
z = self.norm(x)
del x
return z
x = self.norm(x)
return x
class _Decoder(nn.Module):
def __init__(self, c=1):
super(_Decoder, self).__init__()
self.imgae_normalization = _image_normalization(norm_type='denormalization')
#self.imgae_normalization = _image_normalization(norm_type='denormalization')
self.tconv1 = _TransConvWithPReLU(
in_channels=c, out_channels=32, kernel_size=5, stride=1, padding=2)
self.tconv2 = _TransConvWithPReLU(
@ -116,7 +126,7 @@ class _Decoder(nn.Module):
x = self.tconv3(x)
x = self.tconv4(x)
x = self.tconv5(x)
x = self.imgae_normalization(x)
#x = self.imgae_normalization(x)
return x
@ -124,7 +134,7 @@ class DeepJSCC(nn.Module):
def __init__(self, c, channel_type='AWGN', snr=20):
super(DeepJSCC, self).__init__()
self.encoder = _Encoder(c=c)
self.channel = channel.channel(channel_type,snr)
self.channel = channel.channel(channel_type, snr)
self.decoder = _Decoder(c=c)
def forward(self, x):

View File

@ -4,33 +4,31 @@ Created on Tue Dec 17:00:00 2023
@author: chun
"""
import os
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, RandomSampler
import torch.optim as optim
import tqdm
from tqdm import tqdm
from model import DeepJSCC, ratio2filtersize
from torch.nn.parallel import DataParallel
from channel import channel
def config_parser():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=2048, type=int, help='Random seed')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('--epochs', default=100, type=int, help='number of epochs')
parser.add_argument('--batch_size', default=64, type=int, help='batch size')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight_decay', default=1e-3, type=float, help='weight decay')
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list')
parser.add_argument('--ratio_list', default=[1/6, 1/12], type=list, help='ratio_list')
parser.add_argument('--early_stop', default=True, type=bool, help='early_stop')
return parser.parse_args()
@ -46,40 +44,46 @@ def main():
def train(args: config_parser(), ratio: float, snr: float):
print("training with ratio: {}, snr: {}, channel: {}".format(ratio, snr, args.channel))
print("training with ratio: {}, snr_db: {}, channel: {}".format(ratio, snr, args.channel))
device = torch.device('cuda:1')
# load data
transform = transforms.Compose([transforms.ToTensor(), ])
train_dataset = datasets.CIFAR10(root='./Dataset/', train=True,
download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
test_dataset = datasets.MNIST(root='./Dataset/', train=False,
test_dataset = datasets.CIFAR10(root='./Dataset/', train=False,
download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size)
test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size)
image_fisrt = train_dataset.__getitem__(0)[0]
c = ratio2filtersize(image_fisrt, ratio)
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr)
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)
criterion = nn.MSELoss()
criterion = nn.MSELoss().cuda(device=device)
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
epoch_loop = tqdm((args.epochs), total=len(args.epochs), leave=False)
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
for epoch in epoch_loop:
run_loss = 0.0
for images, _ in tqdm((train_loader), leave=False):
optimizer.zero_grad()
images = images.cuda(device=device)
outputs = model(images)
loss = criterion(outputs, images)
loss.backward()
optimizer.step()
run_loss += loss.item()
epoch_loop.set_description(f'Epoch [{epoch}/{args.epochs}]')
epoch_loop.set_postfix(loss=run_loss)
save_model(model, args.saved + '/model_{}_{}.pth'.format(ratio, snr))
epoch_loop.set_postfix(loss=run_loss/len(train_loader))
save_model(model, args.saved + '/model_{:2f}_{:2f}.pth'.format(ratio, snr))
def save_model(model, path):
    """Persist the model's state_dict to a checkpoint file.

    Creates the checkpoint's parent directory if it does not exist yet.
    The original code called ``os.makedirs(path, exist_ok=True)``, which
    created a directory AT the target file path itself, making the
    subsequent ``torch.save`` fail with ``IsADirectoryError`` — we must
    create the *parent* directory instead.

    Args:
        model: torch.nn.Module whose parameters should be saved.
        path: destination file path, e.g. './saved/model_0.17_19.pth'.
    """
    parent = os.path.dirname(path)
    if parent:  # dirname is '' when path has no directory component
        os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), path)
    print("Model saved in {}".format(path))
if __name__ == '__main__':