Encoder test
This commit is contained in:
parent
64323f3ca4
commit
dc08bfa17e
BIN
environment.yml
BIN
environment.yml
Binary file not shown.
59
model.py
59
model.py
@ -1,5 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Tue Dec 11:00:00 2023
|
||||
|
||||
@author: chun
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
class _ConvWithPReLU(nn.Module):
|
||||
@ -26,11 +34,50 @@ class _TransConvWithPReLU(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class Nomalization(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super(Nomalization, self).__init__()
|
||||
self.norm = nn.BatchNorm2d(in_channels)
|
||||
def _image_normalization(tensor, norm_type):
|
||||
if norm_type == 'nomalization':
|
||||
return tensor / 255.0
|
||||
elif norm_type == 'denormalization':
|
||||
return tensor * 255.0
|
||||
else:
|
||||
raise Exception('Unknown type of normalization')
|
||||
|
||||
|
||||
def _NormlizationLayer(norm_type='nomalization'):
|
||||
pass
|
||||
|
||||
|
||||
def ratio2filter_size(x, ratio):
|
||||
before_size = np.prod(x.size())
|
||||
after_size = before_size*ratio
|
||||
encoder_temp = Encoder(c=after_size)
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, c, is_temp=False):
|
||||
super(Encoder, self).__init__()
|
||||
self.is_temp = is_temp
|
||||
self.imgae_normalization = _image_normalization(norm_type='nomalization')
|
||||
self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2)
|
||||
self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2)
|
||||
self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, stride=1)
|
||||
self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, stride=1)
|
||||
self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, stride=1)
|
||||
self.norm = _NormlizationLayer(norm_type='nomalization')
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
return x
|
||||
x = self.imgae_normalization(x)
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.conv3(x)
|
||||
x = self.conv4(x)
|
||||
if not self.is_temp:
|
||||
x = self.conv5(x)
|
||||
z = self.norm(x)
|
||||
del x
|
||||
return z
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self):
|
||||
super(Decoder, self).__init__()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user