v1.1
This commit is contained in:
parent
acd646a9ea
commit
f35c7fe716
9
eval.py
9
eval.py
@ -3,14 +3,15 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from utils import get_psnr
|
from utils import get_psnr, image_normalization
|
||||||
|
|
||||||
|
|
||||||
def config_parser():
|
def config_parser():
|
||||||
import argparse
|
import argparse
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
|
parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
|
||||||
parser.add_argument('--saved', type=str, help='saved_path')
|
parser.add_argument('--saved', type=str, help='saved_path')
|
||||||
parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list')
|
parser.add_argument('--snr', default=20, type=int, help='snr')
|
||||||
parser.add_argument('--test_image', default='./demo/kodim08.png', type=str, help='demo_image')
|
parser.add_argument('--test_image', default='./demo/kodim08.png', type=str, help='demo_image')
|
||||||
parser.add_argument('--times', default=100, type=int, help='num_workers')
|
parser.add_argument('--times', default=100, type=int, help='num_workers')
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
@ -24,14 +25,18 @@ def main():
|
|||||||
test_image.load()
|
test_image.load()
|
||||||
test_image = transform(test_image)
|
test_image = transform(test_image)
|
||||||
model = torch.load(args.saved)
|
model = torch.load(args.saved)
|
||||||
|
model.change_channel(args.channel, args.snr)
|
||||||
psnr_all = 0.0
|
psnr_all = 0.0
|
||||||
for i in range(args.times):
|
for i in range(args.times):
|
||||||
demo_image = model(test_image)
|
demo_image = model(test_image)
|
||||||
|
image = image_normalization('denormalization')(image)
|
||||||
|
gt = image_normalization('denormalization')(gt)
|
||||||
psnr_all += get_psnr(demo_image, test_image)
|
psnr_all += get_psnr(demo_image, test_image)
|
||||||
demo_image = torch.cat([test_image, demo_image], dim=1)
|
demo_image = torch.cat([test_image, demo_image], dim=1)
|
||||||
demo_image = transforms.ToPILImage()(demo_image)
|
demo_image = transforms.ToPILImage()(demo_image)
|
||||||
demo_image.save('./demo/demo.png')
|
demo_image.save('./demo/demo.png')
|
||||||
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
|
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|||||||
3
model.py
3
model.py
@ -140,3 +140,6 @@ class DeepJSCC(nn.Module):
|
|||||||
z = self.channel(z)
|
z = self.channel(z)
|
||||||
x_hat = self.decoder(z)
|
x_hat = self.decoder(z)
|
||||||
return x_hat
|
return x_hat
|
||||||
|
|
||||||
|
def change_channel(self, channel_type, snr):
|
||||||
|
self.channel = channel.channel(channel_type, snr)
|
||||||
|
|||||||
6
train.py
6
train.py
@ -9,7 +9,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torchvision import datasets
|
from torchvision import datasets
|
||||||
from torch.utils.data import DataLoader, RandomSampler
|
from torch.utils.data import DataLoader
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from model import DeepJSCC, ratio2filtersize
|
from model import DeepJSCC, ratio2filtersize
|
||||||
@ -79,7 +79,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
image_fisrt = train_dataset.__getitem__(0)[0]
|
image_fisrt = train_dataset.__getitem__(0)[0]
|
||||||
c = ratio2filtersize(image_fisrt, ratio)
|
c = ratio2filtersize(image_fisrt, ratio)
|
||||||
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)
|
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)
|
||||||
|
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
|
||||||
criterion=nn.MSELoss(reduction='mean').cuda(device=device)
|
criterion=nn.MSELoss(reduction='mean').cuda(device=device)
|
||||||
optimizer=optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
optimizer=optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||||
epoch_loop=tqdm(range(args.epochs), total=args.epochs, leave=False)
|
epoch_loop=tqdm(range(args.epochs), total=args.epochs, leave=False)
|
||||||
@ -88,7 +88,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
run_loss=0.0
|
run_loss=0.0
|
||||||
for images, _ in tqdm((train_loader), leave=False):
|
for images, _ in tqdm((train_loader), leave=False):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
images = images.cuda(device=device)
|
# images = images.cuda(device=device)
|
||||||
outputs=model(images)
|
outputs=model(images)
|
||||||
loss=criterion(image_normalization('denormalization')(outputs),
|
loss=criterion(image_normalization('denormalization')(outputs),
|
||||||
image_normalization('denormalization')(images))
|
image_normalization('denormalization')(images))
|
||||||
|
|||||||
3
utils.py
3
utils.py
@ -15,8 +15,7 @@ def image_normalization(norm_type):
|
|||||||
|
|
||||||
|
|
||||||
def get_psnr(image, gt, max=255):
|
def get_psnr(image, gt, max=255):
|
||||||
image = image_normalization('denormalization')(image)
|
|
||||||
gt = image_normalization('denormalization')(gt)
|
|
||||||
mse = F.mse_loss(image, gt)
|
mse = F.mse_loss(image, gt)
|
||||||
|
|
||||||
psnr = 10 * torch.log10(max**2 / mse)
|
psnr = 10 * torch.log10(max**2 / mse)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user