Update model.py
This commit is contained in:
parent
cb6e52e183
commit
ab409da89e
2
model.py
2
model.py
@ -95,7 +95,7 @@ class _Encoder(nn.Module):
|
|||||||
k = torch.prod(torch.tensor(z_hat.size()))
|
k = torch.prod(torch.tensor(z_hat.size()))
|
||||||
else:
|
else:
|
||||||
raise Exception('Unknown size of input')
|
raise Exception('Unknown size of input')
|
||||||
k = torch.tensor(k)
|
# k = torch.tensor(k)
|
||||||
z_temp = z_hat.reshape(batch_size, 1, 1, -1)
|
z_temp = z_hat.reshape(batch_size, 1, 1, -1)
|
||||||
z_trans = z_hat.reshape(batch_size, 1, -1, 1)
|
z_trans = z_hat.reshape(batch_size, 1, -1, 1)
|
||||||
tensor = torch.sqrt(P * k) * z_hat / torch.sqrt((z_temp @ z_trans))
|
tensor = torch.sqrt(P * k) * z_hat / torch.sqrt((z_temp @ z_trans))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user