Update model.py

This commit is contained in:
ZHENG Chunhang 2024-01-16 23:02:10 +08:00 committed by GitHub
parent cb6e52e183
commit ab409da89e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -95,7 +95,7 @@ class _Encoder(nn.Module):
k = torch.prod(torch.tensor(z_hat.size()))
else:
raise Exception('Unknown size of input')
k = torch.tensor(k)
# k = torch.tensor(k)
z_temp = z_hat.reshape(batch_size, 1, 1, -1)
z_trans = z_hat.reshape(batch_size, 1, -1, 1)
tensor = torch.sqrt(P * k) * z_hat / torch.sqrt((z_temp @ z_trans))