add TODO list
This commit is contained in:
parent
c81101c269
commit
e2852526eb
3
.gitignore
vendored
3
.gitignore
vendored
@ -7,4 +7,5 @@ dataset
|
||||
.vscode/*
|
||||
input.txt
|
||||
output.txt
|
||||
*.json
|
||||
*.json
|
||||
.vscode/*
|
||||
|
||||
8
.vscode/launch.json
vendored
8
.vscode/launch.json
vendored
@ -15,7 +15,7 @@
|
||||
"--lr",
|
||||
"1e-3",
|
||||
"--epochs",
|
||||
"1000",
|
||||
"100",
|
||||
"--batch_size",
|
||||
"512",
|
||||
"--if_scheduler",
|
||||
@ -27,13 +27,15 @@
|
||||
"--num_workers",
|
||||
"4",
|
||||
"--device",
|
||||
"cuda:1",
|
||||
"cuda:0",
|
||||
"--ratio_list",
|
||||
"1/3",
|
||||
"--snr_list",
|
||||
"100",
|
||||
"--seed",
|
||||
"42"
|
||||
"42",
|
||||
"--disable_tqdm",
|
||||
"False"
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
@ -47,7 +47,9 @@ Run(example presented in paper)
|
||||
```
|
||||
python eval.py --channel 'AWGN' --saved ./saved/${mode_path} --snr 20 --ratio_list 1/3 --test_img ${test_img_path}
|
||||
```
|
||||
|
||||
### TO-DO
|
||||
- Add visualization of the model
|
||||
- plot the results with different snr and ratio
|
||||
|
||||
## Citation
|
||||
If you find (part of) this code useful for your research, please consider citing
|
||||
|
||||
3
eval.py
3
eval.py
@ -30,7 +30,8 @@ def main():
|
||||
c = file_name.split('_')[-1].split('.')[0]
|
||||
c = int(c)
|
||||
model = DeepJSCC(c=c, channel_type=args.channel, snr=args.snr)
|
||||
model.load_state_dict(torch.load(args.saved))
|
||||
# model.load_state_dict(torch.load(args.saved))
|
||||
model.load_state_dict(torch.load(args.saved,map_location=torch.device('cuda:0')))
|
||||
model.change_channel(args.channel, args.snr)
|
||||
|
||||
psnr_all = 0.0
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 1.6 MiB After Width: | Height: | Size: 1.6 MiB |
4
train.py
4
train.py
@ -73,12 +73,12 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
# load data
|
||||
if args.dataset == 'cifar10':
|
||||
transform = transforms.Compose([transforms.ToTensor(), ])
|
||||
train_dataset = datasets.CIFAR10(root='./dataset/', train=True,
|
||||
train_dataset = datasets.CIFAR10(root='../dataset/', train=True,
|
||||
download=True, transform=transform)
|
||||
|
||||
train_loader = DataLoader(train_dataset, shuffle=True,
|
||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||
test_dataset = datasets.CIFAR10(root='./dataset/', train=False,
|
||||
test_dataset = datasets.CIFAR10(root='../dataset/', train=False,
|
||||
download=True, transform=transform)
|
||||
test_loader = DataLoader(test_dataset, shuffle=True,
|
||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user