train.py modified
This commit is contained in:
parent
8b1631ae90
commit
3ecc36cb3c
7
train.py
7
train.py
@ -15,6 +15,7 @@ from tqdm import tqdm
|
||||
from model import DeepJSCC, ratio2filtersize
|
||||
from torch.nn.parallel import DataParallel
|
||||
from utils import image_normalization
|
||||
from fractions import Fraction
|
||||
|
||||
|
||||
def config_parser():
|
||||
@ -29,7 +30,8 @@ def config_parser():
|
||||
choices=['AWGN', 'Rayleigh'], help='channel type')
|
||||
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
|
||||
parser.add_argument('--snr_list', default=list(range(1, 19, 3)), nargs='+', help='snr_list')
|
||||
parser.add_argument('--ratio_list', default=[1/3, 1/6, 1/12], nargs='+', help='ratio_list')
|
||||
parser.add_argument('--ratio_list', default=['1/3',
|
||||
'1/6', '1/12'], nargs='+', help='ratio_list')
|
||||
parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
|
||||
parser.add_argument('--dataset', default='cifar10', type=str,
|
||||
choices=['cifar10', 'imagenet'], help='dataset')
|
||||
@ -41,7 +43,8 @@ def config_parser():
|
||||
|
||||
def main():
|
||||
args = config_parser()
|
||||
|
||||
args.snr_list = list(map(float, args.snr_list))
|
||||
args.ratio_list = list(map(lambda x: float(Fraction(x)), args.ratio_list))
|
||||
print("Training Start")
|
||||
for ratio in args.ratio_list:
|
||||
for snr in args.snr_list:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user