mxtaltools.models.modules.vector_LayerNorm

class mxtaltools.models.modules.vector_LayerNorm.VectorLayerNorm(in_channels: int, eps: float = 1e-05, affine: bool = True, mode: str = 'graph')

Bases: Module

Simplified graph-wise layer normalization that operates only on the norms of vectors, adapted from the PyTorch Geometric (torch_geometric) LayerNorm.

TODO: confirm whether the behavior matches layer norm or batch norm.
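The description above does not pin down the exact statistic used. As a minimal sketch of one plausible reading, assume the layer divides every vector by the root-mean-square vector norm of its graph, with no mean-centering, an optional per-channel scale (standing in for affine=True) and no bias. The function name vector_layer_norm_graph and the RMS choice are hypothetical and for illustration only; this is not the mxtaltools implementation, which operates on torch tensors.

```python
import numpy as np


def vector_layer_norm_graph(v, batch=None, batch_size=None, weight=None, eps=1e-5):
    """Hypothetical norm-only, graph-wise layer norm for vectors of shape [n, 3, k]."""
    n, _, k = v.shape
    if batch is None:
        # No batch vector: treat every row as belonging to a single graph.
        batch = np.zeros(n, dtype=np.int64)
    if batch_size is None:
        batch_size = int(batch.max()) + 1

    # Squared norm of each vector channel, shape [n, k].
    sq_norms = (v ** 2).sum(axis=1)

    # Per-graph sum of squared norms over all nodes and channels, shape [B].
    sums = np.zeros(batch_size)
    np.add.at(sums, batch, sq_norms.sum(axis=1))
    # Number of (node, channel) entries per graph; clamp to avoid 0/0 for empty graphs.
    counts = np.bincount(batch, minlength=batch_size) * k
    rms = np.sqrt(sums / np.maximum(counts, 1) + eps)

    # Rescale each vector by its graph's RMS norm; directions are untouched.
    out = v / rms[batch][:, None, None]
    if weight is not None:
        # Per-channel scale only; a bias vector would break rotation equivariance.
        out = out * weight[None, None, :]
    return out
```

Because only norms enter the statistic and no bias vector is added, rotating the input rotates the output in the same way. This is a common reason equivariant networks normalize vector features through their norms rather than centering each Cartesian component.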

forward(v: Tensor, batch: LongTensor | None = None, batch_size: int | None = None) → Tensor
Parameters:
  • v (torch.Tensor) – The source tensor of vectors, with shape [n × 3 × k] (n vectors, 3 Cartesian components, k channels).

  • batch (torch.Tensor, optional) – The batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each element to a specific example. (default: None)

  • batch_size (int, optional) – The number of examples \(B\). Automatically calculated if not given. (default: None)
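The batch vector appears to follow the usual PyTorch Geometric convention. The sketch below uses hypothetical NumPy helpers (make_batch_vector, infer_batch_size and per_graph_counts are illustrative names, not part of mxtaltools) to build b for a mini-batch of graphs. It assumes, as PyG does, that an omitted batch_size is inferred as max(b) + 1.

```python
import numpy as np


def make_batch_vector(graph_sizes):
    """Build b in {0, ..., B-1}^N assigning each vector row to its graph."""
    return np.repeat(np.arange(len(graph_sizes)), graph_sizes)


def infer_batch_size(batch, batch_size=None):
    # Hypothetical fallback: use max index + 1 when batch_size is not given.
    if batch_size is not None:
        return batch_size
    return int(batch.max()) + 1 if batch.size > 0 else 0


def per_graph_counts(batch, batch_size):
    # Number of rows (vectors) belonging to each graph, including empty graphs.
    return np.bincount(batch, minlength=batch_size)


# Three graphs with 2, 3 and 0 nodes: the empty last graph leaves no index in b.
b = make_batch_vector([2, 3, 0])
print(b.tolist())                           # [0, 0, 1, 1, 1]
print(infer_batch_size(b))                  # 2: the empty graph is not counted
print(per_graph_counts(b, 3).tolist())      # [2, 3, 0]
```

Under this assumption, inferring B from the batch vector silently drops trailing empty graphs, so passing batch_size explicitly keeps per-graph statistics aligned with the true number of examples.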

reset_parameters()

Resets all learnable parameters of the module.