Encoder test

This commit is contained in:
chun 2023-12-19 21:33:50 +08:00
parent 64323f3ca4
commit dc08bfa17e
3 changed files with 74 additions and 6 deletions

Binary file not shown.

View File

@ -1,5 +1,13 @@
# -*- coding: utf-8 -*-
"""
Created on Tue Dec 11:00:00 2023
@author: chun
"""
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np
class _ConvWithPReLU(nn.Module): class _ConvWithPReLU(nn.Module):
@ -26,11 +34,50 @@ class _TransConvWithPReLU(nn.Module):
return x return x
class Nomalization(nn.Module): def _image_normalization(tensor, norm_type):
def __init__(self, in_channels): if norm_type == 'nomalization':
super(Nomalization, self).__init__() return tensor / 255.0
self.norm = nn.BatchNorm2d(in_channels) elif norm_type == 'denormalization':
return tensor * 255.0
else:
raise Exception('Unknown type of normalization')
def _NormlizationLayer(norm_type='nomalization'):
pass
def ratio2filter_size(x, ratio):
before_size = np.prod(x.size())
after_size = before_size*ratio
encoder_temp = Encoder(c=after_size)
class Encoder(nn.Module):
def __init__(self, c, is_temp=False):
super(Encoder, self).__init__()
self.is_temp = is_temp
self.imgae_normalization = _image_normalization(norm_type='nomalization')
self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2)
self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2)
self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, stride=1)
self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, stride=1)
self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, stride=1)
self.norm = _NormlizationLayer(norm_type='nomalization')
def forward(self, x): def forward(self, x):
x = self.norm(x) x = self.imgae_normalization(x)
return x x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
if not self.is_temp:
x = self.conv5(x)
z = self.norm(x)
del x
return z
class Decoder(nn.Module):
def __init__(self):
super(Decoder, self).__init__()

View File

@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
"""
Created on Tue Dec 11:00:00 2023
@author: chun
"""
import torch
import torch.nn as nn
def main():
pass
if __name__ == '__main__':
main()