This commit is contained in: v1

parent c756af358d
commit cc27436535

.vscode/launch.json (vendored): 14 lines changed
@@ -13,15 +13,15 @@
             "justMyCode": true,
             "args": [
                 "--lr",
-                "10e-3",
+                "1e-3",
                 "--epochs",
                 "1000",
                 "--batch_size",
-                "64",
+                "512",
                 "--if_scheduler",
-                "True",
-                "step_size",
-                "500e3",
+                "False",
+                "--step_size",
+                "500",
                 "--dataset",
                 "cifar10",
                 "--num_workers",
@@ -31,7 +31,9 @@
                 "--ratio_list",
                 "1/3",
                 "--snr_list",
-                "100"
+                "100",
+                "--seed",
+                "42"
             ]
         }
     ]
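Note: the launch configuration now passes a `--seed` argument. A minimal sketch of how train.py could consume it for reproducibility (an assumption; the parser change itself is outside this diff):

```
# Hypothetical consumer of the new --seed flag; names follow the launch args.
import argparse
import random

import numpy as np
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=42)
args, _ = parser.parse_known_args()

# Seed every RNG the training loop might touch.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
```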
README.md

@@ -10,7 +10,7 @@ This is my first time using PyTorch and git to reproduce a paper, so there may
 I spent 3 days reproducing the paper, and the results are shown below. The results are not good, because I trained the model on cifar10 (32\*32) but tested on kodim (768\*512), and the model was not trained enough.
 
 That is all!!!
-![The demo](demo/demo.png "The demo of Deep JSCC")
+![The demo](demo/demo.png "The demo of Deep JSCC")
 
 
 ## Installation
@@ -27,7 +27,7 @@ pip install requirements.txt
 Run (example presented in the paper) on cifar10
 
 ```
-python train.py --lr 10e-3 --epochs 100 --batch_size 64 --channel 'AWGN' --saved ./saved --snr_list 1 4 7 13 19 --ratio_list 1/6 1/12 --dataset cifar10 --num_workers 4 --parallel True --if_scheduler True --scheduler_step_size 50
+python train.py --lr 1e-3 --epochs 1000 --batch_size 64 --channel 'AWGN' --saved ./saved --snr_list 1 4 7 13 19 --ratio_list 1/6 1/12 --dataset cifar10 --num_workers 4 --parallel True --if_scheduler True --scheduler_step_size 50
 ```
 or run (example presented in the paper) on imagenet
 
BIN demo/1080_1080.jpg: new file (binary not shown; 94 KiB)
BIN demo/kodim23.png: new file (binary not shown; 544 KiB)
eval.py: 3 lines changed
@@ -22,7 +22,6 @@ def config_parser():
 def main():
     args = config_parser()
     transform = transforms.Compose([transforms.ToTensor()])
-    args.saved = './saved/cifar10_50_0.33_100.00_41.pth'  # to be deleted
     test_image = Image.open(args.test_image)
     test_image.load()
     test_image = transform(test_image)
@@ -43,7 +42,7 @@ def main():
     demo_image = torch.cat([test_image, demo_image], dim=1)
     demo_image = transforms.ToPILImage()(demo_image)
     temp = args.saved.split('/')[-1]
-    demo_image.save('./run/{}.png'.format(args.saved.split('/')[-1]))
+    demo_image.save('./run/{}_{}'.format(args.saved.split('/')[-1], args.test_image.split('/')[-1]))
     print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
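Note: the printed metric is `psnr_all / args.times`, i.e. PSNR averaged over repeated noisy transmissions. A hedged sketch of the underlying computation (assuming images scaled to [0, 1]; the actual helper in this repo may differ):

```
import torch

def psnr(x: torch.Tensor, y: torch.Tensor, max_val: float = 1.0) -> float:
    # Peak signal-to-noise ratio in dB for images in [0, max_val].
    mse = torch.mean((x - y) ** 2)
    return (10 * torch.log10(max_val ** 2 / mse)).item()
```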
model.py: 11 lines changed
@@ -32,7 +32,7 @@ def ratio2filtersize(x: torch.Tensor, ratio):
     encoder_temp = _Encoder(is_temp=True)
     z_temp = encoder_temp(x)
     c = before_size * ratio / np.prod(z_temp.size()[-2:])
-    return int(c) + 1
+    return int(c)
 
 
 class _ConvWithPReLU(nn.Module):
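Note: dropping the `+ 1` makes the filter count the floor of `before_size * ratio / (H_z * W_z)` instead of always rounding up. A worked example under stated assumptions (a 3\*32\*32 cifar10 input and an 8\*8 encoder feature map; the real sizes come from `_Encoder` at runtime):

```
import numpy as np

before_size = 3 * 32 * 32   # flattened source size = 3072 (assumed meaning)
z_hw = (8, 8)               # spatial size of the encoder output z (assumed)
ratio = 1 / 3               # target bandwidth ratio

c = before_size * ratio / np.prod(z_hw)   # 16.0
print(int(c))       # new behaviour: 16
print(int(c) + 1)   # old behaviour: 17, even though c is already an integer
```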
@@ -40,6 +40,9 @@ class _ConvWithPReLU(nn.Module):
         super(_ConvWithPReLU, self).__init__()
         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
         self.prelu = nn.PReLU()
+
+        nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='leaky_relu')
+
     def forward(self, x):
         x = self.conv(x)
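Note on the init call: `torch.nn.init` has no 'prelu' option, so 'leaky_relu' is the closest nonlinearity. As written, the slope argument `a` keeps its default 0, which yields the same gain as plain ReLU; passing PReLU's initial slope would look like this (a sketch, not what the commit does):

```
import torch.nn as nn

conv = nn.Conv2d(3, 16, kernel_size=5)
# PReLU starts with slope 0.25, so the matching leaky_relu gain is
# sqrt(2 / (1 + 0.25**2)) ~= 1.372, versus sqrt(2) ~= 1.414 with a=0.
nn.init.kaiming_normal_(conv.weight, a=0.25, mode='fan_out', nonlinearity='leaky_relu')
print(nn.init.calculate_gain('leaky_relu', 0.25))
```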
@@ -53,7 +56,11 @@ class _TransConvWithPReLU(nn.Module):
         self.transconv = nn.ConvTranspose2d(
             in_channels, out_channels, kernel_size, stride, padding, output_padding)
         self.activate = activate
+        if activate == nn.PReLU():
+            nn.init.kaiming_normal_(self.transconv.weight, mode='fan_out', nonlinearity='leaky_relu')
+        else:
+            nn.init.xavier_normal_(self.transconv.weight)
 
     def forward(self, x):
         x = self.transconv(x)
         x = self.activate(x)
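Note: `activate == nn.PReLU()` compares against a freshly constructed module; `nn.Module` does not override `__eq__`, so this is an identity test that is always False, and the Xavier branch always runs. A sketch of the likely intent, using a type check instead:

```
import torch.nn as nn

def init_transconv(transconv: nn.ConvTranspose2d, activate: nn.Module) -> None:
    # isinstance checks the type; `==` on modules is identity comparison and
    # never matches a brand-new nn.PReLU() instance.
    if isinstance(activate, nn.PReLU):
        nn.init.kaiming_normal_(transconv.weight, mode='fan_out', nonlinearity='leaky_relu')
    else:
        nn.init.xavier_normal_(transconv.weight)
```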
Binary image changed (not shown): before 1.5 MiB, after 1.6 MiB
BIN run/cifar10_1000_0.33_100.00_40.pth_kodim23.png: new file (binary not shown; 1.4 MiB)
Binary image removed (not shown): before 1.9 MiB
Binary file changed (not shown)
BIN saved/cifar10_1000_0.33_100.00_40.pth: new file (binary not shown)
4 more binary files changed (not shown)
train.py: 10 lines changed
@@ -98,7 +98,7 @@ def train(args: config_parser(), ratio: float, snr: float):
     criterion = nn.MSELoss(reduction='mean').to(device)
     optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
     if args.if_scheduler:
-        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1)
+        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)
 
     epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=True)
     for epoch in epoch_loop:
@@ -113,8 +113,8 @@ def train(args: config_parser(), ratio: float, snr: float):
             loss.backward()
             optimizer.step()
             run_loss += loss.item()
-            if args.if_scheduler:  # the scheduler is wrong before
-                scheduler.step()
+        if args.if_scheduler:  # the scheduler is wrong before
+            scheduler.step()
         with torch.no_grad():
             model.eval()
             test_mse = 0.0
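Note: moving `scheduler.step()` out of the batch loop is the actual fix here: StepLR counts calls to `step()`, so stepping per batch decayed the learning rate len(train_loader) times faster than intended. A minimal sketch of the corrected cadence (toy model and sizes; step_size=500 and gamma=0.5 follow this commit's new defaults):

```
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(8, 8)                       # stand-in for the JSCC model
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)

for epoch in range(1000):
    optimizer.step()    # stands in for the per-batch forward/backward loop
    scheduler.step()    # once per epoch: lr halves every 500 epochs
```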
@@ -127,8 +127,8 @@ def train(args: config_parser(), ratio: float, snr: float):
                 test_mse += loss.item()
         model.train()
         epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
-        print("epoch: {}, loss: {:.4f}, test_mse: {:.4f}".format(
-            epoch, run_loss/len(train_loader), test_mse/len(test_loader)))
+        print("epoch: {}, loss: {:.4f}, test_mse: {:.4f} lr:{}".format(
+            epoch, run_loss/len(train_loader), test_mse/len(test_loader), optimizer.param_groups[0]['lr']))
     save_model(model, args.saved, args.saved +
         '/{}_{}_{:.2f}_{:.2f}_{}.pth'.format(args.dataset, args.epochs, ratio, snr, c))
 