diff --git a/train.py b/train.py index 410279b..e943379 100644 --- a/train.py +++ b/train.py @@ -15,6 +15,7 @@ from tqdm import tqdm from model import DeepJSCC, ratio2filtersize from torch.nn.parallel import DataParallel from utils import image_normalization +from fractions import Fraction def config_parser(): @@ -29,7 +30,8 @@ def config_parser(): choices=['AWGN', 'Rayleigh'], help='channel type') parser.add_argument('--saved', default='./saved', type=str, help='saved_path') parser.add_argument('--snr_list', default=list(range(1, 19, 3)), nargs='+', help='snr_list') - parser.add_argument('--ratio_list', default=[1/3, 1/6, 1/12], nargs='+', help='ratio_list') + parser.add_argument('--ratio_list', default=['1/3', + '1/6', '1/12'], nargs='+', help='ratio_list') parser.add_argument('--num_workers', default=0, type=int, help='num_workers') parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset') @@ -41,7 +43,8 @@ def config_parser(): def main(): args = config_parser() - + args.snr_list = list(map(float, args.snr_list)) + args.ratio_list = list(map(lambda x: float(Fraction(x)), args.ratio_list)) print("Training Start") for ratio in args.ratio_list: for snr in args.snr_list: