From 8796899e4928c4d7d7e62a29f5793025de418a5f Mon Sep 17 00:00:00 2001 From: chun Date: Sun, 9 Jun 2024 01:48:10 +0800 Subject: [PATCH] Update model arch ref to #7 --- model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/model.py b/model.py index 7d52a0e..619a022 100644 --- a/model.py +++ b/model.py @@ -74,12 +74,12 @@ class _Encoder(nn.Module): super(_Encoder, self).__init__() self.is_temp = is_temp # self.imgae_normalization = _image_normalization(norm_type='nomalization') - self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2) - self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2) + self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2, padding=2) + self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2, padding=2) self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, padding=2) # padding size could be changed here self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, padding=2) - self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, padding=2) + self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=2*c, kernel_size=5, padding=2) self.norm = self._normlizationLayer(P=P) @staticmethod @@ -121,14 +121,14 @@ class _Decoder(nn.Module): super(_Decoder, self).__init__() # self.imgae_normalization = _image_normalization(norm_type='denormalization') self.tconv1 = _TransConvWithPReLU( - in_channels=c, out_channels=32, kernel_size=5, stride=1, padding=2) + in_channels=2*c, out_channels=32, kernel_size=5, stride=1, padding=2) self.tconv2 = _TransConvWithPReLU( in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2) self.tconv3 = _TransConvWithPReLU( in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2) - self.tconv4 = _TransConvWithPReLU(in_channels=32, out_channels=16, kernel_size=6, stride=2) + self.tconv4 = _TransConvWithPReLU(in_channels=32, out_channels=16, kernel_size=5, stride=2, padding=2, output_padding=1) self.tconv5 = _TransConvWithPReLU( - in_channels=16, out_channels=3, kernel_size=6, stride=2, activate=nn.Sigmoid()) + in_channels=16, out_channels=3, kernel_size=5, stride=2, padding=2, output_padding=1,activate=nn.Sigmoid()) # may be some problems in tconv4 and tconv5, the kernal_size is not the same as the paper which is 5 def forward(self, x): @@ -176,7 +176,7 @@ class DeepJSCC(nn.Module): if __name__ == '__main__': model = DeepJSCC(c=20) print(model) - x = torch.rand(1, 3, 32, 32) + x = torch.rand(1, 3, 128, 128) y = model(x) print(y.size()) print(y)