Update model arch ref to #7

This commit is contained in:
chun 2024-06-09 01:48:10 +08:00
parent b7bd3bdd42
commit 8796899e49

View File

@ -74,12 +74,12 @@ class _Encoder(nn.Module):
super(_Encoder, self).__init__()
self.is_temp = is_temp
# self.imgae_normalization = _image_normalization(norm_type='nomalization')
self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2)
self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2)
self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2, padding=2)
self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2, padding=2)
self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32,
kernel_size=5, padding=2) # padding size could be changed here
self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, padding=2)
self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, padding=2)
self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=2*c, kernel_size=5, padding=2)
self.norm = self._normlizationLayer(P=P)
@staticmethod
@ -121,14 +121,14 @@ class _Decoder(nn.Module):
super(_Decoder, self).__init__()
# self.imgae_normalization = _image_normalization(norm_type='denormalization')
self.tconv1 = _TransConvWithPReLU(
in_channels=c, out_channels=32, kernel_size=5, stride=1, padding=2)
in_channels=2*c, out_channels=32, kernel_size=5, stride=1, padding=2)
self.tconv2 = _TransConvWithPReLU(
in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2)
self.tconv3 = _TransConvWithPReLU(
in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2)
self.tconv4 = _TransConvWithPReLU(in_channels=32, out_channels=16, kernel_size=6, stride=2)
self.tconv4 = _TransConvWithPReLU(in_channels=32, out_channels=16, kernel_size=5, stride=2, padding=2, output_padding=1)
self.tconv5 = _TransConvWithPReLU(
in_channels=16, out_channels=3, kernel_size=6, stride=2, activate=nn.Sigmoid())
in_channels=16, out_channels=3, kernel_size=5, stride=2, padding=2, output_padding=1,activate=nn.Sigmoid())
# may be some problems in tconv4 and tconv5, the kernal_size is not the same as the paper which is 5
def forward(self, x):
@ -176,7 +176,7 @@ class DeepJSCC(nn.Module):
if __name__ == '__main__':
model = DeepJSCC(c=20)
print(model)
x = torch.rand(1, 3, 32, 32)
x = torch.rand(1, 3, 128, 128)
y = model(x)
print(y.size())
print(y)