diff --git a/train.py b/train.py index bdf1e65..67ed242 100644 --- a/train.py +++ b/train.py @@ -78,9 +78,9 @@ def train(args: config_parser(), ratio: float, snr: float): image_fisrt = train_dataset.__getitem__(0)[0] c = ratio2filtersize(image_fisrt, ratio) - model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device) + model = DeepJSCC(c=c, channel_type=args.channel, snr=snr) model = DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) - criterion = nn.MSELoss(reduction='mean').cuda(device=device) + criterion = nn.MSELoss(reduction='mean') optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)