train
This commit is contained in:
parent
559415c240
commit
d168f2946e
1
.gitignore
vendored
1
.gitignore
vendored
@ -3,3 +3,4 @@ test.py
|
|||||||
*.log
|
*.log
|
||||||
Dataset/*
|
Dataset/*
|
||||||
*.ipynb
|
*.ipynb
|
||||||
|
*.swp
|
||||||
34
.vscode/launch.json
vendored
Normal file
34
.vscode/launch.json
vendored
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
{
|
||||||
|
// Use IntelliSense to learn about possible attributes.
|
||||||
|
// Hover to view descriptions of existing attributes.
|
||||||
|
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
"name": "Python: 当前文件",
|
||||||
|
"type": "python",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${file}",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": true,
|
||||||
|
"args": [
|
||||||
|
"--lr",
|
||||||
|
"10e-3",
|
||||||
|
"--epochs",
|
||||||
|
"50",
|
||||||
|
"--batch_size",
|
||||||
|
"512",
|
||||||
|
"--dataset",
|
||||||
|
"cifar10",
|
||||||
|
"--num_workers",
|
||||||
|
"4",
|
||||||
|
"--device",
|
||||||
|
"cuda:1",
|
||||||
|
"--ratio_list",
|
||||||
|
"1/3",
|
||||||
|
"--snr_list",
|
||||||
|
"100"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
BIN
demo/demo.png
BIN
demo/demo.png
Binary file not shown.
|
Before Width: | Height: | Size: 1.5 MiB After Width: | Height: | Size: 1.9 MiB |
13
eval.py
13
eval.py
@ -15,14 +15,14 @@ def config_parser():
|
|||||||
parser.add_argument('--saved', type=str, help='saved_path')
|
parser.add_argument('--saved', type=str, help='saved_path')
|
||||||
parser.add_argument('--snr', default=20, type=int, help='snr')
|
parser.add_argument('--snr', default=20, type=int, help='snr')
|
||||||
parser.add_argument('--test_image', default='./demo/kodim08.png', type=str, help='demo_image')
|
parser.add_argument('--test_image', default='./demo/kodim08.png', type=str, help='demo_image')
|
||||||
parser.add_argument('--times', default=100, type=int, help='num_workers')
|
parser.add_argument('--times', default=10, type=int, help='num_workers')
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = config_parser()
|
args = config_parser()
|
||||||
transform = transforms.Compose([transforms.ToTensor()])
|
transform = transforms.Compose([transforms.ToTensor()])
|
||||||
# args.saved = './saved/model_cifar10_0.33_19.00_40.pth' # to be deleted
|
args.saved = './saved/cifar10_50_0.33_100.00_41.pth' # to be deleted
|
||||||
test_image = Image.open(args.test_image)
|
test_image = Image.open(args.test_image)
|
||||||
test_image.load()
|
test_image.load()
|
||||||
test_image = transform(test_image)
|
test_image = transform(test_image)
|
||||||
@ -35,14 +35,15 @@ def main():
|
|||||||
model.change_channel(args.channel, args.snr)
|
model.change_channel(args.channel, args.snr)
|
||||||
|
|
||||||
psnr_all = 0.0
|
psnr_all = 0.0
|
||||||
|
demo_image = model(test_image)
|
||||||
|
demo_image = image_normalization('denormalization')(demo_image)
|
||||||
|
gt = image_normalization('denormalization')(test_image)
|
||||||
for i in range(args.times):
|
for i in range(args.times):
|
||||||
demo_image = model(test_image)
|
|
||||||
demo_image = image_normalization('denormalization')(demo_image)
|
|
||||||
gt = image_normalization('denormalization')(test_image)
|
|
||||||
psnr_all += get_psnr(demo_image, gt)
|
psnr_all += get_psnr(demo_image, gt)
|
||||||
demo_image = torch.cat([test_image, demo_image], dim=1)
|
demo_image = torch.cat([test_image, demo_image], dim=1)
|
||||||
demo_image = transforms.ToPILImage()(demo_image)
|
demo_image = transforms.ToPILImage()(demo_image)
|
||||||
demo_image.save('./demo/demo.png')
|
temp = args.saved.split('/')[-1]
|
||||||
|
demo_image.save('./run/{}.png'.format(args.saved.split('/')[-1]))
|
||||||
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
|
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
2
model.py
2
model.py
@ -32,7 +32,7 @@ def ratio2filtersize(x: torch.Tensor, ratio):
|
|||||||
encoder_temp = _Encoder(is_temp=True)
|
encoder_temp = _Encoder(is_temp=True)
|
||||||
z_temp = encoder_temp(x)
|
z_temp = encoder_temp(x)
|
||||||
c = before_size * ratio / np.prod(z_temp.size()[-2:])
|
c = before_size * ratio / np.prod(z_temp.size()[-2:])
|
||||||
return int(c)
|
return int(c) + 1
|
||||||
|
|
||||||
|
|
||||||
class _ConvWithPReLU(nn.Module):
|
class _ConvWithPReLU(nn.Module):
|
||||||
|
|||||||
BIN
run/cifar10_50_0.33_100.00_41.pth.png
Normal file
BIN
run/cifar10_50_0.33_100.00_41.pth.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.9 MiB |
BIN
saved/cifar10_50_0.33_100.00_41.pth
Normal file
BIN
saved/cifar10_50_0.33_100.00_41.pth
Normal file
Binary file not shown.
BIN
saved/imagenet_3_0.33_100.00_20.pth
Normal file
BIN
saved/imagenet_3_0.33_100.00_20.pth
Normal file
Binary file not shown.
6
train.py
6
train.py
@ -38,7 +38,7 @@ def config_parser():
|
|||||||
parser.add_argument('--dataset', default='cifar10', type=str,
|
parser.add_argument('--dataset', default='cifar10', type=str,
|
||||||
choices=['cifar10', 'imagenet'], help='dataset')
|
choices=['cifar10', 'imagenet'], help='dataset')
|
||||||
parser.add_argument('--parallel', default=False, type=bool, help='parallel')
|
parser.add_argument('--parallel', default=False, type=bool, help='parallel')
|
||||||
parser.add_argument('--if_scheduler', default=True, type=bool, help='if_scheduler')
|
parser.add_argument('--if_scheduler', default=False, type=bool, help='if_scheduler')
|
||||||
parser.add_argument('--step_size', default=640, type=int, help='scheduler')
|
parser.add_argument('--step_size', default=640, type=int, help='scheduler')
|
||||||
parser.add_argument('--device', default='cuda:0', type=str, help='device')
|
parser.add_argument('--device', default='cuda:0', type=str, help='device')
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
@ -53,7 +53,6 @@ def main():
|
|||||||
for snr in args.snr_list:
|
for snr in args.snr_list:
|
||||||
train(args, ratio, snr)
|
train(args, ratio, snr)
|
||||||
|
|
||||||
|
|
||||||
def train(args: config_parser(), ratio: float, snr: float):
|
def train(args: config_parser(), ratio: float, snr: float):
|
||||||
|
|
||||||
device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
|
device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
|
||||||
@ -72,6 +71,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
elif args.dataset == 'imagenet':
|
elif args.dataset == 'imagenet':
|
||||||
transform = transforms.Compose(
|
transform = transforms.Compose(
|
||||||
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
|
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
|
||||||
|
print("loading data of imagenet")
|
||||||
train_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/train', transform=transform)
|
train_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/train', transform=transform)
|
||||||
|
|
||||||
train_loader = DataLoader(train_dataset, shuffle=True,
|
train_loader = DataLoader(train_dataset, shuffle=True,
|
||||||
@ -113,7 +113,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
run_loss += loss.item()
|
run_loss += loss.item()
|
||||||
if args.if_scheduler:
|
if args.if_scheduler: # the scheduler is wrong before
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user