Source code for mxtaltools.models.modules.basis_functions

from math import pi as PI

import torch


class BesselBasisLayer(torch.nn.Module):
    """Radial Bessel basis expansion with a smooth polynomial envelope.

    NOTE: borrowed from the DimeNet implementation.

    Input distances are rescaled by ``cutoff`` and expanded as
    ``envelope(d) * sin(freq * d)``, where ``d = dist / cutoff`` and ``freq``
    is a learnable vector initialised to ``[pi, 2*pi, ..., num_radial*pi]``.

    Args:
        num_radial: Number of basis functions (length of ``freq``).
        cutoff: Distance used to rescale the inputs. May be a Python number
            or a tensor; it is stored as the non-trainable buffer ``cutoff``.
        envelope_exponent: Exponent handed to :class:`Envelope`.
    """

    def __init__(self, num_radial: int, cutoff: float = 5.0, envelope_exponent: int = 5):
        super(BesselBasisLayer, self).__init__()
        # Keep the cutoff as a buffer so it follows the module across devices
        # and is saved in the state dict without being trained.
        self.register_buffer(
            'cutoff',
            cutoff.clone().detach() if torch.is_tensor(cutoff) else torch.tensor(cutoff),
        )
        self.envelope = Envelope(envelope_exponent)
        # Values are filled in by reset_parameters(); allocate uninitialised.
        self.freq = torch.nn.Parameter(torch.empty(num_radial))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset ``freq`` to ``[1, 2, ..., num_radial] * pi``."""
        with torch.no_grad():
            torch.arange(1, self.freq.numel() + 1, out=self.freq).mul_(PI)
        self.freq.requires_grad_()

    def forward(self,
                dist: torch.Tensor
                ) -> torch.Tensor:
        """Expand distances in the Bessel basis.

        Args:
            dist: Tensor of distances, in the same units as ``cutoff``.

        Returns:
            Tensor with one extra trailing dimension of size ``num_radial``.
            Entries for distances larger than ``cutoff`` are exactly zero.
        """
        # From here on `dist` is expressed in units of the cutoff.
        dist = dist.unsqueeze(-1) / self.cutoff
        envelope = self.envelope(dist)
        # The envelope polynomial explodes beyond the cutoff. Because `dist`
        # has already been normalised, the cutoff boundary sits at 1 (not at
        # `self.cutoff`), so that is the threshold to mask against.
        envelope = envelope.masked_fill(dist > 1, 0.0)

        return envelope * (self.freq * dist).sin()
class Envelope(torch.nn.Module):
    """Polynomial envelope ``u(x) = 1/x + a*x^(p-1) + b*x^p + c*x^(p+1)``.

    With ``p = exponent + 1`` the coefficients are

    * ``a = -(p + 1) * (p + 2) / 2``
    * ``b = p * (p + 2)``
    * ``c = -p * (p + 1) / 2``

    so that ``1 + a + b + c == 0``, i.e. the envelope is zero at ``x = 1``.
    No clamping is applied here: the function diverges at ``x = 0`` and is
    not forced to zero for ``x > 1``.

    Args:
        exponent: Envelope exponent; the polynomial order is derived from it.
    """

    def __init__(self, exponent: float):
        super().__init__()
        order = exponent + 1
        self.p = order
        self.a = -(order + 1) * (order + 2) / 2
        self.b = order * (order + 2)
        self.c = -order * (order + 1) / 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the envelope elementwise on ``x``."""
        # Build the three consecutive powers by repeated multiplication
        # instead of three separate pow() calls.
        low = x.pow(self.p - 1)
        mid = low * x
        high = mid * x
        return 1. / x + self.a * low + self.b * mid + self.c * high
class GaussianEmbedding(torch.nn.Module):
    """Expand distances on a grid of evenly spaced Gaussians.

    The Gaussian centres are ``torch.linspace(start, stop, num_gaussians)``
    and the shared exponent coefficient is ``-0.5 / spacing**2``, where
    ``spacing`` is the gap between the first two centres (so at least two
    centres are needed). Both are stored as non-trainable buffers named
    ``offset`` and ``coeff``.

    Args:
        start: Position of the first Gaussian centre.
        stop: Position of the last Gaussian centre.
        num_gaussians: Number of Gaussian centres.
    """

    def __init__(self, start: float = 0.0, stop: float = 5.0, num_gaussians: int = 50):
        super().__init__()
        centers = torch.linspace(start, stop, num_gaussians)
        spacing = (centers[1] - centers[0]).item()
        width_coeff = -0.5 / spacing ** 2
        self.register_buffer('offset', centers)
        self.register_buffer('coeff', torch.tensor([width_coeff], dtype=torch.float32))

    def forward(self, dist: torch.Tensor) -> torch.Tensor:
        """Return ``exp(coeff * (dist - offset)**2)`` for every distance/centre pair.

        The input is flattened, so the result has shape
        ``(dist.numel(), num_gaussians)``.
        """
        column = dist.view(-1, 1)
        row = self.offset.view(1, -1)
        delta = column - row
        return (self.coeff * delta.pow(2)).exp()