add TODO list
This commit is contained in:
parent
c81101c269
commit
e2852526eb
1
.gitignore
vendored
1
.gitignore
vendored
@ -8,3 +8,4 @@ dataset
|
|||||||
input.txt
|
input.txt
|
||||||
output.txt
|
output.txt
|
||||||
*.json
|
*.json
|
||||||
|
.vscode/*
|
||||||
|
|||||||
8
.vscode/launch.json
vendored
8
.vscode/launch.json
vendored
@ -15,7 +15,7 @@
|
|||||||
"--lr",
|
"--lr",
|
||||||
"1e-3",
|
"1e-3",
|
||||||
"--epochs",
|
"--epochs",
|
||||||
"1000",
|
"100",
|
||||||
"--batch_size",
|
"--batch_size",
|
||||||
"512",
|
"512",
|
||||||
"--if_scheduler",
|
"--if_scheduler",
|
||||||
@ -27,13 +27,15 @@
|
|||||||
"--num_workers",
|
"--num_workers",
|
||||||
"4",
|
"4",
|
||||||
"--device",
|
"--device",
|
||||||
"cuda:1",
|
"cuda:0",
|
||||||
"--ratio_list",
|
"--ratio_list",
|
||||||
"1/3",
|
"1/3",
|
||||||
"--snr_list",
|
"--snr_list",
|
||||||
"100",
|
"100",
|
||||||
"--seed",
|
"--seed",
|
||||||
"42"
|
"42",
|
||||||
|
"--disable_tqdm",
|
||||||
|
"False"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
@ -47,7 +47,9 @@ Run(example presented in paper)
|
|||||||
```
|
```
|
||||||
python eval.py --channel 'AWGN' --saved ./saved/${mode_path} --snr 20 --ratio_list 1/3 --test_img ${test_img_path}
|
python eval.py --channel 'AWGN' --saved ./saved/${mode_path} --snr 20 --ratio_list 1/3 --test_img ${test_img_path}
|
||||||
```
|
```
|
||||||
|
### TO-DO
|
||||||
|
- Add visualization of the model
|
||||||
|
- plot the results with different snr and ratio
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
If you find (part of) this code useful for your research, please consider citing
|
If you find (part of) this code useful for your research, please consider citing
|
||||||
|
|||||||
3
eval.py
3
eval.py
@ -30,7 +30,8 @@ def main():
|
|||||||
c = file_name.split('_')[-1].split('.')[0]
|
c = file_name.split('_')[-1].split('.')[0]
|
||||||
c = int(c)
|
c = int(c)
|
||||||
model = DeepJSCC(c=c, channel_type=args.channel, snr=args.snr)
|
model = DeepJSCC(c=c, channel_type=args.channel, snr=args.snr)
|
||||||
model.load_state_dict(torch.load(args.saved))
|
# model.load_state_dict(torch.load(args.saved))
|
||||||
|
model.load_state_dict(torch.load(args.saved,map_location=torch.device('cuda:0')))
|
||||||
model.change_channel(args.channel, args.snr)
|
model.change_channel(args.channel, args.snr)
|
||||||
|
|
||||||
psnr_all = 0.0
|
psnr_all = 0.0
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 1.6 MiB After Width: | Height: | Size: 1.6 MiB |
4
train.py
4
train.py
@ -73,12 +73,12 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
# load data
|
# load data
|
||||||
if args.dataset == 'cifar10':
|
if args.dataset == 'cifar10':
|
||||||
transform = transforms.Compose([transforms.ToTensor(), ])
|
transform = transforms.Compose([transforms.ToTensor(), ])
|
||||||
train_dataset = datasets.CIFAR10(root='./dataset/', train=True,
|
train_dataset = datasets.CIFAR10(root='../dataset/', train=True,
|
||||||
download=True, transform=transform)
|
download=True, transform=transform)
|
||||||
|
|
||||||
train_loader = DataLoader(train_dataset, shuffle=True,
|
train_loader = DataLoader(train_dataset, shuffle=True,
|
||||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||||
test_dataset = datasets.CIFAR10(root='./dataset/', train=False,
|
test_dataset = datasets.CIFAR10(root='../dataset/', train=False,
|
||||||
download=True, transform=transform)
|
download=True, transform=transform)
|
||||||
test_loader = DataLoader(test_dataset, shuffle=True,
|
test_loader = DataLoader(test_dataset, shuffle=True,
|
||||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user