add dataset.py
This commit is contained in:
parent
b92cbc2c35
commit
0818c20233
@ -7,7 +7,7 @@ This is my first time to use PyTorch and git to reproduce a paper, so there may
|
||||

|
||||
|
||||
## Demo
|
||||
I spend 3 days from 12-20 to 12-24 to reproduce the paper, and i get the result as follow. The result is not good, because i trained the model on cifar10 which is 32*32 but test on kodim which is 768*512 and the model is not trained enough.
|
||||
I spend 3 days to reproduce the paper, and i get the result as follow. The result is not good, because i trained the model on cifar10 which is 32\*32 but test on kodim which is 768\*512 and the model is not trained enough.
|
||||
|
||||
That is all enough!!!
|
||||

|
||||
@ -18,14 +18,14 @@ conda or other virtual environment is recommended.
|
||||
|
||||
```
|
||||
git clone https://github.com/chunbaobao/Deep-JSCC-PyTorch.git
|
||||
pip install requirements.txt
|
||||
cd ./Deep-JSCC-PyTorch
|
||||
```
|
||||
|
||||
## Usage
|
||||
### Training Model
|
||||
Run(example presented in paper)
|
||||
```
|
||||
cd ./Deep-JSCC-PyTorch
|
||||
pip install requirements.txt
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
21
dataset.py
Normal file
21
dataset.py
Normal file
@ -0,0 +1,21 @@
|
||||
import os
|
||||
|
||||
|
||||
def main():
|
||||
data_path = './Dataset'
|
||||
os.makedirs(data_path, exist_ok=True)
|
||||
|
||||
if not os.path.exists('./Dataset/ILSVRC2012_img_train.tar') or not os.path.exists('./Dataset/ILSVRC2012_img_val.tar'):
|
||||
print('Please download the dataset from http://www.image-net.org/challenges/LSVRC/2012/downloads and put it in ./Dataset')
|
||||
raise Exception('not find dataset')
|
||||
phases = ['train', 'val']
|
||||
for phase in phases:
|
||||
path = './Dataset/ImageNet/{}'.format(phase)
|
||||
os.makedirs(path, exist_ok=True)
|
||||
os.system('tar -xvf ./Dataset/ILSVRC2012_img_{}.tar -C {}'.format(phase, path))
|
||||
for tar in os.listdir(path):
|
||||
os.system('tar -xvf {}/{} -C {}/{}'.format(path, tar, path,tar))
|
||||
os.remove('{}/{}'.format(path, tar))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
2
train.py
2
train.py
@ -125,7 +125,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
model.train()
|
||||
epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
|
||||
save_model(model, args.saved, args.saved +
|
||||
'/model_{}_{:.2f}_{:.2f}_{}.pth'.format(args.dataset, ratio, snr, c))
|
||||
'/model_{}_{}_{:.2f}_{:.2f}_{}.pth'.format(args.dataset, args.epochs,ratio, snr, c))
|
||||
|
||||
|
||||
def save_model(model, dir, path):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user