mxtaltools.models.modules.vector_LayerNorm
- class mxtaltools.models.modules.vector_LayerNorm.VectorLayerNorm(in_channels: int, eps: float = 1e-05, affine: bool = True, mode: str = 'graph')
Bases: Module
Simplified graph-wise layer normalization that operates only on the norms of the vectors, adapted from the PyTorch Geometric (torch_geometric.nn) LayerNorm.
TODO: confirm layer vs. batch norm behavior.
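A minimal sketch of the idea, assuming the layer rescales all vectors of a graph by the root-mean-square of their norms (a single scalar per graph, so vector directions are untouched) and that `affine=True` corresponds to a learnable per-channel scale. The function `vector_norm_layernorm` and its `weight` argument are hypothetical; this is not the mxtaltools implementation, which may, for example, center the norms or use a different statistic.

```python
import torch
from typing import Optional


def vector_norm_layernorm(v: torch.Tensor,
                          batch: Optional[torch.Tensor] = None,
                          batch_size: Optional[int] = None,
                          weight: Optional[torch.Tensor] = None,
                          eps: float = 1e-5) -> torch.Tensor:
    """Hypothetical graph-wise norm on vector features v of shape [N, 3, k]."""
    # Norm of each 3-vector, per node and channel -> [N, k]
    norms = torch.linalg.vector_norm(v, dim=1)
    if batch is None:
        # Treat all nodes as a single example
        batch = torch.zeros(v.shape[0], dtype=torch.long, device=v.device)
    if batch_size is None:
        batch_size = int(batch.max()) + 1
    # Mean squared norm per node (over channels) -> [N]
    sq = (norms ** 2).mean(dim=1)
    # Per-graph sums and node counts via scatter-add -> [B]
    total = torch.zeros(batch_size, dtype=v.dtype, device=v.device).index_add_(0, batch, sq)
    count = torch.zeros(batch_size, dtype=v.dtype, device=v.device).index_add_(
        0, batch, torch.ones_like(sq))
    # Graph-wise RMS of vector norms; clamp avoids 0/0 for empty graphs
    rms = (total / count.clamp(min=1) + eps).sqrt()
    # One scalar per graph rescales every vector in that graph
    out = v / rms[batch].view(-1, 1, 1)
    if weight is not None:
        # Hypothetical affine step: learnable scale per channel, shape [k]
        out = out * weight.view(1, 1, -1)
    return out
```

Because each graph's vectors are multiplied by a single scalar computed from rotation-invariant norms, rotating the input rotates the output identically; normalizing the x, y, z components separately would break this equivariance.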
- forward(v: Tensor, batch: LongTensor | None = None, batch_size: int | None = None) → Tensor
- Parameters:
v (torch.Tensor) – The source tensor of vector features, with shape [N, 3, k].
batch (torch.Tensor, optional) – The batch vector \(\mathbf{b} \in \{0, \ldots, B-1\}^N\), which assigns each element to a specific example. (default: None)
batch_size (int, optional) – The number of examples \(B\). Automatically calculated if not given. (default: None)
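A short sketch of how the `batch` and `batch_size` inputs relate, assuming the usual PyTorch Geometric convention that nodes of each example are stored contiguously. The helper `make_batch_vector` is hypothetical, and the inference rule `batch.max() + 1` is an assumption about how `batch_size` is "automatically calculated" (it is the rule PyG's norm layers use).

```python
import torch


def make_batch_vector(graph_sizes: list[int]) -> tuple[torch.Tensor, int]:
    """Hypothetical helper: build the batch vector b and the number of examples B."""
    sizes = torch.tensor(graph_sizes, dtype=torch.long)
    # Element i of b is the index of the example (graph) that node i belongs to
    batch = torch.repeat_interleave(torch.arange(len(graph_sizes)), sizes)
    return batch, len(graph_sizes)


# Three graphs with 3, 2 and 0 nodes; k = 4 vector channels per node
batch, batch_size = make_batch_vector([3, 2, 0])
v = torch.randn(batch.numel(), 3, 4)  # shape [N, 3, k] with N = 5

# Inferring B from the batch vector misses trailing empty examples,
# which is one reason to pass batch_size explicitly
inferred = int(batch.max()) + 1
print(batch.tolist(), batch_size, inferred)  # [0, 0, 0, 1, 1] 3 2
```

Tensors built this way match the documented signature, so they should be usable as `VectorLayerNorm(in_channels=4)(v, batch, batch_size)`.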