train.py modified
This commit is contained in:
parent
6e6d83673e
commit
8b1631ae90
10
train.py
10
train.py
@ -34,6 +34,8 @@ def config_parser():
|
|||||||
parser.add_argument('--dataset', default='cifar10', type=str,
|
parser.add_argument('--dataset', default='cifar10', type=str,
|
||||||
choices=['cifar10', 'imagenet'], help='dataset')
|
choices=['cifar10', 'imagenet'], help='dataset')
|
||||||
parser.add_argument('--parallel', default=False, type=bool, help='parallel')
|
parser.add_argument('--parallel', default=False, type=bool, help='parallel')
|
||||||
|
parser.add_argument('--if_scheduler', default=True, type=bool, help='if_scheduler')
|
||||||
|
parser.add_argument('--step_size', default=640, type=int, help='scheduler')
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -80,13 +82,17 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
image_fisrt = train_dataset.__getitem__(0)[0]
|
image_fisrt = train_dataset.__getitem__(0)[0]
|
||||||
c = ratio2filtersize(image_fisrt, ratio)
|
c = ratio2filtersize(image_fisrt, ratio)
|
||||||
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr)
|
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr)
|
||||||
|
|
||||||
if args.parallel and torch.cuda.device_count() > 1:
|
if args.parallel and torch.cuda.device_count() > 1:
|
||||||
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
|
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
|
|
||||||
criterion = nn.MSELoss(reduction='mean').cuda()
|
criterion = nn.MSELoss(reduction='mean').cuda()
|
||||||
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||||
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
|
if args.if_scheduler:
|
||||||
|
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1)
|
||||||
|
|
||||||
|
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
|
||||||
for epoch in epoch_loop:
|
for epoch in epoch_loop:
|
||||||
run_loss = 0.0
|
run_loss = 0.0
|
||||||
for images, _ in tqdm((train_loader), leave=False):
|
for images, _ in tqdm((train_loader), leave=False):
|
||||||
@ -99,6 +105,8 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
run_loss += loss.item()
|
run_loss += loss.item()
|
||||||
|
if args.if_scheduler:
|
||||||
|
scheduler.step()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model.eval()
|
model.eval()
|
||||||
test_mse = 0.0
|
test_mse = 0.0
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user