debug nargs
This commit is contained in:
parent
c7f456b7b4
commit
874d6fb2f0
6
train.py
6
train.py
@ -28,8 +28,8 @@ def config_parser():
|
|||||||
parser.add_argument('--channel', default='AWGN', type=str,
|
parser.add_argument('--channel', default='AWGN', type=str,
|
||||||
choices=['AWGN', 'Rayleigh'], help='channel type')
|
choices=['AWGN', 'Rayleigh'], help='channel type')
|
||||||
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
|
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
|
||||||
parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list')
|
parser.add_argument('--snr_list', default=list(range(1, 19, 3)), nargs='+', help='snr_list')
|
||||||
parser.add_argument('--ratio_list', default=[1/3, 1/6, 1/12], type=list, help='ratio_list')
|
parser.add_argument('--ratio_list', default=[1/3, 1/6, 1/12], nargs='+', help='ratio_list')
|
||||||
parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
|
parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
|
||||||
parser.add_argument('--dataset', default='cifar10', type=str,
|
parser.add_argument('--dataset', default='cifar10', type=str,
|
||||||
choices=['cifar10', 'imagenet'], help='dataset')
|
choices=['cifar10', 'imagenet'], help='dataset')
|
||||||
@ -107,7 +107,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
test_mse += loss.item()
|
test_mse += loss.item()
|
||||||
model.train()
|
model.train()
|
||||||
epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
|
epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
|
||||||
save_model(model, args.saved + '/model{}_{:2f}_{:2f}.pth'.format(args.dataset, ratio, snr))
|
save_model(model, args.saved + '/model{}_{:.2f}_{:.2f}.pth'.format(args.dataset, ratio, snr))
|
||||||
|
|
||||||
|
|
||||||
def save_model(model, path):
|
def save_model(model, path):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user