train.py modified

This commit is contained in:
chun 2023-12-23 20:24:33 +08:00
parent 8b1631ae90
commit 3ecc36cb3c

View File

@ -15,6 +15,7 @@ from tqdm import tqdm
from model import DeepJSCC, ratio2filtersize from model import DeepJSCC, ratio2filtersize
from torch.nn.parallel import DataParallel from torch.nn.parallel import DataParallel
from utils import image_normalization from utils import image_normalization
from fractions import Fraction
def config_parser(): def config_parser():
@ -29,7 +30,8 @@ def config_parser():
choices=['AWGN', 'Rayleigh'], help='channel type') choices=['AWGN', 'Rayleigh'], help='channel type')
parser.add_argument('--saved', default='./saved', type=str, help='saved_path') parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
parser.add_argument('--snr_list', default=list(range(1, 19, 3)), nargs='+', help='snr_list') parser.add_argument('--snr_list', default=list(range(1, 19, 3)), nargs='+', help='snr_list')
parser.add_argument('--ratio_list', default=[1/3, 1/6, 1/12], nargs='+', help='ratio_list') parser.add_argument('--ratio_list', default=['1/3',
'1/6', '1/12'], nargs='+', help='ratio_list')
parser.add_argument('--num_workers', default=0, type=int, help='num_workers') parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
parser.add_argument('--dataset', default='cifar10', type=str, parser.add_argument('--dataset', default='cifar10', type=str,
choices=['cifar10', 'imagenet'], help='dataset') choices=['cifar10', 'imagenet'], help='dataset')
@ -41,7 +43,8 @@ def config_parser():
def main(): def main():
args = config_parser() args = config_parser()
args.snr_list = list(map(float, args.snr_list))
args.ratio_list = list(map(lambda x: float(Fraction(x)), args.ratio_list))
print("Training Start") print("Training Start")
for ratio in args.ratio_list: for ratio in args.ratio_list:
for snr in args.snr_list: for snr in args.snr_list: