Update model.py
This commit is contained in:
parent
e142bb899a
commit
cb6e52e183
16
model.py
16
model.py
@ -7,7 +7,6 @@ Created on Tue Dec 11:00:00 2023
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import channel
|
||||
|
||||
|
||||
@ -24,14 +23,17 @@ import channel
|
||||
|
||||
def ratio2filtersize(x: torch.Tensor, ratio):
|
||||
if x.dim() == 4:
|
||||
before_size = np.prod(x.size()[1:])
|
||||
# before_size = np.prod(x.size()[1:])
|
||||
before_size = torch.prod(torch.tensor(x.size()[1:]))
|
||||
elif x.dim() == 3:
|
||||
before_size = np.prod(x.size())
|
||||
# before_size = np.prod(x.size())
|
||||
before_size = torch.prod(torch.tensor(x.size()))
|
||||
else:
|
||||
raise Exception('Unknown size of input')
|
||||
encoder_temp = _Encoder(is_temp=True)
|
||||
z_temp = encoder_temp(x)
|
||||
c = before_size * ratio / np.prod(z_temp.size()[-2:])
|
||||
# c = before_size * ratio / np.prod(z_temp.size()[-2:])
|
||||
c = before_size * ratio / torch.prod(torch.tensor(z_temp.size()[-2:]))
|
||||
return int(c)
|
||||
|
||||
|
||||
@ -85,10 +87,12 @@ class _Encoder(nn.Module):
|
||||
def _inner(z_hat: torch.Tensor):
|
||||
if z_hat.dim() == 4:
|
||||
batch_size = z_hat.size()[0]
|
||||
k = np.prod(z_hat.size()[1:])
|
||||
# k = np.prod(z_hat.size()[1:])
|
||||
k = torch.prod(torch.tensor(z_hat.size()[1:]))
|
||||
elif z_hat.dim() == 3:
|
||||
batch_size = 1
|
||||
k = np.prod(z_hat.size())
|
||||
# k = np.prod(z_hat.size())
|
||||
k = torch.prod(torch.tensor(z_hat.size()))
|
||||
else:
|
||||
raise Exception('Unknown size of input')
|
||||
k = torch.tensor(k)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user