v1
This commit is contained in:
parent
c756af358d
commit
cc27436535
14
.vscode/launch.json
vendored
14
.vscode/launch.json
vendored
@ -13,15 +13,15 @@
|
|||||||
"justMyCode": true,
|
"justMyCode": true,
|
||||||
"args": [
|
"args": [
|
||||||
"--lr",
|
"--lr",
|
||||||
"10e-3",
|
"1e-3",
|
||||||
"--epochs",
|
"--epochs",
|
||||||
"1000",
|
"1000",
|
||||||
"--batch_size",
|
"--batch_size",
|
||||||
"64",
|
"512",
|
||||||
"--if_scheduler",
|
"--if_scheduler",
|
||||||
"True",
|
"False",
|
||||||
"step_size",
|
"--step_size",
|
||||||
"500e3",
|
"500",
|
||||||
"--dataset",
|
"--dataset",
|
||||||
"cifar10",
|
"cifar10",
|
||||||
"--num_workers",
|
"--num_workers",
|
||||||
@ -31,7 +31,9 @@
|
|||||||
"--ratio_list",
|
"--ratio_list",
|
||||||
"1/3",
|
"1/3",
|
||||||
"--snr_list",
|
"--snr_list",
|
||||||
"100"
|
"100",
|
||||||
|
"--seed",
|
||||||
|
"42"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
@ -10,7 +10,7 @@ This is my first time to use PyTorch and git to reproduce a paper, so there may
|
|||||||
I spend 3 days to reproduce the paper, and i get the result as follow. The result is not good, because i trained the model on cifar10 which is 32\*32 but test on kodim which is 768\*512 and the model is not trained enough.
|
I spend 3 days to reproduce the paper, and i get the result as follow. The result is not good, because i trained the model on cifar10 which is 32\*32 but test on kodim which is 768\*512 and the model is not trained enough.
|
||||||
|
|
||||||
That is all enough!!!
|
That is all enough!!!
|
||||||

|

|
||||||
|
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
@ -27,7 +27,7 @@ pip install requirements.txt
|
|||||||
Run(example presented in paper) on cifar10
|
Run(example presented in paper) on cifar10
|
||||||
|
|
||||||
```
|
```
|
||||||
python train.py --lr 10e-3 --epochs 100 --batch_size 64 --channel 'AWGN' --saved ./saved --snr_list 1 4 7 13 19 --ratio_list 1/6 1/12 --dataset cifar10 --num_workers 4 --parallel True --if_scheduler True --scheduler_step_size 50
|
python train.py --lr 1e-3 --epochs 1000 --batch_size 64 --channel 'AWGN' --saved ./saved --snr_list 1 4 7 13 19 --ratio_list 1/6 1/12 --dataset cifar10 --num_workers 4 --parallel True --if_scheduler True --scheduler_step_size 50
|
||||||
```
|
```
|
||||||
or Run(example presented in paper) on imagenet
|
or Run(example presented in paper) on imagenet
|
||||||
|
|
||||||
|
|||||||
BIN
demo/1080_1080.jpg
Normal file
BIN
demo/1080_1080.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 94 KiB |
BIN
demo/kodim23.png
Normal file
BIN
demo/kodim23.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 544 KiB |
3
eval.py
3
eval.py
@ -22,7 +22,6 @@ def config_parser():
|
|||||||
def main():
|
def main():
|
||||||
args = config_parser()
|
args = config_parser()
|
||||||
transform = transforms.Compose([transforms.ToTensor()])
|
transform = transforms.Compose([transforms.ToTensor()])
|
||||||
args.saved = './saved/cifar10_50_0.33_100.00_41.pth' # to be deleted
|
|
||||||
test_image = Image.open(args.test_image)
|
test_image = Image.open(args.test_image)
|
||||||
test_image.load()
|
test_image.load()
|
||||||
test_image = transform(test_image)
|
test_image = transform(test_image)
|
||||||
@ -43,7 +42,7 @@ def main():
|
|||||||
demo_image = torch.cat([test_image, demo_image], dim=1)
|
demo_image = torch.cat([test_image, demo_image], dim=1)
|
||||||
demo_image = transforms.ToPILImage()(demo_image)
|
demo_image = transforms.ToPILImage()(demo_image)
|
||||||
temp = args.saved.split('/')[-1]
|
temp = args.saved.split('/')[-1]
|
||||||
demo_image.save('./run/{}.png'.format(args.saved.split('/')[-1]))
|
demo_image.save('./run/{}_{}'.format(args.saved.split('/')[-1],args.test_image.split('/')[-1]))
|
||||||
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
|
print("psnr on {} is {}".format(args.test_image, psnr_all / args.times))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
11
model.py
11
model.py
@ -32,7 +32,7 @@ def ratio2filtersize(x: torch.Tensor, ratio):
|
|||||||
encoder_temp = _Encoder(is_temp=True)
|
encoder_temp = _Encoder(is_temp=True)
|
||||||
z_temp = encoder_temp(x)
|
z_temp = encoder_temp(x)
|
||||||
c = before_size * ratio / np.prod(z_temp.size()[-2:])
|
c = before_size * ratio / np.prod(z_temp.size()[-2:])
|
||||||
return int(c) + 1
|
return int(c)
|
||||||
|
|
||||||
|
|
||||||
class _ConvWithPReLU(nn.Module):
|
class _ConvWithPReLU(nn.Module):
|
||||||
@ -40,6 +40,9 @@ class _ConvWithPReLU(nn.Module):
|
|||||||
super(_ConvWithPReLU, self).__init__()
|
super(_ConvWithPReLU, self).__init__()
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
|
||||||
self.prelu = nn.PReLU()
|
self.prelu = nn.PReLU()
|
||||||
|
|
||||||
|
|
||||||
|
nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
@ -53,7 +56,11 @@ class _TransConvWithPReLU(nn.Module):
|
|||||||
self.transconv = nn.ConvTranspose2d(
|
self.transconv = nn.ConvTranspose2d(
|
||||||
in_channels, out_channels, kernel_size, stride, padding, output_padding)
|
in_channels, out_channels, kernel_size, stride, padding, output_padding)
|
||||||
self.activate = activate
|
self.activate = activate
|
||||||
|
if activate == nn.PReLU():
|
||||||
|
nn.init.kaiming_normal_(self.transconv.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||||
|
else:
|
||||||
|
nn.init.xavier_normal_(self.transconv.weight)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.transconv(x)
|
x = self.transconv(x)
|
||||||
x = self.activate(x)
|
x = self.activate(x)
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 1.5 MiB After Width: | Height: | Size: 1.6 MiB |
BIN
run/cifar10_1000_0.33_100.00_40.pth_kodim23.png
Normal file
BIN
run/cifar10_1000_0.33_100.00_40.pth_kodim23.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.4 MiB |
Binary file not shown.
|
Before Width: | Height: | Size: 1.9 MiB |
Binary file not shown.
BIN
saved/cifar10_1000_0.33_100.00_40.pth
Normal file
BIN
saved/cifar10_1000_0.33_100.00_40.pth
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
10
train.py
10
train.py
@ -98,7 +98,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
criterion = nn.MSELoss(reduction='mean').to(device)
|
criterion = nn.MSELoss(reduction='mean').to(device)
|
||||||
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||||
if args.if_scheduler:
|
if args.if_scheduler:
|
||||||
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1)
|
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)
|
||||||
|
|
||||||
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=True)
|
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=True)
|
||||||
for epoch in epoch_loop:
|
for epoch in epoch_loop:
|
||||||
@ -113,8 +113,8 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
run_loss += loss.item()
|
run_loss += loss.item()
|
||||||
if args.if_scheduler: # the scheduler is wrong before
|
if args.if_scheduler: # the scheduler is wrong before
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model.eval()
|
model.eval()
|
||||||
test_mse = 0.0
|
test_mse = 0.0
|
||||||
@ -127,8 +127,8 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
test_mse += loss.item()
|
test_mse += loss.item()
|
||||||
model.train()
|
model.train()
|
||||||
epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
|
epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
|
||||||
print("epoch: {}, loss: {:.4f}, test_mse: {:.4f}".format(
|
print("epoch: {}, loss: {:.4f}, test_mse: {:.4f} lr:{}".format(
|
||||||
epoch, run_loss/len(train_loader), test_mse/len(test_loader)))
|
epoch, run_loss/len(train_loader), test_mse/len(test_loader), optimizer.param_groups[0]['lr']))
|
||||||
save_model(model, args.saved, args.saved +
|
save_model(model, args.saved, args.saved +
|
||||||
'/{}_{}_{:.2f}_{:.2f}_{}.pth'.format(args.dataset, args.epochs, ratio, snr, c))
|
'/{}_{}_{:.2f}_{:.2f}_{}.pth'.format(args.dataset, args.epochs, ratio, snr, c))
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user