This commit is contained in:
chun 2023-12-22 00:11:46 +08:00
parent a7900ec006
commit 9c60ca0e2c
3 changed files with 76 additions and 35 deletions

View File

@ -1,7 +1,34 @@
# Deep JSCC # Deep JSCC
This implements training of deep JSCC models for wireless image transmission as described in the paper [Deep Joint Source-Channel Coding for Wireless Image Transmission](https://ieeexplore.ieee.org/abstract/document/8723589) by Pytorch. And there has been a [Tensorflow implementation](https://github.com/irdanish11/DJSCC-for-Wireless-Image-Transmission). This implements training of deep JSCC models for wireless image transmission as described in the paper [Deep Joint Source-Channel Coding for Wireless Image Transmission](https://ieeexplore.ieee.org/abstract/document/8723589) by Pytorch. And there has been a [Tensorflow and keras implementations ](https://github.com/irdanish11/DJSCC-for-Wireless-Image-Transmission).
This is my first time to use PyTorch and git to reproduce a paper, so there may be some mistakes. If you find any, please let me know. Thanks! This is my first time to use PyTorch and git to reproduce a paper, so there may be some mistakes. If you find any, please let me know. Thanks!
## Requirements ## Installation
conda or other virtual environment is recommended.
```
git clone https://github.com/chunbaobao/Deep-JSCC-PyTorch.git
pip install requirements.txt
```
## Usage
### Training Model
Run(example)
```
cd ./Deep-JSCC-PyTorch
python train.py --seed 2048 --epochs 200 --batch_size 256 --channel 'AWGN' --saved ./saved --snr_list [1,4,7,13,19] --ratio_list [1/6,1/12]
```
### Evaluation
## Citation
If you find (part of) this code useful for your research, please consider citing
```
@misc{chunhang_Deep-JSCC,
author = {chunhang},
title = {a pytorch implementation of Deep JSCC},
url ={https://github.com/chunbaobao/Deep-JSCC-PyTorch},
year = {2023}
}

View File

@ -9,10 +9,9 @@ import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
import channel import channel
import torch.nn.functional as F
def _image_normalization(norm_type): """ def _image_normalization(norm_type):
def _inner(tensor: torch.Tensor): def _inner(tensor: torch.Tensor):
if norm_type == 'nomalization': if norm_type == 'nomalization':
return tensor / 255.0 return tensor / 255.0
@ -20,11 +19,16 @@ def _image_normalization(norm_type):
return (tensor * 255.0).type(torch.FloatTensor) return (tensor * 255.0).type(torch.FloatTensor)
else: else:
raise Exception('Unknown type of normalization') raise Exception('Unknown type of normalization')
return _inner return _inner """
def ratio2filtersize(x, ratio): def ratio2filtersize(x: torch.Tensor, ratio):
if x.dim() == 4:
before_size = np.prod(x.size()[1:])
elif x.dim() == 3:
before_size = np.prod(x.size()) before_size = np.prod(x.size())
else:
raise Exception('Unknown size of input')
encoder_temp = _Encoder(is_temp=True) encoder_temp = _Encoder(is_temp=True)
z_temp = encoder_temp(x) z_temp = encoder_temp(x)
c = before_size * ratio / np.prod(z_temp.size()[-2:]) c = before_size * ratio / np.prod(z_temp.size()[-2:])
@ -60,7 +64,7 @@ class _Encoder(nn.Module):
def __init__(self, c=1, is_temp=False): def __init__(self, c=1, is_temp=False):
super(_Encoder, self).__init__() super(_Encoder, self).__init__()
self.is_temp = is_temp self.is_temp = is_temp
self.imgae_normalization = _image_normalization(norm_type='nomalization') # self.imgae_normalization = _image_normalization(norm_type='nomalization')
self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2) self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2)
self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2) self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2)
self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32, self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32,
@ -72,33 +76,39 @@ class _Encoder(nn.Module):
@staticmethod @staticmethod
def _normlizationLayer(P=1): def _normlizationLayer(P=1):
def _inner(z_hat: torch.Tensor): def _inner(z_hat: torch.Tensor):
if z_hat.dim() == 4:
batch_size = z_hat.size()[0] batch_size = z_hat.size()[0]
k = np.prod(z_hat.size()[1:]) k = np.prod(z_hat.size()[1:])
elif z_hat.dim() == 3:
batch_size = 1
k = np.prod(z_hat.size())
else:
raise Exception('Unknown size of input')
k = torch.tensor(k) k = torch.tensor(k)
z_temp = z_hat.reshape(batch_size, 1, 1, -1) z_temp = z_hat.reshape(batch_size, 1, 1, -1)
z_trans = z_hat.reshape(batch_size, 1, -1, 1) z_trans = z_hat.reshape(batch_size, 1, -1, 1)
tensor = torch.sqrt(P * k) * z_hat / (z_temp @ z_trans) temp = z_temp@z_trans
temp = torch.sqrt((z_temp @ z_trans))
tensor = torch.sqrt(P * k) * z_hat / torch.sqrt((z_temp @ z_trans))
return tensor return tensor
return _inner return _inner
def forward(self, x): def forward(self, x):
x = self.imgae_normalization(x) #x = self.imgae_normalization(x)
x = self.conv1(x) x = self.conv1(x)
x = self.conv2(x) x = self.conv2(x)
x = self.conv3(x) x = self.conv3(x)
x = self.conv4(x) x = self.conv4(x)
if not self.is_temp: if not self.is_temp:
x = self.conv5(x) x = self.conv5(x)
x = self.norm(x)
z = self.norm(x) return x
del x
return z
class _Decoder(nn.Module): class _Decoder(nn.Module):
def __init__(self, c=1): def __init__(self, c=1):
super(_Decoder, self).__init__() super(_Decoder, self).__init__()
self.imgae_normalization = _image_normalization(norm_type='denormalization') #self.imgae_normalization = _image_normalization(norm_type='denormalization')
self.tconv1 = _TransConvWithPReLU( self.tconv1 = _TransConvWithPReLU(
in_channels=c, out_channels=32, kernel_size=5, stride=1, padding=2) in_channels=c, out_channels=32, kernel_size=5, stride=1, padding=2)
self.tconv2 = _TransConvWithPReLU( self.tconv2 = _TransConvWithPReLU(
@ -116,7 +126,7 @@ class _Decoder(nn.Module):
x = self.tconv3(x) x = self.tconv3(x)
x = self.tconv4(x) x = self.tconv4(x)
x = self.tconv5(x) x = self.tconv5(x)
x = self.imgae_normalization(x) #x = self.imgae_normalization(x)
return x return x
@ -124,7 +134,7 @@ class DeepJSCC(nn.Module):
def __init__(self, c, channel_type='AWGN', snr=20): def __init__(self, c, channel_type='AWGN', snr=20):
super(DeepJSCC, self).__init__() super(DeepJSCC, self).__init__()
self.encoder = _Encoder(c=c) self.encoder = _Encoder(c=c)
self.channel = channel.channel(channel_type,snr) self.channel = channel.channel(channel_type, snr)
self.decoder = _Decoder(c=c) self.decoder = _Decoder(c=c)
def forward(self, x): def forward(self, x):

View File

@ -4,33 +4,31 @@ Created on Tue Dec 17:00:00 2023
@author: chun @author: chun
""" """
import os
import torch import torch
import torch.nn as nn import torch.nn as nn
from torchvision import transforms from torchvision import transforms
from torchvision import datasets from torchvision import datasets
from torch.utils.data import DataLoader from torch.utils.data import DataLoader, RandomSampler
import torch.optim as optim import torch.optim as optim
import tqdm from tqdm import tqdm
from model import DeepJSCC, ratio2filtersize from model import DeepJSCC, ratio2filtersize
from torch.nn.parallel import DataParallel from torch.nn.parallel import DataParallel
from channel import channel
def config_parser(): def config_parser():
import argparse import argparse
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=2048, type=int, help='Random seed') parser.add_argument('--seed', default=2048, type=int, help='Random seed')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate') parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('--epochs', default=100, type=int, help='number of epochs') parser.add_argument('--epochs', default=100, type=int, help='number of epochs')
parser.add_argument('--batch_size', default=64, type=int, help='batch size') parser.add_argument('--batch_size', default=64, type=int, help='batch size')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum') parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight_decay', default=1e-3, type=float, help='weight decay') parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
parser.add_argument('--channel', default='AWGN', type=str, help='channel type') parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
parser.add_argument('--saved', default='./saved', type=str, help='saved_path') parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list') parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list')
parser.add_argument('--ratio_list', default=[1/6, 1/12], type=list, help='ratio_list') parser.add_argument('--ratio_list', default=[1/6, 1/12], type=list, help='ratio_list')
parser.add_argument('--early_stop', default=True, type=bool, help='early_stop')
return parser.parse_args() return parser.parse_args()
@ -46,40 +44,46 @@ def main():
def train(args: config_parser(), ratio: float, snr: float): def train(args: config_parser(), ratio: float, snr: float):
print("training with ratio: {}, snr: {}, channel: {}".format(ratio, snr, args.channel)) print("training with ratio: {}, snr_db: {}, channel: {}".format(ratio, snr, args.channel))
device = torch.device('cuda:1')
# load data # load data
transform = transforms.Compose([transforms.ToTensor(), ]) transform = transforms.Compose([transforms.ToTensor(), ])
train_dataset = datasets.CIFAR10(root='./Dataset/', train=True, train_dataset = datasets.CIFAR10(root='./Dataset/', train=True,
download=True, transform=transform) download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size) train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
test_dataset = datasets.MNIST(root='./Dataset/', train=False, test_dataset = datasets.CIFAR10(root='./Dataset/', train=False,
download=True, transform=transform) download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size) test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size)
image_fisrt = train_dataset.__getitem__(0)[0] image_fisrt = train_dataset.__getitem__(0)[0]
c = ratio2filtersize(image_fisrt, ratio) c = ratio2filtersize(image_fisrt, ratio)
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr) model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)
criterion = nn.MSELoss() criterion = nn.MSELoss().cuda(device=device)
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
epoch_loop = tqdm((args.epochs), total=len(args.epochs), leave=False) epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
for epoch in epoch_loop: for epoch in epoch_loop:
run_loss = 0.0 run_loss = 0.0
for images, _ in tqdm((train_loader), leave=False): for images, _ in tqdm((train_loader), leave=False):
optimizer.zero_grad() optimizer.zero_grad()
images = images.cuda(device=device)
outputs = model(images) outputs = model(images)
loss = criterion(outputs, images) loss = criterion(outputs, images)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
run_loss += loss.item() run_loss += loss.item()
epoch_loop.set_description(f'Epoch [{epoch}/{args.epochs}]') epoch_loop.set_description(f'Epoch [{epoch}/{args.epochs}]')
epoch_loop.set_postfix(loss=run_loss) epoch_loop.set_postfix(loss=run_loss/len(train_loader))
save_model(model, args.saved + '/model_{}_{}.pth'.format(ratio, snr)) save_model(model, args.saved + '/model_{:2f}_{:2f}.pth'.format(ratio, snr))
def save_model(model, path): def save_model(model, path):
os.makedirs(path, exist_ok=True)
torch.save(model.state_dict(), path) torch.save(model.state_dict(), path)
print("Model saved in {}".format(path))
if __name__ == '__main__': if __name__ == '__main__':