From 34a6929650c0b431a26f56bf0f725052554f45e8 Mon Sep 17 00:00:00 2001 From: chun Date: Sat, 23 Dec 2023 14:28:11 +0800 Subject: [PATCH] para modified --- train.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 67ed242..8eda390 100644 --- a/train.py +++ b/train.py @@ -33,6 +33,7 @@ def config_parser(): parser.add_argument('--num_workers', default=0, type=int, help='num_workers') parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset') + parser.add_argument('--parallel', default=True, type=bool, help='parallel') return parser.parse_args() @@ -47,7 +48,7 @@ def main(): def train(args: config_parser(), ratio: float, snr: float): - device = torch.device('cuda') + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # load data if args.dataset == 'cifar10': transform = transforms.Compose([transforms.ToTensor(), ]) @@ -79,7 +80,8 @@ def train(args: config_parser(), ratio: float, snr: float): image_fisrt = train_dataset.__getitem__(0)[0] c = ratio2filtersize(image_fisrt, ratio) model = DeepJSCC(c=c, channel_type=args.channel, snr=snr) - model = DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) + if args.parallel and torch.cuda.device_count() > 1: + model = DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) criterion = nn.MSELoss(reduction='mean') optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False) @@ -88,6 +90,8 @@ def train(args: config_parser(), ratio: float, snr: float): run_loss = 0.0 for images, _ in tqdm((train_loader), leave=False): optimizer.zero_grad() + if not args.parallel: + images = images.cuda(device=device) # images = images.cuda(device=device) outputs = model(images) loss = criterion(image_normalization('denormalization')(outputs), @@ -99,7 +103,7 @@ def train(args: config_parser(), ratio: float, snr: float): model.eval() test_mse = 0.0 for images, _ in tqdm((test_loader), leave=False): - images = images.cuda(device=device) + images = images outputs = model(images) images = image_normalization('normalization')(images) outputs = image_normalization('normalization')(outputs)