Update model.py
This commit is contained in:
parent
e142bb899a
commit
cb6e52e183
16
model.py
16
model.py
@ -7,7 +7,6 @@ Created on Tue Dec 11:00:00 2023
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
|
||||||
import channel
|
import channel
|
||||||
|
|
||||||
|
|
||||||
@ -24,14 +23,17 @@ import channel
|
|||||||
|
|
||||||
def ratio2filtersize(x: torch.Tensor, ratio):
|
def ratio2filtersize(x: torch.Tensor, ratio):
|
||||||
if x.dim() == 4:
|
if x.dim() == 4:
|
||||||
before_size = np.prod(x.size()[1:])
|
# before_size = np.prod(x.size()[1:])
|
||||||
|
before_size = torch.prod(torch.tensor(x.size()[1:]))
|
||||||
elif x.dim() == 3:
|
elif x.dim() == 3:
|
||||||
before_size = np.prod(x.size())
|
# before_size = np.prod(x.size())
|
||||||
|
before_size = torch.prod(torch.tensor(x.size()))
|
||||||
else:
|
else:
|
||||||
raise Exception('Unknown size of input')
|
raise Exception('Unknown size of input')
|
||||||
encoder_temp = _Encoder(is_temp=True)
|
encoder_temp = _Encoder(is_temp=True)
|
||||||
z_temp = encoder_temp(x)
|
z_temp = encoder_temp(x)
|
||||||
c = before_size * ratio / np.prod(z_temp.size()[-2:])
|
# c = before_size * ratio / np.prod(z_temp.size()[-2:])
|
||||||
|
c = before_size * ratio / torch.prod(torch.tensor(z_temp.size()[-2:]))
|
||||||
return int(c)
|
return int(c)
|
||||||
|
|
||||||
|
|
||||||
@ -85,10 +87,12 @@ class _Encoder(nn.Module):
|
|||||||
def _inner(z_hat: torch.Tensor):
|
def _inner(z_hat: torch.Tensor):
|
||||||
if z_hat.dim() == 4:
|
if z_hat.dim() == 4:
|
||||||
batch_size = z_hat.size()[0]
|
batch_size = z_hat.size()[0]
|
||||||
k = np.prod(z_hat.size()[1:])
|
# k = np.prod(z_hat.size()[1:])
|
||||||
|
k = torch.prod(torch.tensor(z_hat.size()[1:]))
|
||||||
elif z_hat.dim() == 3:
|
elif z_hat.dim() == 3:
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
k = np.prod(z_hat.size())
|
# k = np.prod(z_hat.size())
|
||||||
|
k = torch.prod(torch.tensor(z_hat.size()))
|
||||||
else:
|
else:
|
||||||
raise Exception('Unknown size of input')
|
raise Exception('Unknown size of input')
|
||||||
k = torch.tensor(k)
|
k = torch.tensor(k)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user