From 5c5ff99bb91d0611ab994f9cc0704006b9285f33 Mon Sep 17 00:00:00 2001 From: chun Date: Sat, 23 Dec 2023 17:30:19 +0800 Subject: [PATCH] debug norm --- train.py | 8 ++++---- utils.py | 4 ---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/train.py b/train.py index 67f5674..62583fe 100644 --- a/train.py +++ b/train.py @@ -83,7 +83,7 @@ def train(args: config_parser(), ratio: float, snr: float): if args.parallel and torch.cuda.device_count() > 1: model = DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) model = model.cuda() - criterion = nn.MSELoss(reduction='mean') + criterion = nn.MSELoss(reduction='mean').cuda() optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False) @@ -103,10 +103,10 @@ def train(args: config_parser(), ratio: float, snr: float): model.eval() test_mse = 0.0 for images, _ in tqdm((test_loader), leave=False): - images = images + images = images.cuda() outputs = model(images) - images = image_normalization('normalization')(images) - outputs = image_normalization('normalization')(outputs) + images = image_normalization('denormalization')(images) + outputs = image_normalization('denormalization')(outputs) loss = criterion(outputs, images) test_mse += loss.item() model.train() diff --git a/utils.py b/utils.py index 1f5c6f8..efd29f8 100644 --- a/utils.py +++ b/utils.py @@ -20,7 +20,3 @@ def get_psnr(image, gt, max=255): psnr = 10 * torch.log10(max**2 / mse) return psnr - - -a = torch.randn(2, 3, 32, 32) -b = image_normalization('nomalization')(a)