debug cuda

This commit is contained in:
chun 2023-12-23 14:21:03 +08:00
parent 874d6fb2f0
commit e40a5c5694

View File

@ -78,9 +78,9 @@ def train(args: config_parser(), ratio: float, snr: float):
image_fisrt = train_dataset.__getitem__(0)[0] image_fisrt = train_dataset.__getitem__(0)[0]
c = ratio2filtersize(image_fisrt, ratio) c = ratio2filtersize(image_fisrt, ratio)
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device) model = DeepJSCC(c=c, channel_type=args.channel, snr=snr)
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
criterion = nn.MSELoss(reduction='mean').cuda(device=device) criterion = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False) epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)