train.py modified
This commit is contained in:
parent
8b1631ae90
commit
3ecc36cb3c
7
train.py
7
train.py
@ -15,6 +15,7 @@ from tqdm import tqdm
|
|||||||
from model import DeepJSCC, ratio2filtersize
|
from model import DeepJSCC, ratio2filtersize
|
||||||
from torch.nn.parallel import DataParallel
|
from torch.nn.parallel import DataParallel
|
||||||
from utils import image_normalization
|
from utils import image_normalization
|
||||||
|
from fractions import Fraction
|
||||||
|
|
||||||
|
|
||||||
def config_parser():
|
def config_parser():
|
||||||
@ -29,7 +30,8 @@ def config_parser():
|
|||||||
choices=['AWGN', 'Rayleigh'], help='channel type')
|
choices=['AWGN', 'Rayleigh'], help='channel type')
|
||||||
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
|
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
|
||||||
parser.add_argument('--snr_list', default=list(range(1, 19, 3)), nargs='+', help='snr_list')
|
parser.add_argument('--snr_list', default=list(range(1, 19, 3)), nargs='+', help='snr_list')
|
||||||
parser.add_argument('--ratio_list', default=[1/3, 1/6, 1/12], nargs='+', help='ratio_list')
|
parser.add_argument('--ratio_list', default=['1/3',
|
||||||
|
'1/6', '1/12'], nargs='+', help='ratio_list')
|
||||||
parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
|
parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
|
||||||
parser.add_argument('--dataset', default='cifar10', type=str,
|
parser.add_argument('--dataset', default='cifar10', type=str,
|
||||||
choices=['cifar10', 'imagenet'], help='dataset')
|
choices=['cifar10', 'imagenet'], help='dataset')
|
||||||
@ -41,7 +43,8 @@ def config_parser():
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = config_parser()
|
args = config_parser()
|
||||||
|
args.snr_list = list(map(float, args.snr_list))
|
||||||
|
args.ratio_list = list(map(lambda x: float(Fraction(x)), args.ratio_list))
|
||||||
print("Training Start")
|
print("Training Start")
|
||||||
for ratio in args.ratio_list:
|
for ratio in args.ratio_list:
|
||||||
for snr in args.snr_list:
|
for snr in args.snr_list:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user