para modified
This commit is contained in:
parent
05cdfb29c0
commit
0e8aaeda07
5
train.py
5
train.py
@ -82,7 +82,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr)
|
||||
if args.parallel and torch.cuda.device_count() > 1:
|
||||
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
|
||||
model.cuda()
|
||||
model = model.cuda()
|
||||
criterion = nn.MSELoss(reduction='mean')
|
||||
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
|
||||
@ -91,8 +91,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
run_loss = 0.0
|
||||
for images, _ in tqdm((train_loader), leave=False):
|
||||
optimizer.zero_grad()
|
||||
if not args.parallel:
|
||||
images = images.cuda(device=device)
|
||||
images = images.cuda()
|
||||
outputs = model(images)
|
||||
loss = criterion(image_normalization('denormalization')(outputs),
|
||||
image_normalization('denormalization')(images))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user