From 874d6fb2f091b9f382212c9adc5cd914ce9af4d9 Mon Sep 17 00:00:00 2001 From: chun Date: Sat, 23 Dec 2023 14:17:21 +0800 Subject: [PATCH] debug nargs --- train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 3087f45..bdf1e65 100644 --- a/train.py +++ b/train.py @@ -28,8 +28,8 @@ def config_parser(): parser.add_argument('--channel', default='AWGN', type=str, choices=['AWGN', 'Rayleigh'], help='channel type') parser.add_argument('--saved', default='./saved', type=str, help='saved_path') - parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list') - parser.add_argument('--ratio_list', default=[1/3, 1/6, 1/12], type=list, help='ratio_list') + parser.add_argument('--snr_list', default=list(range(1, 19, 3)), nargs='+', help='snr_list') + parser.add_argument('--ratio_list', default=[1/3, 1/6, 1/12], nargs='+', help='ratio_list') parser.add_argument('--num_workers', default=0, type=int, help='num_workers') parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset') @@ -107,7 +107,7 @@ def train(args: config_parser(), ratio: float, snr: float): test_mse += loss.item() model.train() epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader)) - save_model(model, args.saved + '/model{}_{:2f}_{:2f}.pth'.format(args.dataset, ratio, snr)) + save_model(model, args.saved + '/model{}_{:.2f}_{:.2f}.pth'.format(args.dataset, ratio, snr)) def save_model(model, path):