This commit is contained in:
chun 2023-12-23 13:40:56 +08:00
parent acd646a9ea
commit f35c7fe716
4 changed files with 26 additions and 19 deletions

View File

@ -3,14 +3,15 @@ import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from utils import get_psnr
from utils import get_psnr, image_normalization
def config_parser():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
parser.add_argument('--saved', type=str, help='saved_path')
parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list')
parser.add_argument('--snr', default=20, type=int, help='snr')
parser.add_argument('--test_image', default='./demo/kodim08.png', type=str, help='demo_image')
parser.add_argument('--times', default=100, type=int, help='num_workers')
return parser.parse_args()
@ -24,14 +25,18 @@ def main():
test_image.load()
test_image = transform(test_image)
model = torch.load(args.saved)
model.change_channel(args.channel, args.snr)
psnr_all = 0.0
for i in range(args.times):
demo_image = model(test_image)
image = image_normalization('denormalization')(image)
gt = image_normalization('denormalization')(gt)
psnr_all += get_psnr(demo_image, test_image)
demo_image = torch.cat([test_image, demo_image], dim=1)
demo_image = transforms.ToPILImage()(demo_image)
demo_image.save('./demo/demo.png')
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
if __name__ == '__main__':
main()

View File

@ -140,3 +140,6 @@ class DeepJSCC(nn.Module):
z = self.channel(z)
x_hat = self.decoder(z)
return x_hat
def change_channel(self, channel_type, snr):
self.channel = channel.channel(channel_type, snr)

View File

@ -9,7 +9,7 @@ import torch
import torch.nn as nn
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
from model import DeepJSCC, ratio2filtersize
@ -79,31 +79,31 @@ def train(args: config_parser(), ratio: float, snr: float):
image_fisrt = train_dataset.__getitem__(0)[0]
c = ratio2filtersize(image_fisrt, ratio)
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)
criterion = nn.MSELoss(reduction='mean').cuda(device=device)
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
criterion=nn.MSELoss(reduction='mean').cuda(device=device)
optimizer=optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
epoch_loop=tqdm(range(args.epochs), total=args.epochs, leave=False)
for epoch in epoch_loop:
run_loss = 0.0
run_loss=0.0
for images, _ in tqdm((train_loader), leave=False):
optimizer.zero_grad()
images = images.cuda(device=device)
outputs = model(images)
loss = criterion(image_normalization('denormalization')(outputs),
# images = images.cuda(device=device)
outputs=model(images)
loss=criterion(image_normalization('denormalization')(outputs),
image_normalization('denormalization')(images))
loss.backward()
optimizer.step()
run_loss += loss.item()
with torch.no_grad():
model.eval()
test_mse = 0.0
test_mse=0.0
for images, _ in tqdm((test_loader), leave=False):
images = images.cuda(device=device)
outputs = model(images)
images = image_normalization('normalization')(images)
outputs = image_normalization('normalization')(outputs)
loss = criterion(outputs, images)
images=images.cuda(device=device)
outputs=model(images)
images=image_normalization('normalization')(images)
outputs=image_normalization('normalization')(outputs)
loss=criterion(outputs, images)
test_mse += loss.item()
model.train()
epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))

View File

@ -15,8 +15,7 @@ def image_normalization(norm_type):
def get_psnr(image, gt, max=255):
image = image_normalization('denormalization')(image)
gt = image_normalization('denormalization')(gt)
mse = F.mse_loss(image, gt)
psnr = 10 * torch.log10(max**2 / mse)