train.py modified

This commit is contained in:
chun 2023-12-23 21:49:20 +08:00
parent 1aec37af4e
commit 1150b8e846
2 changed files with 18 additions and 7 deletions

11
eval.py
View File

@ -4,6 +4,8 @@ import torch.nn as nn
from PIL import Image
from torchvision import transforms
from utils import get_psnr, image_normalization
import os
from model import DeepJSCC
def config_parser():
@ -14,6 +16,7 @@ def config_parser():
parser.add_argument('--snr', default=20, type=int, help='snr')
parser.add_argument('--test_image', default='./demo/kodim08.png', type=str, help='demo_image')
parser.add_argument('--times', default=100, type=int, help='num_workers')
parser.add_argument('--filters', default=40, type=int, help='channel type')
return parser.parse_args()
@ -24,8 +27,14 @@ def main():
test_image = Image.open(args.test_image)
test_image.load()
test_image = transform(test_image)
model = torch.load(args.saved)
file_name = os.path.basename(args.saved)
c = file_name.split('_')[-1].split('.')[0]
c = int(c)
model = DeepJSCC(c=c, channel_type=args.channel, snr=args.snr)
model.load_state_dict(torch.load(args.saved))
model.change_channel(args.channel, args.snr)
psnr_all = 0.0
for i in range(args.times):
demo_image = model(test_image)

View File

@ -23,7 +23,7 @@ def config_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=2048, type=int, help='Random seed')
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('--epochs', default=100, type=int, help='number of epochs')
parser.add_argument('--epochs', default=256, type=int, help='number of epochs')
parser.add_argument('--batch_size', default=256, type=int, help='batch size')
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
parser.add_argument('--channel', default='AWGN', type=str,
@ -85,6 +85,7 @@ def train(args: config_parser(), ratio: float, snr: float):
image_fisrt = train_dataset.__getitem__(0)[0]
c = ratio2filtersize(image_fisrt, ratio)
print("the inner channel is {}".format(c))
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr)
if args.parallel and torch.cuda.device_count() > 1:
@ -96,7 +97,7 @@ def train(args: config_parser(), ratio: float, snr: float):
if args.if_scheduler:
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1)
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=True)
for epoch in epoch_loop:
run_loss = 0.0
for images, _ in tqdm((train_loader), leave=False):
@ -123,12 +124,13 @@ def train(args: config_parser(), ratio: float, snr: float):
test_mse += loss.item()
model.train()
epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
save_model(model, args.saved + '/model{}_{:.2f}_{:.2f}.pth'.format(args.dataset, ratio, snr))
save_model(model, args.saved, args.saved +
'/model_{}_{:.2f}_{:.2f}_{}.pth'.format(args.dataset, ratio, snr, c))
def save_model(model, path):
os.makedirs(path, exist_ok=True)
torch.save(model, path)
def save_model(model, dir, path):
os.makedirs(dir, exist_ok=True)
torch.save(model.state_dict(), path)
print("Model saved in {}".format(path))