add imagenet

This commit is contained in:
chun 2023-12-24 16:17:56 +08:00
parent 4f9c0063f7
commit 3ac47a6329
2 changed files with 7 additions and 8 deletions

View File

@ -19,21 +19,20 @@ conda or other virtual environment is recommended.
```
git clone https://github.com/chunbaobao/Deep-JSCC-PyTorch.git
cd ./Deep-JSCC-PyTorch
pip install requirements.txt
```
## Usage
### Training Model
Run(example presented in paper)
```
pip install requirements.txt
```
Run(example presented in paper) on cifar10
```
python train.py --lr 10e-4 --epochs 100 --batch_size 32 --channel 'AWGN' --saved ./saved --snr_list 1 4 7 13 19 --ratio_list 1/6 1/12 --dataset imagenet
python train.py --lr 10e-3 --epochs 100 --batch_size 64 --channel 'AWGN' --saved ./saved --snr_list 1 4 7 13 19 --ratio_list 1/6 1/12 --dataset cifar10 --num_workers 4 --parallel True --if_scheduler True --scheduler_step_size 50
```
or
or Run(example presented in paper) on imagenet
```
python train.py --lr 10e-3 --epochs 100 --batch_size 512 --channel 'AWGN' --saved ./saved --dataset cifar10 --num_workers 4 --parallel True
python train.py --lr 10e-4 --epochs 300 --batch_size 32 --channel 'AWGN' --saved ./saved --dataset imagenet --num_workers 4 --parallel True
```
### Evaluation
Run(example presented in paper)

View File

@ -75,7 +75,7 @@ def train(args: config_parser(), ratio: float, snr: float):
train_loader = DataLoader(train_dataset, shuffle=True,
batch_size=args.batch_size, num_workers=args.num_workers)
test_dataset = datasets.ImageNet(root='./Dataset/ImageNet/val', transform=transform)
test_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/val', transform=transform)
test_loader = DataLoader(test_dataset, shuffle=True,
batch_size=args.batch_size, num_workers=args.num_workers)
else: