JSCC/eval.py
2023-12-23 13:40:56 +08:00

43 lines
1.5 KiB
Python

# to be implemented
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from utils import get_psnr, image_normalization
def config_parser():
    """Parse the command-line options of the evaluation script.

    Options:
        --channel: channel type string handed to ``model.change_channel``
            (default ``'AWGN'``).
        --saved: path of the saved model that ``main`` passes to
            ``torch.load``. Required, because there is no usable default.
        --snr: integer SNR handed to ``model.change_channel`` (default 20).
        --test_image: path of the image to evaluate on
            (default ``'./demo/kodim08.png'``).
        --times: number of forward passes whose PSNR is averaged
            (default 100).

    Returns:
        argparse.Namespace: the arguments parsed from ``sys.argv``.
    """
    # Kept as a function-local import, as in the original script.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--channel', default='AWGN', type=str,
                        help='channel type')
    # Without a path the script cannot load anything; fail at parse time with
    # a clear usage message instead of later inside torch.load(None).
    parser.add_argument('--saved', type=str, required=True,
                        help='saved_path')
    parser.add_argument('--snr', default=20, type=int, help='snr')
    parser.add_argument('--test_image', default='./demo/kodim08.png',
                        type=str, help='demo_image')
    # The help text used to read 'num_workers', which does not describe what
    # this option controls (the repeat count of the evaluation loop).
    parser.add_argument('--times', default=100, type=int,
                        help='number of evaluation runs to average over')
    return parser.parse_args()
def main():
    """Evaluate a saved model on a single image and print the mean PSNR.

    Steps:
        1. Parse the CLI arguments with ``config_parser``.
        2. Load ``args.test_image`` and convert it with ``ToTensor``.
        3. Load the model from ``args.saved`` and select the channel through
           ``model.change_channel(args.channel, args.snr)``.
        4. Run the model ``args.times`` times and accumulate ``get_psnr`` on
           the denormalized output / input pair.
        5. Concatenate the input with the last output along dim 1, save the
           result to ``./demo/demo.png`` and print the averaged PSNR.

    Raises:
        ValueError: if ``--times`` is not a positive integer (the average
            would otherwise divide by zero and no output would exist).
    """
    args = config_parser()
    if args.times <= 0:
        raise ValueError(
            '--times must be a positive integer, got {}'.format(args.times))

    to_tensor = transforms.Compose([transforms.ToTensor(), ])
    # The context manager releases the file handle once the pixels have been
    # copied into the tensor.
    with Image.open(args.test_image) as source_image:
        test_image = to_tensor(source_image)

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only point --saved at checkpoints you trust. Recent PyTorch
    # releases default to weights_only=True, which may refuse to load a whole
    # pickled model object; confirm against the installed version.
    model = torch.load(args.saved)
    model.change_channel(args.channel, args.snr)
    # Assumes the checkpoint is a full nn.Module; switch it to inference mode.
    model.eval()

    # Loop-invariant: build the denormalization callable once.
    denormalize = image_normalization('denormalization')
    psnr_all = 0.0
    # Evaluation only: no autograd graph should be kept per PSNR term.
    with torch.no_grad():
        gt = denormalize(test_image)
        for _ in range(args.times):
            demo_image = model(test_image)
            # The original read the undefined names `image` and `gt` here
            # (NameError on the first iteration); denormalize the model
            # output and the input instead.
            # NOTE(review): get_psnr is fed the denormalized pair, which is
            # what these calls appear intended for -- confirm the value range
            # get_psnr expects.
            image = denormalize(demo_image)
            psnr_all += get_psnr(image, gt)

    # Stack input and last reconstruction along dim 1 (the height axis of
    # ToTensor's C x H x W layout, assuming the model keeps that layout).
    comparison = torch.cat([test_image, demo_image], dim=1)
    transforms.ToPILImage()(comparison).save('./demo/demo.png')
    print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()