Update-2024.06.04

This commit is contained in:
chun 2024-06-04 11:19:08 +08:00
parent ca6377e98c
commit e1cef6aead
87 changed files with 1436 additions and 209 deletions

4
.gitignore vendored
View File

@ -1,11 +1,11 @@
test.py
test.*
*.pyc
*.log
dataset
*.ipynb
*.swp
.vscode/*
input.txt
output.txt
*.json
.vscode/*
*.sh

42
.vscode/launch.json vendored
View File

@ -1,42 +0,0 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: 当前文件",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": true,
"args": [
"--lr",
"1e-3",
"--epochs",
"100",
"--batch_size",
"512",
"--if_scheduler",
"1",
"--step_size",
"500",
"--dataset",
"cifar10",
"--num_workers",
"4",
"--device",
"cuda:0",
"--ratio_list",
"1/3",
"--snr_list",
"100",
"--seed",
"42",
"--disable_tqdm",
"False"
]
}
]
}

115
README.md
View File

@ -1,19 +1,35 @@
# Deep JSCC
This implements training of deep JSCC models for wireless image transmission as described in the paper [Deep Joint Source-Channel Coding for Wireless Image Transmission](https://ieeexplore.ieee.org/abstract/document/8723589) by Pytorch. And there has been a [Tensorflow and keras implementations ](https://github.com/irdanish11/DJSCC-for-Wireless-Image-Transmission).
This implements training of Deep JSCC models for wireless image transmission as described in the paper [Deep Joint Source-Channel Coding for Wireless Image Transmission](https://ieeexplore.ieee.org/abstract/document/8723589) by Pytorch. And there has been a [Tensorflow and keras implementations ](https://github.com/irdanish11/DJSCC-for-Wireless-Image-Transmission).
This is my first time to use PyTorch and git to reproduce a paper, so there may be some mistakes. If you find any, please let me know. Thanks!
## Update-2024.06.04
- modify the `train.py` to omit most of the args in command line, you can just us `python train.py --dataset ${dataset_name}` to train the model.
- add tensorboard to record the results in exp.
- add the `visualization/` file to visualize the result.
- add bash file to run the code in parallel.
## Architecture
![architecture](./demo/arc.png)
<div style="text-align: center;">
<img src="./demo/arc.png" alt="Image 1" style="width: 500px; height: 250px;">
</div>
## Demo
the model trained on cifar10 which is 32\*32 but test on kodim which is 768\*512.
![demo1](./run/cifar10_3000_0.33_100.00_256_40.pth_kodim08.png)
the model trained on imagenet which is resized to 128\*128 but test on kodim which is 768\*512.
![demo2](./run/imagenet_10_0.33_200.00_32_19.pth_kodim08.png)
The model trained on cifar10 which is 32\*32 but test on kodim which is 768\*512 (left); and the model trained on imagenet which is resized to 128\*128 but test on kodim which is 768\*512 (right).
<div style="display: flex;">
<img src="./demo/cifar10_kodim08.png" alt="Image 1" style="flex: 1;" width="150" height="450">
<div style="width: 20px;"></div>
<img src="./demo/imagenet_kodim08.png" alt="Image 1" style="flex: 1;" width="150" height="450">
</div>
## Installation
conda or other virtual environment is recommended.
@ -31,25 +47,88 @@ The cifar10 dataset can be downloaded automatically by torchvision. But the imag
python dataset.py
```
### Training Model
Run(example presented in paper) on cifar10
### Training
The training command used to be very long, but now you can just use `python train.py --dataset ${dataset_name} --channel ${channel}` to train the model.
The default dataset is cifar10.
The parmeters can be modified in the `train.py` file. The default parameters are similar to the paper.
```
python train.py --lr 1e-3 --epochs 1000 --batch_size 64 --channel 'AWGN' --saved ./saved --snr_list 1 4 7 13 19 --ratio_list 1/6 1/12 --dataset cifar10 --num_workers 4 --parallel True --if_scheduler True --scheduler_step_size 50
```
or Run(example presented in paper) on imagenet
| Parameters | CIFAR-10 | ImageNet |
|------------------------|------------------|------------------|
| `batch_size` | 64 | 32 |
| `init_lr` | 1e-3 | 1e-4 |
| `weight_decay` | 5e-4 | 5e-4 |
| `snr_list` | [19, 13, 7, 4, 1]| [19, 13, 7, 4, 1]|
| `ratio_list` | [1/6, 1/12] | [1/6, 1/12] |
| `if_scheduler` | True | False |
| `step_size` | 640 | N/A |
| `gamma` | 0.1 | 0.1 |
<!-- ALSO! The batch_size for cifar10 training in the paper is small causing the GPU utilization is low. So The bash script is provided to run the code in parallel for different snr and ratio for cifar10 dataset. (Example of two GPUs)
```
python train.py --lr 10e-4 --epochs 300 --batch_size 32 --channel 'AWGN' --saved ./saved --dataset imagenet --num_workers 4 --parallel True
```
bash parallel_train_cifar.sh --channel ${channel}
``` -->
### Evaluation
Run(example presented in paper)
The `eval.py` provides the evaluation of the trained model.
You may need modify slightly to evaluate the model for different snr_list and channel type in `main` function.
```
python eval.py --channel 'AWGN' --saved ./saved/${mode_path} --snr 20 --test_img ${test_img_path}
python eval.py
```
All training and evaluation results are saved in the `./out` directory by default. The `./out` directory may contain the structure as follows:
```
./out
├── checkpoint # trained models
│   ├── $DATASETNAME_$INNERCHANNEL_$SNR_$RATIO_$CHANNEL_TYPE_$TIMES_on_$DATE
│   ├── epoch_$num.pth
│   ├── ...
│   ├── CIFAR10_10_1.0_0.08_AWGN_13h21m37s_on_Jun_02_2024
│   ├── CIFAR10_20_7.0_0.17_Rayleigh_14h03m19s_on_Jun_03_2024
│   ├── ...
├── configs # training configurations
│   ├── $DATASETNAME_$INNERCHANNEL_$SNR_$RATIO_$CHANNEL_TYPE_$TIMES_on_$DATE
│   ├── $CIFAR10_10_4.0_0.08_AWGN_13h21m38s_on_Jun_02_2024.yaml
│   ├── ...
├── logs # training logs
│   ├── $DATASETNAME_$INNERCHANNEL_$SNR_$RATIO_$CHANNEL_TYPE_$TIMES_on_$DATE
│   ├── tensorboard logs
│   ├── ...
├── eval # evaluation results
│   ├── $DATASETNAME_$INNERCHANNEL_$SNR_$RATIO_$CHANNEL_TYPE_$TIMES_on_$DATE
│   ├── tensorboard logs
│   ├── ...
```
### Visualization
The `./visualization` directory contains the scripts for visualization of the training and evaluation results.
- `single_visualization.ipynb` is used to get demo of the model on single image like the demo above.
- `plot_visualization.ipynb` is used to get visualizations of the perfprmance of the model on different snr and ratio.
## Results
Model results and logs for the **CIFAR-10** dataset, tested under various SNR, ratio, and channel types, are available in the `./out` directory. The models' performance is approximately 5dB worse than reported in the paper, likely due to the implementation reflecting a real communication system. However, the performance trends are consistent with those in the paper.
<div style="display: flex;">
<img src="demo/cifar_0.08.png" alt="Image 1" style="flex: 1; max-width: 50%; height: auto;">
<div style="width: 0px;"></div> <!-- 为了让两个图像之间有一点间距 -->
<img src="demo/cifar_0.17.png" alt="Image 2" style="flex: 1; max-width: 50%; height: auto;">
</div>
### TO-DO
- Add visualization of the model
- plot the results with different snr and ratio
- ~~Add visualization of the model~~
- ~~plot the results with different snr and ratio~~
- ~~add Rayleigh channel~~
- train on imagenet
- **Convert the real communication system to a complex communication system**
## Citation
If you find (part of) this code useful for your research, please consider citing

View File

@ -1,30 +0,0 @@
from time import time
import multiprocessing as mp
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, RandomSampler
if __name__ == '__main__':
transform = transforms.Compose([
torchvision.transforms.ToTensor(),
])
trainset = torchvision.datasets.CIFAR10(
root='./dataset/',
train=True, # 如果为True从 training.pt 创建数据,否则从 test.pt 创建数据。
download=True, # 如果为true则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
transform=transform
)
print(f"num of CPU: {mp.cpu_count()}")
for num_workers in range(2, mp.cpu_count(), 2):
train_loader = torch.utils.data.DataLoader(
trainset, shuffle=True, num_workers=num_workers, batch_size=64, pin_memory=True)
start = time()
for epoch in range(1, 3):
for i, data in enumerate(train_loader, 0):
pass
end = time()
print("Finish with:{} second, num_workers={}".format(end - start, num_workers))

View File

@ -2,8 +2,15 @@ import torch
import torch.nn as nn
def channel(channel_type='AWGN', snr=20):
def AWGN_channel(z_hat: torch.Tensor):
class Channel(nn.Module):
def __init__(self, channel_type='AWGN', snr=20):
if channel_type not in ['AWGN', 'Rayleigh']:
raise Exception('Unknown type of channel')
super(Channel, self).__init__()
self.channel_type = channel_type
self.snr = snr
def forward(self, z_hat):
if z_hat.dim() == 4:
# k = np.prod(z_hat.size()[1:])
k = torch.prod(torch.tensor(z_hat.size()[1:]))
@ -12,16 +19,27 @@ def channel(channel_type='AWGN', snr=20):
# k = np.prod(z_hat.size())
k = torch.prod(torch.tensor(z_hat.size()))
sig_pwr = torch.sum(torch.abs(z_hat).square()) / k
noi_pwr = sig_pwr / (10 ** (snr / 10))
noi_pwr = sig_pwr / (10 ** (self.snr / 10))
noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr)
if self.channel_type == 'Rayleigh':
# hc = torch.randn_like(z_hat) wrong implement before
hc = torch.randn(1, device = z_hat.device)
z_hat = hc * z_hat
return z_hat + noise
def Rayleigh_channel(z_hat: torch.Tensor):
pass
def get_channel(self):
return self.channel_type, self.snr
if channel_type == 'AWGN':
return AWGN_channel
elif channel_type == 'Rayleigh':
return Rayleigh_channel
else:
raise Exception('Unknown type of channel')
if __name__ == '__main__':
# test
channel = Channel(channel_type='AWGN', snr=10)
z_hat = torch.randn(64, 10, 5, 5)
z_hat = channel(z_hat)
print(z_hat)
channel = Channel(channel_type='Rayleigh', snr=10)
z_hat = torch.randn(64, 10, 5, 5)
z_hat = channel(z_hat)
print(z_hat)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 94 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.1 KiB

View File

Before

Width:  |  Height:  |  Size: 1.6 MiB

After

Width:  |  Height:  |  Size: 1.6 MiB

BIN
demo/cifar_0.08.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 81 KiB

BIN
demo/cifar_0.17.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 81 KiB

View File

Before

Width:  |  Height:  |  Size: 1.5 MiB

After

Width:  |  Height:  |  Size: 1.5 MiB

102
eval.py
View File

@ -1,51 +1,75 @@
# to be implemented
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from utils import get_psnr, image_normalization
from utils import get_psnr
import os
from model import DeepJSCC
from train import evaluate_epoch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from dataset import Vanilla
import yaml
from tensorboardX import SummaryWriter
import glob
from concurrent.futures import ProcessPoolExecutor
def eval_snr(model, test_loader, writer, param, times=10):
snr_list = range(0, 26, 1)
for snr in snr_list:
model.change_channel(param['channel'], snr)
test_loss = 0
for i in range(times):
test_loss += evaluate_epoch(model, param, test_loader)
test_loss /= times
psnr = get_psnr(image=None, gt=None, mse=test_loss)
writer.add_scalar('psnr', psnr, snr)
def config_parser():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
parser.add_argument('--saved', type=str, help='saved_path')
parser.add_argument('--snr', default=20, type=int, help='snr')
parser.add_argument('--test_image', default='./demo/kodim08.png', type=str, help='demo_image')
parser.add_argument('--times', default=10, type=int, help='num_workers')
return parser.parse_args()
def process_config(config_path, output_dir, dataset_name, times):
with open(config_path, 'r') as f:
config = yaml.load(f, Loader=yaml.UnsafeLoader)
assert dataset_name == config['dataset_name']
params = config['params']
c = config['inner_channel']
if dataset_name == 'cifar10':
transform = transforms.Compose([transforms.ToTensor(), ])
test_dataset = datasets.CIFAR10(root='../dataset/', train=False,
download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=True,
batch_size=params['batch_size'], num_workers=params['num_workers'])
elif dataset_name == 'imagenet':
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
test_dataset = Vanilla(root='../dataset/ImageNet/val', transform=transform)
test_loader = DataLoader(test_dataset, shuffle=True,
batch_size=params['batch_size'], num_workers=params['num_workers'])
else:
raise Exception('Unknown dataset')
name = os.path.splitext(os.path.basename(config_path))[0]
writer = SummaryWriter(os.path.join(output_dir, 'eval', name))
model = DeepJSCC(c=c)
model = model.to(params['device'])
pkl_list = glob.glob(os.path.join(output_dir, 'checkpoints', name, '*.pkl'))
model.load_state_dict(torch.load(pkl_list[-1]))
eval_snr(model, test_loader, writer, params, times)
writer.close()
def main():
args = config_parser()
transform = transforms.Compose([transforms.ToTensor()])
test_image = Image.open(args.test_image)
test_image.load()
test_image = transform(test_image)
times = 10
dataset_name = 'cifar10'
output_dir = './out'
channel_type = 'AWGN'
config_dir = os.path.join(output_dir, 'configs')
config_files = [os.path.join(config_dir, name) for name in os.listdir(config_dir)
if (dataset_name in name or dataset_name.upper() in name) and channel_type in name and name.endswith('.yaml')]
file_name = os.path.basename(args.saved)
c = file_name.split('_')[-1].split('.')[0]
c = int(c)
model = DeepJSCC(c=c, channel_type=args.channel, snr=args.snr)
# model.load_state_dict(torch.load(args.saved))
model.load_state_dict(torch.load(args.saved,map_location=torch.device('cuda:0')))
model.change_channel(args.channel, args.snr)
psnr_all = 0.0
for i in range(args.times):
demo_image = model(test_image)
demo_image = image_normalization('denormalization')(demo_image)
gt = image_normalization('denormalization')(test_image)
psnr_all += get_psnr(demo_image, gt)
demo_image = image_normalization('normalization')(demo_image)
demo_image = torch.cat([test_image, demo_image], dim=1)
demo_image = transforms.ToPILImage()(demo_image)
demo_image.save('./run/{}_{}'.format(args.saved.split('/')[-1], args.test_image.split('/')[-1]))
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
with ProcessPoolExecutor() as executor:
executor.map(process_config, config_files, [output_dir] * len(config_files), [dataset_name] * len(config_files), [times] * len(config_files))
if __name__ == '__main__':

View File

@ -7,7 +7,7 @@ Created on Tue Dec 11:00:00 2023
import torch
import torch.nn as nn
import channel
from channel import Channel
""" def _image_normalization(norm_type):
@ -58,7 +58,8 @@ class _TransConvWithPReLU(nn.Module):
in_channels, out_channels, kernel_size, stride, padding, output_padding)
self.activate = activate
if activate == nn.PReLU():
nn.init.kaiming_normal_(self.transconv.weight, mode='fan_out', nonlinearity='leaky_relu')
nn.init.kaiming_normal_(self.transconv.weight, mode='fan_out',
nonlinearity='leaky_relu')
else:
nn.init.xavier_normal_(self.transconv.weight)
@ -141,17 +142,45 @@ class _Decoder(nn.Module):
class DeepJSCC(nn.Module):
def __init__(self, c, channel_type='AWGN', snr=20):
def __init__(self, c, channel_type='AWGN', snr=None):
super(DeepJSCC, self).__init__()
self.encoder = _Encoder(c=c)
self.channel = channel.channel(channel_type, snr)
if snr is not None:
self.channel = Channel(channel_type, snr)
self.decoder = _Decoder(c=c)
def forward(self, x):
z = self.encoder(x)
if hasattr(self, 'channel') and self.channel is not None:
z = self.channel(z)
x_hat = self.decoder(z)
return x_hat
def change_channel(self, channel_type, snr):
self.channel = channel.channel(channel_type, snr)
def change_channel(self, channel_type='AWGN', snr=None):
if snr is None:
self.channel = None
else:
self.channel = Channel(channel_type, snr)
def get_channel(self):
if hasattr(self, 'channel') and self.channel is not None:
return self.channel.get_channel()
return None
def loss(self, prd, gt):
criterion = nn.MSELoss(reduction='mean')
loss = criterion(prd, gt)
return loss
if __name__ == '__main__':
model = DeepJSCC(c=20)
print(model)
x = torch.rand(1, 3, 32, 32)
y = model(x)
print(y.size())
print(y)
print(model.encoder.norm)
print(model.encoder.norm(y))
print(model.encoder.norm(y).size())
print(model.encoder.norm(y).size()[1:])

View File

@ -0,0 +1,46 @@
dataset_name: cifar10
inner_channel: 10
params:
ReduceLROnPlateau: false
batch_size: 64
channel: AWGN
dataset: cifar10
device: cuda:1
disable_tqdm: false
epochs: 1000
gamma: 0.1
if_scheduler: true
init_lr: 0.001
lr_reduce_factor: 0.5
lr_schedule_patience: 15
max_time: 12
min_lr: 1.0e-05
num_workers: 4
out_dir: ./out
parallel: false
ratio: 0.08333333333333333
ratio_list:
- 0.08333333333333333
seed: 42
snr: 1.0
snr_list:
- 1.0
step_size: 640
weight_decay: 0.0005
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args:
- i8
- false
- true
state: !!python/tuple
- 3
- <
- null
- null
- null
- -1
- -1
- 0
- !!binary |
41QCAAAAAAA=

View File

@ -0,0 +1,46 @@
dataset_name: cifar10
inner_channel: 10
params:
ReduceLROnPlateau: false
batch_size: 64
channel: AWGN
dataset: cifar10
device: cuda:1
disable_tqdm: false
epochs: 1000
gamma: 0.1
if_scheduler: true
init_lr: 0.001
lr_reduce_factor: 0.5
lr_schedule_patience: 15
max_time: 12
min_lr: 1.0e-05
num_workers: 4
out_dir: ./out
parallel: false
ratio: 0.08333333333333333
ratio_list:
- 0.08333333333333333
seed: 42
snr: 13.0
snr_list:
- 13.0
step_size: 640
weight_decay: 0.0005
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args:
- i8
- false
- true
state: !!python/tuple
- 3
- <
- null
- null
- null
- -1
- -1
- 0
- !!binary |
41QCAAAAAAA=

View File

@ -0,0 +1,46 @@
dataset_name: cifar10
inner_channel: 10
params:
ReduceLROnPlateau: false
batch_size: 64
channel: AWGN
dataset: cifar10
device: cuda:1
disable_tqdm: false
epochs: 1000
gamma: 0.1
if_scheduler: true
init_lr: 0.001
lr_reduce_factor: 0.5
lr_schedule_patience: 15
max_time: 12
min_lr: 1.0e-05
num_workers: 4
out_dir: ./out
parallel: false
ratio: 0.08333333333333333
ratio_list:
- 0.08333333333333333
seed: 42
snr: 19.0
snr_list:
- 19.0
step_size: 640
weight_decay: 0.0005
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args:
- i8
- false
- true
state: !!python/tuple
- 3
- <
- null
- null
- null
- -1
- -1
- 0
- !!binary |
41QCAAAAAAA=

View File

@ -0,0 +1,46 @@
dataset_name: cifar10
inner_channel: 10
params:
ReduceLROnPlateau: false
batch_size: 64
channel: AWGN
dataset: cifar10
device: cuda:1
disable_tqdm: false
epochs: 1000
gamma: 0.1
if_scheduler: true
init_lr: 0.001
lr_reduce_factor: 0.5
lr_schedule_patience: 15
max_time: 12
min_lr: 1.0e-05
num_workers: 4
out_dir: ./out
parallel: false
ratio: 0.08333333333333333
ratio_list:
- 0.08333333333333333
seed: 42
snr: 4.0
snr_list:
- 4.0
step_size: 640
weight_decay: 0.0005
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args:
- i8
- false
- true
state: !!python/tuple
- 3
- <
- null
- null
- null
- -1
- -1
- 0
- !!binary |
41QCAAAAAAA=

View File

@ -0,0 +1,46 @@
dataset_name: cifar10
inner_channel: 10
params:
ReduceLROnPlateau: false
batch_size: 64
channel: AWGN
dataset: cifar10
device: cuda:1
disable_tqdm: false
epochs: 1000
gamma: 0.1
if_scheduler: true
init_lr: 0.001
lr_reduce_factor: 0.5
lr_schedule_patience: 15
max_time: 12
min_lr: 1.0e-05
num_workers: 4
out_dir: ./out
parallel: false
ratio: 0.08333333333333333
ratio_list:
- 0.08333333333333333
seed: 42
snr: 7.0
snr_list:
- 7.0
step_size: 640
weight_decay: 0.0005
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args:
- i8
- false
- true
state: !!python/tuple
- 3
- <
- null
- null
- null
- -1
- -1
- 0
- !!binary |
41QCAAAAAAA=

View File

@ -0,0 +1,46 @@
dataset_name: cifar10
inner_channel: 20
params:
ReduceLROnPlateau: false
batch_size: 64
channel: AWGN
dataset: cifar10
device: cuda:0
disable_tqdm: false
epochs: 1000
gamma: 0.1
if_scheduler: true
init_lr: 0.001
lr_reduce_factor: 0.5
lr_schedule_patience: 15
max_time: 12
min_lr: 1.0e-05
num_workers: 4
out_dir: ./out
parallel: false
ratio: 0.16666666666666666
ratio_list:
- 0.16666666666666666
seed: 42
snr: 1.0
snr_list:
- 1.0
step_size: 640
weight_decay: 0.0005
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args:
- i8
- false
- true
state: !!python/tuple
- 3
- <
- null
- null
- null
- -1
- -1
- 0
- !!binary |
bZMCAAAAAAA=

View File

@ -0,0 +1,46 @@
dataset_name: cifar10
inner_channel: 20
params:
ReduceLROnPlateau: false
batch_size: 64
channel: AWGN
dataset: cifar10
device: cuda:0
disable_tqdm: false
epochs: 1000
gamma: 0.1
if_scheduler: true
init_lr: 0.001
lr_reduce_factor: 0.5
lr_schedule_patience: 15
max_time: 12
min_lr: 1.0e-05
num_workers: 4
out_dir: ./out
parallel: false
ratio: 0.16666666666666666
ratio_list:
- 0.16666666666666666
seed: 42
snr: 13.0
snr_list:
- 13.0
step_size: 640
weight_decay: 0.0005
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args:
- i8
- false
- true
state: !!python/tuple
- 3
- <
- null
- null
- null
- -1
- -1
- 0
- !!binary |
bZMCAAAAAAA=

View File

@ -0,0 +1,46 @@
dataset_name: cifar10
inner_channel: 20
params:
ReduceLROnPlateau: false
batch_size: 64
channel: AWGN
dataset: cifar10
device: cuda:0
disable_tqdm: false
epochs: 1000
gamma: 0.1
if_scheduler: true
init_lr: 0.001
lr_reduce_factor: 0.5
lr_schedule_patience: 15
max_time: 12
min_lr: 1.0e-05
num_workers: 4
out_dir: ./out
parallel: false
ratio: 0.16666666666666666
ratio_list:
- 0.16666666666666666
seed: 42
snr: 19.0
snr_list:
- 19.0
step_size: 640
weight_decay: 0.0005
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args:
- i8
- false
- true
state: !!python/tuple
- 3
- <
- null
- null
- null
- -1
- -1
- 0
- !!binary |
bZMCAAAAAAA=

View File

@ -0,0 +1,46 @@
dataset_name: cifar10
inner_channel: 20
params:
ReduceLROnPlateau: false
batch_size: 64
channel: AWGN
dataset: cifar10
device: cuda:0
disable_tqdm: false
epochs: 1000
gamma: 0.1
if_scheduler: true
init_lr: 0.001
lr_reduce_factor: 0.5
lr_schedule_patience: 15
max_time: 12
min_lr: 1.0e-05
num_workers: 4
out_dir: ./out
parallel: false
ratio: 0.16666666666666666
ratio_list:
- 0.16666666666666666
seed: 42
snr: 4.0
snr_list:
- 4.0
step_size: 640
weight_decay: 0.0005
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args:
- i8
- false
- true
state: !!python/tuple
- 3
- <
- null
- null
- null
- -1
- -1
- 0
- !!binary |
bZMCAAAAAAA=

View File

@ -0,0 +1,46 @@
dataset_name: cifar10
inner_channel: 20
params:
ReduceLROnPlateau: false
batch_size: 64
channel: AWGN
dataset: cifar10
device: cuda:0
disable_tqdm: false
epochs: 1000
gamma: 0.1
if_scheduler: true
init_lr: 0.001
lr_reduce_factor: 0.5
lr_schedule_patience: 15
max_time: 12
min_lr: 1.0e-05
num_workers: 4
out_dir: ./out
parallel: false
ratio: 0.16666666666666666
ratio_list:
- 0.16666666666666666
seed: 42
snr: 7.0
snr_list:
- 7.0
step_size: 640
weight_decay: 0.0005
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args:
- i8
- false
- true
state: !!python/tuple
- 3
- <
- null
- null
- null
- -1
- -1
- 0
- !!binary |
bZMCAAAAAAA=

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.6 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.1 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.6 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.3 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.1 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.6 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.0 MiB

365
train.py
View File

@ -14,60 +14,301 @@ import torch.optim as optim
from tqdm import tqdm
from model import DeepJSCC, ratio2filtersize
from torch.nn.parallel import DataParallel
from utils import image_normalization
from utils import image_normalization, set_seed, save_model, view_model_param
from fractions import Fraction
from dataset import Vanilla
import numpy as np
import time
from tensorboardX import SummaryWriter
import glob
def set_seed(seed):
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def train_epoch(model, optimizer, param, data_loader):
model.train()
epoch_loss = 0
for iter, (images, _) in enumerate(data_loader):
images = images.cuda() if param['parallel'] and torch.cuda.device_count(
) > 1 else images.to(param['device'])
optimizer.zero_grad()
outputs = model.forward(images)
outputs = image_normalization('denormalization')(outputs)
images = image_normalization('denormalization')(images)
loss = model.loss(images, outputs) if not param['parallel'] else model.module.loss(
images, outputs)
loss.backward()
optimizer.step()
epoch_loss += loss.detach().item()
epoch_loss /= (iter + 1)
return epoch_loss, optimizer
def config_parser():
def evaluate_epoch(model, param, data_loader):
model.eval()
epoch_loss = 0
with torch.no_grad():
for iter, (images, _) in enumerate(data_loader):
images = images.cuda() if param['parallel'] and torch.cuda.device_count(
) > 1 else images.to(param['device'])
outputs = model.forward(images)
outputs = image_normalization('denormalization')(outputs)
images = image_normalization('denormalization')(images)
loss = model.loss(images, outputs) if not param['parallel'] else model.module.loss(
images, outputs)
epoch_loss += loss.detach().item()
epoch_loss /= (iter + 1)
return epoch_loss
def config_parser_pipeline():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=2048, type=int, help='Random seed')
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('--epochs', default=256, type=int, help='number of epochs')
parser.add_argument('--batch_size', default=256, type=int, help='batch size')
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
parser.add_argument('--channel', default='AWGN', type=str,
choices=['AWGN', 'Rayleigh'], help='channel type')
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
parser.add_argument('--snr_list', default=['19', '13',
'7', '4', '1'], nargs='+', help='snr_list')
parser.add_argument('--ratio_list', default=['1/3',
'1/6', '1/12'], nargs='+', help='ratio_list')
parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
parser.add_argument('--dataset', default='cifar10', type=str,
choices=['cifar10', 'imagenet'], help='dataset')
parser.add_argument('--parallel', default=False, type=bool, help='parallel')
parser.add_argument('--if_scheduler', default=False, type=bool, help='if_scheduler')
parser.add_argument('--step_size', default=640, type=int, help='scheduler')
parser.add_argument('--out', default='./out', type=str, help='out_path')
parser.add_argument('--disable_tqdm', default=False, type=bool, help='disable_tqdm')
parser.add_argument('--device', default='cuda:0', type=str, help='device')
parser.add_argument('--gamma', default=0.5, type=float, help='gamma')
parser.add_argument('--disable_tqdm', default=True, type=bool, help='disable_tqdm')
parser.add_argument('--parallel', default=False, type=bool, help='parallel')
parser.add_argument('--snr_list', default=['19', '13',
'7', '4', '1'], nargs='+', help='snr_list')
parser.add_argument('--ratio_list', default=['1/6', '1/12'], nargs='+', help='ratio_list')
parser.add_argument('--channel', default='AWGN', type=str,
choices=['AWGN', 'Rayleigh'], help='channel')
return parser.parse_args()
def main():
args = config_parser()
def main_pipeline():
args = config_parser_pipeline()
print("Training Start")
dataset_name = args.dataset
out_dir = args.out
args.snr_list = list(map(float, args.snr_list))
args.ratio_list = list(map(lambda x: float(Fraction(x)), args.ratio_list))
set_seed(args.seed)
print("Training Start")
for ratio in args.ratio_list:
for snr in args.snr_list:
train(args, ratio, snr)
params = {}
params['disable_tqdm'] = args.disable_tqdm
params['dataset'] = dataset_name
params['out_dir'] = out_dir
params['device'] = args.device
params['snr_list'] = args.snr_list
params['ratio_list'] = args.ratio_list
params['channel'] = args.channel
if dataset_name == 'cifar10':
params['batch_size'] = 64 # 1024
params['num_workers'] = 4
params['epochs'] = 1000
params['init_lr'] = 1e-3 # 1e-2
params['weight_decay'] = 5e-4
params['parallel'] = False
params['if_scheduler'] = True
params['step_size'] = 640
params['gamma'] = 0.1
params['seed'] = 42
params['ReduceLROnPlateau'] = False
params['lr_reduce_factor'] = 0.5
params['lr_schedule_patience'] = 15
params['max_time'] = 12
params['min_lr'] = 1e-5
elif dataset_name == 'imagenet':
params['batch_size'] = 32
params['num_workers'] = 4
params['epochs'] = 300
params['init_lr'] = 1e-4
params['weight_decay'] = 5e-4
params['parallel'] = True
params['if_scheduler'] = True
params['gamma'] = 0.1
params['seed'] = 42
params['ReduceLROnPlateau'] = True
params['lr_reduce_factor'] = 0.5
params['lr_schedule_patience'] = 15
params['max_time'] = 12
params['min_lr'] = 1e-5
else:
raise Exception('Unknown dataset')
set_seed(params['seed'])
for ratio in params['ratio_list']:
for snr in params['snr_list']:
params['ratio'] = ratio
params['snr'] = snr
train_pipeline(params)
def train(args: config_parser(), ratio: float, snr: float):
# add train_pipeline to with only dataset_name args
def train_pipeline(params):
dataset_name = params['dataset']
# load data
if dataset_name == 'cifar10':
transform = transforms.Compose([transforms.ToTensor(), ])
train_dataset = datasets.CIFAR10(root='../dataset/', train=True,
download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True,
batch_size=params['batch_size'], num_workers=params['num_workers'])
test_dataset = datasets.CIFAR10(root='../dataset/', train=False,
download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=True,
batch_size=params['batch_size'], num_workers=params['num_workers'])
elif dataset_name == 'imagenet':
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
print("loading data of imagenet")
train_dataset = datasets.ImageFolder(root='../dataset/ImageNet/train', transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True,
batch_size=params['batch_size'], num_workers=params['num_workers'])
test_dataset = Vanilla(root='../dataset/ImageNet/val', transform=transform)
test_loader = DataLoader(test_dataset, shuffle=True,
batch_size=params['batch_size'], num_workers=params['num_workers'])
else:
raise Exception('Unknown dataset')
# create model
image_fisrt = train_dataset.__getitem__(0)[0]
c = ratio2filtersize(image_fisrt, params['ratio'])
print("The snr is {}, the inner channel is {}, the ratio is {:.2f}".format(
params['snr'], c, params['ratio']))
model = DeepJSCC(c=c, channel_type=params['channel'], snr=params['snr'])
# init exp dir
out_dir = params['out_dir']
phaser = dataset_name.upper() + '_' + str(c) + '_' + str(params['snr']) + '_' + \
"{:.2f}".format(params['ratio']) + '_' + str(params['channel']) + \
'_' + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
root_log_dir = out_dir + '/' + 'logs/' + phaser
root_ckpt_dir = out_dir + '/' + 'checkpoints/' + phaser
root_config_dir = out_dir + '/' + 'configs/' + phaser
writer = SummaryWriter(log_dir=root_log_dir)
# model init
device = torch.device(params['device'] if torch.cuda.is_available() else 'cpu')
if params['parallel'] and torch.cuda.device_count() > 1:
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
model = model.cuda()
else:
model = model.to(device)
# opt
optimizer = optim.Adam(
model.parameters(), lr=params['init_lr'], weight_decay=params['weight_decay'])
if params['if_scheduler'] and not params['ReduceLROnPlateau']:
scheduler = optim.lr_scheduler.StepLR(
optimizer, step_size=params['step_size'], gamma=params['gamma'])
elif params['ReduceLROnPlateau']:
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
factor=params['lr_reduce_factor'],
patience=params['lr_schedule_patience'],
verbose=False)
else:
print("No scheduler")
scheduler = None
writer.add_text('config', str(params))
t0 = time.time()
epoch_train_losses, epoch_val_losses = [], []
per_epoch_time = []
# train
# At any point you can hit Ctrl + C to break out of training early.
try:
with tqdm(range(params['epochs']), disable=params['disable_tqdm']) as t:
for epoch in t:
t.set_description('Epoch %d' % epoch)
start = time.time()
epoch_train_loss, optimizer = train_epoch(
model, optimizer, params, train_loader)
epoch_val_loss = evaluate_epoch(model, params, test_loader)
epoch_train_losses.append(epoch_train_loss)
epoch_val_losses.append(epoch_val_loss)
writer.add_scalar('train/_loss', epoch_train_loss, epoch)
writer.add_scalar('val/_loss', epoch_val_loss, epoch)
writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)
t.set_postfix(time=time.time() - start, lr=optimizer.param_groups[0]['lr'],
train_loss=epoch_train_loss, val_loss=epoch_val_loss)
per_epoch_time.append(time.time() - start)
# Saving checkpoint
if not os.path.exists(root_ckpt_dir):
os.makedirs(root_ckpt_dir)
torch.save(model.state_dict(), '{}.pkl'.format(
root_ckpt_dir + "/epoch_" + str(epoch)))
files = glob.glob(root_ckpt_dir + '/*.pkl')
for file in files:
epoch_nb = file.split('_')[-1]
epoch_nb = int(epoch_nb.split('.')[0])
if epoch_nb < epoch - 1:
os.remove(file)
if params['ReduceLROnPlateau'] and scheduler is not None:
scheduler.step(epoch_val_loss)
elif params['if_scheduler'] and not params['ReduceLROnPlateau']:
scheduler.step() # use only information from the validation loss
if optimizer.param_groups[0]['lr'] < params['min_lr']:
print("\n!! LR EQUAL TO MIN LR SET.")
break
# Stop training after params['max_time'] hours
if time.time() - t0 > params['max_time'] * 3600:
print('-' * 89)
print("Max_time for training elapsed {:.2f} hours, so stopping".format(
params['max_time']))
break
except KeyboardInterrupt:
print('-' * 89)
print('Exiting from training early because of KeyboardInterrupt')
test_loss = evaluate_epoch(model, params, test_loader)
train_loss = evaluate_epoch(model, params, train_loader)
print("Test Accuracy: {:.4f}".format(test_loss))
print("Train Accuracy: {:.4f}".format(train_loss))
print("Convergence Time (Epochs): {:.4f}".format(epoch))
print("TOTAL TIME TAKEN: {:.4f}s".format(time.time() - t0))
print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time)))
"""
Write the results in out_dir/results folder
"""
writer.add_text(tag='result', text_string="""Dataset: {}\nparams={}\n\nTotal Parameters: {}\n\n
FINAL RESULTS\nTEST Loss: {:.4f}\nTRAIN Loss: {:.4f}\n\n
Convergence Time (Epochs): {:.4f}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n"""
.format(dataset_name, params, view_model_param(model), np.mean(np.array(train_loss)),
np.mean(np.array(test_loss)), epoch, (time.time() - t0) / 3600, np.mean(per_epoch_time)))
writer.close()
if not os.path.exists(os.path.dirname(root_config_dir)):
os.makedirs(os.path.dirname(root_config_dir))
with open(root_config_dir + '.yaml', 'w') as f:
dict_yaml = {'dataset_name': dataset_name, 'params': params,
'inner_channel': c, 'total_parameters': view_model_param(model)}
import yaml
yaml.dump(dict_yaml, f)
del model, optimizer, scheduler, train_loader, test_loader
del writer
def train(args, ratio: float, snr: float): # deprecated
device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
# load data
@ -142,22 +383,48 @@ def train(args: config_parser(), ratio: float, snr: float):
# epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
print("epoch: {}, loss: {:.4f}, test_mse: {:.4f}, lr:{}".format(
epoch, run_loss / len(train_loader), test_mse / len(test_loader), optimizer.param_groups[0]['lr']))
save_model(model, args.saved, args.saved +
'/{}_{}_{:.2f}_{:.2f}_{}_{}.pth'.format(args.dataset, args.epochs, ratio, snr, args.batch_size, c))
save_model(model, args.saved, args.saved + '/{}_{}_{:.2f}_{:.2f}_{}_{}.pth'
.format(args.dataset, args.epochs, ratio, snr, args.batch_size, c))
def save_model(model, dir, path):
os.makedirs(dir, exist_ok=True)
flag = 1
while True:
if os.path.exists(path):
path = path+'_'+str(flag)
flag += 1
else:
break
torch.save(model.state_dict(), path)
print("Model saved in {}".format(path))
def config_parser(): # deprecated
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=2048, type=int, help='Random seed')
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('--epochs', default=256, type=int, help='number of epochs')
parser.add_argument('--batch_size', default=256, type=int, help='batch size')
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
parser.add_argument('--channel', default='AWGN', type=str,
choices=['AWGN', 'Rayleigh'], help='channel type')
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
parser.add_argument('--snr_list', default=['19', '13',
'7', '4', '1'], nargs='+', help='snr_list')
parser.add_argument('--ratio_list', default=['1/3',
'1/6', '1/12'], nargs='+', help='ratio_list')
parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
parser.add_argument('--dataset', default='cifar10', type=str,
choices=['cifar10', 'imagenet'], help='dataset')
parser.add_argument('--parallel', default=False, type=bool, help='parallel')
parser.add_argument('--if_scheduler', default=False, type=bool, help='if_scheduler')
parser.add_argument('--step_size', default=640, type=int, help='scheduler')
parser.add_argument('--device', default='cuda:0', type=str, help='device')
parser.add_argument('--gamma', default=0.5, type=float, help='gamma')
parser.add_argument('--disable_tqdm', default=True, type=bool, help='disable_tqdm')
return parser.parse_args()
def main(): # deprecated
args = config_parser()
args.snr_list = list(map(float, args.snr_list))
args.ratio_list = list(map(lambda x: float(Fraction(x)), args.ratio_list))
set_seed(args.seed)
print("Training Start")
for ratio in args.ratio_list:
for snr in args.snr_list:
train(args, ratio, snr)
if __name__ == '__main__':
main()
main_pipeline()
# main()

View File

@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
def image_normalization(norm_type):
@ -14,9 +15,41 @@ def image_normalization(norm_type):
return _inner
def get_psnr(image, gt, max=255):
def get_psnr(image, gt, max_val=255, mse=None):
if mse is None:
mse = F.mse_loss(image, gt)
mse = torch.tensor(mse) if not isinstance(mse, torch.Tensor) else mse
psnr = 10 * torch.log10(max**2 / mse)
psnr = 10 * torch.log10(max_val**2 / mse)
return psnr
def save_model(model, dir, path):
os.makedirs(dir, exist_ok=True)
flag = 1
while True:
if os.path.exists(path):
path = path + '_' + str(flag)
flag += 1
else:
break
torch.save(model.state_dict(), path)
print("Model saved in {}".format(path))
def set_seed(seed):
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def view_model_param(model):
total_param = 0
for param in model.parameters():
# print(param.data.size())
total_param += np.prod(list(param.data.size()))
return total_param

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long