format print

This commit is contained in:
chun 2023-12-23 14:06:51 +08:00
parent f35c7fe716
commit c7f456b7b4

View File

@ -74,36 +74,36 @@ def train(args: config_parser(), ratio: float, snr: float):
else: else:
raise Exception('Unknown dataset') raise Exception('Unknown dataset')
print("training with ratio: {:2f}, snr_db: {}, channel: {}".format(ratio, snr, args.channel)) print("training with ratio: {:.2f}, snr_db: {}, channel: {}".format(ratio, snr, args.channel))
image_fisrt = train_dataset.__getitem__(0)[0] image_fisrt = train_dataset.__getitem__(0)[0]
c = ratio2filtersize(image_fisrt, ratio) c = ratio2filtersize(image_fisrt, ratio)
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device) model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
criterion=nn.MSELoss(reduction='mean').cuda(device=device) criterion = nn.MSELoss(reduction='mean').cuda(device=device)
optimizer=optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
epoch_loop=tqdm(range(args.epochs), total=args.epochs, leave=False) epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
for epoch in epoch_loop: for epoch in epoch_loop:
run_loss=0.0 run_loss = 0.0
for images, _ in tqdm((train_loader), leave=False): for images, _ in tqdm((train_loader), leave=False):
optimizer.zero_grad() optimizer.zero_grad()
# images = images.cuda(device=device) # images = images.cuda(device=device)
outputs=model(images) outputs = model(images)
loss=criterion(image_normalization('denormalization')(outputs), loss = criterion(image_normalization('denormalization')(outputs),
image_normalization('denormalization')(images)) image_normalization('denormalization')(images))
loss.backward() loss.backward()
optimizer.step() optimizer.step()
run_loss += loss.item() run_loss += loss.item()
with torch.no_grad(): with torch.no_grad():
model.eval() model.eval()
test_mse=0.0 test_mse = 0.0
for images, _ in tqdm((test_loader), leave=False): for images, _ in tqdm((test_loader), leave=False):
images=images.cuda(device=device) images = images.cuda(device=device)
outputs=model(images) outputs = model(images)
images=image_normalization('normalization')(images) images = image_normalization('normalization')(images)
outputs=image_normalization('normalization')(outputs) outputs = image_normalization('normalization')(outputs)
loss=criterion(outputs, images) loss = criterion(outputs, images)
test_mse += loss.item() test_mse += loss.item()
model.train() model.train()
epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader)) epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))