Compare commits
10 Commits
ca6377e98c
...
2665e0dc6d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2665e0dc6d | ||
|
|
cf516747c1 | ||
|
|
143570f7b5 | ||
|
|
dce3f80d36 | ||
|
|
8796899e49 | ||
|
|
b7bd3bdd42 | ||
|
|
6d9b108ae5 | ||
|
|
ff3d377583 | ||
|
|
74d0f3f3e6 | ||
|
|
e1cef6aead |
4
.gitignore
vendored
@ -1,11 +1,11 @@
|
||||
test.py
|
||||
test.*
|
||||
*.pyc
|
||||
*.log
|
||||
dataset
|
||||
*.ipynb
|
||||
*.swp
|
||||
.vscode/*
|
||||
input.txt
|
||||
output.txt
|
||||
*.json
|
||||
.vscode/*
|
||||
*.sh
|
||||
42
.vscode/launch.json
vendored
@ -1,42 +0,0 @@
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: 当前文件",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${file}",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true,
|
||||
"args": [
|
||||
"--lr",
|
||||
"1e-3",
|
||||
"--epochs",
|
||||
"100",
|
||||
"--batch_size",
|
||||
"512",
|
||||
"--if_scheduler",
|
||||
"1",
|
||||
"--step_size",
|
||||
"500",
|
||||
"--dataset",
|
||||
"cifar10",
|
||||
"--num_workers",
|
||||
"4",
|
||||
"--device",
|
||||
"cuda:0",
|
||||
"--ratio_list",
|
||||
"1/3",
|
||||
"--snr_list",
|
||||
"100",
|
||||
"--seed",
|
||||
"42",
|
||||
"--disable_tqdm",
|
||||
"False"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
121
README.md
@ -1,18 +1,33 @@
|
||||
# Deep JSCC
|
||||
This implements training of deep JSCC models for wireless image transmission as described in the paper [Deep Joint Source-Channel Coding for Wireless Image Transmission](https://ieeexplore.ieee.org/abstract/document/8723589) by Pytorch. And there has been a [Tensorflow and keras implementations ](https://github.com/irdanish11/DJSCC-for-Wireless-Image-Transmission).
|
||||
This implements training of Deep JSCC models for wireless image transmission as described in the paper [Deep Joint Source-Channel Coding for Wireless Image Transmission](https://ieeexplore.ieee.org/abstract/document/8723589) by Pytorch. And there has been a [Tensorflow and keras implementations ](https://github.com/irdanish11/DJSCC-for-Wireless-Image-Transmission).
|
||||
|
||||
This is my first time to use PyTorch and git to reproduce a paper, so there may be some mistakes. If you find any, please let me know. Thanks!
|
||||
|
||||
|
||||
|
||||
## Update-2024.06.04
|
||||
- modify the `train.py` to omit most of the args in command line, you can just use `python train.py --dataset ${dataset_name}` to train the model.
|
||||
- add tensorboard to record the results in exp.
|
||||
- add the `visualization/` file to visualize the result.
|
||||
- add bash file to run the code in parallel.
|
||||
|
||||
|
||||
|
||||
## Architecture
|
||||
|
||||

|
||||
<div style="text-align: center;">
|
||||
<img src="./demo/arc.png" alt="Image 1" style="width: 500px; height: 250px;">
|
||||
</div>
|
||||
|
||||
## Demo
|
||||
|
||||
the model trained on cifar10 which is 32\*32 but test on kodim which is 768\*512.
|
||||

|
||||
|
||||
the model trained on imagenet which is resized to 128\*128 but test on kodim which is 768\*512.
|
||||

|
||||
The model trained on cifar10 which is 32\*32 but test on kodim which is 768\*512 (top); and the model trained on imagenet which is resized to 128\*128 but test on kodim which is 768\*512 (bottom).
|
||||
<div style="display: flex;">
|
||||
<img src="./demo/cifar10_kodim08.png" alt="Image 1" style="flex: 1; max-width: 48%; height: auto;">
|
||||
<div style="width: 5px;"></div>
|
||||
<img src="./demo/imagenet_kodim08.png" alt="Image 2" style="flex: 1; max-width: 48%; height: auto;">
|
||||
</div>
|
||||
|
||||
|
||||
## Installation
|
||||
@ -31,25 +46,95 @@ The cifar10 dataset can be downloaded automatically by torchvision. But the imag
|
||||
python dataset.py
|
||||
```
|
||||
|
||||
### Training Model
|
||||
Run(example presented in paper) on cifar10
|
||||
### Training
|
||||
The training command used to be very long, but now you can just use `python train.py --dataset ${dataset_name} --channel ${channel}` to train the model.
|
||||
The default dataset is cifar10.
|
||||
The parmeters can be modified in the `train.py` file. The default parameters are similar to the paper.
|
||||
|
||||
```
|
||||
python train.py --lr 1e-3 --epochs 1000 --batch_size 64 --channel 'AWGN' --saved ./saved --snr_list 1 4 7 13 19 --ratio_list 1/6 1/12 --dataset cifar10 --num_workers 4 --parallel True --if_scheduler True --scheduler_step_size 50
|
||||
```
|
||||
or Run(example presented in paper) on imagenet
|
||||
|
||||
| Parameters | CIFAR-10 | ImageNet |
|
||||
|------------------------|------------------|------------------|
|
||||
| `batch_size` | 64 | 32 |
|
||||
| `init_lr` | 1e-3 | 1e-4 |
|
||||
| `weight_decay` | 5e-4 | 5e-4 |
|
||||
| `snr_list` | [19, 13, 7, 4, 1]| [19, 13, 7, 4, 1]|
|
||||
| `ratio_list` | [1/6, 1/12] | [1/6, 1/12] |
|
||||
| `if_scheduler` | True | False |
|
||||
| `step_size` | 640 | N/A |
|
||||
| `gamma` | 0.1 | 0.1 |
|
||||
|
||||
|
||||
|
||||
<!-- ALSO! The batch_size for cifar10 training in the paper is small causing the GPU utilization is low. So The bash script is provided to run the code in parallel for different snr and ratio for cifar10 dataset. (Example of two GPUs)
|
||||
```
|
||||
python train.py --lr 10e-4 --epochs 300 --batch_size 32 --channel 'AWGN' --saved ./saved --dataset imagenet --num_workers 4 --parallel True
|
||||
```
|
||||
bash parallel_train_cifar.sh --channel ${channel}
|
||||
``` -->
|
||||
|
||||
|
||||
|
||||
### Evaluation
|
||||
Run(example presented in paper)
|
||||
The `eval.py` provides the evaluation of the trained model.
|
||||
|
||||
You may need modify slightly to evaluate the model for different snr_list and channel type in `main` function.
|
||||
```
|
||||
python eval.py --channel 'AWGN' --saved ./saved/${mode_path} --snr 20 --test_img ${test_img_path}
|
||||
python eval.py
|
||||
```
|
||||
All training and evaluation results are saved in the `./out` directory by default. The `./out` directory may contain the structure as follows:
|
||||
```
|
||||
./out
|
||||
├── checkpoint # trained models
|
||||
│ ├── $DATASETNAME_$INNERCHANNEL_$SNR_$RATIO_$CHANNEL_TYPE_$TIMES_on_$DATE
|
||||
│ ├── epoch_$num.pth
|
||||
│ ├── ...
|
||||
│ ├── CIFAR10_10_1.0_0.08_AWGN_13h21m37s_on_Jun_02_2024
|
||||
│ ├── CIFAR10_20_7.0_0.17_Rayleigh_14h03m19s_on_Jun_03_2024
|
||||
│ ├── ...
|
||||
├── configs # training configurations
|
||||
│ ├── $DATASETNAME_$INNERCHANNEL_$SNR_$RATIO_$CHANNEL_TYPE_$TIMES_on_$DATE
|
||||
│ ├── $CIFAR10_10_4.0_0.08_AWGN_13h21m38s_on_Jun_02_2024.yaml
|
||||
│ ├── ...
|
||||
├── logs # training logs
|
||||
│ ├── $DATASETNAME_$INNERCHANNEL_$SNR_$RATIO_$CHANNEL_TYPE_$TIMES_on_$DATE
|
||||
│ ├── tensorboard logs
|
||||
│ ├── ...
|
||||
├── eval # evaluation results
|
||||
│ ├── $DATASETNAME_$INNERCHANNEL_$SNR_$RATIO_$CHANNEL_TYPE_$TIMES_on_$DATE
|
||||
│ ├── tensorboard logs
|
||||
│ ├── ...
|
||||
```
|
||||
### Visualization
|
||||
|
||||
The `./visualization` directory contains the scripts for visualization of the training and evaluation results.
|
||||
|
||||
- `single_visualization.ipynb` is used to get demo of the model on single image like the demo above.
|
||||
- `plot_visualization.ipynb` is used to get visualizations of the perfprmance of the model on different snr and ratio.
|
||||
|
||||
## Results
|
||||
|
||||
|
||||
|
||||
|
||||
<div style="display: flex;">
|
||||
<img src="demo/cifar_0.08_AWGN.png" alt="Image 1" style="flex: 1; max-width: 48%; height: auto;">
|
||||
<div style="width: 0px;"></div> <!-- 为了让两个图像之间有一点间距 -->
|
||||
<img src="demo/cifar_0.17_AWGN.png" alt="Image 2" style="flex: 1; max-width: 48%; height: auto;">
|
||||
</div>
|
||||
|
||||
<div style="display: flex;">
|
||||
<img src="demo/cifar_0.17_Rayleigh.png" alt="Image 1" style="flex: 1; max-width: 48%; height: auto;">
|
||||
<div style="width: 0px;"></div> <!-- 为了让两个图像之间有一点间距 -->
|
||||
<img src="demo/cifar_0.34_Rayleigh.png" alt="Image 2" style="flex: 1; max-width: 48%; height: auto;">
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
|
||||
### TO-DO
|
||||
- Add visualization of the model
|
||||
- plot the results with different snr and ratio
|
||||
- ~~Add visualization of the model~~
|
||||
- ~~plot the results with different snr and ratio~~
|
||||
- ~~add Rayleigh channel~~
|
||||
- train on imagenet
|
||||
- **Convert the real communication system to a complex communication system**
|
||||
|
||||
## Citation
|
||||
If you find (part of) this code useful for your research, please consider citing
|
||||
|
||||
@ -1,30 +0,0 @@
|
||||
from time import time
|
||||
import multiprocessing as mp
|
||||
import torch
|
||||
import torchvision
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
transform = transforms.Compose([
|
||||
torchvision.transforms.ToTensor(),
|
||||
])
|
||||
|
||||
trainset = torchvision.datasets.CIFAR10(
|
||||
root='./dataset/',
|
||||
train=True, # 如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。
|
||||
download=True, # 如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
|
||||
transform=transform
|
||||
)
|
||||
|
||||
print(f"num of CPU: {mp.cpu_count()}")
|
||||
for num_workers in range(2, mp.cpu_count(), 2):
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
trainset, shuffle=True, num_workers=num_workers, batch_size=64, pin_memory=True)
|
||||
start = time()
|
||||
for epoch in range(1, 3):
|
||||
for i, data in enumerate(train_loader, 0):
|
||||
pass
|
||||
end = time()
|
||||
print("Finish with:{} second, num_workers={}".format(end - start, num_workers))
|
||||
75
channel.py
@ -2,26 +2,61 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def channel(channel_type='AWGN', snr=20):
|
||||
def AWGN_channel(z_hat: torch.Tensor):
|
||||
if z_hat.dim() == 4:
|
||||
# k = np.prod(z_hat.size()[1:])
|
||||
k = torch.prod(torch.tensor(z_hat.size()[1:]))
|
||||
sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True)/k
|
||||
elif z_hat.dim() == 3:
|
||||
# k = np.prod(z_hat.size())
|
||||
k = torch.prod(torch.tensor(z_hat.size()))
|
||||
sig_pwr = torch.sum(torch.abs(z_hat).square())/k
|
||||
noi_pwr = sig_pwr / (10 ** (snr / 10))
|
||||
noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr)
|
||||
class Channel(nn.Module):
|
||||
def __init__(self, channel_type='AWGN', snr=20):
|
||||
if channel_type not in ['AWGN', 'Rayleigh']:
|
||||
raise Exception('Unknown type of channel')
|
||||
super(Channel, self).__init__()
|
||||
self.channel_type = channel_type
|
||||
self.snr = snr
|
||||
|
||||
def forward(self, z_hat):
|
||||
if z_hat.dim() not in {3, 4}:
|
||||
raise ValueError('Input tensor must be 3D or 4D')
|
||||
|
||||
# if z_hat.dim() == 4:
|
||||
# # k = np.prod(z_hat.size()[1:])
|
||||
# k = torch.prod(torch.tensor(z_hat.size()[1:]))
|
||||
# sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True) / k
|
||||
# elif z_hat.dim() == 3:
|
||||
# # k = np.prod(z_hat.size())
|
||||
# k = torch.prod(torch.tensor(z_hat.size()))
|
||||
# sig_pwr = torch.sum(torch.abs(z_hat).square()) / k
|
||||
|
||||
if z_hat.dim() == 3:
|
||||
z_hat = z_hat.unsqueeze(0)
|
||||
|
||||
k = z_hat[0].numel()
|
||||
sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True) / k
|
||||
noi_pwr = sig_pwr / (10 ** (self.snr / 10))
|
||||
noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr/2)
|
||||
if self.channel_type == 'Rayleigh':
|
||||
# hc = torch.randn_like(z_hat) wrong implement before
|
||||
# hc = torch.randn(1, device = z_hat.device)
|
||||
hc = torch.randn(2, device = z_hat.device)
|
||||
|
||||
# clone for in-place operation
|
||||
z_hat = z_hat.clone()
|
||||
z_hat[:,:z_hat.size(1)//2] = hc[0] * z_hat[:,:z_hat.size(1)//2]
|
||||
z_hat[:,z_hat.size(1)//2:] = hc[1] * z_hat[:,z_hat.size(1)//2:]
|
||||
|
||||
|
||||
# z_hat = hc * z_hat
|
||||
|
||||
return z_hat + noise
|
||||
|
||||
def Rayleigh_channel(z_hat: torch.Tensor):
|
||||
pass
|
||||
def get_channel(self):
|
||||
return self.channel_type, self.snr
|
||||
|
||||
if channel_type == 'AWGN':
|
||||
return AWGN_channel
|
||||
elif channel_type == 'Rayleigh':
|
||||
return Rayleigh_channel
|
||||
else:
|
||||
raise Exception('Unknown type of channel')
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test
|
||||
channel = Channel(channel_type='AWGN', snr=10)
|
||||
z_hat = torch.randn(64, 10, 5, 5)
|
||||
z_hat = channel(z_hat)
|
||||
print(z_hat)
|
||||
|
||||
channel = Channel(channel_type='Rayleigh', snr=10)
|
||||
z_hat = torch.randn(10, 5, 5)
|
||||
z_hat = channel(z_hat)
|
||||
print(z_hat)
|
||||
|
||||
|
Before Width: | Height: | Size: 94 KiB |
BIN
demo/127_127.jpg
|
Before Width: | Height: | Size: 20 KiB |
BIN
demo/32_32.jpg
|
Before Width: | Height: | Size: 2.1 KiB |
|
Before Width: | Height: | Size: 1.6 MiB After Width: | Height: | Size: 1.6 MiB |
BIN
demo/cifar_0.08_AWGN.png
Normal file
|
After Width: | Height: | Size: 82 KiB |
BIN
demo/cifar_0.17_AWGN.png
Normal file
|
After Width: | Height: | Size: 82 KiB |
BIN
demo/cifar_0.17_Rayleigh.png
Normal file
|
After Width: | Height: | Size: 83 KiB |
BIN
demo/cifar_0.34_Rayleigh.png
Normal file
|
After Width: | Height: | Size: 83 KiB |
|
Before Width: | Height: | Size: 1.5 MiB After Width: | Height: | Size: 1.5 MiB |
102
eval.py
@ -1,51 +1,75 @@
|
||||
# to be implemented
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from utils import get_psnr, image_normalization
|
||||
from utils import get_psnr
|
||||
import os
|
||||
from model import DeepJSCC
|
||||
from train import evaluate_epoch
|
||||
from torchvision import transforms
|
||||
from torchvision import datasets
|
||||
from torch.utils.data import DataLoader
|
||||
from dataset import Vanilla
|
||||
import yaml
|
||||
from tensorboardX import SummaryWriter
|
||||
import glob
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
def eval_snr(model, test_loader, writer, param, times=10):
|
||||
snr_list = range(0, 26, 1)
|
||||
for snr in snr_list:
|
||||
model.change_channel(param['channel'], snr)
|
||||
test_loss = 0
|
||||
for i in range(times):
|
||||
test_loss += evaluate_epoch(model, param, test_loader)
|
||||
|
||||
test_loss /= times
|
||||
psnr = get_psnr(image=None, gt=None, mse=test_loss)
|
||||
writer.add_scalar('psnr', psnr, snr)
|
||||
|
||||
|
||||
|
||||
def config_parser():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
|
||||
parser.add_argument('--saved', type=str, help='saved_path')
|
||||
parser.add_argument('--snr', default=20, type=int, help='snr')
|
||||
parser.add_argument('--test_image', default='./demo/kodim08.png', type=str, help='demo_image')
|
||||
parser.add_argument('--times', default=10, type=int, help='num_workers')
|
||||
return parser.parse_args()
|
||||
def process_config(config_path, output_dir, dataset_name, times):
|
||||
with open(config_path, 'r') as f:
|
||||
config = yaml.load(f, Loader=yaml.UnsafeLoader)
|
||||
assert dataset_name == config['dataset_name']
|
||||
params = config['params']
|
||||
c = config['inner_channel']
|
||||
|
||||
if dataset_name == 'cifar10':
|
||||
transform = transforms.Compose([transforms.ToTensor(), ])
|
||||
test_dataset = datasets.CIFAR10(root='../dataset/', train=False,
|
||||
download=True, transform=transform)
|
||||
test_loader = DataLoader(test_dataset, shuffle=True,
|
||||
batch_size=params['batch_size'], num_workers=params['num_workers'])
|
||||
|
||||
elif dataset_name == 'imagenet':
|
||||
transform = transforms.Compose(
|
||||
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
|
||||
|
||||
test_dataset = Vanilla(root='../dataset/ImageNet/val', transform=transform)
|
||||
test_loader = DataLoader(test_dataset, shuffle=True,
|
||||
batch_size=params['batch_size'], num_workers=params['num_workers'])
|
||||
else:
|
||||
raise Exception('Unknown dataset')
|
||||
|
||||
name = os.path.splitext(os.path.basename(config_path))[0]
|
||||
writer = SummaryWriter(os.path.join(output_dir, 'eval', name))
|
||||
model = DeepJSCC(c=c)
|
||||
model = model.to(params['device'])
|
||||
pkl_list = glob.glob(os.path.join(output_dir, 'checkpoints', name, '*.pkl'))
|
||||
model.load_state_dict(torch.load(pkl_list[-1]))
|
||||
eval_snr(model, test_loader, writer, params, times)
|
||||
writer.close()
|
||||
|
||||
def main():
|
||||
args = config_parser()
|
||||
transform = transforms.Compose([transforms.ToTensor()])
|
||||
test_image = Image.open(args.test_image)
|
||||
test_image.load()
|
||||
test_image = transform(test_image)
|
||||
times = 10
|
||||
dataset_name = 'cifar10'
|
||||
output_dir = './out'
|
||||
channel_type = 'AWGN'
|
||||
config_dir = os.path.join(output_dir, 'configs')
|
||||
config_files = [os.path.join(config_dir, name) for name in os.listdir(config_dir)
|
||||
if (dataset_name in name or dataset_name.upper() in name) and channel_type in name and name.endswith('.yaml')]
|
||||
|
||||
file_name = os.path.basename(args.saved)
|
||||
c = file_name.split('_')[-1].split('.')[0]
|
||||
c = int(c)
|
||||
model = DeepJSCC(c=c, channel_type=args.channel, snr=args.snr)
|
||||
# model.load_state_dict(torch.load(args.saved))
|
||||
model.load_state_dict(torch.load(args.saved,map_location=torch.device('cuda:0')))
|
||||
model.change_channel(args.channel, args.snr)
|
||||
|
||||
psnr_all = 0.0
|
||||
|
||||
for i in range(args.times):
|
||||
demo_image = model(test_image)
|
||||
demo_image = image_normalization('denormalization')(demo_image)
|
||||
gt = image_normalization('denormalization')(test_image)
|
||||
psnr_all += get_psnr(demo_image, gt)
|
||||
demo_image = image_normalization('normalization')(demo_image)
|
||||
demo_image = torch.cat([test_image, demo_image], dim=1)
|
||||
demo_image = transforms.ToPILImage()(demo_image)
|
||||
demo_image.save('./run/{}_{}'.format(args.saved.split('/')[-1], args.test_image.split('/')[-1]))
|
||||
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
|
||||
with ProcessPoolExecutor() as executor:
|
||||
executor.map(process_config, config_files, [output_dir] * len(config_files), [dataset_name] * len(config_files), [times] * len(config_files))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
61
model.py
@ -7,7 +7,7 @@ Created on Tue Dec 11:00:00 2023
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import channel
|
||||
from channel import Channel
|
||||
|
||||
|
||||
""" def _image_normalization(norm_type):
|
||||
@ -58,7 +58,8 @@ class _TransConvWithPReLU(nn.Module):
|
||||
in_channels, out_channels, kernel_size, stride, padding, output_padding)
|
||||
self.activate = activate
|
||||
if activate == nn.PReLU():
|
||||
nn.init.kaiming_normal_(self.transconv.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||
nn.init.kaiming_normal_(self.transconv.weight, mode='fan_out',
|
||||
nonlinearity='leaky_relu')
|
||||
else:
|
||||
nn.init.xavier_normal_(self.transconv.weight)
|
||||
|
||||
@ -73,12 +74,12 @@ class _Encoder(nn.Module):
|
||||
super(_Encoder, self).__init__()
|
||||
self.is_temp = is_temp
|
||||
# self.imgae_normalization = _image_normalization(norm_type='nomalization')
|
||||
self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2)
|
||||
self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2)
|
||||
self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2, padding=2)
|
||||
self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2, padding=2)
|
||||
self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32,
|
||||
kernel_size=5, padding=2) # padding size could be changed here
|
||||
self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, padding=2)
|
||||
self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, padding=2)
|
||||
self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=2*c, kernel_size=5, padding=2)
|
||||
self.norm = self._normlizationLayer(P=P)
|
||||
|
||||
@staticmethod
|
||||
@ -104,7 +105,7 @@ class _Encoder(nn.Module):
|
||||
return _inner
|
||||
|
||||
def forward(self, x):
|
||||
#x = self.imgae_normalization(x)
|
||||
# x = self.imgae_normalization(x)
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.conv3(x)
|
||||
@ -118,16 +119,16 @@ class _Encoder(nn.Module):
|
||||
class _Decoder(nn.Module):
|
||||
def __init__(self, c=1):
|
||||
super(_Decoder, self).__init__()
|
||||
#self.imgae_normalization = _image_normalization(norm_type='denormalization')
|
||||
# self.imgae_normalization = _image_normalization(norm_type='denormalization')
|
||||
self.tconv1 = _TransConvWithPReLU(
|
||||
in_channels=c, out_channels=32, kernel_size=5, stride=1, padding=2)
|
||||
in_channels=2*c, out_channels=32, kernel_size=5, stride=1, padding=2)
|
||||
self.tconv2 = _TransConvWithPReLU(
|
||||
in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2)
|
||||
self.tconv3 = _TransConvWithPReLU(
|
||||
in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2)
|
||||
self.tconv4 = _TransConvWithPReLU(in_channels=32, out_channels=16, kernel_size=6, stride=2)
|
||||
self.tconv4 = _TransConvWithPReLU(in_channels=32, out_channels=16, kernel_size=5, stride=2, padding=2, output_padding=1)
|
||||
self.tconv5 = _TransConvWithPReLU(
|
||||
in_channels=16, out_channels=3, kernel_size=6, stride=2, activate=nn.Sigmoid())
|
||||
in_channels=16, out_channels=3, kernel_size=5, stride=2, padding=2, output_padding=1,activate=nn.Sigmoid())
|
||||
# may be some problems in tconv4 and tconv5, the kernal_size is not the same as the paper which is 5
|
||||
|
||||
def forward(self, x):
|
||||
@ -136,22 +137,50 @@ class _Decoder(nn.Module):
|
||||
x = self.tconv3(x)
|
||||
x = self.tconv4(x)
|
||||
x = self.tconv5(x)
|
||||
#x = self.imgae_normalization(x)
|
||||
# x = self.imgae_normalization(x)
|
||||
return x
|
||||
|
||||
|
||||
class DeepJSCC(nn.Module):
|
||||
def __init__(self, c, channel_type='AWGN', snr=20):
|
||||
def __init__(self, c, channel_type='AWGN', snr=None):
|
||||
super(DeepJSCC, self).__init__()
|
||||
self.encoder = _Encoder(c=c)
|
||||
self.channel = channel.channel(channel_type, snr)
|
||||
if snr is not None:
|
||||
self.channel = Channel(channel_type, snr)
|
||||
self.decoder = _Decoder(c=c)
|
||||
|
||||
def forward(self, x):
|
||||
z = self.encoder(x)
|
||||
z = self.channel(z)
|
||||
if hasattr(self, 'channel') and self.channel is not None:
|
||||
z = self.channel(z)
|
||||
x_hat = self.decoder(z)
|
||||
return x_hat
|
||||
|
||||
def change_channel(self, channel_type, snr):
|
||||
self.channel = channel.channel(channel_type, snr)
|
||||
def change_channel(self, channel_type='AWGN', snr=None):
|
||||
if snr is None:
|
||||
self.channel = None
|
||||
else:
|
||||
self.channel = Channel(channel_type, snr)
|
||||
|
||||
def get_channel(self):
|
||||
if hasattr(self, 'channel') and self.channel is not None:
|
||||
return self.channel.get_channel()
|
||||
return None
|
||||
|
||||
def loss(self, prd, gt):
|
||||
criterion = nn.MSELoss(reduction='mean')
|
||||
loss = criterion(prd, gt)
|
||||
return loss
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = DeepJSCC(c=20)
|
||||
print(model)
|
||||
x = torch.rand(1, 3, 128, 128)
|
||||
y = model(x)
|
||||
print(y.size())
|
||||
print(y)
|
||||
print(model.encoder.norm)
|
||||
print(model.encoder.norm(y))
|
||||
print(model.encoder.norm(y).size())
|
||||
print(model.encoder.norm(y).size()[1:])
|
||||
|
||||
@ -0,0 +1,46 @@
|
||||
dataset_name: cifar10
|
||||
inner_channel: 16
|
||||
params:
|
||||
ReduceLROnPlateau: false
|
||||
batch_size: 64
|
||||
channel: Rayleigh
|
||||
dataset: cifar10
|
||||
device: cuda:1
|
||||
disable_tqdm: false
|
||||
epochs: 1000
|
||||
gamma: 0.1
|
||||
if_scheduler: true
|
||||
init_lr: 0.001
|
||||
lr_reduce_factor: 0.5
|
||||
lr_schedule_patience: 15
|
||||
max_time: 12
|
||||
min_lr: 1.0e-05
|
||||
num_workers: 4
|
||||
out_dir: ./out
|
||||
parallel: false
|
||||
ratio: 0.3333333333333333
|
||||
ratio_list:
|
||||
- 0.3333333333333333
|
||||
seed: 42
|
||||
snr: 1.0
|
||||
snr_list:
|
||||
- 1.0
|
||||
step_size: 640
|
||||
weight_decay: 0.0005
|
||||
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||
- !!python/object/apply:numpy.dtype
|
||||
args:
|
||||
- i8
|
||||
- false
|
||||
- true
|
||||
state: !!python/tuple
|
||||
- 3
|
||||
- <
|
||||
- null
|
||||
- null
|
||||
- null
|
||||
- -1
|
||||
- -1
|
||||
- 0
|
||||
- !!binary |
|
||||
acYCAAAAAAA=
|
||||
@ -0,0 +1,46 @@
|
||||
dataset_name: cifar10
|
||||
inner_channel: 16
|
||||
params:
|
||||
ReduceLROnPlateau: false
|
||||
batch_size: 64
|
||||
channel: Rayleigh
|
||||
dataset: cifar10
|
||||
device: cuda:1
|
||||
disable_tqdm: false
|
||||
epochs: 1000
|
||||
gamma: 0.1
|
||||
if_scheduler: true
|
||||
init_lr: 0.001
|
||||
lr_reduce_factor: 0.5
|
||||
lr_schedule_patience: 15
|
||||
max_time: 12
|
||||
min_lr: 1.0e-05
|
||||
num_workers: 4
|
||||
out_dir: ./out
|
||||
parallel: false
|
||||
ratio: 0.3333333333333333
|
||||
ratio_list:
|
||||
- 0.3333333333333333
|
||||
seed: 42
|
||||
snr: 13.0
|
||||
snr_list:
|
||||
- 13.0
|
||||
step_size: 640
|
||||
weight_decay: 0.0005
|
||||
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||
- !!python/object/apply:numpy.dtype
|
||||
args:
|
||||
- i8
|
||||
- false
|
||||
- true
|
||||
state: !!python/tuple
|
||||
- 3
|
||||
- <
|
||||
- null
|
||||
- null
|
||||
- null
|
||||
- -1
|
||||
- -1
|
||||
- 0
|
||||
- !!binary |
|
||||
acYCAAAAAAA=
|
||||
@ -0,0 +1,46 @@
|
||||
dataset_name: cifar10
|
||||
inner_channel: 16
|
||||
params:
|
||||
ReduceLROnPlateau: false
|
||||
batch_size: 64
|
||||
channel: Rayleigh
|
||||
dataset: cifar10
|
||||
device: cuda:1
|
||||
disable_tqdm: false
|
||||
epochs: 1000
|
||||
gamma: 0.1
|
||||
if_scheduler: true
|
||||
init_lr: 0.001
|
||||
lr_reduce_factor: 0.5
|
||||
lr_schedule_patience: 15
|
||||
max_time: 12
|
||||
min_lr: 1.0e-05
|
||||
num_workers: 4
|
||||
out_dir: ./out
|
||||
parallel: false
|
||||
ratio: 0.3333333333333333
|
||||
ratio_list:
|
||||
- 0.3333333333333333
|
||||
seed: 42
|
||||
snr: 19.0
|
||||
snr_list:
|
||||
- 19.0
|
||||
step_size: 640
|
||||
weight_decay: 0.0005
|
||||
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||
- !!python/object/apply:numpy.dtype
|
||||
args:
|
||||
- i8
|
||||
- false
|
||||
- true
|
||||
state: !!python/tuple
|
||||
- 3
|
||||
- <
|
||||
- null
|
||||
- null
|
||||
- null
|
||||
- -1
|
||||
- -1
|
||||
- 0
|
||||
- !!binary |
|
||||
acYCAAAAAAA=
|
||||
@ -0,0 +1,46 @@
|
||||
dataset_name: cifar10
|
||||
inner_channel: 16
|
||||
params:
|
||||
ReduceLROnPlateau: false
|
||||
batch_size: 64
|
||||
channel: Rayleigh
|
||||
dataset: cifar10
|
||||
device: cuda:1
|
||||
disable_tqdm: false
|
||||
epochs: 1000
|
||||
gamma: 0.1
|
||||
if_scheduler: true
|
||||
init_lr: 0.001
|
||||
lr_reduce_factor: 0.5
|
||||
lr_schedule_patience: 15
|
||||
max_time: 12
|
||||
min_lr: 1.0e-05
|
||||
num_workers: 4
|
||||
out_dir: ./out
|
||||
parallel: false
|
||||
ratio: 0.3333333333333333
|
||||
ratio_list:
|
||||
- 0.3333333333333333
|
||||
seed: 42
|
||||
snr: 4.0
|
||||
snr_list:
|
||||
- 4.0
|
||||
step_size: 640
|
||||
weight_decay: 0.0005
|
||||
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||
- !!python/object/apply:numpy.dtype
|
||||
args:
|
||||
- i8
|
||||
- false
|
||||
- true
|
||||
state: !!python/tuple
|
||||
- 3
|
||||
- <
|
||||
- null
|
||||
- null
|
||||
- null
|
||||
- -1
|
||||
- -1
|
||||
- 0
|
||||
- !!binary |
|
||||
acYCAAAAAAA=
|
||||
@ -0,0 +1,46 @@
|
||||
dataset_name: cifar10
|
||||
inner_channel: 16
|
||||
params:
|
||||
ReduceLROnPlateau: false
|
||||
batch_size: 64
|
||||
channel: Rayleigh
|
||||
dataset: cifar10
|
||||
device: cuda:1
|
||||
disable_tqdm: false
|
||||
epochs: 1000
|
||||
gamma: 0.1
|
||||
if_scheduler: true
|
||||
init_lr: 0.001
|
||||
lr_reduce_factor: 0.5
|
||||
lr_schedule_patience: 15
|
||||
max_time: 12
|
||||
min_lr: 1.0e-05
|
||||
num_workers: 4
|
||||
out_dir: ./out
|
||||
parallel: false
|
||||
ratio: 0.3333333333333333
|
||||
ratio_list:
|
||||
- 0.3333333333333333
|
||||
seed: 42
|
||||
snr: 7.0
|
||||
snr_list:
|
||||
- 7.0
|
||||
step_size: 640
|
||||
weight_decay: 0.0005
|
||||
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||
- !!python/object/apply:numpy.dtype
|
||||
args:
|
||||
- i8
|
||||
- false
|
||||
- true
|
||||
state: !!python/tuple
|
||||
- 3
|
||||
- <
|
||||
- null
|
||||
- null
|
||||
- null
|
||||
- -1
|
||||
- -1
|
||||
- 0
|
||||
- !!binary |
|
||||
acYCAAAAAAA=
|
||||
@ -0,0 +1,46 @@
|
||||
dataset_name: cifar10
|
||||
inner_channel: 4
|
||||
params:
|
||||
ReduceLROnPlateau: false
|
||||
batch_size: 64
|
||||
channel: AWGN
|
||||
dataset: cifar10
|
||||
device: cuda:1
|
||||
disable_tqdm: false
|
||||
epochs: 1000
|
||||
gamma: 0.1
|
||||
if_scheduler: true
|
||||
init_lr: 0.001
|
||||
lr_reduce_factor: 0.5
|
||||
lr_schedule_patience: 15
|
||||
max_time: 12
|
||||
min_lr: 1.0e-05
|
||||
num_workers: 4
|
||||
out_dir: ./out
|
||||
parallel: false
|
||||
ratio: 0.08333333333333333
|
||||
ratio_list:
|
||||
- 0.08333333333333333
|
||||
seed: 42
|
||||
snr: 1.0
|
||||
snr_list:
|
||||
- 1.0
|
||||
step_size: 640
|
||||
weight_decay: 0.0005
|
||||
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||
- !!python/object/apply:numpy.dtype
|
||||
args:
|
||||
- i8
|
||||
- false
|
||||
- true
|
||||
state: !!python/tuple
|
||||
- 3
|
||||
- <
|
||||
- null
|
||||
- null
|
||||
- null
|
||||
- -1
|
||||
- -1
|
||||
- 0
|
||||
- !!binary |
|
||||
UTACAAAAAAA=
|
||||
@ -0,0 +1,46 @@
|
||||
dataset_name: cifar10
|
||||
inner_channel: 4
|
||||
params:
|
||||
ReduceLROnPlateau: false
|
||||
batch_size: 64
|
||||
channel: AWGN
|
||||
dataset: cifar10
|
||||
device: cuda:1
|
||||
disable_tqdm: false
|
||||
epochs: 1000
|
||||
gamma: 0.1
|
||||
if_scheduler: true
|
||||
init_lr: 0.001
|
||||
lr_reduce_factor: 0.5
|
||||
lr_schedule_patience: 15
|
||||
max_time: 12
|
||||
min_lr: 1.0e-05
|
||||
num_workers: 4
|
||||
out_dir: ./out
|
||||
parallel: false
|
||||
ratio: 0.08333333333333333
|
||||
ratio_list:
|
||||
- 0.08333333333333333
|
||||
seed: 42
|
||||
snr: 13.0
|
||||
snr_list:
|
||||
- 13.0
|
||||
step_size: 640
|
||||
weight_decay: 0.0005
|
||||
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||
- !!python/object/apply:numpy.dtype
|
||||
args:
|
||||
- i8
|
||||
- false
|
||||
- true
|
||||
state: !!python/tuple
|
||||
- 3
|
||||
- <
|
||||
- null
|
||||
- null
|
||||
- null
|
||||
- -1
|
||||
- -1
|
||||
- 0
|
||||
- !!binary |
|
||||
UTACAAAAAAA=
|
||||
@ -0,0 +1,46 @@
|
||||
dataset_name: cifar10
|
||||
inner_channel: 4
|
||||
params:
|
||||
ReduceLROnPlateau: false
|
||||
batch_size: 64
|
||||
channel: AWGN
|
||||
dataset: cifar10
|
||||
device: cuda:1
|
||||
disable_tqdm: false
|
||||
epochs: 1000
|
||||
gamma: 0.1
|
||||
if_scheduler: true
|
||||
init_lr: 0.001
|
||||
lr_reduce_factor: 0.5
|
||||
lr_schedule_patience: 15
|
||||
max_time: 12
|
||||
min_lr: 1.0e-05
|
||||
num_workers: 4
|
||||
out_dir: ./out
|
||||
parallel: false
|
||||
ratio: 0.08333333333333333
|
||||
ratio_list:
|
||||
- 0.08333333333333333
|
||||
seed: 42
|
||||
snr: 19.0
|
||||
snr_list:
|
||||
- 19.0
|
||||
step_size: 640
|
||||
weight_decay: 0.0005
|
||||
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||
- !!python/object/apply:numpy.dtype
|
||||
args:
|
||||
- i8
|
||||
- false
|
||||
- true
|
||||
state: !!python/tuple
|
||||
- 3
|
||||
- <
|
||||
- null
|
||||
- null
|
||||
- null
|
||||
- -1
|
||||
- -1
|
||||
- 0
|
||||
- !!binary |
|
||||
UTACAAAAAAA=
|
||||
@ -0,0 +1,46 @@
|
||||
dataset_name: cifar10
|
||||
inner_channel: 4
|
||||
params:
|
||||
ReduceLROnPlateau: false
|
||||
batch_size: 64
|
||||
channel: AWGN
|
||||
dataset: cifar10
|
||||
device: cuda:1
|
||||
disable_tqdm: false
|
||||
epochs: 1000
|
||||
gamma: 0.1
|
||||
if_scheduler: true
|
||||
init_lr: 0.001
|
||||
lr_reduce_factor: 0.5
|
||||
lr_schedule_patience: 15
|
||||
max_time: 12
|
||||
min_lr: 1.0e-05
|
||||
num_workers: 4
|
||||
out_dir: ./out
|
||||
parallel: false
|
||||
ratio: 0.08333333333333333
|
||||
ratio_list:
|
||||
- 0.08333333333333333
|
||||
seed: 42
|
||||
snr: 4.0
|
||||
snr_list:
|
||||
- 4.0
|
||||
step_size: 640
|
||||
weight_decay: 0.0005
|
||||
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||
- !!python/object/apply:numpy.dtype
|
||||
args:
|
||||
- i8
|
||||
- false
|
||||
- true
|
||||
state: !!python/tuple
|
||||
- 3
|
||||
- <
|
||||
- null
|
||||
- null
|
||||
- null
|
||||
- -1
|
||||
- -1
|
||||
- 0
|
||||
- !!binary |
|
||||
UTACAAAAAAA=
|
||||
@ -0,0 +1,46 @@
|
||||
dataset_name: cifar10
|
||||
inner_channel: 4
|
||||
params:
|
||||
ReduceLROnPlateau: false
|
||||
batch_size: 64
|
||||
channel: AWGN
|
||||
dataset: cifar10
|
||||
device: cuda:1
|
||||
disable_tqdm: false
|
||||
epochs: 1000
|
||||
gamma: 0.1
|
||||
if_scheduler: true
|
||||
init_lr: 0.001
|
||||
lr_reduce_factor: 0.5
|
||||
lr_schedule_patience: 15
|
||||
max_time: 12
|
||||
min_lr: 1.0e-05
|
||||
num_workers: 4
|
||||
out_dir: ./out
|
||||
parallel: false
|
||||
ratio: 0.08333333333333333
|
||||
ratio_list:
|
||||
- 0.08333333333333333
|
||||
seed: 42
|
||||
snr: 7.0
|
||||
snr_list:
|
||||
- 7.0
|
||||
step_size: 640
|
||||
weight_decay: 0.0005
|
||||
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||
- !!python/object/apply:numpy.dtype
|
||||
args:
|
||||
- i8
|
||||
- false
|
||||
- true
|
||||
state: !!python/tuple
|
||||
- 3
|
||||
- <
|
||||
- null
|
||||
- null
|
||||
- null
|
||||
- -1
|
||||
- -1
|
||||
- 0
|
||||
- !!binary |
|
||||
UTACAAAAAAA=
|
||||
@ -0,0 +1,46 @@
|
||||
dataset_name: cifar10
|
||||
inner_channel: 8
|
||||
params:
|
||||
ReduceLROnPlateau: false
|
||||
batch_size: 64
|
||||
channel: AWGN
|
||||
dataset: cifar10
|
||||
device: cuda:0
|
||||
disable_tqdm: false
|
||||
epochs: 1000
|
||||
gamma: 0.1
|
||||
if_scheduler: true
|
||||
init_lr: 0.001
|
||||
lr_reduce_factor: 0.5
|
||||
lr_schedule_patience: 15
|
||||
max_time: 12
|
||||
min_lr: 1.0e-05
|
||||
num_workers: 4
|
||||
out_dir: ./out
|
||||
parallel: false
|
||||
ratio: 0.16666666666666666
|
||||
ratio_list:
|
||||
- 0.16666666666666666
|
||||
seed: 42
|
||||
snr: 1.0
|
||||
snr_list:
|
||||
- 1.0
|
||||
step_size: 640
|
||||
weight_decay: 0.0005
|
||||
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||
- !!python/object/apply:numpy.dtype
|
||||
args:
|
||||
- i8
|
||||
- false
|
||||
- true
|
||||
state: !!python/tuple
|
||||
- 3
|
||||
- <
|
||||
- null
|
||||
- null
|
||||
- null
|
||||
- -1
|
||||
- -1
|
||||
- 0
|
||||
- !!binary |
|
||||
WWICAAAAAAA=
|
||||
@ -0,0 +1,46 @@
|
||||
dataset_name: cifar10
|
||||
inner_channel: 8
|
||||
params:
|
||||
ReduceLROnPlateau: false
|
||||
batch_size: 64
|
||||
channel: Rayleigh
|
||||
dataset: cifar10
|
||||
device: cuda:0
|
||||
disable_tqdm: false
|
||||
epochs: 1000
|
||||
gamma: 0.1
|
||||
if_scheduler: true
|
||||
init_lr: 0.001
|
||||
lr_reduce_factor: 0.5
|
||||
lr_schedule_patience: 15
|
||||
max_time: 12
|
||||
min_lr: 1.0e-05
|
||||
num_workers: 4
|
||||
out_dir: ./out
|
||||
parallel: false
|
||||
ratio: 0.16666666666666666
|
||||
ratio_list:
|
||||
- 0.16666666666666666
|
||||
seed: 42
|
||||
snr: 1.0
|
||||
snr_list:
|
||||
- 1.0
|
||||
step_size: 640
|
||||
weight_decay: 0.0005
|
||||
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||
- !!python/object/apply:numpy.dtype
|
||||
args:
|
||||
- i8
|
||||
- false
|
||||
- true
|
||||
state: !!python/tuple
|
||||
- 3
|
||||
- <
|
||||
- null
|
||||
- null
|
||||
- null
|
||||
- -1
|
||||
- -1
|
||||
- 0
|
||||
- !!binary |
|
||||
WWICAAAAAAA=
|
||||
@ -0,0 +1,46 @@
|
||||
dataset_name: cifar10
|
||||
inner_channel: 8
|
||||
params:
|
||||
ReduceLROnPlateau: false
|
||||
batch_size: 64
|
||||
channel: AWGN
|
||||
dataset: cifar10
|
||||
device: cuda:0
|
||||
disable_tqdm: false
|
||||
epochs: 1000
|
||||
gamma: 0.1
|
||||
if_scheduler: true
|
||||
init_lr: 0.001
|
||||
lr_reduce_factor: 0.5
|
||||
lr_schedule_patience: 15
|
||||
max_time: 12
|
||||
min_lr: 1.0e-05
|
||||
num_workers: 4
|
||||
out_dir: ./out
|
||||
parallel: false
|
||||
ratio: 0.16666666666666666
|
||||
ratio_list:
|
||||
- 0.16666666666666666
|
||||
seed: 42
|
||||
snr: 13.0
|
||||
snr_list:
|
||||
- 13.0
|
||||
step_size: 640
|
||||
weight_decay: 0.0005
|
||||
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||
- !!python/object/apply:numpy.dtype
|
||||
args:
|
||||
- i8
|
||||
- false
|
||||
- true
|
||||
state: !!python/tuple
|
||||
- 3
|
||||
- <
|
||||
- null
|
||||
- null
|
||||
- null
|
||||
- -1
|
||||
- -1
|
||||
- 0
|
||||
- !!binary |
|
||||
WWICAAAAAAA=
|
||||
@ -0,0 +1,46 @@
|
||||
dataset_name: cifar10
|
||||
inner_channel: 8
|
||||
params:
|
||||
ReduceLROnPlateau: false
|
||||
batch_size: 64
|
||||
channel: Rayleigh
|
||||
dataset: cifar10
|
||||
device: cuda:0
|
||||
disable_tqdm: false
|
||||
epochs: 1000
|
||||
gamma: 0.1
|
||||
if_scheduler: true
|
||||
init_lr: 0.001
|
||||
lr_reduce_factor: 0.5
|
||||
lr_schedule_patience: 15
|
||||
max_time: 12
|
||||
min_lr: 1.0e-05
|
||||
num_workers: 4
|
||||
out_dir: ./out
|
||||
parallel: false
|
||||
ratio: 0.16666666666666666
|
||||
ratio_list:
|
||||
- 0.16666666666666666
|
||||
seed: 42
|
||||
snr: 13.0
|
||||
snr_list:
|
||||
- 13.0
|
||||
step_size: 640
|
||||
weight_decay: 0.0005
|
||||
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||
- !!python/object/apply:numpy.dtype
|
||||
args:
|
||||
- i8
|
||||
- false
|
||||
- true
|
||||
state: !!python/tuple
|
||||
- 3
|
||||
- <
|
||||
- null
|
||||
- null
|
||||
- null
|
||||
- -1
|
||||
- -1
|
||||
- 0
|
||||
- !!binary |
|
||||
WWICAAAAAAA=
|
||||
@ -0,0 +1,46 @@
|
||||
dataset_name: cifar10
|
||||
inner_channel: 8
|
||||
params:
|
||||
ReduceLROnPlateau: false
|
||||
batch_size: 64
|
||||
channel: AWGN
|
||||
dataset: cifar10
|
||||
device: cuda:0
|
||||
disable_tqdm: false
|
||||
epochs: 1000
|
||||
gamma: 0.1
|
||||
if_scheduler: true
|
||||
init_lr: 0.001
|
||||
lr_reduce_factor: 0.5
|
||||
lr_schedule_patience: 15
|
||||
max_time: 12
|
||||
min_lr: 1.0e-05
|
||||
num_workers: 4
|
||||
out_dir: ./out
|
||||
parallel: false
|
||||
ratio: 0.16666666666666666
|
||||
ratio_list:
|
||||
- 0.16666666666666666
|
||||
seed: 42
|
||||
snr: 19.0
|
||||
snr_list:
|
||||
- 19.0
|
||||
step_size: 640
|
||||
weight_decay: 0.0005
|
||||
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||
- !!python/object/apply:numpy.dtype
|
||||
args:
|
||||
- i8
|
||||
- false
|
||||
- true
|
||||
state: !!python/tuple
|
||||
- 3
|
||||
- <
|
||||
- null
|
||||
- null
|
||||
- null
|
||||
- -1
|
||||
- -1
|
||||
- 0
|
||||
- !!binary |
|
||||
WWICAAAAAAA=
|
||||
@ -0,0 +1,46 @@
|
||||
dataset_name: cifar10
|
||||
inner_channel: 8
|
||||
params:
|
||||
ReduceLROnPlateau: false
|
||||
batch_size: 64
|
||||
channel: Rayleigh
|
||||
dataset: cifar10
|
||||
device: cuda:0
|
||||
disable_tqdm: false
|
||||
epochs: 1000
|
||||
gamma: 0.1
|
||||
if_scheduler: true
|
||||
init_lr: 0.001
|
||||
lr_reduce_factor: 0.5
|
||||
lr_schedule_patience: 15
|
||||
max_time: 12
|
||||
min_lr: 1.0e-05
|
||||
num_workers: 4
|
||||
out_dir: ./out
|
||||
parallel: false
|
||||
ratio: 0.16666666666666666
|
||||
ratio_list:
|
||||
- 0.16666666666666666
|
||||
seed: 42
|
||||
snr: 19.0
|
||||
snr_list:
|
||||
- 19.0
|
||||
step_size: 640
|
||||
weight_decay: 0.0005
|
||||
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||
- !!python/object/apply:numpy.dtype
|
||||
args:
|
||||
- i8
|
||||
- false
|
||||
- true
|
||||
state: !!python/tuple
|
||||
- 3
|
||||
- <
|
||||
- null
|
||||
- null
|
||||
- null
|
||||
- -1
|
||||
- -1
|
||||
- 0
|
||||
- !!binary |
|
||||
WWICAAAAAAA=
|
||||
@ -0,0 +1,46 @@
|
||||
dataset_name: cifar10
|
||||
inner_channel: 8
|
||||
params:
|
||||
ReduceLROnPlateau: false
|
||||
batch_size: 64
|
||||
channel: AWGN
|
||||
dataset: cifar10
|
||||
device: cuda:0
|
||||
disable_tqdm: false
|
||||
epochs: 1000
|
||||
gamma: 0.1
|
||||
if_scheduler: true
|
||||
init_lr: 0.001
|
||||
lr_reduce_factor: 0.5
|
||||
lr_schedule_patience: 15
|
||||
max_time: 12
|
||||
min_lr: 1.0e-05
|
||||
num_workers: 4
|
||||
out_dir: ./out
|
||||
parallel: false
|
||||
ratio: 0.16666666666666666
|
||||
ratio_list:
|
||||
- 0.16666666666666666
|
||||
seed: 42
|
||||
snr: 4.0
|
||||
snr_list:
|
||||
- 4.0
|
||||
step_size: 640
|
||||
weight_decay: 0.0005
|
||||
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||
- !!python/object/apply:numpy.dtype
|
||||
args:
|
||||
- i8
|
||||
- false
|
||||
- true
|
||||
state: !!python/tuple
|
||||
- 3
|
||||
- <
|
||||
- null
|
||||
- null
|
||||
- null
|
||||
- -1
|
||||
- -1
|
||||
- 0
|
||||
- !!binary |
|
||||
WWICAAAAAAA=
|
||||
@ -0,0 +1,46 @@
|
||||
dataset_name: cifar10
|
||||
inner_channel: 8
|
||||
params:
|
||||
ReduceLROnPlateau: false
|
||||
batch_size: 64
|
||||
channel: Rayleigh
|
||||
dataset: cifar10
|
||||
device: cuda:0
|
||||
disable_tqdm: false
|
||||
epochs: 1000
|
||||
gamma: 0.1
|
||||
if_scheduler: true
|
||||
init_lr: 0.001
|
||||
lr_reduce_factor: 0.5
|
||||
lr_schedule_patience: 15
|
||||
max_time: 12
|
||||
min_lr: 1.0e-05
|
||||
num_workers: 4
|
||||
out_dir: ./out
|
||||
parallel: false
|
||||
ratio: 0.16666666666666666
|
||||
ratio_list:
|
||||
- 0.16666666666666666
|
||||
seed: 42
|
||||
snr: 4.0
|
||||
snr_list:
|
||||
- 4.0
|
||||
step_size: 640
|
||||
weight_decay: 0.0005
|
||||
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||
- !!python/object/apply:numpy.dtype
|
||||
args:
|
||||
- i8
|
||||
- false
|
||||
- true
|
||||
state: !!python/tuple
|
||||
- 3
|
||||
- <
|
||||
- null
|
||||
- null
|
||||
- null
|
||||
- -1
|
||||
- -1
|
||||
- 0
|
||||
- !!binary |
|
||||
WWICAAAAAAA=
|
||||
@ -0,0 +1,46 @@
|
||||
dataset_name: cifar10
|
||||
inner_channel: 8
|
||||
params:
|
||||
ReduceLROnPlateau: false
|
||||
batch_size: 64
|
||||
channel: AWGN
|
||||
dataset: cifar10
|
||||
device: cuda:0
|
||||
disable_tqdm: false
|
||||
epochs: 1000
|
||||
gamma: 0.1
|
||||
if_scheduler: true
|
||||
init_lr: 0.001
|
||||
lr_reduce_factor: 0.5
|
||||
lr_schedule_patience: 15
|
||||
max_time: 12
|
||||
min_lr: 1.0e-05
|
||||
num_workers: 4
|
||||
out_dir: ./out
|
||||
parallel: false
|
||||
ratio: 0.16666666666666666
|
||||
ratio_list:
|
||||
- 0.16666666666666666
|
||||
seed: 42
|
||||
snr: 7.0
|
||||
snr_list:
|
||||
- 7.0
|
||||
step_size: 640
|
||||
weight_decay: 0.0005
|
||||
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||
- !!python/object/apply:numpy.dtype
|
||||
args:
|
||||
- i8
|
||||
- false
|
||||
- true
|
||||
state: !!python/tuple
|
||||
- 3
|
||||
- <
|
||||
- null
|
||||
- null
|
||||
- null
|
||||
- -1
|
||||
- -1
|
||||
- 0
|
||||
- !!binary |
|
||||
WWICAAAAAAA=
|
||||
@ -0,0 +1,46 @@
|
||||
dataset_name: cifar10
|
||||
inner_channel: 8
|
||||
params:
|
||||
ReduceLROnPlateau: false
|
||||
batch_size: 64
|
||||
channel: Rayleigh
|
||||
dataset: cifar10
|
||||
device: cuda:0
|
||||
disable_tqdm: false
|
||||
epochs: 1000
|
||||
gamma: 0.1
|
||||
if_scheduler: true
|
||||
init_lr: 0.001
|
||||
lr_reduce_factor: 0.5
|
||||
lr_schedule_patience: 15
|
||||
max_time: 12
|
||||
min_lr: 1.0e-05
|
||||
num_workers: 4
|
||||
out_dir: ./out
|
||||
parallel: false
|
||||
ratio: 0.16666666666666666
|
||||
ratio_list:
|
||||
- 0.16666666666666666
|
||||
seed: 42
|
||||
snr: 7.0
|
||||
snr_list:
|
||||
- 7.0
|
||||
step_size: 640
|
||||
weight_decay: 0.0005
|
||||
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||
- !!python/object/apply:numpy.dtype
|
||||
args:
|
||||
- i8
|
||||
- false
|
||||
- true
|
||||
state: !!python/tuple
|
||||
- 3
|
||||
- <
|
||||
- null
|
||||
- null
|
||||
- null
|
||||
- -1
|
||||
- -1
|
||||
- 0
|
||||
- !!binary |
|
||||
WWICAAAAAAA=
|
||||