para modified

This commit is contained in:
chun 2023-12-23 17:16:39 +08:00
parent 05cdfb29c0
commit 0e8aaeda07

View File

@ -82,7 +82,7 @@ def train(args: config_parser(), ratio: float, snr: float):
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr) model = DeepJSCC(c=c, channel_type=args.channel, snr=snr)
if args.parallel and torch.cuda.device_count() > 1: if args.parallel and torch.cuda.device_count() > 1:
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
model.cuda() model = model.cuda()
criterion = nn.MSELoss(reduction='mean') criterion = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False) epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
@ -91,8 +91,7 @@ def train(args: config_parser(), ratio: float, snr: float):
run_loss = 0.0 run_loss = 0.0
for images, _ in tqdm((train_loader), leave=False): for images, _ in tqdm((train_loader), leave=False):
optimizer.zero_grad() optimizer.zero_grad()
if not args.parallel: images = images.cuda()
images = images.cuda(device=device)
outputs = model(images) outputs = model(images)
loss = criterion(image_normalization('denormalization')(outputs), loss = criterion(image_normalization('denormalization')(outputs),
image_normalization('denormalization')(images)) image_normalization('denormalization')(images))