Update model arch ref to #7
This commit is contained in:
parent
b7bd3bdd42
commit
8796899e49
14
model.py
14
model.py
@ -74,12 +74,12 @@ class _Encoder(nn.Module):
|
|||||||
super(_Encoder, self).__init__()
|
super(_Encoder, self).__init__()
|
||||||
self.is_temp = is_temp
|
self.is_temp = is_temp
|
||||||
# self.imgae_normalization = _image_normalization(norm_type='nomalization')
|
# self.imgae_normalization = _image_normalization(norm_type='nomalization')
|
||||||
self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2)
|
self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2, padding=2)
|
||||||
self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2)
|
self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2, padding=2)
|
||||||
self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32,
|
self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32,
|
||||||
kernel_size=5, padding=2) # padding size could be changed here
|
kernel_size=5, padding=2) # padding size could be changed here
|
||||||
self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, padding=2)
|
self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, padding=2)
|
||||||
self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, padding=2)
|
self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=2*c, kernel_size=5, padding=2)
|
||||||
self.norm = self._normlizationLayer(P=P)
|
self.norm = self._normlizationLayer(P=P)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -121,14 +121,14 @@ class _Decoder(nn.Module):
|
|||||||
super(_Decoder, self).__init__()
|
super(_Decoder, self).__init__()
|
||||||
# self.imgae_normalization = _image_normalization(norm_type='denormalization')
|
# self.imgae_normalization = _image_normalization(norm_type='denormalization')
|
||||||
self.tconv1 = _TransConvWithPReLU(
|
self.tconv1 = _TransConvWithPReLU(
|
||||||
in_channels=c, out_channels=32, kernel_size=5, stride=1, padding=2)
|
in_channels=2*c, out_channels=32, kernel_size=5, stride=1, padding=2)
|
||||||
self.tconv2 = _TransConvWithPReLU(
|
self.tconv2 = _TransConvWithPReLU(
|
||||||
in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2)
|
in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2)
|
||||||
self.tconv3 = _TransConvWithPReLU(
|
self.tconv3 = _TransConvWithPReLU(
|
||||||
in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2)
|
in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2)
|
||||||
self.tconv4 = _TransConvWithPReLU(in_channels=32, out_channels=16, kernel_size=6, stride=2)
|
self.tconv4 = _TransConvWithPReLU(in_channels=32, out_channels=16, kernel_size=5, stride=2, padding=2, output_padding=1)
|
||||||
self.tconv5 = _TransConvWithPReLU(
|
self.tconv5 = _TransConvWithPReLU(
|
||||||
in_channels=16, out_channels=3, kernel_size=6, stride=2, activate=nn.Sigmoid())
|
in_channels=16, out_channels=3, kernel_size=5, stride=2, padding=2, output_padding=1,activate=nn.Sigmoid())
|
||||||
# may be some problems in tconv4 and tconv5, the kernal_size is not the same as the paper which is 5
|
# may be some problems in tconv4 and tconv5, the kernal_size is not the same as the paper which is 5
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -176,7 +176,7 @@ class DeepJSCC(nn.Module):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
model = DeepJSCC(c=20)
|
model = DeepJSCC(c=20)
|
||||||
print(model)
|
print(model)
|
||||||
x = torch.rand(1, 3, 32, 32)
|
x = torch.rand(1, 3, 128, 128)
|
||||||
y = model(x)
|
y = model(x)
|
||||||
print(y.size())
|
print(y.size())
|
||||||
print(y)
|
print(y)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user