add dataset.py

This commit is contained in:
chun 2023-12-24 11:39:24 +08:00
parent b92cbc2c35
commit 0818c20233
3 changed files with 25 additions and 4 deletions

View File

@ -7,7 +7,7 @@ This is my first time to use PyTorch and git to reproduce a paper, so there may
![architecture](./demo/arc.png)
## Demo
I spend 3 days from 12-20 to 12-24 to reproduce the paper, and i get the result as follow. The result is not good, because i trained the model on cifar10 which is 32*32 but test on kodim which is 768*512 and the model is not trained enough.
I spend 3 days to reproduce the paper, and i get the result as follow. The result is not good, because i trained the model on cifar10 which is 32\*32 but test on kodim which is 768\*512 and the model is not trained enough.
That is all enough!!
![demo](./demo/demo.png)
@ -18,14 +18,14 @@ conda or other virtual environment is recommended.
```
git clone https://github.com/chunbaobao/Deep-JSCC-PyTorch.git
pip install requirements.txt
cd ./Deep-JSCC-PyTorch
```
## Usage
### Training Model
Run(example presented in paper)
```
cd ./Deep-JSCC-PyTorch
pip install requirements.txt
```
```

21
dataset.py Normal file
View File

@ -0,0 +1,21 @@
import os
def main():
data_path = './Dataset'
os.makedirs(data_path, exist_ok=True)
if not os.path.exists('./Dataset/ILSVRC2012_img_train.tar') or not os.path.exists('./Dataset/ILSVRC2012_img_val.tar'):
print('Please download the dataset from http://www.image-net.org/challenges/LSVRC/2012/downloads and put it in ./Dataset')
raise Exception('not find dataset')
phases = ['train', 'val']
for phase in phases:
path = './Dataset/ImageNet/{}'.format(phase)
os.makedirs(path, exist_ok=True)
os.system('tar -xvf ./Dataset/ILSVRC2012_img_{}.tar -C {}'.format(phase, path))
for tar in os.listdir(path):
os.system('tar -xvf {}/{} -C {}/{}'.format(path, tar, path,tar))
os.remove('{}/{}'.format(path, tar))
if __name__ == '__main__':
main()

View File

@ -125,7 +125,7 @@ def train(args: config_parser(), ratio: float, snr: float):
model.train()
epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
save_model(model, args.saved, args.saved +
'/model_{}_{:.2f}_{:.2f}_{}.pth'.format(args.dataset, ratio, snr, c))
'/model_{}_{}_{:.2f}_{:.2f}_{}.pth'.format(args.dataset, args.epochs,ratio, snr, c))
def save_model(model, dir, path):