train.py modified
This commit is contained in:
parent
1aec37af4e
commit
1150b8e846
11
eval.py
11
eval.py
@ -4,6 +4,8 @@ import torch.nn as nn
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from utils import get_psnr, image_normalization
|
||||
import os
|
||||
from model import DeepJSCC
|
||||
|
||||
|
||||
def config_parser():
|
||||
@ -14,6 +16,7 @@ def config_parser():
|
||||
parser.add_argument('--snr', default=20, type=int, help='snr')
|
||||
parser.add_argument('--test_image', default='./demo/kodim08.png', type=str, help='demo_image')
|
||||
parser.add_argument('--times', default=100, type=int, help='num_workers')
|
||||
parser.add_argument('--filters', default=40, type=int, help='channel type')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -24,8 +27,14 @@ def main():
|
||||
test_image = Image.open(args.test_image)
|
||||
test_image.load()
|
||||
test_image = transform(test_image)
|
||||
model = torch.load(args.saved)
|
||||
|
||||
file_name = os.path.basename(args.saved)
|
||||
c = file_name.split('_')[-1].split('.')[0]
|
||||
c = int(c)
|
||||
model = DeepJSCC(c=c, channel_type=args.channel, snr=args.snr)
|
||||
model.load_state_dict(torch.load(args.saved))
|
||||
model.change_channel(args.channel, args.snr)
|
||||
|
||||
psnr_all = 0.0
|
||||
for i in range(args.times):
|
||||
demo_image = model(test_image)
|
||||
|
||||
14
train.py
14
train.py
@ -23,7 +23,7 @@ def config_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--seed', default=2048, type=int, help='Random seed')
|
||||
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
|
||||
parser.add_argument('--epochs', default=100, type=int, help='number of epochs')
|
||||
parser.add_argument('--epochs', default=256, type=int, help='number of epochs')
|
||||
parser.add_argument('--batch_size', default=256, type=int, help='batch size')
|
||||
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
|
||||
parser.add_argument('--channel', default='AWGN', type=str,
|
||||
@ -85,6 +85,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
|
||||
image_fisrt = train_dataset.__getitem__(0)[0]
|
||||
c = ratio2filtersize(image_fisrt, ratio)
|
||||
print("the inner channel is {}".format(c))
|
||||
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr)
|
||||
|
||||
if args.parallel and torch.cuda.device_count() > 1:
|
||||
@ -96,7 +97,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
if args.if_scheduler:
|
||||
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1)
|
||||
|
||||
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
|
||||
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=True)
|
||||
for epoch in epoch_loop:
|
||||
run_loss = 0.0
|
||||
for images, _ in tqdm((train_loader), leave=False):
|
||||
@ -123,12 +124,13 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
test_mse += loss.item()
|
||||
model.train()
|
||||
epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
|
||||
save_model(model, args.saved + '/model{}_{:.2f}_{:.2f}.pth'.format(args.dataset, ratio, snr))
|
||||
save_model(model, args.saved, args.saved +
|
||||
'/model_{}_{:.2f}_{:.2f}_{}.pth'.format(args.dataset, ratio, snr, c))
|
||||
|
||||
|
||||
def save_model(model, path):
|
||||
os.makedirs(path, exist_ok=True)
|
||||
torch.save(model, path)
|
||||
def save_model(model, dir, path):
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
torch.save(model.state_dict(), path)
|
||||
print("Model saved in {}".format(path))
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user