add TODO list

This commit is contained in:
chun 2024-03-18 13:52:46 +08:00
parent c81101c269
commit e2852526eb
6 changed files with 14 additions and 8 deletions

1
.gitignore vendored
View File

@ -8,3 +8,4 @@ dataset
input.txt input.txt
output.txt output.txt
*.json *.json
.vscode/*

8
.vscode/launch.json vendored
View File

@ -15,7 +15,7 @@
"--lr", "--lr",
"1e-3", "1e-3",
"--epochs", "--epochs",
"1000", "100",
"--batch_size", "--batch_size",
"512", "512",
"--if_scheduler", "--if_scheduler",
@ -27,13 +27,15 @@
"--num_workers", "--num_workers",
"4", "4",
"--device", "--device",
"cuda:1", "cuda:0",
"--ratio_list", "--ratio_list",
"1/3", "1/3",
"--snr_list", "--snr_list",
"100", "100",
"--seed", "--seed",
"42" "42",
"--disable_tqdm",
"False"
] ]
} }
] ]

View File

@ -47,7 +47,9 @@ Run(example presented in paper)
``` ```
python eval.py --channel 'AWGN' --saved ./saved/${mode_path} --snr 20 --ratio_list 1/3 --test_img ${test_img_path} python eval.py --channel 'AWGN' --saved ./saved/${mode_path} --snr 20 --ratio_list 1/3 --test_img ${test_img_path}
``` ```
### TO-DO
- Add visualization of the model
- plot the results with different snr and ratio
## Citation ## Citation
If you find (part of) this code useful for your research, please consider citing If you find (part of) this code useful for your research, please consider citing

View File

@ -30,7 +30,8 @@ def main():
c = file_name.split('_')[-1].split('.')[0] c = file_name.split('_')[-1].split('.')[0]
c = int(c) c = int(c)
model = DeepJSCC(c=c, channel_type=args.channel, snr=args.snr) model = DeepJSCC(c=c, channel_type=args.channel, snr=args.snr)
model.load_state_dict(torch.load(args.saved)) # model.load_state_dict(torch.load(args.saved))
model.load_state_dict(torch.load(args.saved,map_location=torch.device('cuda:0')))
model.change_channel(args.channel, args.snr) model.change_channel(args.channel, args.snr)
psnr_all = 0.0 psnr_all = 0.0

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.6 MiB

After

Width:  |  Height:  |  Size: 1.6 MiB

View File

@ -73,12 +73,12 @@ def train(args: config_parser(), ratio: float, snr: float):
# load data # load data
if args.dataset == 'cifar10': if args.dataset == 'cifar10':
transform = transforms.Compose([transforms.ToTensor(), ]) transform = transforms.Compose([transforms.ToTensor(), ])
train_dataset = datasets.CIFAR10(root='./dataset/', train=True, train_dataset = datasets.CIFAR10(root='../dataset/', train=True,
download=True, transform=transform) download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, train_loader = DataLoader(train_dataset, shuffle=True,
batch_size=args.batch_size, num_workers=args.num_workers) batch_size=args.batch_size, num_workers=args.num_workers)
test_dataset = datasets.CIFAR10(root='./dataset/', train=False, test_dataset = datasets.CIFAR10(root='../dataset/', train=False,
download=True, transform=transform) download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=True, test_loader = DataLoader(test_dataset, shuffle=True,
batch_size=args.batch_size, num_workers=args.num_workers) batch_size=args.batch_size, num_workers=args.num_workers)