Update model.py
This commit is contained in:
parent
cb6e52e183
commit
ab409da89e
2
model.py
2
model.py
@ -95,7 +95,7 @@ class _Encoder(nn.Module):
|
||||
k = torch.prod(torch.tensor(z_hat.size()))
|
||||
else:
|
||||
raise Exception('Unknown size of input')
|
||||
k = torch.tensor(k)
|
||||
# k = torch.tensor(k)
|
||||
z_temp = z_hat.reshape(batch_size, 1, 1, -1)
|
||||
z_trans = z_hat.reshape(batch_size, 1, -1, 1)
|
||||
tensor = torch.sqrt(P * k) * z_hat / torch.sqrt((z_temp @ z_trans))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user