This commit is contained in:
chun 2023-12-23 13:40:56 +08:00
parent acd646a9ea
commit f35c7fe716
4 changed files with 26 additions and 19 deletions

View File

@ -3,14 +3,15 @@ import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image from PIL import Image
from torchvision import transforms from torchvision import transforms
from utils import get_psnr from utils import get_psnr, image_normalization
def config_parser(): def config_parser():
import argparse import argparse
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--channel', default='AWGN', type=str, help='channel type') parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
parser.add_argument('--saved', type=str, help='saved_path') parser.add_argument('--saved', type=str, help='saved_path')
parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list') parser.add_argument('--snr', default=20, type=int, help='snr')
parser.add_argument('--test_image', default='./demo/kodim08.png', type=str, help='demo_image') parser.add_argument('--test_image', default='./demo/kodim08.png', type=str, help='demo_image')
parser.add_argument('--times', default=100, type=int, help='num_workers') parser.add_argument('--times', default=100, type=int, help='num_workers')
return parser.parse_args() return parser.parse_args()
@ -24,14 +25,18 @@ def main():
test_image.load() test_image.load()
test_image = transform(test_image) test_image = transform(test_image)
model = torch.load(args.saved) model = torch.load(args.saved)
model.change_channel(args.channel, args.snr)
psnr_all = 0.0 psnr_all = 0.0
for i in range(args.times): for i in range(args.times):
demo_image = model(test_image) demo_image = model(test_image)
image = image_normalization('denormalization')(image)
gt = image_normalization('denormalization')(gt)
psnr_all += get_psnr(demo_image, test_image) psnr_all += get_psnr(demo_image, test_image)
demo_image = torch.cat([test_image, demo_image], dim=1) demo_image = torch.cat([test_image, demo_image], dim=1)
demo_image = transforms.ToPILImage()(demo_image) demo_image = transforms.ToPILImage()(demo_image)
demo_image.save('./demo/demo.png') demo_image.save('./demo/demo.png')
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times)) print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -140,3 +140,6 @@ class DeepJSCC(nn.Module):
z = self.channel(z) z = self.channel(z)
x_hat = self.decoder(z) x_hat = self.decoder(z)
return x_hat return x_hat
def change_channel(self, channel_type, snr):
self.channel = channel.channel(channel_type, snr)

View File

@ -9,7 +9,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torchvision import transforms from torchvision import transforms
from torchvision import datasets from torchvision import datasets
from torch.utils.data import DataLoader, RandomSampler from torch.utils.data import DataLoader
import torch.optim as optim import torch.optim as optim
from tqdm import tqdm from tqdm import tqdm
from model import DeepJSCC, ratio2filtersize from model import DeepJSCC, ratio2filtersize
@ -79,7 +79,7 @@ def train(args: config_parser(), ratio: float, snr: float):
image_fisrt = train_dataset.__getitem__(0)[0] image_fisrt = train_dataset.__getitem__(0)[0]
c = ratio2filtersize(image_fisrt, ratio) c = ratio2filtersize(image_fisrt, ratio)
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device) model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
criterion=nn.MSELoss(reduction='mean').cuda(device=device) criterion=nn.MSELoss(reduction='mean').cuda(device=device)
optimizer=optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) optimizer=optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
epoch_loop=tqdm(range(args.epochs), total=args.epochs, leave=False) epoch_loop=tqdm(range(args.epochs), total=args.epochs, leave=False)
@ -88,7 +88,7 @@ def train(args: config_parser(), ratio: float, snr: float):
run_loss=0.0 run_loss=0.0
for images, _ in tqdm((train_loader), leave=False): for images, _ in tqdm((train_loader), leave=False):
optimizer.zero_grad() optimizer.zero_grad()
images = images.cuda(device=device) # images = images.cuda(device=device)
outputs=model(images) outputs=model(images)
loss=criterion(image_normalization('denormalization')(outputs), loss=criterion(image_normalization('denormalization')(outputs),
image_normalization('denormalization')(images)) image_normalization('denormalization')(images))

View File

@ -15,8 +15,7 @@ def image_normalization(norm_type):
def get_psnr(image, gt, max=255): def get_psnr(image, gt, max=255):
image = image_normalization('denormalization')(image)
gt = image_normalization('denormalization')(gt)
mse = F.mse_loss(image, gt) mse = F.mse_loss(image, gt)
psnr = 10 * torch.log10(max**2 / mse) psnr = 10 * torch.log10(max**2 / mse)