para modified

This commit is contained in:
chun 2023-12-23 14:28:11 +08:00
parent e40a5c5694
commit 34a6929650

View File

@ -33,6 +33,7 @@ def config_parser():
parser.add_argument('--num_workers', default=0, type=int, help='num_workers') parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
parser.add_argument('--dataset', default='cifar10', type=str, parser.add_argument('--dataset', default='cifar10', type=str,
choices=['cifar10', 'imagenet'], help='dataset') choices=['cifar10', 'imagenet'], help='dataset')
parser.add_argument('--parallel', default=True, type=bool, help='parallel')
return parser.parse_args() return parser.parse_args()
@ -47,7 +48,7 @@ def main():
def train(args: config_parser(), ratio: float, snr: float): def train(args: config_parser(), ratio: float, snr: float):
device = torch.device('cuda') device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# load data # load data
if args.dataset == 'cifar10': if args.dataset == 'cifar10':
transform = transforms.Compose([transforms.ToTensor(), ]) transform = transforms.Compose([transforms.ToTensor(), ])
@ -79,6 +80,7 @@ def train(args: config_parser(), ratio: float, snr: float):
image_fisrt = train_dataset.__getitem__(0)[0] image_fisrt = train_dataset.__getitem__(0)[0]
c = ratio2filtersize(image_fisrt, ratio) c = ratio2filtersize(image_fisrt, ratio)
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr) model = DeepJSCC(c=c, channel_type=args.channel, snr=snr)
if args.parallel and torch.cuda.device_count() > 1:
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
criterion = nn.MSELoss(reduction='mean') criterion = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
@ -88,6 +90,8 @@ def train(args: config_parser(), ratio: float, snr: float):
run_loss = 0.0 run_loss = 0.0
for images, _ in tqdm((train_loader), leave=False): for images, _ in tqdm((train_loader), leave=False):
optimizer.zero_grad() optimizer.zero_grad()
if not args.parallel:
images = images.cuda(device=device)
# images = images.cuda(device=device) # images = images.cuda(device=device)
outputs = model(images) outputs = model(images)
loss = criterion(image_normalization('denormalization')(outputs), loss = criterion(image_normalization('denormalization')(outputs),
@ -99,7 +103,7 @@ def train(args: config_parser(), ratio: float, snr: float):
model.eval() model.eval()
test_mse = 0.0 test_mse = 0.0
for images, _ in tqdm((test_loader), leave=False): for images, _ in tqdm((test_loader), leave=False):
images = images.cuda(device=device) images = images
outputs = model(images) outputs = model(images)
images = image_normalization('normalization')(images) images = image_normalization('normalization')(images)
outputs = image_normalization('normalization')(outputs) outputs = image_normalization('normalization')(outputs)