Update-2024.06.04
4
.gitignore
vendored
@ -1,11 +1,11 @@
|
|||||||
test.py
|
test.*
|
||||||
*.pyc
|
*.pyc
|
||||||
*.log
|
*.log
|
||||||
dataset
|
dataset
|
||||||
*.ipynb
|
|
||||||
*.swp
|
*.swp
|
||||||
.vscode/*
|
.vscode/*
|
||||||
input.txt
|
input.txt
|
||||||
output.txt
|
output.txt
|
||||||
*.json
|
*.json
|
||||||
.vscode/*
|
.vscode/*
|
||||||
|
*.sh
|
||||||
42
.vscode/launch.json
vendored
@ -1,42 +0,0 @@
|
|||||||
{
|
|
||||||
// Use IntelliSense to learn about possible attributes.
|
|
||||||
// Hover to view descriptions of existing attributes.
|
|
||||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
|
||||||
"version": "0.2.0",
|
|
||||||
"configurations": [
|
|
||||||
{
|
|
||||||
"name": "Python: 当前文件",
|
|
||||||
"type": "python",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "${file}",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"justMyCode": true,
|
|
||||||
"args": [
|
|
||||||
"--lr",
|
|
||||||
"1e-3",
|
|
||||||
"--epochs",
|
|
||||||
"100",
|
|
||||||
"--batch_size",
|
|
||||||
"512",
|
|
||||||
"--if_scheduler",
|
|
||||||
"1",
|
|
||||||
"--step_size",
|
|
||||||
"500",
|
|
||||||
"--dataset",
|
|
||||||
"cifar10",
|
|
||||||
"--num_workers",
|
|
||||||
"4",
|
|
||||||
"--device",
|
|
||||||
"cuda:0",
|
|
||||||
"--ratio_list",
|
|
||||||
"1/3",
|
|
||||||
"--snr_list",
|
|
||||||
"100",
|
|
||||||
"--seed",
|
|
||||||
"42",
|
|
||||||
"--disable_tqdm",
|
|
||||||
"False"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
115
README.md
@ -1,19 +1,35 @@
|
|||||||
# Deep JSCC
|
# Deep JSCC
|
||||||
This implements training of deep JSCC models for wireless image transmission as described in the paper [Deep Joint Source-Channel Coding for Wireless Image Transmission](https://ieeexplore.ieee.org/abstract/document/8723589) by Pytorch. And there has been a [Tensorflow and keras implementations ](https://github.com/irdanish11/DJSCC-for-Wireless-Image-Transmission).
|
This repository implements training of Deep JSCC models for wireless image transmission, as described in the paper [Deep Joint Source-Channel Coding for Wireless Image Transmission](https://ieeexplore.ieee.org/abstract/document/8723589), in PyTorch. There is also a [TensorFlow and Keras implementation](https://github.com/irdanish11/DJSCC-for-Wireless-Image-Transmission).
|
||||||
|
|
||||||
This is my first time to use PyTorch and git to reproduce a paper, so there may be some mistakes. If you find any, please let me know. Thanks!
|
This is my first time using PyTorch and git to reproduce a paper, so there may be some mistakes. If you find any, please let me know. Thanks!
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Update-2024.06.04
|
||||||
|
- modify `train.py` to omit most of the command-line args; you can just use `python train.py --dataset ${dataset_name}` to train the model.
|
||||||
|
- add tensorboard to record the results in exp.
|
||||||
|
- add the `visualization/` directory to visualize the results.
|
||||||
|
- add bash file to run the code in parallel.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Architecture
|
## Architecture
|
||||||
|
|
||||||

|
<div style="text-align: center;">
|
||||||
|
<img src="./demo/arc.png" alt="Image 1" style="width: 500px; height: 250px;">
|
||||||
|
</div>
|
||||||
|
|
||||||
## Demo
|
## Demo
|
||||||
|
|
||||||
the model trained on cifar10 which is 32\*32 but test on kodim which is 768\*512.
|
|
||||||

|
|
||||||
|
|
||||||
the model trained on imagenet which is resized to 128\*128 but test on kodim which is 768\*512.
|
Left: the model trained on cifar10 (32\*32 images) but tested on kodim (768\*512). Right: the model trained on imagenet (resized to 128\*128) but tested on kodim (768\*512).
|
||||||

|
<div style="display: flex;">
|
||||||
|
|
||||||
|
<img src="./demo/cifar10_kodim08.png" alt="Image 1" style="flex: 1;" width="150" height="450">
|
||||||
|
<div style="width: 20px;"></div>
|
||||||
|
<img src="./demo/imagenet_kodim08.png" alt="Image 1" style="flex: 1;" width="150" height="450">
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
conda or other virtual environment is recommended.
|
conda or other virtual environment is recommended.
|
||||||
@ -31,25 +47,88 @@ The cifar10 dataset can be downloaded automatically by torchvision. But the imag
|
|||||||
python dataset.py
|
python dataset.py
|
||||||
```
|
```
|
||||||
|
|
||||||
### Training Model
|
### Training
|
||||||
Run(example presented in paper) on cifar10
|
The training command used to be very long, but now you can just use `python train.py --dataset ${dataset_name} --channel ${channel}` to train the model.
|
||||||
|
The default dataset is cifar10.
|
||||||
|
The parameters can be modified in the `train.py` file. The default parameters are similar to those in the paper.
|
||||||
|
|
||||||
```
|
|
||||||
python train.py --lr 1e-3 --epochs 1000 --batch_size 64 --channel 'AWGN' --saved ./saved --snr_list 1 4 7 13 19 --ratio_list 1/6 1/12 --dataset cifar10 --num_workers 4 --parallel True --if_scheduler True --scheduler_step_size 50
|
|
||||||
```
|
|
||||||
or Run(example presented in paper) on imagenet
|
|
||||||
|
|
||||||
|
| Parameters | CIFAR-10 | ImageNet |
|
||||||
|
|------------------------|------------------|------------------|
|
||||||
|
| `batch_size` | 64 | 32 |
|
||||||
|
| `init_lr` | 1e-3 | 1e-4 |
|
||||||
|
| `weight_decay` | 5e-4 | 5e-4 |
|
||||||
|
| `snr_list` | [19, 13, 7, 4, 1]| [19, 13, 7, 4, 1]|
|
||||||
|
| `ratio_list` | [1/6, 1/12] | [1/6, 1/12] |
|
||||||
|
| `if_scheduler` | True | False |
|
||||||
|
| `step_size` | 640 | N/A |
|
||||||
|
| `gamma` | 0.1 | 0.1 |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
<!-- ALSO! The batch_size used for cifar10 training in the paper is small, so GPU utilization is low. A bash script is therefore provided to run the code in parallel for different snr and ratio values on the cifar10 dataset. (Example with two GPUs)
|
||||||
```
|
```
|
||||||
python train.py --lr 10e-4 --epochs 300 --batch_size 32 --channel 'AWGN' --saved ./saved --dataset imagenet --num_workers 4 --parallel True
|
bash parallel_train_cifar.sh --channel ${channel}
|
||||||
```
|
``` -->
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Evaluation
|
### Evaluation
|
||||||
Run(example presented in paper)
|
The `eval.py` provides the evaluation of the trained model.
|
||||||
|
|
||||||
|
You may need to modify the `main` function slightly to evaluate the model for a different snr_list and channel type.
|
||||||
```
|
```
|
||||||
python eval.py --channel 'AWGN' --saved ./saved/${mode_path} --snr 20 --test_img ${test_img_path}
|
python eval.py
|
||||||
```
|
```
|
||||||
|
All training and evaluation results are saved in the `./out` directory by default. The `./out` directory has a structure like the following:
|
||||||
|
```
|
||||||
|
./out
|
||||||
|
├── checkpoint # trained models
|
||||||
|
│ ├── $DATASETNAME_$INNERCHANNEL_$SNR_$RATIO_$CHANNEL_TYPE_$TIMES_on_$DATE
|
||||||
|
│ ├── epoch_$num.pth
|
||||||
|
│ ├── ...
|
||||||
|
│ ├── CIFAR10_10_1.0_0.08_AWGN_13h21m37s_on_Jun_02_2024
|
||||||
|
│ ├── CIFAR10_20_7.0_0.17_Rayleigh_14h03m19s_on_Jun_03_2024
|
||||||
|
│ ├── ...
|
||||||
|
├── configs # training configurations
|
||||||
|
│ ├── $DATASETNAME_$INNERCHANNEL_$SNR_$RATIO_$CHANNEL_TYPE_$TIMES_on_$DATE
|
||||||
|
│ ├── $CIFAR10_10_4.0_0.08_AWGN_13h21m38s_on_Jun_02_2024.yaml
|
||||||
|
│ ├── ...
|
||||||
|
├── logs # training logs
|
||||||
|
│ ├── $DATASETNAME_$INNERCHANNEL_$SNR_$RATIO_$CHANNEL_TYPE_$TIMES_on_$DATE
|
||||||
|
│ ├── tensorboard logs
|
||||||
|
│ ├── ...
|
||||||
|
├── eval # evaluation results
|
||||||
|
│ ├── $DATASETNAME_$INNERCHANNEL_$SNR_$RATIO_$CHANNEL_TYPE_$TIMES_on_$DATE
|
||||||
|
│ ├── tensorboard logs
|
||||||
|
│ ├── ...
|
||||||
|
```
|
||||||
|
### Visualization
|
||||||
|
|
||||||
|
The `./visualization` directory contains the scripts for visualization of the training and evaluation results.
|
||||||
|
|
||||||
|
- `single_visualization.ipynb` is used to produce a demo of the model on a single image, like the demo above.
|
||||||
|
- `plot_visualization.ipynb` is used to visualize the performance of the model at different snr and ratio values.
|
||||||
|
|
||||||
|
## Results
|
||||||
|
|
||||||
|
Model results and logs for the **CIFAR-10** dataset, tested under various SNR, ratio, and channel types, are available in the `./out` directory. The models' performance is approximately 5dB worse than reported in the paper, likely due to the implementation reflecting a real communication system. However, the performance trends are consistent with those in the paper.
|
||||||
|
|
||||||
|
|
||||||
|
<div style="display: flex;">
|
||||||
|
<img src="demo/cifar_0.08.png" alt="Image 1" style="flex: 1; max-width: 50%; height: auto;">
|
||||||
|
<div style="width: 0px;"></div> <!-- spacer to leave a small gap between the two images -->
|
||||||
|
<img src="demo/cifar_0.17.png" alt="Image 2" style="flex: 1; max-width: 50%; height: auto;">
|
||||||
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### TO-DO
|
### TO-DO
|
||||||
- Add visualization of the model
|
- ~~Add visualization of the model~~
|
||||||
- plot the results with different snr and ratio
|
- ~~plot the results with different snr and ratio~~
|
||||||
|
- ~~add Rayleigh channel~~
|
||||||
|
- train on imagenet
|
||||||
|
- **Convert the real communication system to a complex communication system**
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
If you find (part of) this code useful for your research, please consider citing
|
If you find (part of) this code useful for your research, please consider citing
|
||||||
|
|||||||
@ -1,30 +0,0 @@
|
|||||||
from time import time
|
|
||||||
import multiprocessing as mp
|
|
||||||
import torch
|
|
||||||
import torchvision
|
|
||||||
from torchvision import transforms
|
|
||||||
from torch.utils.data import DataLoader, RandomSampler
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
transform = transforms.Compose([
|
|
||||||
torchvision.transforms.ToTensor(),
|
|
||||||
])
|
|
||||||
|
|
||||||
trainset = torchvision.datasets.CIFAR10(
|
|
||||||
root='./dataset/',
|
|
||||||
train=True, # 如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。
|
|
||||||
download=True, # 如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
|
|
||||||
transform=transform
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"num of CPU: {mp.cpu_count()}")
|
|
||||||
for num_workers in range(2, mp.cpu_count(), 2):
|
|
||||||
train_loader = torch.utils.data.DataLoader(
|
|
||||||
trainset, shuffle=True, num_workers=num_workers, batch_size=64, pin_memory=True)
|
|
||||||
start = time()
|
|
||||||
for epoch in range(1, 3):
|
|
||||||
for i, data in enumerate(train_loader, 0):
|
|
||||||
pass
|
|
||||||
end = time()
|
|
||||||
print("Finish with:{} second, num_workers={}".format(end - start, num_workers))
|
|
||||||
44
channel.py
@ -2,26 +2,44 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
def channel(channel_type='AWGN', snr=20):
|
class Channel(nn.Module):
|
||||||
def AWGN_channel(z_hat: torch.Tensor):
|
def __init__(self, channel_type='AWGN', snr=20):
|
||||||
|
if channel_type not in ['AWGN', 'Rayleigh']:
|
||||||
|
raise Exception('Unknown type of channel')
|
||||||
|
super(Channel, self).__init__()
|
||||||
|
self.channel_type = channel_type
|
||||||
|
self.snr = snr
|
||||||
|
|
||||||
|
def forward(self, z_hat):
|
||||||
if z_hat.dim() == 4:
|
if z_hat.dim() == 4:
|
||||||
# k = np.prod(z_hat.size()[1:])
|
# k = np.prod(z_hat.size()[1:])
|
||||||
k = torch.prod(torch.tensor(z_hat.size()[1:]))
|
k = torch.prod(torch.tensor(z_hat.size()[1:]))
|
||||||
sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True)/k
|
sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True) / k
|
||||||
elif z_hat.dim() == 3:
|
elif z_hat.dim() == 3:
|
||||||
# k = np.prod(z_hat.size())
|
# k = np.prod(z_hat.size())
|
||||||
k = torch.prod(torch.tensor(z_hat.size()))
|
k = torch.prod(torch.tensor(z_hat.size()))
|
||||||
sig_pwr = torch.sum(torch.abs(z_hat).square())/k
|
sig_pwr = torch.sum(torch.abs(z_hat).square()) / k
|
||||||
noi_pwr = sig_pwr / (10 ** (snr / 10))
|
noi_pwr = sig_pwr / (10 ** (self.snr / 10))
|
||||||
noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr)
|
noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr)
|
||||||
|
if self.channel_type == 'Rayleigh':
|
||||||
|
# hc = torch.randn_like(z_hat) wrong implement before
|
||||||
|
hc = torch.randn(1, device = z_hat.device)
|
||||||
|
z_hat = hc * z_hat
|
||||||
|
|
||||||
return z_hat + noise
|
return z_hat + noise
|
||||||
|
|
||||||
def Rayleigh_channel(z_hat: torch.Tensor):
|
def get_channel(self):
|
||||||
pass
|
return self.channel_type, self.snr
|
||||||
|
|
||||||
if channel_type == 'AWGN':
|
|
||||||
return AWGN_channel
|
if __name__ == '__main__':
|
||||||
elif channel_type == 'Rayleigh':
|
# test
|
||||||
return Rayleigh_channel
|
channel = Channel(channel_type='AWGN', snr=10)
|
||||||
else:
|
z_hat = torch.randn(64, 10, 5, 5)
|
||||||
raise Exception('Unknown type of channel')
|
z_hat = channel(z_hat)
|
||||||
|
print(z_hat)
|
||||||
|
|
||||||
|
channel = Channel(channel_type='Rayleigh', snr=10)
|
||||||
|
z_hat = torch.randn(64, 10, 5, 5)
|
||||||
|
z_hat = channel(z_hat)
|
||||||
|
print(z_hat)
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 94 KiB |
BIN
demo/127_127.jpg
|
Before Width: | Height: | Size: 20 KiB |
BIN
demo/32_32.jpg
|
Before Width: | Height: | Size: 2.1 KiB |
|
Before Width: | Height: | Size: 1.6 MiB After Width: | Height: | Size: 1.6 MiB |
BIN
demo/cifar_0.08.png
Normal file
|
After Width: | Height: | Size: 81 KiB |
BIN
demo/cifar_0.17.png
Normal file
|
After Width: | Height: | Size: 81 KiB |
|
Before Width: | Height: | Size: 1.5 MiB After Width: | Height: | Size: 1.5 MiB |
102
eval.py
@ -1,51 +1,75 @@
|
|||||||
# to be implemented
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
from utils import get_psnr
|
||||||
from PIL import Image
|
|
||||||
from torchvision import transforms
|
|
||||||
from utils import get_psnr, image_normalization
|
|
||||||
import os
|
import os
|
||||||
from model import DeepJSCC
|
from model import DeepJSCC
|
||||||
|
from train import evaluate_epoch
|
||||||
|
from torchvision import transforms
|
||||||
|
from torchvision import datasets
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from dataset import Vanilla
|
||||||
|
import yaml
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
|
import glob
|
||||||
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
|
|
||||||
|
def eval_snr(model, test_loader, writer, param, times=10):
|
||||||
|
snr_list = range(0, 26, 1)
|
||||||
|
for snr in snr_list:
|
||||||
|
model.change_channel(param['channel'], snr)
|
||||||
|
test_loss = 0
|
||||||
|
for i in range(times):
|
||||||
|
test_loss += evaluate_epoch(model, param, test_loader)
|
||||||
|
|
||||||
|
test_loss /= times
|
||||||
|
psnr = get_psnr(image=None, gt=None, mse=test_loss)
|
||||||
|
writer.add_scalar('psnr', psnr, snr)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def config_parser():
|
def process_config(config_path, output_dir, dataset_name, times):
|
||||||
import argparse
|
with open(config_path, 'r') as f:
|
||||||
parser = argparse.ArgumentParser()
|
config = yaml.load(f, Loader=yaml.UnsafeLoader)
|
||||||
parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
|
assert dataset_name == config['dataset_name']
|
||||||
parser.add_argument('--saved', type=str, help='saved_path')
|
params = config['params']
|
||||||
parser.add_argument('--snr', default=20, type=int, help='snr')
|
c = config['inner_channel']
|
||||||
parser.add_argument('--test_image', default='./demo/kodim08.png', type=str, help='demo_image')
|
|
||||||
parser.add_argument('--times', default=10, type=int, help='num_workers')
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
if dataset_name == 'cifar10':
|
||||||
|
transform = transforms.Compose([transforms.ToTensor(), ])
|
||||||
|
test_dataset = datasets.CIFAR10(root='../dataset/', train=False,
|
||||||
|
download=True, transform=transform)
|
||||||
|
test_loader = DataLoader(test_dataset, shuffle=True,
|
||||||
|
batch_size=params['batch_size'], num_workers=params['num_workers'])
|
||||||
|
|
||||||
|
elif dataset_name == 'imagenet':
|
||||||
|
transform = transforms.Compose(
|
||||||
|
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
|
||||||
|
|
||||||
|
test_dataset = Vanilla(root='../dataset/ImageNet/val', transform=transform)
|
||||||
|
test_loader = DataLoader(test_dataset, shuffle=True,
|
||||||
|
batch_size=params['batch_size'], num_workers=params['num_workers'])
|
||||||
|
else:
|
||||||
|
raise Exception('Unknown dataset')
|
||||||
|
|
||||||
|
name = os.path.splitext(os.path.basename(config_path))[0]
|
||||||
|
writer = SummaryWriter(os.path.join(output_dir, 'eval', name))
|
||||||
|
model = DeepJSCC(c=c)
|
||||||
|
model = model.to(params['device'])
|
||||||
|
pkl_list = glob.glob(os.path.join(output_dir, 'checkpoints', name, '*.pkl'))
|
||||||
|
model.load_state_dict(torch.load(pkl_list[-1]))
|
||||||
|
eval_snr(model, test_loader, writer, params, times)
|
||||||
|
writer.close()
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = config_parser()
|
times = 10
|
||||||
transform = transforms.Compose([transforms.ToTensor()])
|
dataset_name = 'cifar10'
|
||||||
test_image = Image.open(args.test_image)
|
output_dir = './out'
|
||||||
test_image.load()
|
channel_type = 'AWGN'
|
||||||
test_image = transform(test_image)
|
config_dir = os.path.join(output_dir, 'configs')
|
||||||
|
config_files = [os.path.join(config_dir, name) for name in os.listdir(config_dir)
|
||||||
|
if (dataset_name in name or dataset_name.upper() in name) and channel_type in name and name.endswith('.yaml')]
|
||||||
|
|
||||||
file_name = os.path.basename(args.saved)
|
with ProcessPoolExecutor() as executor:
|
||||||
c = file_name.split('_')[-1].split('.')[0]
|
executor.map(process_config, config_files, [output_dir] * len(config_files), [dataset_name] * len(config_files), [times] * len(config_files))
|
||||||
c = int(c)
|
|
||||||
model = DeepJSCC(c=c, channel_type=args.channel, snr=args.snr)
|
|
||||||
# model.load_state_dict(torch.load(args.saved))
|
|
||||||
model.load_state_dict(torch.load(args.saved,map_location=torch.device('cuda:0')))
|
|
||||||
model.change_channel(args.channel, args.snr)
|
|
||||||
|
|
||||||
psnr_all = 0.0
|
|
||||||
|
|
||||||
for i in range(args.times):
|
|
||||||
demo_image = model(test_image)
|
|
||||||
demo_image = image_normalization('denormalization')(demo_image)
|
|
||||||
gt = image_normalization('denormalization')(test_image)
|
|
||||||
psnr_all += get_psnr(demo_image, gt)
|
|
||||||
demo_image = image_normalization('normalization')(demo_image)
|
|
||||||
demo_image = torch.cat([test_image, demo_image], dim=1)
|
|
||||||
demo_image = transforms.ToPILImage()(demo_image)
|
|
||||||
demo_image.save('./run/{}_{}'.format(args.saved.split('/')[-1], args.test_image.split('/')[-1]))
|
|
||||||
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
49
model.py
@ -7,7 +7,7 @@ Created on Tue Dec 11:00:00 2023
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import channel
|
from channel import Channel
|
||||||
|
|
||||||
|
|
||||||
""" def _image_normalization(norm_type):
|
""" def _image_normalization(norm_type):
|
||||||
@ -58,7 +58,8 @@ class _TransConvWithPReLU(nn.Module):
|
|||||||
in_channels, out_channels, kernel_size, stride, padding, output_padding)
|
in_channels, out_channels, kernel_size, stride, padding, output_padding)
|
||||||
self.activate = activate
|
self.activate = activate
|
||||||
if activate == nn.PReLU():
|
if activate == nn.PReLU():
|
||||||
nn.init.kaiming_normal_(self.transconv.weight, mode='fan_out', nonlinearity='leaky_relu')
|
nn.init.kaiming_normal_(self.transconv.weight, mode='fan_out',
|
||||||
|
nonlinearity='leaky_relu')
|
||||||
else:
|
else:
|
||||||
nn.init.xavier_normal_(self.transconv.weight)
|
nn.init.xavier_normal_(self.transconv.weight)
|
||||||
|
|
||||||
@ -104,7 +105,7 @@ class _Encoder(nn.Module):
|
|||||||
return _inner
|
return _inner
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
#x = self.imgae_normalization(x)
|
# x = self.imgae_normalization(x)
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
x = self.conv3(x)
|
x = self.conv3(x)
|
||||||
@ -118,7 +119,7 @@ class _Encoder(nn.Module):
|
|||||||
class _Decoder(nn.Module):
|
class _Decoder(nn.Module):
|
||||||
def __init__(self, c=1):
|
def __init__(self, c=1):
|
||||||
super(_Decoder, self).__init__()
|
super(_Decoder, self).__init__()
|
||||||
#self.imgae_normalization = _image_normalization(norm_type='denormalization')
|
# self.imgae_normalization = _image_normalization(norm_type='denormalization')
|
||||||
self.tconv1 = _TransConvWithPReLU(
|
self.tconv1 = _TransConvWithPReLU(
|
||||||
in_channels=c, out_channels=32, kernel_size=5, stride=1, padding=2)
|
in_channels=c, out_channels=32, kernel_size=5, stride=1, padding=2)
|
||||||
self.tconv2 = _TransConvWithPReLU(
|
self.tconv2 = _TransConvWithPReLU(
|
||||||
@ -136,22 +137,50 @@ class _Decoder(nn.Module):
|
|||||||
x = self.tconv3(x)
|
x = self.tconv3(x)
|
||||||
x = self.tconv4(x)
|
x = self.tconv4(x)
|
||||||
x = self.tconv5(x)
|
x = self.tconv5(x)
|
||||||
#x = self.imgae_normalization(x)
|
# x = self.imgae_normalization(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class DeepJSCC(nn.Module):
|
class DeepJSCC(nn.Module):
|
||||||
def __init__(self, c, channel_type='AWGN', snr=20):
|
def __init__(self, c, channel_type='AWGN', snr=None):
|
||||||
super(DeepJSCC, self).__init__()
|
super(DeepJSCC, self).__init__()
|
||||||
self.encoder = _Encoder(c=c)
|
self.encoder = _Encoder(c=c)
|
||||||
self.channel = channel.channel(channel_type, snr)
|
if snr is not None:
|
||||||
|
self.channel = Channel(channel_type, snr)
|
||||||
self.decoder = _Decoder(c=c)
|
self.decoder = _Decoder(c=c)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
z = self.encoder(x)
|
z = self.encoder(x)
|
||||||
z = self.channel(z)
|
if hasattr(self, 'channel') and self.channel is not None:
|
||||||
|
z = self.channel(z)
|
||||||
x_hat = self.decoder(z)
|
x_hat = self.decoder(z)
|
||||||
return x_hat
|
return x_hat
|
||||||
|
|
||||||
def change_channel(self, channel_type, snr):
|
def change_channel(self, channel_type='AWGN', snr=None):
|
||||||
self.channel = channel.channel(channel_type, snr)
|
if snr is None:
|
||||||
|
self.channel = None
|
||||||
|
else:
|
||||||
|
self.channel = Channel(channel_type, snr)
|
||||||
|
|
||||||
|
def get_channel(self):
|
||||||
|
if hasattr(self, 'channel') and self.channel is not None:
|
||||||
|
return self.channel.get_channel()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def loss(self, prd, gt):
|
||||||
|
criterion = nn.MSELoss(reduction='mean')
|
||||||
|
loss = criterion(prd, gt)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
model = DeepJSCC(c=20)
|
||||||
|
print(model)
|
||||||
|
x = torch.rand(1, 3, 32, 32)
|
||||||
|
y = model(x)
|
||||||
|
print(y.size())
|
||||||
|
print(y)
|
||||||
|
print(model.encoder.norm)
|
||||||
|
print(model.encoder.norm(y))
|
||||||
|
print(model.encoder.norm(y).size())
|
||||||
|
print(model.encoder.norm(y).size()[1:])
|
||||||
|
|||||||
@ -0,0 +1,46 @@
|
|||||||
|
dataset_name: cifar10
|
||||||
|
inner_channel: 10
|
||||||
|
params:
|
||||||
|
ReduceLROnPlateau: false
|
||||||
|
batch_size: 64
|
||||||
|
channel: AWGN
|
||||||
|
dataset: cifar10
|
||||||
|
device: cuda:1
|
||||||
|
disable_tqdm: false
|
||||||
|
epochs: 1000
|
||||||
|
gamma: 0.1
|
||||||
|
if_scheduler: true
|
||||||
|
init_lr: 0.001
|
||||||
|
lr_reduce_factor: 0.5
|
||||||
|
lr_schedule_patience: 15
|
||||||
|
max_time: 12
|
||||||
|
min_lr: 1.0e-05
|
||||||
|
num_workers: 4
|
||||||
|
out_dir: ./out
|
||||||
|
parallel: false
|
||||||
|
ratio: 0.08333333333333333
|
||||||
|
ratio_list:
|
||||||
|
- 0.08333333333333333
|
||||||
|
seed: 42
|
||||||
|
snr: 1.0
|
||||||
|
snr_list:
|
||||||
|
- 1.0
|
||||||
|
step_size: 640
|
||||||
|
weight_decay: 0.0005
|
||||||
|
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||||
|
- !!python/object/apply:numpy.dtype
|
||||||
|
args:
|
||||||
|
- i8
|
||||||
|
- false
|
||||||
|
- true
|
||||||
|
state: !!python/tuple
|
||||||
|
- 3
|
||||||
|
- <
|
||||||
|
- null
|
||||||
|
- null
|
||||||
|
- null
|
||||||
|
- -1
|
||||||
|
- -1
|
||||||
|
- 0
|
||||||
|
- !!binary |
|
||||||
|
41QCAAAAAAA=
|
||||||
@ -0,0 +1,46 @@
|
|||||||
|
dataset_name: cifar10
|
||||||
|
inner_channel: 10
|
||||||
|
params:
|
||||||
|
ReduceLROnPlateau: false
|
||||||
|
batch_size: 64
|
||||||
|
channel: AWGN
|
||||||
|
dataset: cifar10
|
||||||
|
device: cuda:1
|
||||||
|
disable_tqdm: false
|
||||||
|
epochs: 1000
|
||||||
|
gamma: 0.1
|
||||||
|
if_scheduler: true
|
||||||
|
init_lr: 0.001
|
||||||
|
lr_reduce_factor: 0.5
|
||||||
|
lr_schedule_patience: 15
|
||||||
|
max_time: 12
|
||||||
|
min_lr: 1.0e-05
|
||||||
|
num_workers: 4
|
||||||
|
out_dir: ./out
|
||||||
|
parallel: false
|
||||||
|
ratio: 0.08333333333333333
|
||||||
|
ratio_list:
|
||||||
|
- 0.08333333333333333
|
||||||
|
seed: 42
|
||||||
|
snr: 13.0
|
||||||
|
snr_list:
|
||||||
|
- 13.0
|
||||||
|
step_size: 640
|
||||||
|
weight_decay: 0.0005
|
||||||
|
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||||
|
- !!python/object/apply:numpy.dtype
|
||||||
|
args:
|
||||||
|
- i8
|
||||||
|
- false
|
||||||
|
- true
|
||||||
|
state: !!python/tuple
|
||||||
|
- 3
|
||||||
|
- <
|
||||||
|
- null
|
||||||
|
- null
|
||||||
|
- null
|
||||||
|
- -1
|
||||||
|
- -1
|
||||||
|
- 0
|
||||||
|
- !!binary |
|
||||||
|
41QCAAAAAAA=
|
||||||
@ -0,0 +1,46 @@
|
|||||||
|
dataset_name: cifar10
|
||||||
|
inner_channel: 10
|
||||||
|
params:
|
||||||
|
ReduceLROnPlateau: false
|
||||||
|
batch_size: 64
|
||||||
|
channel: AWGN
|
||||||
|
dataset: cifar10
|
||||||
|
device: cuda:1
|
||||||
|
disable_tqdm: false
|
||||||
|
epochs: 1000
|
||||||
|
gamma: 0.1
|
||||||
|
if_scheduler: true
|
||||||
|
init_lr: 0.001
|
||||||
|
lr_reduce_factor: 0.5
|
||||||
|
lr_schedule_patience: 15
|
||||||
|
max_time: 12
|
||||||
|
min_lr: 1.0e-05
|
||||||
|
num_workers: 4
|
||||||
|
out_dir: ./out
|
||||||
|
parallel: false
|
||||||
|
ratio: 0.08333333333333333
|
||||||
|
ratio_list:
|
||||||
|
- 0.08333333333333333
|
||||||
|
seed: 42
|
||||||
|
snr: 19.0
|
||||||
|
snr_list:
|
||||||
|
- 19.0
|
||||||
|
step_size: 640
|
||||||
|
weight_decay: 0.0005
|
||||||
|
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||||
|
- !!python/object/apply:numpy.dtype
|
||||||
|
args:
|
||||||
|
- i8
|
||||||
|
- false
|
||||||
|
- true
|
||||||
|
state: !!python/tuple
|
||||||
|
- 3
|
||||||
|
- <
|
||||||
|
- null
|
||||||
|
- null
|
||||||
|
- null
|
||||||
|
- -1
|
||||||
|
- -1
|
||||||
|
- 0
|
||||||
|
- !!binary |
|
||||||
|
41QCAAAAAAA=
|
||||||
@ -0,0 +1,46 @@
|
|||||||
|
dataset_name: cifar10
|
||||||
|
inner_channel: 10
|
||||||
|
params:
|
||||||
|
ReduceLROnPlateau: false
|
||||||
|
batch_size: 64
|
||||||
|
channel: AWGN
|
||||||
|
dataset: cifar10
|
||||||
|
device: cuda:1
|
||||||
|
disable_tqdm: false
|
||||||
|
epochs: 1000
|
||||||
|
gamma: 0.1
|
||||||
|
if_scheduler: true
|
||||||
|
init_lr: 0.001
|
||||||
|
lr_reduce_factor: 0.5
|
||||||
|
lr_schedule_patience: 15
|
||||||
|
max_time: 12
|
||||||
|
min_lr: 1.0e-05
|
||||||
|
num_workers: 4
|
||||||
|
out_dir: ./out
|
||||||
|
parallel: false
|
||||||
|
ratio: 0.08333333333333333
|
||||||
|
ratio_list:
|
||||||
|
- 0.08333333333333333
|
||||||
|
seed: 42
|
||||||
|
snr: 4.0
|
||||||
|
snr_list:
|
||||||
|
- 4.0
|
||||||
|
step_size: 640
|
||||||
|
weight_decay: 0.0005
|
||||||
|
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||||
|
- !!python/object/apply:numpy.dtype
|
||||||
|
args:
|
||||||
|
- i8
|
||||||
|
- false
|
||||||
|
- true
|
||||||
|
state: !!python/tuple
|
||||||
|
- 3
|
||||||
|
- <
|
||||||
|
- null
|
||||||
|
- null
|
||||||
|
- null
|
||||||
|
- -1
|
||||||
|
- -1
|
||||||
|
- 0
|
||||||
|
- !!binary |
|
||||||
|
41QCAAAAAAA=
|
||||||
@ -0,0 +1,46 @@
|
|||||||
|
dataset_name: cifar10
|
||||||
|
inner_channel: 10
|
||||||
|
params:
|
||||||
|
ReduceLROnPlateau: false
|
||||||
|
batch_size: 64
|
||||||
|
channel: AWGN
|
||||||
|
dataset: cifar10
|
||||||
|
device: cuda:1
|
||||||
|
disable_tqdm: false
|
||||||
|
epochs: 1000
|
||||||
|
gamma: 0.1
|
||||||
|
if_scheduler: true
|
||||||
|
init_lr: 0.001
|
||||||
|
lr_reduce_factor: 0.5
|
||||||
|
lr_schedule_patience: 15
|
||||||
|
max_time: 12
|
||||||
|
min_lr: 1.0e-05
|
||||||
|
num_workers: 4
|
||||||
|
out_dir: ./out
|
||||||
|
parallel: false
|
||||||
|
ratio: 0.08333333333333333
|
||||||
|
ratio_list:
|
||||||
|
- 0.08333333333333333
|
||||||
|
seed: 42
|
||||||
|
snr: 7.0
|
||||||
|
snr_list:
|
||||||
|
- 7.0
|
||||||
|
step_size: 640
|
||||||
|
weight_decay: 0.0005
|
||||||
|
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||||
|
- !!python/object/apply:numpy.dtype
|
||||||
|
args:
|
||||||
|
- i8
|
||||||
|
- false
|
||||||
|
- true
|
||||||
|
state: !!python/tuple
|
||||||
|
- 3
|
||||||
|
- <
|
||||||
|
- null
|
||||||
|
- null
|
||||||
|
- null
|
||||||
|
- -1
|
||||||
|
- -1
|
||||||
|
- 0
|
||||||
|
- !!binary |
|
||||||
|
41QCAAAAAAA=
|
||||||
@ -0,0 +1,46 @@
|
|||||||
|
dataset_name: cifar10
|
||||||
|
inner_channel: 20
|
||||||
|
params:
|
||||||
|
ReduceLROnPlateau: false
|
||||||
|
batch_size: 64
|
||||||
|
channel: AWGN
|
||||||
|
dataset: cifar10
|
||||||
|
device: cuda:0
|
||||||
|
disable_tqdm: false
|
||||||
|
epochs: 1000
|
||||||
|
gamma: 0.1
|
||||||
|
if_scheduler: true
|
||||||
|
init_lr: 0.001
|
||||||
|
lr_reduce_factor: 0.5
|
||||||
|
lr_schedule_patience: 15
|
||||||
|
max_time: 12
|
||||||
|
min_lr: 1.0e-05
|
||||||
|
num_workers: 4
|
||||||
|
out_dir: ./out
|
||||||
|
parallel: false
|
||||||
|
ratio: 0.16666666666666666
|
||||||
|
ratio_list:
|
||||||
|
- 0.16666666666666666
|
||||||
|
seed: 42
|
||||||
|
snr: 1.0
|
||||||
|
snr_list:
|
||||||
|
- 1.0
|
||||||
|
step_size: 640
|
||||||
|
weight_decay: 0.0005
|
||||||
|
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||||
|
- !!python/object/apply:numpy.dtype
|
||||||
|
args:
|
||||||
|
- i8
|
||||||
|
- false
|
||||||
|
- true
|
||||||
|
state: !!python/tuple
|
||||||
|
- 3
|
||||||
|
- <
|
||||||
|
- null
|
||||||
|
- null
|
||||||
|
- null
|
||||||
|
- -1
|
||||||
|
- -1
|
||||||
|
- 0
|
||||||
|
- !!binary |
|
||||||
|
bZMCAAAAAAA=
|
||||||
@ -0,0 +1,46 @@
|
|||||||
|
dataset_name: cifar10
|
||||||
|
inner_channel: 20
|
||||||
|
params:
|
||||||
|
ReduceLROnPlateau: false
|
||||||
|
batch_size: 64
|
||||||
|
channel: AWGN
|
||||||
|
dataset: cifar10
|
||||||
|
device: cuda:0
|
||||||
|
disable_tqdm: false
|
||||||
|
epochs: 1000
|
||||||
|
gamma: 0.1
|
||||||
|
if_scheduler: true
|
||||||
|
init_lr: 0.001
|
||||||
|
lr_reduce_factor: 0.5
|
||||||
|
lr_schedule_patience: 15
|
||||||
|
max_time: 12
|
||||||
|
min_lr: 1.0e-05
|
||||||
|
num_workers: 4
|
||||||
|
out_dir: ./out
|
||||||
|
parallel: false
|
||||||
|
ratio: 0.16666666666666666
|
||||||
|
ratio_list:
|
||||||
|
- 0.16666666666666666
|
||||||
|
seed: 42
|
||||||
|
snr: 13.0
|
||||||
|
snr_list:
|
||||||
|
- 13.0
|
||||||
|
step_size: 640
|
||||||
|
weight_decay: 0.0005
|
||||||
|
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||||
|
- !!python/object/apply:numpy.dtype
|
||||||
|
args:
|
||||||
|
- i8
|
||||||
|
- false
|
||||||
|
- true
|
||||||
|
state: !!python/tuple
|
||||||
|
- 3
|
||||||
|
- <
|
||||||
|
- null
|
||||||
|
- null
|
||||||
|
- null
|
||||||
|
- -1
|
||||||
|
- -1
|
||||||
|
- 0
|
||||||
|
- !!binary |
|
||||||
|
bZMCAAAAAAA=
|
||||||
@ -0,0 +1,46 @@
|
|||||||
|
dataset_name: cifar10
|
||||||
|
inner_channel: 20
|
||||||
|
params:
|
||||||
|
ReduceLROnPlateau: false
|
||||||
|
batch_size: 64
|
||||||
|
channel: AWGN
|
||||||
|
dataset: cifar10
|
||||||
|
device: cuda:0
|
||||||
|
disable_tqdm: false
|
||||||
|
epochs: 1000
|
||||||
|
gamma: 0.1
|
||||||
|
if_scheduler: true
|
||||||
|
init_lr: 0.001
|
||||||
|
lr_reduce_factor: 0.5
|
||||||
|
lr_schedule_patience: 15
|
||||||
|
max_time: 12
|
||||||
|
min_lr: 1.0e-05
|
||||||
|
num_workers: 4
|
||||||
|
out_dir: ./out
|
||||||
|
parallel: false
|
||||||
|
ratio: 0.16666666666666666
|
||||||
|
ratio_list:
|
||||||
|
- 0.16666666666666666
|
||||||
|
seed: 42
|
||||||
|
snr: 19.0
|
||||||
|
snr_list:
|
||||||
|
- 19.0
|
||||||
|
step_size: 640
|
||||||
|
weight_decay: 0.0005
|
||||||
|
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||||
|
- !!python/object/apply:numpy.dtype
|
||||||
|
args:
|
||||||
|
- i8
|
||||||
|
- false
|
||||||
|
- true
|
||||||
|
state: !!python/tuple
|
||||||
|
- 3
|
||||||
|
- <
|
||||||
|
- null
|
||||||
|
- null
|
||||||
|
- null
|
||||||
|
- -1
|
||||||
|
- -1
|
||||||
|
- 0
|
||||||
|
- !!binary |
|
||||||
|
bZMCAAAAAAA=
|
||||||
@ -0,0 +1,46 @@
|
|||||||
|
dataset_name: cifar10
|
||||||
|
inner_channel: 20
|
||||||
|
params:
|
||||||
|
ReduceLROnPlateau: false
|
||||||
|
batch_size: 64
|
||||||
|
channel: AWGN
|
||||||
|
dataset: cifar10
|
||||||
|
device: cuda:0
|
||||||
|
disable_tqdm: false
|
||||||
|
epochs: 1000
|
||||||
|
gamma: 0.1
|
||||||
|
if_scheduler: true
|
||||||
|
init_lr: 0.001
|
||||||
|
lr_reduce_factor: 0.5
|
||||||
|
lr_schedule_patience: 15
|
||||||
|
max_time: 12
|
||||||
|
min_lr: 1.0e-05
|
||||||
|
num_workers: 4
|
||||||
|
out_dir: ./out
|
||||||
|
parallel: false
|
||||||
|
ratio: 0.16666666666666666
|
||||||
|
ratio_list:
|
||||||
|
- 0.16666666666666666
|
||||||
|
seed: 42
|
||||||
|
snr: 4.0
|
||||||
|
snr_list:
|
||||||
|
- 4.0
|
||||||
|
step_size: 640
|
||||||
|
weight_decay: 0.0005
|
||||||
|
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||||
|
- !!python/object/apply:numpy.dtype
|
||||||
|
args:
|
||||||
|
- i8
|
||||||
|
- false
|
||||||
|
- true
|
||||||
|
state: !!python/tuple
|
||||||
|
- 3
|
||||||
|
- <
|
||||||
|
- null
|
||||||
|
- null
|
||||||
|
- null
|
||||||
|
- -1
|
||||||
|
- -1
|
||||||
|
- 0
|
||||||
|
- !!binary |
|
||||||
|
bZMCAAAAAAA=
|
||||||
@ -0,0 +1,46 @@
|
|||||||
|
dataset_name: cifar10
|
||||||
|
inner_channel: 20
|
||||||
|
params:
|
||||||
|
ReduceLROnPlateau: false
|
||||||
|
batch_size: 64
|
||||||
|
channel: AWGN
|
||||||
|
dataset: cifar10
|
||||||
|
device: cuda:0
|
||||||
|
disable_tqdm: false
|
||||||
|
epochs: 1000
|
||||||
|
gamma: 0.1
|
||||||
|
if_scheduler: true
|
||||||
|
init_lr: 0.001
|
||||||
|
lr_reduce_factor: 0.5
|
||||||
|
lr_schedule_patience: 15
|
||||||
|
max_time: 12
|
||||||
|
min_lr: 1.0e-05
|
||||||
|
num_workers: 4
|
||||||
|
out_dir: ./out
|
||||||
|
parallel: false
|
||||||
|
ratio: 0.16666666666666666
|
||||||
|
ratio_list:
|
||||||
|
- 0.16666666666666666
|
||||||
|
seed: 42
|
||||||
|
snr: 7.0
|
||||||
|
snr_list:
|
||||||
|
- 7.0
|
||||||
|
step_size: 640
|
||||||
|
weight_decay: 0.0005
|
||||||
|
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
|
||||||
|
- !!python/object/apply:numpy.dtype
|
||||||
|
args:
|
||||||
|
- i8
|
||||||
|
- false
|
||||||
|
- true
|
||||||
|
state: !!python/tuple
|
||||||
|
- 3
|
||||||
|
- <
|
||||||
|
- null
|
||||||
|
- null
|
||||||
|
- null
|
||||||
|
- -1
|
||||||
|
- -1
|
||||||
|
- 0
|
||||||
|
- !!binary |
|
||||||
|
bZMCAAAAAAA=
|
||||||
|
Before Width: | Height: | Size: 1.6 MiB |
|
Before Width: | Height: | Size: 1.2 MiB |
|
Before Width: | Height: | Size: 1.4 MiB |
|
Before Width: | Height: | Size: 1.1 MiB |
|
Before Width: | Height: | Size: 1.6 MiB |
|
Before Width: | Height: | Size: 1.4 MiB |
|
Before Width: | Height: | Size: 1.3 MiB |
|
Before Width: | Height: | Size: 1.4 MiB |
|
Before Width: | Height: | Size: 1.1 MiB |
|
Before Width: | Height: | Size: 1.6 MiB |
|
Before Width: | Height: | Size: 1.2 MiB |
|
Before Width: | Height: | Size: 1.0 MiB |
367
train.py
@ -14,60 +14,301 @@ import torch.optim as optim
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from model import DeepJSCC, ratio2filtersize
|
from model import DeepJSCC, ratio2filtersize
|
||||||
from torch.nn.parallel import DataParallel
|
from torch.nn.parallel import DataParallel
|
||||||
from utils import image_normalization
|
from utils import image_normalization, set_seed, save_model, view_model_param
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from dataset import Vanilla
|
from dataset import Vanilla
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import time
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
|
import glob
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed):
|
def train_epoch(model, optimizer, param, data_loader):
|
||||||
np.random.seed(seed)
|
model.train()
|
||||||
torch.manual_seed(seed)
|
epoch_loss = 0
|
||||||
torch.cuda.manual_seed(seed)
|
|
||||||
torch.cuda.manual_seed_all(seed)
|
for iter, (images, _) in enumerate(data_loader):
|
||||||
torch.backends.cudnn.deterministic = True
|
images = images.cuda() if param['parallel'] and torch.cuda.device_count(
|
||||||
torch.backends.cudnn.benchmark = False
|
) > 1 else images.to(param['device'])
|
||||||
|
optimizer.zero_grad()
|
||||||
|
outputs = model.forward(images)
|
||||||
|
outputs = image_normalization('denormalization')(outputs)
|
||||||
|
images = image_normalization('denormalization')(images)
|
||||||
|
loss = model.loss(images, outputs) if not param['parallel'] else model.module.loss(
|
||||||
|
images, outputs)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
epoch_loss += loss.detach().item()
|
||||||
|
epoch_loss /= (iter + 1)
|
||||||
|
|
||||||
|
return epoch_loss, optimizer
|
||||||
|
|
||||||
|
|
||||||
def config_parser():
|
def evaluate_epoch(model, param, data_loader):
|
||||||
|
model.eval()
|
||||||
|
epoch_loss = 0
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for iter, (images, _) in enumerate(data_loader):
|
||||||
|
images = images.cuda() if param['parallel'] and torch.cuda.device_count(
|
||||||
|
) > 1 else images.to(param['device'])
|
||||||
|
outputs = model.forward(images)
|
||||||
|
outputs = image_normalization('denormalization')(outputs)
|
||||||
|
images = image_normalization('denormalization')(images)
|
||||||
|
loss = model.loss(images, outputs) if not param['parallel'] else model.module.loss(
|
||||||
|
images, outputs)
|
||||||
|
epoch_loss += loss.detach().item()
|
||||||
|
epoch_loss /= (iter + 1)
|
||||||
|
|
||||||
|
return epoch_loss
|
||||||
|
|
||||||
|
|
||||||
|
def config_parser_pipeline():
|
||||||
import argparse
|
import argparse
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--seed', default=2048, type=int, help='Random seed')
|
|
||||||
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
|
|
||||||
parser.add_argument('--epochs', default=256, type=int, help='number of epochs')
|
|
||||||
parser.add_argument('--batch_size', default=256, type=int, help='batch size')
|
|
||||||
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
|
|
||||||
parser.add_argument('--channel', default='AWGN', type=str,
|
|
||||||
choices=['AWGN', 'Rayleigh'], help='channel type')
|
|
||||||
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
|
|
||||||
parser.add_argument('--snr_list', default=['19', '13',
|
|
||||||
'7', '4', '1'], nargs='+', help='snr_list')
|
|
||||||
parser.add_argument('--ratio_list', default=['1/3',
|
|
||||||
'1/6', '1/12'], nargs='+', help='ratio_list')
|
|
||||||
parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
|
|
||||||
parser.add_argument('--dataset', default='cifar10', type=str,
|
parser.add_argument('--dataset', default='cifar10', type=str,
|
||||||
choices=['cifar10', 'imagenet'], help='dataset')
|
choices=['cifar10', 'imagenet'], help='dataset')
|
||||||
parser.add_argument('--parallel', default=False, type=bool, help='parallel')
|
parser.add_argument('--out', default='./out', type=str, help='out_path')
|
||||||
parser.add_argument('--if_scheduler', default=False, type=bool, help='if_scheduler')
|
parser.add_argument('--disable_tqdm', default=False, type=bool, help='disable_tqdm')
|
||||||
parser.add_argument('--step_size', default=640, type=int, help='scheduler')
|
|
||||||
parser.add_argument('--device', default='cuda:0', type=str, help='device')
|
parser.add_argument('--device', default='cuda:0', type=str, help='device')
|
||||||
parser.add_argument('--gamma', default=0.5, type=float, help='gamma')
|
parser.add_argument('--parallel', default=False, type=bool, help='parallel')
|
||||||
parser.add_argument('--disable_tqdm', default=True, type=bool, help='disable_tqdm')
|
parser.add_argument('--snr_list', default=['19', '13',
|
||||||
|
'7', '4', '1'], nargs='+', help='snr_list')
|
||||||
|
parser.add_argument('--ratio_list', default=['1/6', '1/12'], nargs='+', help='ratio_list')
|
||||||
|
parser.add_argument('--channel', default='AWGN', type=str,
|
||||||
|
choices=['AWGN', 'Rayleigh'], help='channel')
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main_pipeline():
|
||||||
args = config_parser()
|
args = config_parser_pipeline()
|
||||||
|
|
||||||
|
print("Training Start")
|
||||||
|
dataset_name = args.dataset
|
||||||
|
out_dir = args.out
|
||||||
args.snr_list = list(map(float, args.snr_list))
|
args.snr_list = list(map(float, args.snr_list))
|
||||||
args.ratio_list = list(map(lambda x: float(Fraction(x)), args.ratio_list))
|
args.ratio_list = list(map(lambda x: float(Fraction(x)), args.ratio_list))
|
||||||
set_seed(args.seed)
|
params = {}
|
||||||
print("Training Start")
|
params['disable_tqdm'] = args.disable_tqdm
|
||||||
for ratio in args.ratio_list:
|
params['dataset'] = dataset_name
|
||||||
for snr in args.snr_list:
|
params['out_dir'] = out_dir
|
||||||
train(args, ratio, snr)
|
params['device'] = args.device
|
||||||
|
params['snr_list'] = args.snr_list
|
||||||
|
params['ratio_list'] = args.ratio_list
|
||||||
|
params['channel'] = args.channel
|
||||||
|
if dataset_name == 'cifar10':
|
||||||
|
params['batch_size'] = 64 # 1024
|
||||||
|
params['num_workers'] = 4
|
||||||
|
params['epochs'] = 1000
|
||||||
|
params['init_lr'] = 1e-3 # 1e-2
|
||||||
|
params['weight_decay'] = 5e-4
|
||||||
|
params['parallel'] = False
|
||||||
|
params['if_scheduler'] = True
|
||||||
|
params['step_size'] = 640
|
||||||
|
params['gamma'] = 0.1
|
||||||
|
params['seed'] = 42
|
||||||
|
params['ReduceLROnPlateau'] = False
|
||||||
|
params['lr_reduce_factor'] = 0.5
|
||||||
|
params['lr_schedule_patience'] = 15
|
||||||
|
params['max_time'] = 12
|
||||||
|
params['min_lr'] = 1e-5
|
||||||
|
elif dataset_name == 'imagenet':
|
||||||
|
params['batch_size'] = 32
|
||||||
|
params['num_workers'] = 4
|
||||||
|
params['epochs'] = 300
|
||||||
|
params['init_lr'] = 1e-4
|
||||||
|
params['weight_decay'] = 5e-4
|
||||||
|
params['parallel'] = True
|
||||||
|
params['if_scheduler'] = True
|
||||||
|
params['gamma'] = 0.1
|
||||||
|
params['seed'] = 42
|
||||||
|
params['ReduceLROnPlateau'] = True
|
||||||
|
params['lr_reduce_factor'] = 0.5
|
||||||
|
params['lr_schedule_patience'] = 15
|
||||||
|
params['max_time'] = 12
|
||||||
|
params['min_lr'] = 1e-5
|
||||||
|
else:
|
||||||
|
raise Exception('Unknown dataset')
|
||||||
|
|
||||||
|
set_seed(params['seed'])
|
||||||
|
|
||||||
|
for ratio in params['ratio_list']:
|
||||||
|
for snr in params['snr_list']:
|
||||||
|
params['ratio'] = ratio
|
||||||
|
params['snr'] = snr
|
||||||
|
|
||||||
|
train_pipeline(params)
|
||||||
|
|
||||||
|
|
||||||
def train(args: config_parser(), ratio: float, snr: float):
|
# add train_pipeline to with only dataset_name args
|
||||||
|
def train_pipeline(params):
|
||||||
|
|
||||||
|
dataset_name = params['dataset']
|
||||||
|
# load data
|
||||||
|
if dataset_name == 'cifar10':
|
||||||
|
transform = transforms.Compose([transforms.ToTensor(), ])
|
||||||
|
train_dataset = datasets.CIFAR10(root='../dataset/', train=True,
|
||||||
|
download=True, transform=transform)
|
||||||
|
|
||||||
|
train_loader = DataLoader(train_dataset, shuffle=True,
|
||||||
|
batch_size=params['batch_size'], num_workers=params['num_workers'])
|
||||||
|
test_dataset = datasets.CIFAR10(root='../dataset/', train=False,
|
||||||
|
download=True, transform=transform)
|
||||||
|
test_loader = DataLoader(test_dataset, shuffle=True,
|
||||||
|
batch_size=params['batch_size'], num_workers=params['num_workers'])
|
||||||
|
|
||||||
|
elif dataset_name == 'imagenet':
|
||||||
|
transform = transforms.Compose(
|
||||||
|
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
|
||||||
|
print("loading data of imagenet")
|
||||||
|
train_dataset = datasets.ImageFolder(root='../dataset/ImageNet/train', transform=transform)
|
||||||
|
|
||||||
|
train_loader = DataLoader(train_dataset, shuffle=True,
|
||||||
|
batch_size=params['batch_size'], num_workers=params['num_workers'])
|
||||||
|
test_dataset = Vanilla(root='../dataset/ImageNet/val', transform=transform)
|
||||||
|
test_loader = DataLoader(test_dataset, shuffle=True,
|
||||||
|
batch_size=params['batch_size'], num_workers=params['num_workers'])
|
||||||
|
else:
|
||||||
|
raise Exception('Unknown dataset')
|
||||||
|
|
||||||
|
# create model
|
||||||
|
image_fisrt = train_dataset.__getitem__(0)[0]
|
||||||
|
c = ratio2filtersize(image_fisrt, params['ratio'])
|
||||||
|
print("The snr is {}, the inner channel is {}, the ratio is {:.2f}".format(
|
||||||
|
params['snr'], c, params['ratio']))
|
||||||
|
model = DeepJSCC(c=c, channel_type=params['channel'], snr=params['snr'])
|
||||||
|
|
||||||
|
# init exp dir
|
||||||
|
out_dir = params['out_dir']
|
||||||
|
phaser = dataset_name.upper() + '_' + str(c) + '_' + str(params['snr']) + '_' + \
|
||||||
|
"{:.2f}".format(params['ratio']) + '_' + str(params['channel']) + \
|
||||||
|
'_' + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
|
||||||
|
root_log_dir = out_dir + '/' + 'logs/' + phaser
|
||||||
|
root_ckpt_dir = out_dir + '/' + 'checkpoints/' + phaser
|
||||||
|
root_config_dir = out_dir + '/' + 'configs/' + phaser
|
||||||
|
writer = SummaryWriter(log_dir=root_log_dir)
|
||||||
|
|
||||||
|
# model init
|
||||||
|
device = torch.device(params['device'] if torch.cuda.is_available() else 'cpu')
|
||||||
|
if params['parallel'] and torch.cuda.device_count() > 1:
|
||||||
|
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
|
||||||
|
model = model.cuda()
|
||||||
|
else:
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
# opt
|
||||||
|
optimizer = optim.Adam(
|
||||||
|
model.parameters(), lr=params['init_lr'], weight_decay=params['weight_decay'])
|
||||||
|
if params['if_scheduler'] and not params['ReduceLROnPlateau']:
|
||||||
|
scheduler = optim.lr_scheduler.StepLR(
|
||||||
|
optimizer, step_size=params['step_size'], gamma=params['gamma'])
|
||||||
|
elif params['ReduceLROnPlateau']:
|
||||||
|
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
|
||||||
|
factor=params['lr_reduce_factor'],
|
||||||
|
patience=params['lr_schedule_patience'],
|
||||||
|
verbose=False)
|
||||||
|
else:
|
||||||
|
print("No scheduler")
|
||||||
|
scheduler = None
|
||||||
|
|
||||||
|
writer.add_text('config', str(params))
|
||||||
|
t0 = time.time()
|
||||||
|
epoch_train_losses, epoch_val_losses = [], []
|
||||||
|
per_epoch_time = []
|
||||||
|
|
||||||
|
# train
|
||||||
|
# At any point you can hit Ctrl + C to break out of training early.
|
||||||
|
try:
|
||||||
|
with tqdm(range(params['epochs']), disable=params['disable_tqdm']) as t:
|
||||||
|
for epoch in t:
|
||||||
|
|
||||||
|
t.set_description('Epoch %d' % epoch)
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
|
||||||
|
epoch_train_loss, optimizer = train_epoch(
|
||||||
|
model, optimizer, params, train_loader)
|
||||||
|
|
||||||
|
epoch_val_loss = evaluate_epoch(model, params, test_loader)
|
||||||
|
|
||||||
|
epoch_train_losses.append(epoch_train_loss)
|
||||||
|
epoch_val_losses.append(epoch_val_loss)
|
||||||
|
|
||||||
|
writer.add_scalar('train/_loss', epoch_train_loss, epoch)
|
||||||
|
writer.add_scalar('val/_loss', epoch_val_loss, epoch)
|
||||||
|
writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)
|
||||||
|
|
||||||
|
t.set_postfix(time=time.time() - start, lr=optimizer.param_groups[0]['lr'],
|
||||||
|
train_loss=epoch_train_loss, val_loss=epoch_val_loss)
|
||||||
|
|
||||||
|
per_epoch_time.append(time.time() - start)
|
||||||
|
|
||||||
|
# Saving checkpoint
|
||||||
|
|
||||||
|
if not os.path.exists(root_ckpt_dir):
|
||||||
|
os.makedirs(root_ckpt_dir)
|
||||||
|
torch.save(model.state_dict(), '{}.pkl'.format(
|
||||||
|
root_ckpt_dir + "/epoch_" + str(epoch)))
|
||||||
|
|
||||||
|
files = glob.glob(root_ckpt_dir + '/*.pkl')
|
||||||
|
for file in files:
|
||||||
|
epoch_nb = file.split('_')[-1]
|
||||||
|
epoch_nb = int(epoch_nb.split('.')[0])
|
||||||
|
if epoch_nb < epoch - 1:
|
||||||
|
os.remove(file)
|
||||||
|
|
||||||
|
if params['ReduceLROnPlateau'] and scheduler is not None:
|
||||||
|
scheduler.step(epoch_val_loss)
|
||||||
|
elif params['if_scheduler'] and not params['ReduceLROnPlateau']:
|
||||||
|
scheduler.step() # use only information from the validation loss
|
||||||
|
|
||||||
|
if optimizer.param_groups[0]['lr'] < params['min_lr']:
|
||||||
|
print("\n!! LR EQUAL TO MIN LR SET.")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Stop training after params['max_time'] hours
|
||||||
|
if time.time() - t0 > params['max_time'] * 3600:
|
||||||
|
print('-' * 89)
|
||||||
|
print("Max_time for training elapsed {:.2f} hours, so stopping".format(
|
||||||
|
params['max_time']))
|
||||||
|
break
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print('-' * 89)
|
||||||
|
print('Exiting from training early because of KeyboardInterrupt')
|
||||||
|
|
||||||
|
test_loss = evaluate_epoch(model, params, test_loader)
|
||||||
|
train_loss = evaluate_epoch(model, params, train_loader)
|
||||||
|
print("Test Accuracy: {:.4f}".format(test_loss))
|
||||||
|
print("Train Accuracy: {:.4f}".format(train_loss))
|
||||||
|
print("Convergence Time (Epochs): {:.4f}".format(epoch))
|
||||||
|
print("TOTAL TIME TAKEN: {:.4f}s".format(time.time() - t0))
|
||||||
|
print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time)))
|
||||||
|
|
||||||
|
"""
|
||||||
|
Write the results in out_dir/results folder
|
||||||
|
"""
|
||||||
|
|
||||||
|
writer.add_text(tag='result', text_string="""Dataset: {}\nparams={}\n\nTotal Parameters: {}\n\n
|
||||||
|
FINAL RESULTS\nTEST Loss: {:.4f}\nTRAIN Loss: {:.4f}\n\n
|
||||||
|
Convergence Time (Epochs): {:.4f}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n"""
|
||||||
|
.format(dataset_name, params, view_model_param(model), np.mean(np.array(train_loss)),
|
||||||
|
np.mean(np.array(test_loss)), epoch, (time.time() - t0) / 3600, np.mean(per_epoch_time)))
|
||||||
|
writer.close()
|
||||||
|
if not os.path.exists(os.path.dirname(root_config_dir)):
|
||||||
|
os.makedirs(os.path.dirname(root_config_dir))
|
||||||
|
with open(root_config_dir + '.yaml', 'w') as f:
|
||||||
|
dict_yaml = {'dataset_name': dataset_name, 'params': params,
|
||||||
|
'inner_channel': c, 'total_parameters': view_model_param(model)}
|
||||||
|
import yaml
|
||||||
|
yaml.dump(dict_yaml, f)
|
||||||
|
|
||||||
|
del model, optimizer, scheduler, train_loader, test_loader
|
||||||
|
del writer
|
||||||
|
|
||||||
|
|
||||||
|
def train(args, ratio: float, snr: float): # deprecated
|
||||||
|
|
||||||
device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
|
device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
|
||||||
# load data
|
# load data
|
||||||
@ -141,23 +382,49 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
model.train()
|
model.train()
|
||||||
# epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
|
# epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
|
||||||
print("epoch: {}, loss: {:.4f}, test_mse: {:.4f}, lr:{}".format(
|
print("epoch: {}, loss: {:.4f}, test_mse: {:.4f}, lr:{}".format(
|
||||||
epoch, run_loss/len(train_loader), test_mse/len(test_loader), optimizer.param_groups[0]['lr']))
|
epoch, run_loss / len(train_loader), test_mse / len(test_loader), optimizer.param_groups[0]['lr']))
|
||||||
save_model(model, args.saved, args.saved +
|
save_model(model, args.saved, args.saved + '/{}_{}_{:.2f}_{:.2f}_{}_{}.pth'
|
||||||
'/{}_{}_{:.2f}_{:.2f}_{}_{}.pth'.format(args.dataset, args.epochs, ratio, snr, args.batch_size, c))
|
.format(args.dataset, args.epochs, ratio, snr, args.batch_size, c))
|
||||||
|
|
||||||
|
|
||||||
def save_model(model, dir, path):
|
def config_parser(): # deprecated
|
||||||
os.makedirs(dir, exist_ok=True)
|
import argparse
|
||||||
flag = 1
|
parser = argparse.ArgumentParser()
|
||||||
while True:
|
parser.add_argument('--seed', default=2048, type=int, help='Random seed')
|
||||||
if os.path.exists(path):
|
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
|
||||||
path = path+'_'+str(flag)
|
parser.add_argument('--epochs', default=256, type=int, help='number of epochs')
|
||||||
flag += 1
|
parser.add_argument('--batch_size', default=256, type=int, help='batch size')
|
||||||
else:
|
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
|
||||||
break
|
parser.add_argument('--channel', default='AWGN', type=str,
|
||||||
torch.save(model.state_dict(), path)
|
choices=['AWGN', 'Rayleigh'], help='channel type')
|
||||||
print("Model saved in {}".format(path))
|
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
|
||||||
|
parser.add_argument('--snr_list', default=['19', '13',
|
||||||
|
'7', '4', '1'], nargs='+', help='snr_list')
|
||||||
|
parser.add_argument('--ratio_list', default=['1/3',
|
||||||
|
'1/6', '1/12'], nargs='+', help='ratio_list')
|
||||||
|
parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
|
||||||
|
parser.add_argument('--dataset', default='cifar10', type=str,
|
||||||
|
choices=['cifar10', 'imagenet'], help='dataset')
|
||||||
|
parser.add_argument('--parallel', default=False, type=bool, help='parallel')
|
||||||
|
parser.add_argument('--if_scheduler', default=False, type=bool, help='if_scheduler')
|
||||||
|
parser.add_argument('--step_size', default=640, type=int, help='scheduler')
|
||||||
|
parser.add_argument('--device', default='cuda:0', type=str, help='device')
|
||||||
|
parser.add_argument('--gamma', default=0.5, type=float, help='gamma')
|
||||||
|
parser.add_argument('--disable_tqdm', default=True, type=bool, help='disable_tqdm')
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main(): # deprecated
|
||||||
|
args = config_parser()
|
||||||
|
args.snr_list = list(map(float, args.snr_list))
|
||||||
|
args.ratio_list = list(map(lambda x: float(Fraction(x)), args.ratio_list))
|
||||||
|
set_seed(args.seed)
|
||||||
|
print("Training Start")
|
||||||
|
for ratio in args.ratio_list:
|
||||||
|
for snr in args.snr_list:
|
||||||
|
train(args, ratio, snr)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main_pipeline()
|
||||||
|
# main()
|
||||||
|
|||||||
43
utils.py
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def image_normalization(norm_type):
|
def image_normalization(norm_type):
|
||||||
@ -14,9 +15,41 @@ def image_normalization(norm_type):
|
|||||||
return _inner
|
return _inner
|
||||||
|
|
||||||
|
|
||||||
def get_psnr(image, gt, max=255):
|
def get_psnr(image, gt, max_val=255, mse=None):
|
||||||
|
if mse is None:
|
||||||
|
mse = F.mse_loss(image, gt)
|
||||||
|
mse = torch.tensor(mse) if not isinstance(mse, torch.Tensor) else mse
|
||||||
|
|
||||||
mse = F.mse_loss(image, gt)
|
psnr = 10 * torch.log10(max_val**2 / mse)
|
||||||
|
|
||||||
psnr = 10 * torch.log10(max**2 / mse)
|
|
||||||
return psnr
|
return psnr
|
||||||
|
|
||||||
|
|
||||||
|
def save_model(model, dir, path):
|
||||||
|
os.makedirs(dir, exist_ok=True)
|
||||||
|
flag = 1
|
||||||
|
while True:
|
||||||
|
if os.path.exists(path):
|
||||||
|
path = path + '_' + str(flag)
|
||||||
|
flag += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
torch.save(model.state_dict(), path)
|
||||||
|
print("Model saved in {}".format(path))
|
||||||
|
|
||||||
|
|
||||||
|
def set_seed(seed):
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
|
||||||
|
|
||||||
|
def view_model_param(model):
|
||||||
|
total_param = 0
|
||||||
|
|
||||||
|
for param in model.parameters():
|
||||||
|
# print(param.data.size())
|
||||||
|
total_param += np.prod(list(param.data.size()))
|
||||||
|
return total_param
|
||||||
|
|||||||