From c7f456b7b4dfe3c94b44638dece673ffaebf439e Mon Sep 17 00:00:00 2001 From: chun Date: Sat, 23 Dec 2023 14:06:51 +0800 Subject: [PATCH] format print --- train.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/train.py b/train.py index 3e99b38..3087f45 100644 --- a/train.py +++ b/train.py @@ -74,36 +74,36 @@ def train(args: config_parser(), ratio: float, snr: float): else: raise Exception('Unknown dataset') - print("training with ratio: {:2f}, snr_db: {}, channel: {}".format(ratio, snr, args.channel)) + print("training with ratio: {:.2f}, snr_db: {}, channel: {}".format(ratio, snr, args.channel)) image_fisrt = train_dataset.__getitem__(0)[0] c = ratio2filtersize(image_fisrt, ratio) model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device) model = DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) - criterion=nn.MSELoss(reduction='mean').cuda(device=device) - optimizer=optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) - epoch_loop=tqdm(range(args.epochs), total=args.epochs, leave=False) + criterion = nn.MSELoss(reduction='mean').cuda(device=device) + optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False) for epoch in epoch_loop: - run_loss=0.0 + run_loss = 0.0 for images, _ in tqdm((train_loader), leave=False): optimizer.zero_grad() # images = images.cuda(device=device) - outputs=model(images) - loss=criterion(image_normalization('denormalization')(outputs), + outputs = model(images) + loss = criterion(image_normalization('denormalization')(outputs), image_normalization('denormalization')(images)) loss.backward() optimizer.step() run_loss += loss.item() with torch.no_grad(): model.eval() - test_mse=0.0 + test_mse = 0.0 for images, _ in tqdm((test_loader), leave=False): - images=images.cuda(device=device) - outputs=model(images) - images=image_normalization('normalization')(images) - outputs=image_normalization('normalization')(outputs) - loss=criterion(outputs, images) + images = images.cuda(device=device) + outputs = model(images) + images = image_normalization('normalization')(images) + outputs = image_normalization('normalization')(outputs) + loss = criterion(outputs, images) test_mse += loss.item() model.train() epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))