diff --git a/environment.yml b/environment.yml deleted file mode 100644 index 6507ad1..0000000 Binary files a/environment.yml and /dev/null differ diff --git a/model.py b/model.py index 5ac275b..2a444c8 100644 --- a/model.py +++ b/model.py @@ -1,5 +1,13 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Dec 11:00:00 2023 + +@author: chun +""" + import torch import torch.nn as nn +import numpy as np class _ConvWithPReLU(nn.Module): @@ -26,11 +34,50 @@ class _TransConvWithPReLU(nn.Module): return x -class Nomalization(nn.Module): - def __init__(self, in_channels): - super(Nomalization, self).__init__() - self.norm = nn.BatchNorm2d(in_channels) +def _image_normalization(tensor, norm_type): + if norm_type == 'nomalization': + return tensor / 255.0 + elif norm_type == 'denormalization': + return tensor * 255.0 + else: + raise Exception('Unknown type of normalization') + + +def _NormlizationLayer(norm_type='nomalization'): + pass + + +def ratio2filter_size(x, ratio): + before_size = np.prod(x.size()) + after_size = before_size*ratio + encoder_temp = Encoder(c=after_size) + + +class Encoder(nn.Module): + def __init__(self, c, is_temp=False): + super(Encoder, self).__init__() + self.is_temp = is_temp + self.imgae_normalization = _image_normalization(norm_type='nomalization') + self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2) + self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2) + self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, stride=1) + self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, stride=1) + self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, stride=1) + self.norm = _NormlizationLayer(norm_type='nomalization') def forward(self, x): - x = self.norm(x) - return x + x = self.imgae_normalization(x) + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + if not self.is_temp: + x = self.conv5(x) + z = self.norm(x) + del x + return z + + +class Decoder(nn.Module): + def __init__(self): + super(Decoder, self).__init__() diff --git a/train.py b/train.py index e69de29..fa323c5 100644 --- a/train.py +++ b/train.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Dec 11:00:00 2023 + +@author: chun +""" + +import torch +import torch.nn as nn + + + +def main(): + pass + + + + + +if __name__ == '__main__': + main() \ No newline at end of file