para modified
This commit is contained in:
parent
e40a5c5694
commit
34a6929650
8
train.py
8
train.py
@ -33,6 +33,7 @@ def config_parser():
|
|||||||
parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
|
parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
|
||||||
parser.add_argument('--dataset', default='cifar10', type=str,
|
parser.add_argument('--dataset', default='cifar10', type=str,
|
||||||
choices=['cifar10', 'imagenet'], help='dataset')
|
choices=['cifar10', 'imagenet'], help='dataset')
|
||||||
|
parser.add_argument('--parallel', default=True, type=bool, help='parallel')
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -47,7 +48,7 @@ def main():
|
|||||||
|
|
||||||
def train(args: config_parser(), ratio: float, snr: float):
|
def train(args: config_parser(), ratio: float, snr: float):
|
||||||
|
|
||||||
device = torch.device('cuda')
|
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||||
# load data
|
# load data
|
||||||
if args.dataset == 'cifar10':
|
if args.dataset == 'cifar10':
|
||||||
transform = transforms.Compose([transforms.ToTensor(), ])
|
transform = transforms.Compose([transforms.ToTensor(), ])
|
||||||
@ -79,6 +80,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
image_fisrt = train_dataset.__getitem__(0)[0]
|
image_fisrt = train_dataset.__getitem__(0)[0]
|
||||||
c = ratio2filtersize(image_fisrt, ratio)
|
c = ratio2filtersize(image_fisrt, ratio)
|
||||||
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr)
|
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr)
|
||||||
|
if args.parallel and torch.cuda.device_count() > 1:
|
||||||
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
|
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
|
||||||
criterion = nn.MSELoss(reduction='mean')
|
criterion = nn.MSELoss(reduction='mean')
|
||||||
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||||
@ -88,6 +90,8 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
run_loss = 0.0
|
run_loss = 0.0
|
||||||
for images, _ in tqdm((train_loader), leave=False):
|
for images, _ in tqdm((train_loader), leave=False):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
if not args.parallel:
|
||||||
|
images = images.cuda(device=device)
|
||||||
# images = images.cuda(device=device)
|
# images = images.cuda(device=device)
|
||||||
outputs = model(images)
|
outputs = model(images)
|
||||||
loss = criterion(image_normalization('denormalization')(outputs),
|
loss = criterion(image_normalization('denormalization')(outputs),
|
||||||
@ -99,7 +103,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
model.eval()
|
model.eval()
|
||||||
test_mse = 0.0
|
test_mse = 0.0
|
||||||
for images, _ in tqdm((test_loader), leave=False):
|
for images, _ in tqdm((test_loader), leave=False):
|
||||||
images = images.cuda(device=device)
|
images = images
|
||||||
outputs = model(images)
|
outputs = model(images)
|
||||||
images = image_normalization('normalization')(images)
|
images = image_normalization('normalization')(images)
|
||||||
outputs = image_normalization('normalization')(outputs)
|
outputs = image_normalization('normalization')(outputs)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user