add TODO list

chun 2024-03-18 13:52:46 +08:00
parent c81101c269
commit e2852526eb
6 changed files with 14 additions and 8 deletions

.gitignore (3 changed lines)

@@ -7,4 +7,5 @@ dataset
 .vscode/*
 input.txt
 output.txt
-*.json
+*.json
+.vscode/*

.vscode/launch.json (8 changed lines)

@@ -15,7 +15,7 @@
 "--lr",
 "1e-3",
 "--epochs",
-"1000",
+"100",
 "--batch_size",
 "512",
 "--if_scheduler",
@@ -27,13 +27,15 @@
 "--num_workers",
 "4",
 "--device",
-"cuda:1",
+"cuda:0",
 "--ratio_list",
 "1/3",
 "--snr_list",
 "100",
 "--seed",
-"42"
+"42",
+"--disable_tqdm",
+"False"
 ]
 }
 ]
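
The new `--disable_tqdm` argument is passed as the string "False". Whether that actually disables anything depends on how the training script parses it: with `type=bool`, argparse treats every non-empty string, including "False", as truthy. Below is a minimal sketch of a string-to-boolean converter that handles this case; the `str2bool` helper and the flag wiring are illustrative assumptions, not code from this repository.

```
import argparse

def str2bool(value: str) -> bool:
    # Map common string spellings onto booleans so "--disable_tqdm False" means False.
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

parser = argparse.ArgumentParser()
# With type=bool the string "False" would come through as True; use the converter instead.
parser.add_argument("--disable_tqdm", type=str2bool, default=False)

args = parser.parse_args(["--disable_tqdm", "False"])
print(args.disable_tqdm)  # -> False
```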


@@ -47,7 +47,9 @@ Run(example presented in paper)
 ```
 python eval.py --channel 'AWGN' --saved ./saved/${mode_path} --snr 20 --ratio_list 1/3 --test_img ${test_img_path}
 ```
+### TO-DO
+- Add visualization of the model
+- Plot the results with different SNR and ratio
 ## Citation
 If you find (part of) this code useful for your research, please consider citing


@@ -30,7 +30,8 @@ def main():
     c = file_name.split('_')[-1].split('.')[0]
     c = int(c)
     model = DeepJSCC(c=c, channel_type=args.channel, snr=args.snr)
-    model.load_state_dict(torch.load(args.saved))
+    # model.load_state_dict(torch.load(args.saved))
+    model.load_state_dict(torch.load(args.saved, map_location=torch.device('cuda:0')))
     model.change_channel(args.channel, args.snr)
     psnr_all = 0.0
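
The added `map_location` remaps tensors that were stored from `cuda:1` onto `cuda:0` at load time, which is why the checkpoint now loads on a single-GPU machine. Hard-coding `cuda:0` still fails on a CPU-only host; a device-agnostic variant is sketched below as an assumption, not the repository's code: the checkpoint path is illustrative and a tiny module stands in for the real DeepJSCC model.

```
import torch
import torch.nn as nn

# Use whatever device exists locally; the checkpoint may have been written on a
# different GPU (cuda:1 in the original training run).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = nn.Linear(4, 2)  # stand-in for the real model
# map_location remaps every tensor in the checkpoint onto `device` before
# load_state_dict copies the weights into the model.
state_dict = torch.load("checkpoint.pth", map_location=device)  # illustrative path
model.load_state_dict(state_dict)
model.to(device)
```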

Binary image file changed (content not shown); size 1.6 MiB both before and after.


@@ -73,12 +73,12 @@ def train(args: config_parser(), ratio: float, snr: float):
     # load data
     if args.dataset == 'cifar10':
         transform = transforms.Compose([transforms.ToTensor(), ])
-        train_dataset = datasets.CIFAR10(root='./dataset/', train=True,
+        train_dataset = datasets.CIFAR10(root='../dataset/', train=True,
             download=True, transform=transform)
         train_loader = DataLoader(train_dataset, shuffle=True,
             batch_size=args.batch_size, num_workers=args.num_workers)
-        test_dataset = datasets.CIFAR10(root='./dataset/', train=False,
+        test_dataset = datasets.CIFAR10(root='../dataset/', train=False,
             download=True, transform=transform)
         test_loader = DataLoader(test_dataset, shuffle=True,
             batch_size=args.batch_size, num_workers=args.num_workers)
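
Switching the CIFAR-10 root from `./dataset/` to `../dataset/` ties the download location to whatever directory the script is launched from. One way to make this independent of the working directory is sketched below; it is an alternative, not the repository's code, and `DATA_ROOT` plus the batch size are a name and values introduced here for illustration.

```
from pathlib import Path

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Resolve the dataset directory relative to this file rather than the current
# working directory, so the script behaves the same wherever it is run from.
DATA_ROOT = Path(__file__).resolve().parent.parent / "dataset"

transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.CIFAR10(root=str(DATA_ROOT), train=True,
                                 download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True,
                          batch_size=512, num_workers=4)  # example values
```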