This commit is contained in:
chun 2023-12-26 15:46:17 +08:00
parent c756af358d
commit cc27436535
15 changed files with 25 additions and 17 deletions

14
.vscode/launch.json vendored
View File

@ -13,15 +13,15 @@
"justMyCode": true,
"args": [
"--lr",
"10e-3",
"1e-3",
"--epochs",
"1000",
"--batch_size",
"64",
"512",
"--if_scheduler",
"True",
"step_size",
"500e3",
"False",
"--step_size",
"500",
"--dataset",
"cifar10",
"--num_workers",
@ -31,7 +31,9 @@
"--ratio_list",
"1/3",
"--snr_list",
"100"
"100",
"--seed",
"42"
]
}
]

View File

@ -10,7 +10,7 @@ This is my first time to use PyTorch and git to reproduce a paper, so there may
I spend 3 days to reproduce the paper, and i get the result as follow. The result is not good, because i trained the model on cifar10 which is 32\*32 but test on kodim which is 768\*512 and the model is not trained enough.
That is all enough!!
![demo](./demo/demo.png)
![demo](./run/cifar10_1000_0.33_100.00_40.pth_kodim08.png)
## Installation
@ -27,7 +27,7 @@ pip install requirements.txt
Run(example presented in paper) on cifar10
```
python train.py --lr 10e-3 --epochs 100 --batch_size 64 --channel 'AWGN' --saved ./saved --snr_list 1 4 7 13 19 --ratio_list 1/6 1/12 --dataset cifar10 --num_workers 4 --parallel True --if_scheduler True --scheduler_step_size 50
python train.py --lr 1e-3 --epochs 1000 --batch_size 64 --channel 'AWGN' --saved ./saved --snr_list 1 4 7 13 19 --ratio_list 1/6 1/12 --dataset cifar10 --num_workers 4 --parallel True --if_scheduler True --scheduler_step_size 50
```
or Run(example presented in paper) on imagenet

BIN
demo/1080_1080.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 94 KiB

BIN
demo/kodim23.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 544 KiB

View File

@ -22,7 +22,6 @@ def config_parser():
def main():
args = config_parser()
transform = transforms.Compose([transforms.ToTensor()])
args.saved = './saved/cifar10_50_0.33_100.00_41.pth' # to be deleted
test_image = Image.open(args.test_image)
test_image.load()
test_image = transform(test_image)
@ -43,7 +42,7 @@ def main():
demo_image = torch.cat([test_image, demo_image], dim=1)
demo_image = transforms.ToPILImage()(demo_image)
temp = args.saved.split('/')[-1]
demo_image.save('./run/{}.png'.format(args.saved.split('/')[-1]))
demo_image.save('./run/{}_{}'.format(args.saved.split('/')[-1],args.test_image.split('/')[-1]))
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))

View File

@ -32,7 +32,7 @@ def ratio2filtersize(x: torch.Tensor, ratio):
encoder_temp = _Encoder(is_temp=True)
z_temp = encoder_temp(x)
c = before_size * ratio / np.prod(z_temp.size()[-2:])
return int(c) + 1
return int(c)
class _ConvWithPReLU(nn.Module):
@ -40,6 +40,9 @@ class _ConvWithPReLU(nn.Module):
super(_ConvWithPReLU, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
self.prelu = nn.PReLU()
nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='leaky_relu')
def forward(self, x):
x = self.conv(x)
@ -53,7 +56,11 @@ class _TransConvWithPReLU(nn.Module):
self.transconv = nn.ConvTranspose2d(
in_channels, out_channels, kernel_size, stride, padding, output_padding)
self.activate = activate
if activate == nn.PReLU():
nn.init.kaiming_normal_(self.transconv.weight, mode='fan_out', nonlinearity='leaky_relu')
else:
nn.init.xavier_normal_(self.transconv.weight)
def forward(self, x):
x = self.transconv(x)
x = self.activate(x)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.5 MiB

After

Width:  |  Height:  |  Size: 1.6 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.9 MiB

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -98,7 +98,7 @@ def train(args: config_parser(), ratio: float, snr: float):
criterion = nn.MSELoss(reduction='mean').to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
if args.if_scheduler:
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=True)
for epoch in epoch_loop:
@ -113,8 +113,8 @@ def train(args: config_parser(), ratio: float, snr: float):
loss.backward()
optimizer.step()
run_loss += loss.item()
if args.if_scheduler: # the scheduler is wrong before
scheduler.step()
if args.if_scheduler: # the scheduler is wrong before
scheduler.step()
with torch.no_grad():
model.eval()
test_mse = 0.0
@ -127,8 +127,8 @@ def train(args: config_parser(), ratio: float, snr: float):
test_mse += loss.item()
model.train()
epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
print("epoch: {}, loss: {:.4f}, test_mse: {:.4f}".format(
epoch, run_loss/len(train_loader), test_mse/len(test_loader)))
print("epoch: {}, loss: {:.4f}, test_mse: {:.4f} lr:{}".format(
epoch, run_loss/len(train_loader), test_mse/len(test_loader), optimizer.param_groups[0]['lr']))
save_model(model, args.saved, args.saved +
'/{}_{}_{:.2f}_{:.2f}_{}.pth'.format(args.dataset, args.epochs, ratio, snr, c))