This commit is contained in:
chun 2023-12-25 16:16:47 +08:00
parent 559415c240
commit d168f2946e
9 changed files with 47 additions and 11 deletions

3
.gitignore vendored
View File

@ -2,4 +2,5 @@ test.py
*.pyc
*.log
Dataset/*
*.ipynb
*.ipynb
*.swp

34
.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,34 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: 当前文件",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": true,
"args": [
"--lr",
"10e-3",
"--epochs",
"50",
"--batch_size",
"512",
"--dataset",
"cifar10",
"--num_workers",
"4",
"--device",
"cuda:1",
"--ratio_list",
"1/3",
"--snr_list",
"100"
]
}
]
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.5 MiB

After

Width:  |  Height:  |  Size: 1.9 MiB

13
eval.py
View File

@ -15,14 +15,14 @@ def config_parser():
parser.add_argument('--saved', type=str, help='saved_path')
parser.add_argument('--snr', default=20, type=int, help='snr')
parser.add_argument('--test_image', default='./demo/kodim08.png', type=str, help='demo_image')
parser.add_argument('--times', default=100, type=int, help='num_workers')
parser.add_argument('--times', default=10, type=int, help='num_workers')
return parser.parse_args()
def main():
args = config_parser()
transform = transforms.Compose([transforms.ToTensor()])
# args.saved = './saved/model_cifar10_0.33_19.00_40.pth' # to be deleted
args.saved = './saved/cifar10_50_0.33_100.00_41.pth' # to be deleted
test_image = Image.open(args.test_image)
test_image.load()
test_image = transform(test_image)
@ -35,14 +35,15 @@ def main():
model.change_channel(args.channel, args.snr)
psnr_all = 0.0
demo_image = model(test_image)
demo_image = image_normalization('denormalization')(demo_image)
gt = image_normalization('denormalization')(test_image)
for i in range(args.times):
demo_image = model(test_image)
demo_image = image_normalization('denormalization')(demo_image)
gt = image_normalization('denormalization')(test_image)
psnr_all += get_psnr(demo_image, gt)
demo_image = torch.cat([test_image, demo_image], dim=1)
demo_image = transforms.ToPILImage()(demo_image)
demo_image.save('./demo/demo.png')
temp = args.saved.split('/')[-1]
demo_image.save('./run/{}.png'.format(args.saved.split('/')[-1]))
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))

View File

@ -32,7 +32,7 @@ def ratio2filtersize(x: torch.Tensor, ratio):
encoder_temp = _Encoder(is_temp=True)
z_temp = encoder_temp(x)
c = before_size * ratio / np.prod(z_temp.size()[-2:])
return int(c)
return int(c) + 1
class _ConvWithPReLU(nn.Module):

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 MiB

Binary file not shown.

Binary file not shown.

View File

@ -38,7 +38,7 @@ def config_parser():
parser.add_argument('--dataset', default='cifar10', type=str,
choices=['cifar10', 'imagenet'], help='dataset')
parser.add_argument('--parallel', default=False, type=bool, help='parallel')
parser.add_argument('--if_scheduler', default=True, type=bool, help='if_scheduler')
parser.add_argument('--if_scheduler', default=False, type=bool, help='if_scheduler')
parser.add_argument('--step_size', default=640, type=int, help='scheduler')
parser.add_argument('--device', default='cuda:0', type=str, help='device')
return parser.parse_args()
@ -53,7 +53,6 @@ def main():
for snr in args.snr_list:
train(args, ratio, snr)
def train(args: config_parser(), ratio: float, snr: float):
device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
@ -72,6 +71,7 @@ def train(args: config_parser(), ratio: float, snr: float):
elif args.dataset == 'imagenet':
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
print("loading data of imagenet")
train_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/train', transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True,
@ -113,7 +113,7 @@ def train(args: config_parser(), ratio: float, snr: float):
loss.backward()
optimizer.step()
run_loss += loss.item()
if args.if_scheduler:
if args.if_scheduler: # the scheduler is wrong before
scheduler.step()
with torch.no_grad():
model.eval()