train
This commit is contained in:
parent
559415c240
commit
d168f2946e
3
.gitignore
vendored
3
.gitignore
vendored
@ -2,4 +2,5 @@ test.py
|
||||
*.pyc
|
||||
*.log
|
||||
Dataset/*
|
||||
*.ipynb
|
||||
*.ipynb
|
||||
*.swp
|
||||
34
.vscode/launch.json
vendored
Normal file
34
.vscode/launch.json
vendored
Normal file
@ -0,0 +1,34 @@
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: 当前文件",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${file}",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true,
|
||||
"args": [
|
||||
"--lr",
|
||||
"10e-3",
|
||||
"--epochs",
|
||||
"50",
|
||||
"--batch_size",
|
||||
"512",
|
||||
"--dataset",
|
||||
"cifar10",
|
||||
"--num_workers",
|
||||
"4",
|
||||
"--device",
|
||||
"cuda:1",
|
||||
"--ratio_list",
|
||||
"1/3",
|
||||
"--snr_list",
|
||||
"100"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
BIN
demo/demo.png
BIN
demo/demo.png
Binary file not shown.
|
Before Width: | Height: | Size: 1.5 MiB After Width: | Height: | Size: 1.9 MiB |
13
eval.py
13
eval.py
@ -15,14 +15,14 @@ def config_parser():
|
||||
parser.add_argument('--saved', type=str, help='saved_path')
|
||||
parser.add_argument('--snr', default=20, type=int, help='snr')
|
||||
parser.add_argument('--test_image', default='./demo/kodim08.png', type=str, help='demo_image')
|
||||
parser.add_argument('--times', default=100, type=int, help='num_workers')
|
||||
parser.add_argument('--times', default=10, type=int, help='num_workers')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = config_parser()
|
||||
transform = transforms.Compose([transforms.ToTensor()])
|
||||
# args.saved = './saved/model_cifar10_0.33_19.00_40.pth' # to be deleted
|
||||
args.saved = './saved/cifar10_50_0.33_100.00_41.pth' # to be deleted
|
||||
test_image = Image.open(args.test_image)
|
||||
test_image.load()
|
||||
test_image = transform(test_image)
|
||||
@ -35,14 +35,15 @@ def main():
|
||||
model.change_channel(args.channel, args.snr)
|
||||
|
||||
psnr_all = 0.0
|
||||
demo_image = model(test_image)
|
||||
demo_image = image_normalization('denormalization')(demo_image)
|
||||
gt = image_normalization('denormalization')(test_image)
|
||||
for i in range(args.times):
|
||||
demo_image = model(test_image)
|
||||
demo_image = image_normalization('denormalization')(demo_image)
|
||||
gt = image_normalization('denormalization')(test_image)
|
||||
psnr_all += get_psnr(demo_image, gt)
|
||||
demo_image = torch.cat([test_image, demo_image], dim=1)
|
||||
demo_image = transforms.ToPILImage()(demo_image)
|
||||
demo_image.save('./demo/demo.png')
|
||||
temp = args.saved.split('/')[-1]
|
||||
demo_image.save('./run/{}.png'.format(args.saved.split('/')[-1]))
|
||||
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
|
||||
|
||||
|
||||
|
||||
2
model.py
2
model.py
@ -32,7 +32,7 @@ def ratio2filtersize(x: torch.Tensor, ratio):
|
||||
encoder_temp = _Encoder(is_temp=True)
|
||||
z_temp = encoder_temp(x)
|
||||
c = before_size * ratio / np.prod(z_temp.size()[-2:])
|
||||
return int(c)
|
||||
return int(c) + 1
|
||||
|
||||
|
||||
class _ConvWithPReLU(nn.Module):
|
||||
|
||||
BIN
run/cifar10_50_0.33_100.00_41.pth.png
Normal file
BIN
run/cifar10_50_0.33_100.00_41.pth.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.9 MiB |
BIN
saved/cifar10_50_0.33_100.00_41.pth
Normal file
BIN
saved/cifar10_50_0.33_100.00_41.pth
Normal file
Binary file not shown.
BIN
saved/imagenet_3_0.33_100.00_20.pth
Normal file
BIN
saved/imagenet_3_0.33_100.00_20.pth
Normal file
Binary file not shown.
6
train.py
6
train.py
@ -38,7 +38,7 @@ def config_parser():
|
||||
parser.add_argument('--dataset', default='cifar10', type=str,
|
||||
choices=['cifar10', 'imagenet'], help='dataset')
|
||||
parser.add_argument('--parallel', default=False, type=bool, help='parallel')
|
||||
parser.add_argument('--if_scheduler', default=True, type=bool, help='if_scheduler')
|
||||
parser.add_argument('--if_scheduler', default=False, type=bool, help='if_scheduler')
|
||||
parser.add_argument('--step_size', default=640, type=int, help='scheduler')
|
||||
parser.add_argument('--device', default='cuda:0', type=str, help='device')
|
||||
return parser.parse_args()
|
||||
@ -53,7 +53,6 @@ def main():
|
||||
for snr in args.snr_list:
|
||||
train(args, ratio, snr)
|
||||
|
||||
|
||||
def train(args: config_parser(), ratio: float, snr: float):
|
||||
|
||||
device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
|
||||
@ -72,6 +71,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
elif args.dataset == 'imagenet':
|
||||
transform = transforms.Compose(
|
||||
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
|
||||
print("loading data of imagenet")
|
||||
train_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/train', transform=transform)
|
||||
|
||||
train_loader = DataLoader(train_dataset, shuffle=True,
|
||||
@ -113,7 +113,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
run_loss += loss.item()
|
||||
if args.if_scheduler:
|
||||
if args.if_scheduler: # the scheduler is wrong before
|
||||
scheduler.step()
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user