ver1.1
This commit is contained in:
parent
a7900ec006
commit
9c60ca0e2c
31
README.md
31
README.md
@ -1,7 +1,34 @@
|
|||||||
# Deep JSCC
|
# Deep JSCC
|
||||||
This implements training of deep JSCC models for wireless image transmission as described in the paper [Deep Joint Source-Channel Coding for Wireless Image Transmission](https://ieeexplore.ieee.org/abstract/document/8723589) by Pytorch. And there has been a [Tensorflow implementation](https://github.com/irdanish11/DJSCC-for-Wireless-Image-Transmission).
|
This implements training of deep JSCC models for wireless image transmission as described in the paper [Deep Joint Source-Channel Coding for Wireless Image Transmission](https://ieeexplore.ieee.org/abstract/document/8723589) by Pytorch. And there has been a [Tensorflow and keras implementations ](https://github.com/irdanish11/DJSCC-for-Wireless-Image-Transmission).
|
||||||
|
|
||||||
This is my first time to use PyTorch and git to reproduce a paper, so there may be some mistakes. If you find any, please let me know. Thanks!
|
This is my first time to use PyTorch and git to reproduce a paper, so there may be some mistakes. If you find any, please let me know. Thanks!
|
||||||
|
|
||||||
## Requirements
|
## Installation
|
||||||
|
conda or other virtual environment is recommended.
|
||||||
|
|
||||||
|
```
|
||||||
|
git clone https://github.com/chunbaobao/Deep-JSCC-PyTorch.git
|
||||||
|
pip install requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
### Training Model
|
||||||
|
Run(example)
|
||||||
|
```
|
||||||
|
cd ./Deep-JSCC-PyTorch
|
||||||
|
python train.py --seed 2048 --epochs 200 --batch_size 256 --channel 'AWGN' --saved ./saved --snr_list [1,4,7,13,19] --ratio_list [1/6,1/12]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Evaluation
|
||||||
|
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
If you find (part of) this code useful for your research, please consider citing
|
||||||
|
```
|
||||||
|
@misc{chunhang_Deep-JSCC,
|
||||||
|
author = {chunhang},
|
||||||
|
title = {a pytorch implementation of Deep JSCC},
|
||||||
|
url ={https://github.com/chunbaobao/Deep-JSCC-PyTorch},
|
||||||
|
year = {2023}
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
38
model.py
38
model.py
@ -9,10 +9,9 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import channel
|
import channel
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
def _image_normalization(norm_type):
|
""" def _image_normalization(norm_type):
|
||||||
def _inner(tensor: torch.Tensor):
|
def _inner(tensor: torch.Tensor):
|
||||||
if norm_type == 'nomalization':
|
if norm_type == 'nomalization':
|
||||||
return tensor / 255.0
|
return tensor / 255.0
|
||||||
@ -20,11 +19,16 @@ def _image_normalization(norm_type):
|
|||||||
return (tensor * 255.0).type(torch.FloatTensor)
|
return (tensor * 255.0).type(torch.FloatTensor)
|
||||||
else:
|
else:
|
||||||
raise Exception('Unknown type of normalization')
|
raise Exception('Unknown type of normalization')
|
||||||
return _inner
|
return _inner """
|
||||||
|
|
||||||
|
|
||||||
def ratio2filtersize(x, ratio):
|
def ratio2filtersize(x: torch.Tensor, ratio):
|
||||||
|
if x.dim() == 4:
|
||||||
|
before_size = np.prod(x.size()[1:])
|
||||||
|
elif x.dim() == 3:
|
||||||
before_size = np.prod(x.size())
|
before_size = np.prod(x.size())
|
||||||
|
else:
|
||||||
|
raise Exception('Unknown size of input')
|
||||||
encoder_temp = _Encoder(is_temp=True)
|
encoder_temp = _Encoder(is_temp=True)
|
||||||
z_temp = encoder_temp(x)
|
z_temp = encoder_temp(x)
|
||||||
c = before_size * ratio / np.prod(z_temp.size()[-2:])
|
c = before_size * ratio / np.prod(z_temp.size()[-2:])
|
||||||
@ -60,7 +64,7 @@ class _Encoder(nn.Module):
|
|||||||
def __init__(self, c=1, is_temp=False):
|
def __init__(self, c=1, is_temp=False):
|
||||||
super(_Encoder, self).__init__()
|
super(_Encoder, self).__init__()
|
||||||
self.is_temp = is_temp
|
self.is_temp = is_temp
|
||||||
self.imgae_normalization = _image_normalization(norm_type='nomalization')
|
# self.imgae_normalization = _image_normalization(norm_type='nomalization')
|
||||||
self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2)
|
self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2)
|
||||||
self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2)
|
self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2)
|
||||||
self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32,
|
self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32,
|
||||||
@ -72,33 +76,39 @@ class _Encoder(nn.Module):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _normlizationLayer(P=1):
|
def _normlizationLayer(P=1):
|
||||||
def _inner(z_hat: torch.Tensor):
|
def _inner(z_hat: torch.Tensor):
|
||||||
|
if z_hat.dim() == 4:
|
||||||
batch_size = z_hat.size()[0]
|
batch_size = z_hat.size()[0]
|
||||||
k = np.prod(z_hat.size()[1:])
|
k = np.prod(z_hat.size()[1:])
|
||||||
|
elif z_hat.dim() == 3:
|
||||||
|
batch_size = 1
|
||||||
|
k = np.prod(z_hat.size())
|
||||||
|
else:
|
||||||
|
raise Exception('Unknown size of input')
|
||||||
k = torch.tensor(k)
|
k = torch.tensor(k)
|
||||||
z_temp = z_hat.reshape(batch_size, 1, 1, -1)
|
z_temp = z_hat.reshape(batch_size, 1, 1, -1)
|
||||||
z_trans = z_hat.reshape(batch_size, 1, -1, 1)
|
z_trans = z_hat.reshape(batch_size, 1, -1, 1)
|
||||||
tensor = torch.sqrt(P * k) * z_hat / (z_temp @ z_trans)
|
temp = z_temp@z_trans
|
||||||
|
temp = torch.sqrt((z_temp @ z_trans))
|
||||||
|
tensor = torch.sqrt(P * k) * z_hat / torch.sqrt((z_temp @ z_trans))
|
||||||
return tensor
|
return tensor
|
||||||
return _inner
|
return _inner
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.imgae_normalization(x)
|
#x = self.imgae_normalization(x)
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
x = self.conv3(x)
|
x = self.conv3(x)
|
||||||
x = self.conv4(x)
|
x = self.conv4(x)
|
||||||
if not self.is_temp:
|
if not self.is_temp:
|
||||||
x = self.conv5(x)
|
x = self.conv5(x)
|
||||||
|
x = self.norm(x)
|
||||||
z = self.norm(x)
|
return x
|
||||||
del x
|
|
||||||
return z
|
|
||||||
|
|
||||||
|
|
||||||
class _Decoder(nn.Module):
|
class _Decoder(nn.Module):
|
||||||
def __init__(self, c=1):
|
def __init__(self, c=1):
|
||||||
super(_Decoder, self).__init__()
|
super(_Decoder, self).__init__()
|
||||||
self.imgae_normalization = _image_normalization(norm_type='denormalization')
|
#self.imgae_normalization = _image_normalization(norm_type='denormalization')
|
||||||
self.tconv1 = _TransConvWithPReLU(
|
self.tconv1 = _TransConvWithPReLU(
|
||||||
in_channels=c, out_channels=32, kernel_size=5, stride=1, padding=2)
|
in_channels=c, out_channels=32, kernel_size=5, stride=1, padding=2)
|
||||||
self.tconv2 = _TransConvWithPReLU(
|
self.tconv2 = _TransConvWithPReLU(
|
||||||
@ -116,7 +126,7 @@ class _Decoder(nn.Module):
|
|||||||
x = self.tconv3(x)
|
x = self.tconv3(x)
|
||||||
x = self.tconv4(x)
|
x = self.tconv4(x)
|
||||||
x = self.tconv5(x)
|
x = self.tconv5(x)
|
||||||
x = self.imgae_normalization(x)
|
#x = self.imgae_normalization(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -124,7 +134,7 @@ class DeepJSCC(nn.Module):
|
|||||||
def __init__(self, c, channel_type='AWGN', snr=20):
|
def __init__(self, c, channel_type='AWGN', snr=20):
|
||||||
super(DeepJSCC, self).__init__()
|
super(DeepJSCC, self).__init__()
|
||||||
self.encoder = _Encoder(c=c)
|
self.encoder = _Encoder(c=c)
|
||||||
self.channel = channel.channel(channel_type,snr)
|
self.channel = channel.channel(channel_type, snr)
|
||||||
self.decoder = _Decoder(c=c)
|
self.decoder = _Decoder(c=c)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|||||||
34
train.py
34
train.py
@ -4,33 +4,31 @@ Created on Tue Dec 17:00:00 2023
|
|||||||
|
|
||||||
@author: chun
|
@author: chun
|
||||||
"""
|
"""
|
||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torchvision import datasets
|
from torchvision import datasets
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader, RandomSampler
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import tqdm
|
from tqdm import tqdm
|
||||||
from model import DeepJSCC, ratio2filtersize
|
from model import DeepJSCC, ratio2filtersize
|
||||||
from torch.nn.parallel import DataParallel
|
from torch.nn.parallel import DataParallel
|
||||||
from channel import channel
|
|
||||||
|
|
||||||
|
|
||||||
def config_parser():
|
def config_parser():
|
||||||
import argparse
|
import argparse
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--seed', default=2048, type=int, help='Random seed')
|
parser.add_argument('--seed', default=2048, type=int, help='Random seed')
|
||||||
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
|
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
|
||||||
parser.add_argument('--epochs', default=100, type=int, help='number of epochs')
|
parser.add_argument('--epochs', default=100, type=int, help='number of epochs')
|
||||||
parser.add_argument('--batch_size', default=64, type=int, help='batch size')
|
parser.add_argument('--batch_size', default=64, type=int, help='batch size')
|
||||||
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
|
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
|
||||||
parser.add_argument('--weight_decay', default=1e-3, type=float, help='weight decay')
|
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
|
||||||
parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
|
parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
|
||||||
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
|
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
|
||||||
parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list')
|
parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list')
|
||||||
parser.add_argument('--ratio_list', default=[1/6, 1/12], type=list, help='ratio_list')
|
parser.add_argument('--ratio_list', default=[1/6, 1/12], type=list, help='ratio_list')
|
||||||
parser.add_argument('--early_stop', default=True, type=bool, help='early_stop')
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -46,40 +44,46 @@ def main():
|
|||||||
|
|
||||||
def train(args: config_parser(), ratio: float, snr: float):
|
def train(args: config_parser(), ratio: float, snr: float):
|
||||||
|
|
||||||
print("training with ratio: {}, snr: {}, channel: {}".format(ratio, snr, args.channel))
|
print("training with ratio: {}, snr_db: {}, channel: {}".format(ratio, snr, args.channel))
|
||||||
|
|
||||||
|
device = torch.device('cuda:1')
|
||||||
# load data
|
# load data
|
||||||
transform = transforms.Compose([transforms.ToTensor(), ])
|
transform = transforms.Compose([transforms.ToTensor(), ])
|
||||||
train_dataset = datasets.CIFAR10(root='./Dataset/', train=True,
|
train_dataset = datasets.CIFAR10(root='./Dataset/', train=True,
|
||||||
download=True, transform=transform)
|
download=True, transform=transform)
|
||||||
|
|
||||||
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
|
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
|
||||||
test_dataset = datasets.MNIST(root='./Dataset/', train=False,
|
test_dataset = datasets.CIFAR10(root='./Dataset/', train=False,
|
||||||
download=True, transform=transform)
|
download=True, transform=transform)
|
||||||
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size)
|
test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size)
|
||||||
image_fisrt = train_dataset.__getitem__(0)[0]
|
image_fisrt = train_dataset.__getitem__(0)[0]
|
||||||
c = ratio2filtersize(image_fisrt, ratio)
|
c = ratio2filtersize(image_fisrt, ratio)
|
||||||
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr)
|
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)
|
||||||
|
|
||||||
criterion = nn.MSELoss()
|
criterion = nn.MSELoss().cuda(device=device)
|
||||||
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||||
epoch_loop = tqdm((args.epochs), total=len(args.epochs), leave=False)
|
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
|
||||||
|
|
||||||
for epoch in epoch_loop:
|
for epoch in epoch_loop:
|
||||||
run_loss = 0.0
|
run_loss = 0.0
|
||||||
for images, _ in tqdm((train_loader), leave=False):
|
for images, _ in tqdm((train_loader), leave=False):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
images = images.cuda(device=device)
|
||||||
outputs = model(images)
|
outputs = model(images)
|
||||||
loss = criterion(outputs, images)
|
loss = criterion(outputs, images)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
run_loss += loss.item()
|
run_loss += loss.item()
|
||||||
|
|
||||||
epoch_loop.set_description(f'Epoch [{epoch}/{args.epochs}]')
|
epoch_loop.set_description(f'Epoch [{epoch}/{args.epochs}]')
|
||||||
epoch_loop.set_postfix(loss=run_loss)
|
epoch_loop.set_postfix(loss=run_loss/len(train_loader))
|
||||||
save_model(model, args.saved + '/model_{}_{}.pth'.format(ratio, snr))
|
save_model(model, args.saved + '/model_{:2f}_{:2f}.pth'.format(ratio, snr))
|
||||||
|
|
||||||
|
|
||||||
def save_model(model, path):
|
def save_model(model, path):
|
||||||
|
os.makedirs(path, exist_ok=True)
|
||||||
torch.save(model.state_dict(), path)
|
torch.save(model.state_dict(), path)
|
||||||
|
print("Model saved in {}".format(path))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user