add model
2
.vscode/launch.json
vendored
@ -19,7 +19,7 @@
|
|||||||
"--batch_size",
|
"--batch_size",
|
||||||
"512",
|
"512",
|
||||||
"--if_scheduler",
|
"--if_scheduler",
|
||||||
"False",
|
"1",
|
||||||
"--step_size",
|
"--step_size",
|
||||||
"500",
|
"500",
|
||||||
"--dataset",
|
"--dataset",
|
||||||
|
|||||||
BIN
demo/127_127.jpg
Normal file
|
After Width: | Height: | Size: 20 KiB |
BIN
demo/32_32.jpg
Normal file
|
After Width: | Height: | Size: 2.1 KiB |
5
eval.py
@ -34,14 +34,15 @@ def main():
|
|||||||
model.change_channel(args.channel, args.snr)
|
model.change_channel(args.channel, args.snr)
|
||||||
|
|
||||||
psnr_all = 0.0
|
psnr_all = 0.0
|
||||||
|
|
||||||
|
for i in range(args.times):
|
||||||
demo_image = model(test_image)
|
demo_image = model(test_image)
|
||||||
demo_image = image_normalization('denormalization')(demo_image)
|
demo_image = image_normalization('denormalization')(demo_image)
|
||||||
gt = image_normalization('denormalization')(test_image)
|
gt = image_normalization('denormalization')(test_image)
|
||||||
for i in range(args.times):
|
|
||||||
psnr_all += get_psnr(demo_image, gt)
|
psnr_all += get_psnr(demo_image, gt)
|
||||||
|
demo_image = image_normalization('normalization')(demo_image)
|
||||||
demo_image = torch.cat([test_image, demo_image], dim=1)
|
demo_image = torch.cat([test_image, demo_image], dim=1)
|
||||||
demo_image = transforms.ToPILImage()(demo_image)
|
demo_image = transforms.ToPILImage()(demo_image)
|
||||||
temp = args.saved.split('/')[-1]
|
|
||||||
demo_image.save('./run/{}_{}'.format(args.saved.split('/')[-1],args.test_image.split('/')[-1]))
|
demo_image.save('./run/{}_{}'.format(args.saved.split('/')[-1],args.test_image.split('/')[-1]))
|
||||||
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
|
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
|
||||||
|
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 1.6 MiB After Width: | Height: | Size: 1.4 MiB |
|
Before Width: | Height: | Size: 1.4 MiB After Width: | Height: | Size: 1.1 MiB |
BIN
run/cifar10_2000_0.33_100.00_64_40.pth_kodim08.png
Normal file
|
After Width: | Height: | Size: 1.6 MiB |
BIN
run/cifar10_2000_0.33_100.00_64_40.pth_kodim23.png
Normal file
|
After Width: | Height: | Size: 1.4 MiB |
BIN
run/cifar10_3000_0.33_100.00_256_40.pth_kodim08.png
Normal file
|
After Width: | Height: | Size: 1.6 MiB |
BIN
run/cifar10_3000_0.33_100.00_256_40.pth_kodim23.png
Normal file
|
After Width: | Height: | Size: 1.3 MiB |
BIN
run/cifar10_3000_0.33_200.00_1024_40.pth_kodim08.png
Normal file
|
After Width: | Height: | Size: 1.4 MiB |
BIN
run/cifar10_3000_0.33_200.00_1024_40.pth_kodim23.png
Normal file
|
After Width: | Height: | Size: 1.1 MiB |
BIN
run/imagenet_10_0.33_200.00_32_19.pth_kodim08.png
Normal file
|
After Width: | Height: | Size: 1.5 MiB |
BIN
run/imagenet_10_0.33_200.00_32_19.pth_kodim23.png
Normal file
|
After Width: | Height: | Size: 1.0 MiB |
BIN
saved/cifar10_2000_0.33_100.00_64_40.pth
Normal file
BIN
saved/cifar10_3000_0.33_100.00_256_40.pth
Normal file
BIN
saved/cifar10_3000_0.33_200.00_1024_40.pth
Normal file
BIN
saved/imagenet_10_0.33_200.00_32_19.pth
Normal file
37
train.py
@ -17,6 +17,15 @@ from torch.nn.parallel import DataParallel
|
|||||||
from utils import image_normalization
|
from utils import image_normalization
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from dataset import Vanilla
|
from dataset import Vanilla
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def set_seed(seed):
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
|
||||||
|
|
||||||
def config_parser():
|
def config_parser():
|
||||||
@ -41,13 +50,17 @@ def config_parser():
|
|||||||
parser.add_argument('--if_scheduler', default=False, type=bool, help='if_scheduler')
|
parser.add_argument('--if_scheduler', default=False, type=bool, help='if_scheduler')
|
||||||
parser.add_argument('--step_size', default=640, type=int, help='scheduler')
|
parser.add_argument('--step_size', default=640, type=int, help='scheduler')
|
||||||
parser.add_argument('--device', default='cuda:0', type=str, help='device')
|
parser.add_argument('--device', default='cuda:0', type=str, help='device')
|
||||||
|
parser.add_argument('--gamma', default=0.5, type=float, help='gamma')
|
||||||
|
parser.add_argument('--disable_tqdm', default=True, type=bool, help='disable_tqdm')
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = config_parser()
|
args = config_parser()
|
||||||
args.snr_list = list(map(float, args.snr_list))
|
args.snr_list = list(map(float, args.snr_list))
|
||||||
args.ratio_list = list(map(lambda x: float(Fraction(x)), args.ratio_list))
|
args.ratio_list = list(map(lambda x: float(Fraction(x)), args.ratio_list))
|
||||||
|
set_seed(args.seed)
|
||||||
print("Training Start")
|
print("Training Start")
|
||||||
for ratio in args.ratio_list:
|
for ratio in args.ratio_list:
|
||||||
for snr in args.snr_list:
|
for snr in args.snr_list:
|
||||||
@ -82,8 +95,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
else:
|
else:
|
||||||
raise Exception('Unknown dataset')
|
raise Exception('Unknown dataset')
|
||||||
|
|
||||||
print("training with ratio: {:.2f}, snr_db: {}, channel: {}".format(ratio, snr, args.channel))
|
print(args)
|
||||||
|
|
||||||
image_fisrt = train_dataset.__getitem__(0)[0]
|
image_fisrt = train_dataset.__getitem__(0)[0]
|
||||||
c = ratio2filtersize(image_fisrt, ratio)
|
c = ratio2filtersize(image_fisrt, ratio)
|
||||||
print("the inner channel is {}".format(c))
|
print("the inner channel is {}".format(c))
|
||||||
@ -98,12 +110,12 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
criterion = nn.MSELoss(reduction='mean').to(device)
|
criterion = nn.MSELoss(reduction='mean').to(device)
|
||||||
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||||
if args.if_scheduler:
|
if args.if_scheduler:
|
||||||
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)
|
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
|
||||||
|
|
||||||
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=True)
|
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=True, disable=args.disable_tqdm)
|
||||||
for epoch in epoch_loop:
|
for epoch in epoch_loop:
|
||||||
run_loss = 0.0
|
run_loss = 0.0
|
||||||
for images, _ in tqdm((train_loader), leave=False):
|
for images, _ in tqdm((train_loader), leave=False, disable=args.disable_tqdm):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
images = images.cuda() if args.parallel else images.to(device)
|
images = images.cuda() if args.parallel else images.to(device)
|
||||||
outputs = model(images)
|
outputs = model(images)
|
||||||
@ -118,7 +130,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model.eval()
|
model.eval()
|
||||||
test_mse = 0.0
|
test_mse = 0.0
|
||||||
for images, _ in tqdm((test_loader), leave=False):
|
for images, _ in tqdm((test_loader), leave=False, disable=args.disable_tqdm):
|
||||||
images = images if args.parallel else images.to(device)
|
images = images if args.parallel else images.to(device)
|
||||||
outputs = model(images)
|
outputs = model(images)
|
||||||
images = image_normalization('denormalization')(images)
|
images = image_normalization('denormalization')(images)
|
||||||
@ -126,15 +138,22 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
loss = criterion(outputs, images)
|
loss = criterion(outputs, images)
|
||||||
test_mse += loss.item()
|
test_mse += loss.item()
|
||||||
model.train()
|
model.train()
|
||||||
epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
|
# epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
|
||||||
print("epoch: {}, loss: {:.4f}, test_mse: {:.4f} lr:{}".format(
|
print("epoch: {}, loss: {:.4f}, test_mse: {:.4f}, lr:{}".format(
|
||||||
epoch, run_loss/len(train_loader), test_mse/len(test_loader), optimizer.param_groups[0]['lr']))
|
epoch, run_loss/len(train_loader), test_mse/len(test_loader), optimizer.param_groups[0]['lr']))
|
||||||
save_model(model, args.saved, args.saved +
|
save_model(model, args.saved, args.saved +
|
||||||
'/{}_{}_{:.2f}_{:.2f}_{}.pth'.format(args.dataset, args.epochs, ratio, snr, c))
|
'/{}_{}_{:.2f}_{:.2f}_{}_{}.pth'.format(args.dataset, args.epochs, ratio, snr, args.batch_size,c))
|
||||||
|
|
||||||
|
|
||||||
def save_model(model, dir, path):
|
def save_model(model, dir, path):
|
||||||
os.makedirs(dir, exist_ok=True)
|
os.makedirs(dir, exist_ok=True)
|
||||||
|
flag = 1
|
||||||
|
while True:
|
||||||
|
if os.path.exists(path):
|
||||||
|
path = path+'_'+str(flag)
|
||||||
|
flag += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
torch.save(model.state_dict(), path)
|
torch.save(model.state_dict(), path)
|
||||||
print("Model saved in {}".format(path))
|
print("Model saved in {}".format(path))
|
||||||
|
|
||||||
|
|||||||