Source code for mxtaltools.models.modules.vector_LayerNorm

from typing import Optional

import torch
from torch import Tensor
from torch.nn import Parameter
from torch_geometric.nn.inits import ones
from torch_geometric.utils import scatter


class VectorLayerNorm(torch.nn.Module):
    r"""Normalize vector features by the magnitude of their Euclidean norms.

    Simplified layer normalization, modelled on the PyTorch Geometric
    ``LayerNorm``, that acts only on vector norms: every vector is divided by
    a positive scalar, so vector directions are left unchanged.

    Two modes are supported:

    * ``'graph'``: each channel is divided by the mean vector norm of that
      channel, averaged over all nodes of the same graph (or over the whole
      input when no ``batch`` vector is given).
    * ``'node'``: each vector is divided by its own norm.

    Args:
        in_channels (int): Number of vector channels ``k``; sizes the
            learnable per-channel scale.
        eps (float, optional): Value added to the denominator for numerical
            stability. (default: :obj:`1e-5`)
        affine (bool, optional): If :obj:`True`, the module has a learnable
            per-channel scale ``weight``. (default: :obj:`True`)
        mode (str, optional): Either ``'graph'`` or ``'node'``.
            (default: ``'graph'``)

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
    """
    # TODO confirm layer vs batch norm behavior

    # Normalization modes understood by ``forward``.
    _MODES = ('graph', 'node')

    def __init__(
            self,
            in_channels: int,
            eps: float = 1e-5,
            affine: bool = True,
            mode: str = 'graph',
    ):
        # Fail fast: a mistyped mode should surface at construction time
        # rather than on the first forward pass.
        if mode not in self._MODES:
            raise ValueError(f"Unknown normalization mode: {mode}")

        super().__init__()

        self.in_channels = in_channels
        self.eps = eps
        self.affine = affine
        self.mode = mode

        if affine:
            self.weight = Parameter(torch.empty(in_channels))
        else:
            # Register the name so ``self.weight`` exists and is ``None``.
            self.register_parameter('weight', None)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        # The scale starts at one; ``ones`` is expected to ignore ``None``
        # (the non-affine case).
        ones(self.weight)

    def forward(self, v: Tensor, batch: Optional[torch.LongTensor] = None,
                batch_size: Optional[int] = None) -> Tensor:
        r"""Rescale the vectors in ``v`` according to the configured mode.

        Args:
            v (torch.Tensor): The source tensor, vector [nx3xk]. Norms are
                taken over dimension 1.
            batch (torch.Tensor, optional): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each element to a specific example. Only used in ``'graph'``
                mode; if omitted the whole input is treated as one graph.
                (default: :obj:`None`)
            batch_size (int, optional): The number of examples :math:`B`.
                Automatically calculated if not given. (default: :obj:`None`)

        Returns:
            torch.Tensor: Tensor of the same shape as ``v``.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode (e.g. it was
                reassigned after construction).
        """
        if self.mode not in self._MODES:
            raise ValueError(f"Unknown normalization mode: {self.mode}")

        norm = torch.linalg.norm(v, dim=1)  # [n, k]

        if self.mode == 'node':
            # Every vector is scaled by its own norm.
            scale = norm
        elif batch is None:
            # Whole input is a single graph: one mean norm per channel.
            scale = norm.mean(0, keepdim=True)  # [1, k], broadcast over nodes
        else:
            if batch_size is None:
                # ``batch.max()`` raises on an empty tensor; an empty batch
                # simply contains zero graphs.
                batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 0
            graph_mean = scatter(norm, batch, dim=0, dim_size=batch_size,
                                 reduce='mean')  # [B, k]
            # Hand every node the mean norm of the graph it belongs to.
            scale = graph_mean.index_select(0, batch)  # [n, k]

        # Insert the vector-component axis so the scale broadcasts over it.
        out = v / (scale + self.eps)[:, None, :]

        if self.weight is not None:
            out = out * self.weight

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'affine={self.affine}, mode={self.mode})')