ver1.1
This commit is contained in:
parent
a7900ec006
commit
9c60ca0e2c
31
README.md
31
README.md
@ -1,7 +1,34 @@
|
||||
# Deep JSCC
|
||||
This implements training of deep JSCC models for wireless image transmission as described in the paper [Deep Joint Source-Channel Coding for Wireless Image Transmission](https://ieeexplore.ieee.org/abstract/document/8723589) by Pytorch. And there has been a [Tensorflow implementation](https://github.com/irdanish11/DJSCC-for-Wireless-Image-Transmission).
|
||||
This implements training of deep JSCC models for wireless image transmission as described in the paper [Deep Joint Source-Channel Coding for Wireless Image Transmission](https://ieeexplore.ieee.org/abstract/document/8723589) by Pytorch. And there has been a [Tensorflow and keras implementations ](https://github.com/irdanish11/DJSCC-for-Wireless-Image-Transmission).
|
||||
|
||||
This is my first time to use PyTorch and git to reproduce a paper, so there may be some mistakes. If you find any, please let me know. Thanks!
|
||||
|
||||
## Requirements
|
||||
## Installation
|
||||
conda or other virtual environment is recommended.
|
||||
|
||||
```
|
||||
git clone https://github.com/chunbaobao/Deep-JSCC-PyTorch.git
|
||||
pip install requirements.txt
|
||||
```
|
||||
|
||||
## Usage
|
||||
### Training Model
|
||||
Run(example)
|
||||
```
|
||||
cd ./Deep-JSCC-PyTorch
|
||||
python train.py --seed 2048 --epochs 200 --batch_size 256 --channel 'AWGN' --saved ./saved --snr_list [1,4,7,13,19] --ratio_list [1/6,1/12]
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
|
||||
|
||||
## Citation
|
||||
If you find (part of) this code useful for your research, please consider citing
|
||||
```
|
||||
@misc{chunhang_Deep-JSCC,
|
||||
author = {chunhang},
|
||||
title = {a pytorch implementation of Deep JSCC},
|
||||
url ={https://github.com/chunbaobao/Deep-JSCC-PyTorch},
|
||||
year = {2023}
|
||||
}
|
||||
|
||||
|
||||
44
model.py
44
model.py
@ -9,10 +9,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import channel
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def _image_normalization(norm_type):
|
||||
""" def _image_normalization(norm_type):
|
||||
def _inner(tensor: torch.Tensor):
|
||||
if norm_type == 'nomalization':
|
||||
return tensor / 255.0
|
||||
@ -20,11 +19,16 @@ def _image_normalization(norm_type):
|
||||
return (tensor * 255.0).type(torch.FloatTensor)
|
||||
else:
|
||||
raise Exception('Unknown type of normalization')
|
||||
return _inner
|
||||
return _inner """
|
||||
|
||||
|
||||
def ratio2filtersize(x, ratio):
|
||||
before_size = np.prod(x.size())
|
||||
def ratio2filtersize(x: torch.Tensor, ratio):
|
||||
if x.dim() == 4:
|
||||
before_size = np.prod(x.size()[1:])
|
||||
elif x.dim() == 3:
|
||||
before_size = np.prod(x.size())
|
||||
else:
|
||||
raise Exception('Unknown size of input')
|
||||
encoder_temp = _Encoder(is_temp=True)
|
||||
z_temp = encoder_temp(x)
|
||||
c = before_size * ratio / np.prod(z_temp.size()[-2:])
|
||||
@ -60,7 +64,7 @@ class _Encoder(nn.Module):
|
||||
def __init__(self, c=1, is_temp=False):
|
||||
super(_Encoder, self).__init__()
|
||||
self.is_temp = is_temp
|
||||
self.imgae_normalization = _image_normalization(norm_type='nomalization')
|
||||
# self.imgae_normalization = _image_normalization(norm_type='nomalization')
|
||||
self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2)
|
||||
self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2)
|
||||
self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32,
|
||||
@ -72,33 +76,39 @@ class _Encoder(nn.Module):
|
||||
@staticmethod
|
||||
def _normlizationLayer(P=1):
|
||||
def _inner(z_hat: torch.Tensor):
|
||||
batch_size = z_hat.size()[0]
|
||||
k = np.prod(z_hat.size()[1:])
|
||||
if z_hat.dim() == 4:
|
||||
batch_size = z_hat.size()[0]
|
||||
k = np.prod(z_hat.size()[1:])
|
||||
elif z_hat.dim() == 3:
|
||||
batch_size = 1
|
||||
k = np.prod(z_hat.size())
|
||||
else:
|
||||
raise Exception('Unknown size of input')
|
||||
k = torch.tensor(k)
|
||||
z_temp = z_hat.reshape(batch_size, 1, 1, -1)
|
||||
z_trans = z_hat.reshape(batch_size, 1, -1, 1)
|
||||
tensor = torch.sqrt(P * k) * z_hat / (z_temp @ z_trans)
|
||||
temp = z_temp@z_trans
|
||||
temp = torch.sqrt((z_temp @ z_trans))
|
||||
tensor = torch.sqrt(P * k) * z_hat / torch.sqrt((z_temp @ z_trans))
|
||||
return tensor
|
||||
return _inner
|
||||
|
||||
def forward(self, x):
|
||||
x = self.imgae_normalization(x)
|
||||
#x = self.imgae_normalization(x)
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.conv3(x)
|
||||
x = self.conv4(x)
|
||||
if not self.is_temp:
|
||||
x = self.conv5(x)
|
||||
|
||||
z = self.norm(x)
|
||||
del x
|
||||
return z
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class _Decoder(nn.Module):
|
||||
def __init__(self, c=1):
|
||||
super(_Decoder, self).__init__()
|
||||
self.imgae_normalization = _image_normalization(norm_type='denormalization')
|
||||
#self.imgae_normalization = _image_normalization(norm_type='denormalization')
|
||||
self.tconv1 = _TransConvWithPReLU(
|
||||
in_channels=c, out_channels=32, kernel_size=5, stride=1, padding=2)
|
||||
self.tconv2 = _TransConvWithPReLU(
|
||||
@ -116,7 +126,7 @@ class _Decoder(nn.Module):
|
||||
x = self.tconv3(x)
|
||||
x = self.tconv4(x)
|
||||
x = self.tconv5(x)
|
||||
x = self.imgae_normalization(x)
|
||||
#x = self.imgae_normalization(x)
|
||||
return x
|
||||
|
||||
|
||||
@ -124,7 +134,7 @@ class DeepJSCC(nn.Module):
|
||||
def __init__(self, c, channel_type='AWGN', snr=20):
|
||||
super(DeepJSCC, self).__init__()
|
||||
self.encoder = _Encoder(c=c)
|
||||
self.channel = channel.channel(channel_type,snr)
|
||||
self.channel = channel.channel(channel_type, snr)
|
||||
self.decoder = _Decoder(c=c)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
36
train.py
36
train.py
@ -4,33 +4,31 @@ Created on Tue Dec 17:00:00 2023
|
||||
|
||||
@author: chun
|
||||
"""
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision import transforms
|
||||
from torchvision import datasets
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
import torch.optim as optim
|
||||
import tqdm
|
||||
from tqdm import tqdm
|
||||
from model import DeepJSCC, ratio2filtersize
|
||||
from torch.nn.parallel import DataParallel
|
||||
from channel import channel
|
||||
|
||||
|
||||
def config_parser():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--seed', default=2048, type=int, help='Random seed')
|
||||
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
|
||||
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
|
||||
parser.add_argument('--epochs', default=100, type=int, help='number of epochs')
|
||||
parser.add_argument('--batch_size', default=64, type=int, help='batch size')
|
||||
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
|
||||
parser.add_argument('--weight_decay', default=1e-3, type=float, help='weight decay')
|
||||
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
|
||||
parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
|
||||
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
|
||||
parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list')
|
||||
parser.add_argument('--ratio_list', default=[1/6, 1/12], type=list, help='ratio_list')
|
||||
parser.add_argument('--early_stop', default=True, type=bool, help='early_stop')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -46,40 +44,46 @@ def main():
|
||||
|
||||
def train(args: config_parser(), ratio: float, snr: float):
|
||||
|
||||
print("training with ratio: {}, snr: {}, channel: {}".format(ratio, snr, args.channel))
|
||||
print("training with ratio: {}, snr_db: {}, channel: {}".format(ratio, snr, args.channel))
|
||||
|
||||
device = torch.device('cuda:1')
|
||||
# load data
|
||||
transform = transforms.Compose([transforms.ToTensor(), ])
|
||||
train_dataset = datasets.CIFAR10(root='./Dataset/', train=True,
|
||||
download=True, transform=transform)
|
||||
|
||||
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
|
||||
test_dataset = datasets.MNIST(root='./Dataset/', train=False,
|
||||
download=True, transform=transform)
|
||||
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size)
|
||||
test_dataset = datasets.CIFAR10(root='./Dataset/', train=False,
|
||||
download=True, transform=transform)
|
||||
test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size)
|
||||
image_fisrt = train_dataset.__getitem__(0)[0]
|
||||
c = ratio2filtersize(image_fisrt, ratio)
|
||||
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr)
|
||||
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)
|
||||
|
||||
criterion = nn.MSELoss()
|
||||
criterion = nn.MSELoss().cuda(device=device)
|
||||
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||
epoch_loop = tqdm((args.epochs), total=len(args.epochs), leave=False)
|
||||
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
|
||||
|
||||
for epoch in epoch_loop:
|
||||
run_loss = 0.0
|
||||
for images, _ in tqdm((train_loader), leave=False):
|
||||
optimizer.zero_grad()
|
||||
images = images.cuda(device=device)
|
||||
outputs = model(images)
|
||||
loss = criterion(outputs, images)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
run_loss += loss.item()
|
||||
|
||||
epoch_loop.set_description(f'Epoch [{epoch}/{args.epochs}]')
|
||||
epoch_loop.set_postfix(loss=run_loss)
|
||||
save_model(model, args.saved + '/model_{}_{}.pth'.format(ratio, snr))
|
||||
epoch_loop.set_postfix(loss=run_loss/len(train_loader))
|
||||
save_model(model, args.saved + '/model_{:2f}_{:2f}.pth'.format(ratio, snr))
|
||||
|
||||
|
||||
def save_model(model, path):
|
||||
os.makedirs(path, exist_ok=True)
|
||||
torch.save(model.state_dict(), path)
|
||||
print("Model saved in {}".format(path))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Loading…
Reference in New Issue
Block a user