Update model.py

This commit is contained in:
ZHENG Chunhang 2024-01-16 23:00:12 +08:00 committed by GitHub
parent e142bb899a
commit cb6e52e183
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -7,7 +7,6 @@ Created on Tue Dec 11:00:00 2023
import torch
import torch.nn as nn
import numpy as np
import channel
@ -24,14 +23,17 @@ import channel
def ratio2filtersize(x: torch.Tensor, ratio):
if x.dim() == 4:
before_size = np.prod(x.size()[1:])
# before_size = np.prod(x.size()[1:])
before_size = torch.prod(torch.tensor(x.size()[1:]))
elif x.dim() == 3:
before_size = np.prod(x.size())
# before_size = np.prod(x.size())
before_size = torch.prod(torch.tensor(x.size()))
else:
raise Exception('Unknown size of input')
encoder_temp = _Encoder(is_temp=True)
z_temp = encoder_temp(x)
c = before_size * ratio / np.prod(z_temp.size()[-2:])
# c = before_size * ratio / np.prod(z_temp.size()[-2:])
c = before_size * ratio / torch.prod(torch.tensor(z_temp.size()[-2:]))
return int(c)
@ -85,10 +87,12 @@ class _Encoder(nn.Module):
def _inner(z_hat: torch.Tensor):
if z_hat.dim() == 4:
batch_size = z_hat.size()[0]
k = np.prod(z_hat.size()[1:])
# k = np.prod(z_hat.size()[1:])
k = torch.prod(torch.tensor(z_hat.size()[1:]))
elif z_hat.dim() == 3:
batch_size = 1
k = np.prod(z_hat.size())
# k = np.prod(z_hat.size())
k = torch.prod(torch.tensor(z_hat.size()))
else:
raise Exception('Unknown size of input')
k = torch.tensor(k)