From 3ecc36cb3c20286280180b2a34a6e23e36a79c88 Mon Sep 17 00:00:00 2001 From: chun Date: Sat, 23 Dec 2023 20:24:33 +0800 Subject: [PATCH] train.py modified --- train.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 410279b..e943379 100644 --- a/train.py +++ b/train.py @@ -15,6 +15,7 @@ from tqdm import tqdm from model import DeepJSCC, ratio2filtersize from torch.nn.parallel import DataParallel from utils import image_normalization +from fractions import Fraction def config_parser(): @@ -29,7 +30,8 @@ def config_parser(): choices=['AWGN', 'Rayleigh'], help='channel type') parser.add_argument('--saved', default='./saved', type=str, help='saved_path') parser.add_argument('--snr_list', default=list(range(1, 19, 3)), nargs='+', help='snr_list') - parser.add_argument('--ratio_list', default=[1/3, 1/6, 1/12], nargs='+', help='ratio_list') + parser.add_argument('--ratio_list', default=['1/3', + '1/6', '1/12'], nargs='+', help='ratio_list') parser.add_argument('--num_workers', default=0, type=int, help='num_workers') parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset') @@ -41,7 +43,8 @@ def config_parser(): def main(): args = config_parser() - + args.snr_list = list(map(float, args.snr_list)) + args.ratio_list = list(map(lambda x: float(Fraction(x)), args.ratio_list)) print("Training Start") for ratio in args.ratio_list: for snr in args.snr_list: