format print
This commit is contained in:
parent
f35c7fe716
commit
c7f456b7b4
26
train.py
26
train.py
@ -74,36 +74,36 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
else:
|
else:
|
||||||
raise Exception('Unknown dataset')
|
raise Exception('Unknown dataset')
|
||||||
|
|
||||||
print("training with ratio: {:2f}, snr_db: {}, channel: {}".format(ratio, snr, args.channel))
|
print("training with ratio: {:.2f}, snr_db: {}, channel: {}".format(ratio, snr, args.channel))
|
||||||
|
|
||||||
image_fisrt = train_dataset.__getitem__(0)[0]
|
image_fisrt = train_dataset.__getitem__(0)[0]
|
||||||
c = ratio2filtersize(image_fisrt, ratio)
|
c = ratio2filtersize(image_fisrt, ratio)
|
||||||
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)
|
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)
|
||||||
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
|
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
|
||||||
criterion=nn.MSELoss(reduction='mean').cuda(device=device)
|
criterion = nn.MSELoss(reduction='mean').cuda(device=device)
|
||||||
optimizer=optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||||
epoch_loop=tqdm(range(args.epochs), total=args.epochs, leave=False)
|
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
|
||||||
|
|
||||||
for epoch in epoch_loop:
|
for epoch in epoch_loop:
|
||||||
run_loss=0.0
|
run_loss = 0.0
|
||||||
for images, _ in tqdm((train_loader), leave=False):
|
for images, _ in tqdm((train_loader), leave=False):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
# images = images.cuda(device=device)
|
# images = images.cuda(device=device)
|
||||||
outputs=model(images)
|
outputs = model(images)
|
||||||
loss=criterion(image_normalization('denormalization')(outputs),
|
loss = criterion(image_normalization('denormalization')(outputs),
|
||||||
image_normalization('denormalization')(images))
|
image_normalization('denormalization')(images))
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
run_loss += loss.item()
|
run_loss += loss.item()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model.eval()
|
model.eval()
|
||||||
test_mse=0.0
|
test_mse = 0.0
|
||||||
for images, _ in tqdm((test_loader), leave=False):
|
for images, _ in tqdm((test_loader), leave=False):
|
||||||
images=images.cuda(device=device)
|
images = images.cuda(device=device)
|
||||||
outputs=model(images)
|
outputs = model(images)
|
||||||
images=image_normalization('normalization')(images)
|
images = image_normalization('normalization')(images)
|
||||||
outputs=image_normalization('normalization')(outputs)
|
outputs = image_normalization('normalization')(outputs)
|
||||||
loss=criterion(outputs, images)
|
loss = criterion(outputs, images)
|
||||||
test_mse += loss.item()
|
test_mse += loss.item()
|
||||||
model.train()
|
model.train()
|
||||||
epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
|
epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user