Update-2024.06.04

This commit is contained in:
chun 2024-06-04 11:19:08 +08:00
parent ca6377e98c
commit e1cef6aead
87 changed files with 1436 additions and 209 deletions

4
.gitignore vendored
View File

@ -1,11 +1,11 @@
test.py test.*
*.pyc *.pyc
*.log *.log
dataset dataset
*.ipynb
*.swp *.swp
.vscode/* .vscode/*
input.txt input.txt
output.txt output.txt
*.json *.json
.vscode/* .vscode/*
*.sh

42
.vscode/launch.json vendored
View File

@ -1,42 +0,0 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: 当前文件",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": true,
"args": [
"--lr",
"1e-3",
"--epochs",
"100",
"--batch_size",
"512",
"--if_scheduler",
"1",
"--step_size",
"500",
"--dataset",
"cifar10",
"--num_workers",
"4",
"--device",
"cuda:0",
"--ratio_list",
"1/3",
"--snr_list",
"100",
"--seed",
"42",
"--disable_tqdm",
"False"
]
}
]
}

115
README.md
View File

@ -1,19 +1,35 @@
# Deep JSCC # Deep JSCC
This implements training of deep JSCC models for wireless image transmission as described in the paper [Deep Joint Source-Channel Coding for Wireless Image Transmission](https://ieeexplore.ieee.org/abstract/document/8723589) by Pytorch. And there has been a [Tensorflow and keras implementations ](https://github.com/irdanish11/DJSCC-for-Wireless-Image-Transmission). This implements training of Deep JSCC models for wireless image transmission as described in the paper [Deep Joint Source-Channel Coding for Wireless Image Transmission](https://ieeexplore.ieee.org/abstract/document/8723589) by Pytorch. And there has been a [Tensorflow and keras implementations ](https://github.com/irdanish11/DJSCC-for-Wireless-Image-Transmission).
This is my first time to use PyTorch and git to reproduce a paper, so there may be some mistakes. If you find any, please let me know. Thanks! This is my first time to use PyTorch and git to reproduce a paper, so there may be some mistakes. If you find any, please let me know. Thanks!
## Update-2024.06.04
- modify the `train.py` to omit most of the args in command line, you can just us `python train.py --dataset ${dataset_name}` to train the model.
- add tensorboard to record the results in exp.
- add the `visualization/` file to visualize the result.
- add bash file to run the code in parallel.
## Architecture ## Architecture
![architecture](./demo/arc.png) <div style="text-align: center;">
<img src="./demo/arc.png" alt="Image 1" style="width: 500px; height: 250px;">
</div>
## Demo ## Demo
the model trained on cifar10 which is 32\*32 but test on kodim which is 768\*512.
![demo1](./run/cifar10_3000_0.33_100.00_256_40.pth_kodim08.png)
the model trained on imagenet which is resized to 128\*128 but test on kodim which is 768\*512. The model trained on cifar10 which is 32\*32 but test on kodim which is 768\*512 (left); and the model trained on imagenet which is resized to 128\*128 but test on kodim which is 768\*512 (right).
![demo2](./run/imagenet_10_0.33_200.00_32_19.pth_kodim08.png) <div style="display: flex;">
<img src="./demo/cifar10_kodim08.png" alt="Image 1" style="flex: 1;" width="150" height="450">
<div style="width: 20px;"></div>
<img src="./demo/imagenet_kodim08.png" alt="Image 1" style="flex: 1;" width="150" height="450">
</div>
## Installation ## Installation
conda or other virtual environment is recommended. conda or other virtual environment is recommended.
@ -31,25 +47,88 @@ The cifar10 dataset can be downloaded automatically by torchvision. But the imag
python dataset.py python dataset.py
``` ```
### Training Model ### Training
Run(example presented in paper) on cifar10 The training command used to be very long, but now you can just use `python train.py --dataset ${dataset_name} --channel ${channel}` to train the model.
The default dataset is cifar10.
The parmeters can be modified in the `train.py` file. The default parameters are similar to the paper.
```
python train.py --lr 1e-3 --epochs 1000 --batch_size 64 --channel 'AWGN' --saved ./saved --snr_list 1 4 7 13 19 --ratio_list 1/6 1/12 --dataset cifar10 --num_workers 4 --parallel True --if_scheduler True --scheduler_step_size 50
```
or Run(example presented in paper) on imagenet
| Parameters | CIFAR-10 | ImageNet |
|------------------------|------------------|------------------|
| `batch_size` | 64 | 32 |
| `init_lr` | 1e-3 | 1e-4 |
| `weight_decay` | 5e-4 | 5e-4 |
| `snr_list` | [19, 13, 7, 4, 1]| [19, 13, 7, 4, 1]|
| `ratio_list` | [1/6, 1/12] | [1/6, 1/12] |
| `if_scheduler` | True | False |
| `step_size` | 640 | N/A |
| `gamma` | 0.1 | 0.1 |
<!-- ALSO! The batch_size for cifar10 training in the paper is small causing the GPU utilization is low. So The bash script is provided to run the code in parallel for different snr and ratio for cifar10 dataset. (Example of two GPUs)
``` ```
python train.py --lr 10e-4 --epochs 300 --batch_size 32 --channel 'AWGN' --saved ./saved --dataset imagenet --num_workers 4 --parallel True bash parallel_train_cifar.sh --channel ${channel}
``` ``` -->
### Evaluation ### Evaluation
Run(example presented in paper) The `eval.py` provides the evaluation of the trained model.
You may need modify slightly to evaluate the model for different snr_list and channel type in `main` function.
``` ```
python eval.py --channel 'AWGN' --saved ./saved/${mode_path} --snr 20 --test_img ${test_img_path} python eval.py
``` ```
All training and evaluation results are saved in the `./out` directory by default. The `./out` directory may contain the structure as follows:
```
./out
├── checkpoint # trained models
│   ├── $DATASETNAME_$INNERCHANNEL_$SNR_$RATIO_$CHANNEL_TYPE_$TIMES_on_$DATE
│   ├── epoch_$num.pth
│   ├── ...
│   ├── CIFAR10_10_1.0_0.08_AWGN_13h21m37s_on_Jun_02_2024
│   ├── CIFAR10_20_7.0_0.17_Rayleigh_14h03m19s_on_Jun_03_2024
│   ├── ...
├── configs # training configurations
│   ├── $DATASETNAME_$INNERCHANNEL_$SNR_$RATIO_$CHANNEL_TYPE_$TIMES_on_$DATE
│   ├── $CIFAR10_10_4.0_0.08_AWGN_13h21m38s_on_Jun_02_2024.yaml
│   ├── ...
├── logs # training logs
│   ├── $DATASETNAME_$INNERCHANNEL_$SNR_$RATIO_$CHANNEL_TYPE_$TIMES_on_$DATE
│   ├── tensorboard logs
│   ├── ...
├── eval # evaluation results
│   ├── $DATASETNAME_$INNERCHANNEL_$SNR_$RATIO_$CHANNEL_TYPE_$TIMES_on_$DATE
│   ├── tensorboard logs
│   ├── ...
```
### Visualization
The `./visualization` directory contains the scripts for visualization of the training and evaluation results.
- `single_visualization.ipynb` is used to get demo of the model on single image like the demo above.
- `plot_visualization.ipynb` is used to get visualizations of the perfprmance of the model on different snr and ratio.
## Results
Model results and logs for the **CIFAR-10** dataset, tested under various SNR, ratio, and channel types, are available in the `./out` directory. The models' performance is approximately 5dB worse than reported in the paper, likely due to the implementation reflecting a real communication system. However, the performance trends are consistent with those in the paper.
<div style="display: flex;">
<img src="demo/cifar_0.08.png" alt="Image 1" style="flex: 1; max-width: 50%; height: auto;">
<div style="width: 0px;"></div> <!-- 为了让两个图像之间有一点间距 -->
<img src="demo/cifar_0.17.png" alt="Image 2" style="flex: 1; max-width: 50%; height: auto;">
</div>
### TO-DO ### TO-DO
- Add visualization of the model - ~~Add visualization of the model~~
- plot the results with different snr and ratio - ~~plot the results with different snr and ratio~~
- ~~add Rayleigh channel~~
- train on imagenet
- **Convert the real communication system to a complex communication system**
## Citation ## Citation
If you find (part of) this code useful for your research, please consider citing If you find (part of) this code useful for your research, please consider citing

View File

@ -1,30 +0,0 @@
from time import time
import multiprocessing as mp
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, RandomSampler
if __name__ == '__main__':
transform = transforms.Compose([
torchvision.transforms.ToTensor(),
])
trainset = torchvision.datasets.CIFAR10(
root='./dataset/',
train=True, # 如果为True从 training.pt 创建数据,否则从 test.pt 创建数据。
download=True, # 如果为true则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
transform=transform
)
print(f"num of CPU: {mp.cpu_count()}")
for num_workers in range(2, mp.cpu_count(), 2):
train_loader = torch.utils.data.DataLoader(
trainset, shuffle=True, num_workers=num_workers, batch_size=64, pin_memory=True)
start = time()
for epoch in range(1, 3):
for i, data in enumerate(train_loader, 0):
pass
end = time()
print("Finish with:{} second, num_workers={}".format(end - start, num_workers))

View File

@ -2,26 +2,44 @@ import torch
import torch.nn as nn import torch.nn as nn
def channel(channel_type='AWGN', snr=20): class Channel(nn.Module):
def AWGN_channel(z_hat: torch.Tensor): def __init__(self, channel_type='AWGN', snr=20):
if channel_type not in ['AWGN', 'Rayleigh']:
raise Exception('Unknown type of channel')
super(Channel, self).__init__()
self.channel_type = channel_type
self.snr = snr
def forward(self, z_hat):
if z_hat.dim() == 4: if z_hat.dim() == 4:
# k = np.prod(z_hat.size()[1:]) # k = np.prod(z_hat.size()[1:])
k = torch.prod(torch.tensor(z_hat.size()[1:])) k = torch.prod(torch.tensor(z_hat.size()[1:]))
sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True)/k sig_pwr = torch.sum(torch.abs(z_hat).square(), dim=(1, 2, 3), keepdim=True) / k
elif z_hat.dim() == 3: elif z_hat.dim() == 3:
# k = np.prod(z_hat.size()) # k = np.prod(z_hat.size())
k = torch.prod(torch.tensor(z_hat.size())) k = torch.prod(torch.tensor(z_hat.size()))
sig_pwr = torch.sum(torch.abs(z_hat).square())/k sig_pwr = torch.sum(torch.abs(z_hat).square()) / k
noi_pwr = sig_pwr / (10 ** (snr / 10)) noi_pwr = sig_pwr / (10 ** (self.snr / 10))
noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr) noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr)
if self.channel_type == 'Rayleigh':
# hc = torch.randn_like(z_hat) wrong implement before
hc = torch.randn(1, device = z_hat.device)
z_hat = hc * z_hat
return z_hat + noise return z_hat + noise
def Rayleigh_channel(z_hat: torch.Tensor): def get_channel(self):
pass return self.channel_type, self.snr
if channel_type == 'AWGN':
return AWGN_channel if __name__ == '__main__':
elif channel_type == 'Rayleigh': # test
return Rayleigh_channel channel = Channel(channel_type='AWGN', snr=10)
else: z_hat = torch.randn(64, 10, 5, 5)
raise Exception('Unknown type of channel') z_hat = channel(z_hat)
print(z_hat)
channel = Channel(channel_type='Rayleigh', snr=10)
z_hat = torch.randn(64, 10, 5, 5)
z_hat = channel(z_hat)
print(z_hat)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 94 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.1 KiB

View File

Before

Width:  |  Height:  |  Size: 1.6 MiB

After

Width:  |  Height:  |  Size: 1.6 MiB

BIN
demo/cifar_0.08.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 81 KiB

BIN
demo/cifar_0.17.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 81 KiB

View File

Before

Width:  |  Height:  |  Size: 1.5 MiB

After

Width:  |  Height:  |  Size: 1.5 MiB

102
eval.py
View File

@ -1,51 +1,75 @@
# to be implemented
import torch import torch
import torch.nn as nn from utils import get_psnr
from PIL import Image
from torchvision import transforms
from utils import get_psnr, image_normalization
import os import os
from model import DeepJSCC from model import DeepJSCC
from train import evaluate_epoch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from dataset import Vanilla
import yaml
from tensorboardX import SummaryWriter
import glob
from concurrent.futures import ProcessPoolExecutor
def eval_snr(model, test_loader, writer, param, times=10):
snr_list = range(0, 26, 1)
for snr in snr_list:
model.change_channel(param['channel'], snr)
test_loss = 0
for i in range(times):
test_loss += evaluate_epoch(model, param, test_loader)
test_loss /= times
psnr = get_psnr(image=None, gt=None, mse=test_loss)
writer.add_scalar('psnr', psnr, snr)
def config_parser(): def process_config(config_path, output_dir, dataset_name, times):
import argparse with open(config_path, 'r') as f:
parser = argparse.ArgumentParser() config = yaml.load(f, Loader=yaml.UnsafeLoader)
parser.add_argument('--channel', default='AWGN', type=str, help='channel type') assert dataset_name == config['dataset_name']
parser.add_argument('--saved', type=str, help='saved_path') params = config['params']
parser.add_argument('--snr', default=20, type=int, help='snr') c = config['inner_channel']
parser.add_argument('--test_image', default='./demo/kodim08.png', type=str, help='demo_image')
parser.add_argument('--times', default=10, type=int, help='num_workers')
return parser.parse_args()
if dataset_name == 'cifar10':
transform = transforms.Compose([transforms.ToTensor(), ])
test_dataset = datasets.CIFAR10(root='../dataset/', train=False,
download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=True,
batch_size=params['batch_size'], num_workers=params['num_workers'])
elif dataset_name == 'imagenet':
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
test_dataset = Vanilla(root='../dataset/ImageNet/val', transform=transform)
test_loader = DataLoader(test_dataset, shuffle=True,
batch_size=params['batch_size'], num_workers=params['num_workers'])
else:
raise Exception('Unknown dataset')
name = os.path.splitext(os.path.basename(config_path))[0]
writer = SummaryWriter(os.path.join(output_dir, 'eval', name))
model = DeepJSCC(c=c)
model = model.to(params['device'])
pkl_list = glob.glob(os.path.join(output_dir, 'checkpoints', name, '*.pkl'))
model.load_state_dict(torch.load(pkl_list[-1]))
eval_snr(model, test_loader, writer, params, times)
writer.close()
def main(): def main():
args = config_parser() times = 10
transform = transforms.Compose([transforms.ToTensor()]) dataset_name = 'cifar10'
test_image = Image.open(args.test_image) output_dir = './out'
test_image.load() channel_type = 'AWGN'
test_image = transform(test_image) config_dir = os.path.join(output_dir, 'configs')
config_files = [os.path.join(config_dir, name) for name in os.listdir(config_dir)
if (dataset_name in name or dataset_name.upper() in name) and channel_type in name and name.endswith('.yaml')]
file_name = os.path.basename(args.saved) with ProcessPoolExecutor() as executor:
c = file_name.split('_')[-1].split('.')[0] executor.map(process_config, config_files, [output_dir] * len(config_files), [dataset_name] * len(config_files), [times] * len(config_files))
c = int(c)
model = DeepJSCC(c=c, channel_type=args.channel, snr=args.snr)
# model.load_state_dict(torch.load(args.saved))
model.load_state_dict(torch.load(args.saved,map_location=torch.device('cuda:0')))
model.change_channel(args.channel, args.snr)
psnr_all = 0.0
for i in range(args.times):
demo_image = model(test_image)
demo_image = image_normalization('denormalization')(demo_image)
gt = image_normalization('denormalization')(test_image)
psnr_all += get_psnr(demo_image, gt)
demo_image = image_normalization('normalization')(demo_image)
demo_image = torch.cat([test_image, demo_image], dim=1)
demo_image = transforms.ToPILImage()(demo_image)
demo_image.save('./run/{}_{}'.format(args.saved.split('/')[-1], args.test_image.split('/')[-1]))
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -7,7 +7,7 @@ Created on Tue Dec 11:00:00 2023
import torch import torch
import torch.nn as nn import torch.nn as nn
import channel from channel import Channel
""" def _image_normalization(norm_type): """ def _image_normalization(norm_type):
@ -58,7 +58,8 @@ class _TransConvWithPReLU(nn.Module):
in_channels, out_channels, kernel_size, stride, padding, output_padding) in_channels, out_channels, kernel_size, stride, padding, output_padding)
self.activate = activate self.activate = activate
if activate == nn.PReLU(): if activate == nn.PReLU():
nn.init.kaiming_normal_(self.transconv.weight, mode='fan_out', nonlinearity='leaky_relu') nn.init.kaiming_normal_(self.transconv.weight, mode='fan_out',
nonlinearity='leaky_relu')
else: else:
nn.init.xavier_normal_(self.transconv.weight) nn.init.xavier_normal_(self.transconv.weight)
@ -104,7 +105,7 @@ class _Encoder(nn.Module):
return _inner return _inner
def forward(self, x): def forward(self, x):
#x = self.imgae_normalization(x) # x = self.imgae_normalization(x)
x = self.conv1(x) x = self.conv1(x)
x = self.conv2(x) x = self.conv2(x)
x = self.conv3(x) x = self.conv3(x)
@ -118,7 +119,7 @@ class _Encoder(nn.Module):
class _Decoder(nn.Module): class _Decoder(nn.Module):
def __init__(self, c=1): def __init__(self, c=1):
super(_Decoder, self).__init__() super(_Decoder, self).__init__()
#self.imgae_normalization = _image_normalization(norm_type='denormalization') # self.imgae_normalization = _image_normalization(norm_type='denormalization')
self.tconv1 = _TransConvWithPReLU( self.tconv1 = _TransConvWithPReLU(
in_channels=c, out_channels=32, kernel_size=5, stride=1, padding=2) in_channels=c, out_channels=32, kernel_size=5, stride=1, padding=2)
self.tconv2 = _TransConvWithPReLU( self.tconv2 = _TransConvWithPReLU(
@ -136,22 +137,50 @@ class _Decoder(nn.Module):
x = self.tconv3(x) x = self.tconv3(x)
x = self.tconv4(x) x = self.tconv4(x)
x = self.tconv5(x) x = self.tconv5(x)
#x = self.imgae_normalization(x) # x = self.imgae_normalization(x)
return x return x
class DeepJSCC(nn.Module): class DeepJSCC(nn.Module):
def __init__(self, c, channel_type='AWGN', snr=20): def __init__(self, c, channel_type='AWGN', snr=None):
super(DeepJSCC, self).__init__() super(DeepJSCC, self).__init__()
self.encoder = _Encoder(c=c) self.encoder = _Encoder(c=c)
self.channel = channel.channel(channel_type, snr) if snr is not None:
self.channel = Channel(channel_type, snr)
self.decoder = _Decoder(c=c) self.decoder = _Decoder(c=c)
def forward(self, x): def forward(self, x):
z = self.encoder(x) z = self.encoder(x)
z = self.channel(z) if hasattr(self, 'channel') and self.channel is not None:
z = self.channel(z)
x_hat = self.decoder(z) x_hat = self.decoder(z)
return x_hat return x_hat
def change_channel(self, channel_type, snr): def change_channel(self, channel_type='AWGN', snr=None):
self.channel = channel.channel(channel_type, snr) if snr is None:
self.channel = None
else:
self.channel = Channel(channel_type, snr)
def get_channel(self):
if hasattr(self, 'channel') and self.channel is not None:
return self.channel.get_channel()
return None
def loss(self, prd, gt):
criterion = nn.MSELoss(reduction='mean')
loss = criterion(prd, gt)
return loss
if __name__ == '__main__':
model = DeepJSCC(c=20)
print(model)
x = torch.rand(1, 3, 32, 32)
y = model(x)
print(y.size())
print(y)
print(model.encoder.norm)
print(model.encoder.norm(y))
print(model.encoder.norm(y).size())
print(model.encoder.norm(y).size()[1:])

View File

@ -0,0 +1,46 @@
dataset_name: cifar10
inner_channel: 10
params:
ReduceLROnPlateau: false
batch_size: 64
channel: AWGN
dataset: cifar10
device: cuda:1
disable_tqdm: false
epochs: 1000
gamma: 0.1
if_scheduler: true
init_lr: 0.001
lr_reduce_factor: 0.5
lr_schedule_patience: 15
max_time: 12
min_lr: 1.0e-05
num_workers: 4
out_dir: ./out
parallel: false
ratio: 0.08333333333333333
ratio_list:
- 0.08333333333333333
seed: 42
snr: 1.0
snr_list:
- 1.0
step_size: 640
weight_decay: 0.0005
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args:
- i8
- false
- true
state: !!python/tuple
- 3
- <
- null
- null
- null
- -1
- -1
- 0
- !!binary |
41QCAAAAAAA=

View File

@ -0,0 +1,46 @@
dataset_name: cifar10
inner_channel: 10
params:
ReduceLROnPlateau: false
batch_size: 64
channel: AWGN
dataset: cifar10
device: cuda:1
disable_tqdm: false
epochs: 1000
gamma: 0.1
if_scheduler: true
init_lr: 0.001
lr_reduce_factor: 0.5
lr_schedule_patience: 15
max_time: 12
min_lr: 1.0e-05
num_workers: 4
out_dir: ./out
parallel: false
ratio: 0.08333333333333333
ratio_list:
- 0.08333333333333333
seed: 42
snr: 13.0
snr_list:
- 13.0
step_size: 640
weight_decay: 0.0005
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args:
- i8
- false
- true
state: !!python/tuple
- 3
- <
- null
- null
- null
- -1
- -1
- 0
- !!binary |
41QCAAAAAAA=

View File

@ -0,0 +1,46 @@
dataset_name: cifar10
inner_channel: 10
params:
ReduceLROnPlateau: false
batch_size: 64
channel: AWGN
dataset: cifar10
device: cuda:1
disable_tqdm: false
epochs: 1000
gamma: 0.1
if_scheduler: true
init_lr: 0.001
lr_reduce_factor: 0.5
lr_schedule_patience: 15
max_time: 12
min_lr: 1.0e-05
num_workers: 4
out_dir: ./out
parallel: false
ratio: 0.08333333333333333
ratio_list:
- 0.08333333333333333
seed: 42
snr: 19.0
snr_list:
- 19.0
step_size: 640
weight_decay: 0.0005
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args:
- i8
- false
- true
state: !!python/tuple
- 3
- <
- null
- null
- null
- -1
- -1
- 0
- !!binary |
41QCAAAAAAA=

View File

@ -0,0 +1,46 @@
dataset_name: cifar10
inner_channel: 10
params:
ReduceLROnPlateau: false
batch_size: 64
channel: AWGN
dataset: cifar10
device: cuda:1
disable_tqdm: false
epochs: 1000
gamma: 0.1
if_scheduler: true
init_lr: 0.001
lr_reduce_factor: 0.5
lr_schedule_patience: 15
max_time: 12
min_lr: 1.0e-05
num_workers: 4
out_dir: ./out
parallel: false
ratio: 0.08333333333333333
ratio_list:
- 0.08333333333333333
seed: 42
snr: 4.0
snr_list:
- 4.0
step_size: 640
weight_decay: 0.0005
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args:
- i8
- false
- true
state: !!python/tuple
- 3
- <
- null
- null
- null
- -1
- -1
- 0
- !!binary |
41QCAAAAAAA=

View File

@ -0,0 +1,46 @@
dataset_name: cifar10
inner_channel: 10
params:
ReduceLROnPlateau: false
batch_size: 64
channel: AWGN
dataset: cifar10
device: cuda:1
disable_tqdm: false
epochs: 1000
gamma: 0.1
if_scheduler: true
init_lr: 0.001
lr_reduce_factor: 0.5
lr_schedule_patience: 15
max_time: 12
min_lr: 1.0e-05
num_workers: 4
out_dir: ./out
parallel: false
ratio: 0.08333333333333333
ratio_list:
- 0.08333333333333333
seed: 42
snr: 7.0
snr_list:
- 7.0
step_size: 640
weight_decay: 0.0005
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args:
- i8
- false
- true
state: !!python/tuple
- 3
- <
- null
- null
- null
- -1
- -1
- 0
- !!binary |
41QCAAAAAAA=

View File

@ -0,0 +1,46 @@
dataset_name: cifar10
inner_channel: 20
params:
ReduceLROnPlateau: false
batch_size: 64
channel: AWGN
dataset: cifar10
device: cuda:0
disable_tqdm: false
epochs: 1000
gamma: 0.1
if_scheduler: true
init_lr: 0.001
lr_reduce_factor: 0.5
lr_schedule_patience: 15
max_time: 12
min_lr: 1.0e-05
num_workers: 4
out_dir: ./out
parallel: false
ratio: 0.16666666666666666
ratio_list:
- 0.16666666666666666
seed: 42
snr: 1.0
snr_list:
- 1.0
step_size: 640
weight_decay: 0.0005
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args:
- i8
- false
- true
state: !!python/tuple
- 3
- <
- null
- null
- null
- -1
- -1
- 0
- !!binary |
bZMCAAAAAAA=

View File

@ -0,0 +1,46 @@
dataset_name: cifar10
inner_channel: 20
params:
ReduceLROnPlateau: false
batch_size: 64
channel: AWGN
dataset: cifar10
device: cuda:0
disable_tqdm: false
epochs: 1000
gamma: 0.1
if_scheduler: true
init_lr: 0.001
lr_reduce_factor: 0.5
lr_schedule_patience: 15
max_time: 12
min_lr: 1.0e-05
num_workers: 4
out_dir: ./out
parallel: false
ratio: 0.16666666666666666
ratio_list:
- 0.16666666666666666
seed: 42
snr: 13.0
snr_list:
- 13.0
step_size: 640
weight_decay: 0.0005
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args:
- i8
- false
- true
state: !!python/tuple
- 3
- <
- null
- null
- null
- -1
- -1
- 0
- !!binary |
bZMCAAAAAAA=

View File

@ -0,0 +1,46 @@
dataset_name: cifar10
inner_channel: 20
params:
ReduceLROnPlateau: false
batch_size: 64
channel: AWGN
dataset: cifar10
device: cuda:0
disable_tqdm: false
epochs: 1000
gamma: 0.1
if_scheduler: true
init_lr: 0.001
lr_reduce_factor: 0.5
lr_schedule_patience: 15
max_time: 12
min_lr: 1.0e-05
num_workers: 4
out_dir: ./out
parallel: false
ratio: 0.16666666666666666
ratio_list:
- 0.16666666666666666
seed: 42
snr: 19.0
snr_list:
- 19.0
step_size: 640
weight_decay: 0.0005
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args:
- i8
- false
- true
state: !!python/tuple
- 3
- <
- null
- null
- null
- -1
- -1
- 0
- !!binary |
bZMCAAAAAAA=

View File

@ -0,0 +1,46 @@
dataset_name: cifar10
inner_channel: 20
params:
ReduceLROnPlateau: false
batch_size: 64
channel: AWGN
dataset: cifar10
device: cuda:0
disable_tqdm: false
epochs: 1000
gamma: 0.1
if_scheduler: true
init_lr: 0.001
lr_reduce_factor: 0.5
lr_schedule_patience: 15
max_time: 12
min_lr: 1.0e-05
num_workers: 4
out_dir: ./out
parallel: false
ratio: 0.16666666666666666
ratio_list:
- 0.16666666666666666
seed: 42
snr: 4.0
snr_list:
- 4.0
step_size: 640
weight_decay: 0.0005
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args:
- i8
- false
- true
state: !!python/tuple
- 3
- <
- null
- null
- null
- -1
- -1
- 0
- !!binary |
bZMCAAAAAAA=

View File

@ -0,0 +1,46 @@
dataset_name: cifar10
inner_channel: 20
params:
ReduceLROnPlateau: false
batch_size: 64
channel: AWGN
dataset: cifar10
device: cuda:0
disable_tqdm: false
epochs: 1000
gamma: 0.1
if_scheduler: true
init_lr: 0.001
lr_reduce_factor: 0.5
lr_schedule_patience: 15
max_time: 12
min_lr: 1.0e-05
num_workers: 4
out_dir: ./out
parallel: false
ratio: 0.16666666666666666
ratio_list:
- 0.16666666666666666
seed: 42
snr: 7.0
snr_list:
- 7.0
step_size: 640
weight_decay: 0.0005
total_parameters: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args:
- i8
- false
- true
state: !!python/tuple
- 3
- <
- null
- null
- null
- -1
- -1
- 0
- !!binary |
bZMCAAAAAAA=

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.6 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.1 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.6 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.3 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.1 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.6 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.0 MiB

367
train.py
View File

@ -14,60 +14,301 @@ import torch.optim as optim
from tqdm import tqdm from tqdm import tqdm
from model import DeepJSCC, ratio2filtersize from model import DeepJSCC, ratio2filtersize
from torch.nn.parallel import DataParallel from torch.nn.parallel import DataParallel
from utils import image_normalization from utils import image_normalization, set_seed, save_model, view_model_param
from fractions import Fraction from fractions import Fraction
from dataset import Vanilla from dataset import Vanilla
import numpy as np import numpy as np
import time
from tensorboardX import SummaryWriter
import glob
def set_seed(seed): def train_epoch(model, optimizer, param, data_loader):
np.random.seed(seed) model.train()
torch.manual_seed(seed) epoch_loss = 0
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) for iter, (images, _) in enumerate(data_loader):
torch.backends.cudnn.deterministic = True images = images.cuda() if param['parallel'] and torch.cuda.device_count(
torch.backends.cudnn.benchmark = False ) > 1 else images.to(param['device'])
optimizer.zero_grad()
outputs = model.forward(images)
outputs = image_normalization('denormalization')(outputs)
images = image_normalization('denormalization')(images)
loss = model.loss(images, outputs) if not param['parallel'] else model.module.loss(
images, outputs)
loss.backward()
optimizer.step()
epoch_loss += loss.detach().item()
epoch_loss /= (iter + 1)
return epoch_loss, optimizer
def config_parser(): def evaluate_epoch(model, param, data_loader):
model.eval()
epoch_loss = 0
with torch.no_grad():
for iter, (images, _) in enumerate(data_loader):
images = images.cuda() if param['parallel'] and torch.cuda.device_count(
) > 1 else images.to(param['device'])
outputs = model.forward(images)
outputs = image_normalization('denormalization')(outputs)
images = image_normalization('denormalization')(images)
loss = model.loss(images, outputs) if not param['parallel'] else model.module.loss(
images, outputs)
epoch_loss += loss.detach().item()
epoch_loss /= (iter + 1)
return epoch_loss
def config_parser_pipeline():
import argparse import argparse
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=2048, type=int, help='Random seed')
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('--epochs', default=256, type=int, help='number of epochs')
parser.add_argument('--batch_size', default=256, type=int, help='batch size')
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
parser.add_argument('--channel', default='AWGN', type=str,
choices=['AWGN', 'Rayleigh'], help='channel type')
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
parser.add_argument('--snr_list', default=['19', '13',
'7', '4', '1'], nargs='+', help='snr_list')
parser.add_argument('--ratio_list', default=['1/3',
'1/6', '1/12'], nargs='+', help='ratio_list')
parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
parser.add_argument('--dataset', default='cifar10', type=str, parser.add_argument('--dataset', default='cifar10', type=str,
choices=['cifar10', 'imagenet'], help='dataset') choices=['cifar10', 'imagenet'], help='dataset')
parser.add_argument('--parallel', default=False, type=bool, help='parallel') parser.add_argument('--out', default='./out', type=str, help='out_path')
parser.add_argument('--if_scheduler', default=False, type=bool, help='if_scheduler') parser.add_argument('--disable_tqdm', default=False, type=bool, help='disable_tqdm')
parser.add_argument('--step_size', default=640, type=int, help='scheduler')
parser.add_argument('--device', default='cuda:0', type=str, help='device') parser.add_argument('--device', default='cuda:0', type=str, help='device')
parser.add_argument('--gamma', default=0.5, type=float, help='gamma') parser.add_argument('--parallel', default=False, type=bool, help='parallel')
parser.add_argument('--disable_tqdm', default=True, type=bool, help='disable_tqdm') parser.add_argument('--snr_list', default=['19', '13',
'7', '4', '1'], nargs='+', help='snr_list')
parser.add_argument('--ratio_list', default=['1/6', '1/12'], nargs='+', help='ratio_list')
parser.add_argument('--channel', default='AWGN', type=str,
choices=['AWGN', 'Rayleigh'], help='channel')
return parser.parse_args() return parser.parse_args()
def main(): def main_pipeline():
args = config_parser() args = config_parser_pipeline()
print("Training Start")
dataset_name = args.dataset
out_dir = args.out
args.snr_list = list(map(float, args.snr_list)) args.snr_list = list(map(float, args.snr_list))
args.ratio_list = list(map(lambda x: float(Fraction(x)), args.ratio_list)) args.ratio_list = list(map(lambda x: float(Fraction(x)), args.ratio_list))
set_seed(args.seed) params = {}
print("Training Start") params['disable_tqdm'] = args.disable_tqdm
for ratio in args.ratio_list: params['dataset'] = dataset_name
for snr in args.snr_list: params['out_dir'] = out_dir
train(args, ratio, snr) params['device'] = args.device
params['snr_list'] = args.snr_list
params['ratio_list'] = args.ratio_list
params['channel'] = args.channel
if dataset_name == 'cifar10':
params['batch_size'] = 64 # 1024
params['num_workers'] = 4
params['epochs'] = 1000
params['init_lr'] = 1e-3 # 1e-2
params['weight_decay'] = 5e-4
params['parallel'] = False
params['if_scheduler'] = True
params['step_size'] = 640
params['gamma'] = 0.1
params['seed'] = 42
params['ReduceLROnPlateau'] = False
params['lr_reduce_factor'] = 0.5
params['lr_schedule_patience'] = 15
params['max_time'] = 12
params['min_lr'] = 1e-5
elif dataset_name == 'imagenet':
params['batch_size'] = 32
params['num_workers'] = 4
params['epochs'] = 300
params['init_lr'] = 1e-4
params['weight_decay'] = 5e-4
params['parallel'] = True
params['if_scheduler'] = True
params['gamma'] = 0.1
params['seed'] = 42
params['ReduceLROnPlateau'] = True
params['lr_reduce_factor'] = 0.5
params['lr_schedule_patience'] = 15
params['max_time'] = 12
params['min_lr'] = 1e-5
else:
raise Exception('Unknown dataset')
set_seed(params['seed'])
for ratio in params['ratio_list']:
for snr in params['snr_list']:
params['ratio'] = ratio
params['snr'] = snr
train_pipeline(params)
def train(args: config_parser(), ratio: float, snr: float): # add train_pipeline to with only dataset_name args
def train_pipeline(params):
dataset_name = params['dataset']
# load data
if dataset_name == 'cifar10':
transform = transforms.Compose([transforms.ToTensor(), ])
train_dataset = datasets.CIFAR10(root='../dataset/', train=True,
download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True,
batch_size=params['batch_size'], num_workers=params['num_workers'])
test_dataset = datasets.CIFAR10(root='../dataset/', train=False,
download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=True,
batch_size=params['batch_size'], num_workers=params['num_workers'])
elif dataset_name == 'imagenet':
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
print("loading data of imagenet")
train_dataset = datasets.ImageFolder(root='../dataset/ImageNet/train', transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True,
batch_size=params['batch_size'], num_workers=params['num_workers'])
test_dataset = Vanilla(root='../dataset/ImageNet/val', transform=transform)
test_loader = DataLoader(test_dataset, shuffle=True,
batch_size=params['batch_size'], num_workers=params['num_workers'])
else:
raise Exception('Unknown dataset')
# create model
image_fisrt = train_dataset.__getitem__(0)[0]
c = ratio2filtersize(image_fisrt, params['ratio'])
print("The snr is {}, the inner channel is {}, the ratio is {:.2f}".format(
params['snr'], c, params['ratio']))
model = DeepJSCC(c=c, channel_type=params['channel'], snr=params['snr'])
# init exp dir
out_dir = params['out_dir']
phaser = dataset_name.upper() + '_' + str(c) + '_' + str(params['snr']) + '_' + \
"{:.2f}".format(params['ratio']) + '_' + str(params['channel']) + \
'_' + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
root_log_dir = out_dir + '/' + 'logs/' + phaser
root_ckpt_dir = out_dir + '/' + 'checkpoints/' + phaser
root_config_dir = out_dir + '/' + 'configs/' + phaser
writer = SummaryWriter(log_dir=root_log_dir)
# model init
device = torch.device(params['device'] if torch.cuda.is_available() else 'cpu')
if params['parallel'] and torch.cuda.device_count() > 1:
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
model = model.cuda()
else:
model = model.to(device)
# opt
optimizer = optim.Adam(
model.parameters(), lr=params['init_lr'], weight_decay=params['weight_decay'])
if params['if_scheduler'] and not params['ReduceLROnPlateau']:
scheduler = optim.lr_scheduler.StepLR(
optimizer, step_size=params['step_size'], gamma=params['gamma'])
elif params['ReduceLROnPlateau']:
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
factor=params['lr_reduce_factor'],
patience=params['lr_schedule_patience'],
verbose=False)
else:
print("No scheduler")
scheduler = None
writer.add_text('config', str(params))
t0 = time.time()
epoch_train_losses, epoch_val_losses = [], []
per_epoch_time = []
# train
# At any point you can hit Ctrl + C to break out of training early.
try:
with tqdm(range(params['epochs']), disable=params['disable_tqdm']) as t:
for epoch in t:
t.set_description('Epoch %d' % epoch)
start = time.time()
epoch_train_loss, optimizer = train_epoch(
model, optimizer, params, train_loader)
epoch_val_loss = evaluate_epoch(model, params, test_loader)
epoch_train_losses.append(epoch_train_loss)
epoch_val_losses.append(epoch_val_loss)
writer.add_scalar('train/_loss', epoch_train_loss, epoch)
writer.add_scalar('val/_loss', epoch_val_loss, epoch)
writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)
t.set_postfix(time=time.time() - start, lr=optimizer.param_groups[0]['lr'],
train_loss=epoch_train_loss, val_loss=epoch_val_loss)
per_epoch_time.append(time.time() - start)
# Saving checkpoint
if not os.path.exists(root_ckpt_dir):
os.makedirs(root_ckpt_dir)
torch.save(model.state_dict(), '{}.pkl'.format(
root_ckpt_dir + "/epoch_" + str(epoch)))
files = glob.glob(root_ckpt_dir + '/*.pkl')
for file in files:
epoch_nb = file.split('_')[-1]
epoch_nb = int(epoch_nb.split('.')[0])
if epoch_nb < epoch - 1:
os.remove(file)
if params['ReduceLROnPlateau'] and scheduler is not None:
scheduler.step(epoch_val_loss)
elif params['if_scheduler'] and not params['ReduceLROnPlateau']:
scheduler.step() # use only information from the validation loss
if optimizer.param_groups[0]['lr'] < params['min_lr']:
print("\n!! LR EQUAL TO MIN LR SET.")
break
# Stop training after params['max_time'] hours
if time.time() - t0 > params['max_time'] * 3600:
print('-' * 89)
print("Max_time for training elapsed {:.2f} hours, so stopping".format(
params['max_time']))
break
except KeyboardInterrupt:
print('-' * 89)
print('Exiting from training early because of KeyboardInterrupt')
test_loss = evaluate_epoch(model, params, test_loader)
train_loss = evaluate_epoch(model, params, train_loader)
print("Test Accuracy: {:.4f}".format(test_loss))
print("Train Accuracy: {:.4f}".format(train_loss))
print("Convergence Time (Epochs): {:.4f}".format(epoch))
print("TOTAL TIME TAKEN: {:.4f}s".format(time.time() - t0))
print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time)))
"""
Write the results in out_dir/results folder
"""
writer.add_text(tag='result', text_string="""Dataset: {}\nparams={}\n\nTotal Parameters: {}\n\n
FINAL RESULTS\nTEST Loss: {:.4f}\nTRAIN Loss: {:.4f}\n\n
Convergence Time (Epochs): {:.4f}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n"""
.format(dataset_name, params, view_model_param(model), np.mean(np.array(train_loss)),
np.mean(np.array(test_loss)), epoch, (time.time() - t0) / 3600, np.mean(per_epoch_time)))
writer.close()
if not os.path.exists(os.path.dirname(root_config_dir)):
os.makedirs(os.path.dirname(root_config_dir))
with open(root_config_dir + '.yaml', 'w') as f:
dict_yaml = {'dataset_name': dataset_name, 'params': params,
'inner_channel': c, 'total_parameters': view_model_param(model)}
import yaml
yaml.dump(dict_yaml, f)
del model, optimizer, scheduler, train_loader, test_loader
del writer
def train(args, ratio: float, snr: float): # deprecated
device = torch.device(args.device if torch.cuda.is_available() else 'cpu') device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
# load data # load data
@ -141,23 +382,49 @@ def train(args: config_parser(), ratio: float, snr: float):
model.train() model.train()
# epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader)) # epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
print("epoch: {}, loss: {:.4f}, test_mse: {:.4f}, lr:{}".format( print("epoch: {}, loss: {:.4f}, test_mse: {:.4f}, lr:{}".format(
epoch, run_loss/len(train_loader), test_mse/len(test_loader), optimizer.param_groups[0]['lr'])) epoch, run_loss / len(train_loader), test_mse / len(test_loader), optimizer.param_groups[0]['lr']))
save_model(model, args.saved, args.saved + save_model(model, args.saved, args.saved + '/{}_{}_{:.2f}_{:.2f}_{}_{}.pth'
'/{}_{}_{:.2f}_{:.2f}_{}_{}.pth'.format(args.dataset, args.epochs, ratio, snr, args.batch_size, c)) .format(args.dataset, args.epochs, ratio, snr, args.batch_size, c))
def save_model(model, dir, path): def config_parser(): # deprecated
os.makedirs(dir, exist_ok=True) import argparse
flag = 1 parser = argparse.ArgumentParser()
while True: parser.add_argument('--seed', default=2048, type=int, help='Random seed')
if os.path.exists(path): parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
path = path+'_'+str(flag) parser.add_argument('--epochs', default=256, type=int, help='number of epochs')
flag += 1 parser.add_argument('--batch_size', default=256, type=int, help='batch size')
else: parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
break parser.add_argument('--channel', default='AWGN', type=str,
torch.save(model.state_dict(), path) choices=['AWGN', 'Rayleigh'], help='channel type')
print("Model saved in {}".format(path)) parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
parser.add_argument('--snr_list', default=['19', '13',
'7', '4', '1'], nargs='+', help='snr_list')
parser.add_argument('--ratio_list', default=['1/3',
'1/6', '1/12'], nargs='+', help='ratio_list')
parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
parser.add_argument('--dataset', default='cifar10', type=str,
choices=['cifar10', 'imagenet'], help='dataset')
parser.add_argument('--parallel', default=False, type=bool, help='parallel')
parser.add_argument('--if_scheduler', default=False, type=bool, help='if_scheduler')
parser.add_argument('--step_size', default=640, type=int, help='scheduler')
parser.add_argument('--device', default='cuda:0', type=str, help='device')
parser.add_argument('--gamma', default=0.5, type=float, help='gamma')
parser.add_argument('--disable_tqdm', default=True, type=bool, help='disable_tqdm')
return parser.parse_args()
def main(): # deprecated
args = config_parser()
args.snr_list = list(map(float, args.snr_list))
args.ratio_list = list(map(lambda x: float(Fraction(x)), args.ratio_list))
set_seed(args.seed)
print("Training Start")
for ratio in args.ratio_list:
for snr in args.snr_list:
train(args, ratio, snr)
if __name__ == '__main__': if __name__ == '__main__':
main() main_pipeline()
# main()

View File

@ -1,6 +1,7 @@
import torch import torch
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import os
import numpy as np
def image_normalization(norm_type): def image_normalization(norm_type):
@ -14,9 +15,41 @@ def image_normalization(norm_type):
return _inner return _inner
def get_psnr(image, gt, max=255): def get_psnr(image, gt, max_val=255, mse=None):
if mse is None:
mse = F.mse_loss(image, gt)
mse = torch.tensor(mse) if not isinstance(mse, torch.Tensor) else mse
mse = F.mse_loss(image, gt) psnr = 10 * torch.log10(max_val**2 / mse)
psnr = 10 * torch.log10(max**2 / mse)
return psnr return psnr
def save_model(model, dir, path):
os.makedirs(dir, exist_ok=True)
flag = 1
while True:
if os.path.exists(path):
path = path + '_' + str(flag)
flag += 1
else:
break
torch.save(model.state_dict(), path)
print("Model saved in {}".format(path))
def set_seed(seed):
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def view_model_param(model):
total_param = 0
for param in model.parameters():
# print(param.data.size())
total_param += np.prod(list(param.data.size()))
return total_param

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long