from math import pi as PI
import torch
[docs]
class BesselBasisLayer(torch.nn.Module): # NOTE borrowed from DimeNet implementation
def __init__(self,
num_radial: int,
cutoff: float = 5.0,
envelope_exponent: int = 5):
super(BesselBasisLayer, self).__init__()
self.register_buffer('cutoff', torch.tensor(cutoff) if not torch.is_tensor(cutoff) else cutoff.clone().detach())
self.envelope = Envelope(envelope_exponent)
self.freq = torch.nn.Parameter(torch.Tensor(num_radial))
self.reset_parameters()
[docs]
def reset_parameters(self):
with torch.no_grad():
torch.arange(1, self.freq.numel() + 1, out=self.freq).mul_(PI)
self.freq.requires_grad_()
[docs]
def forward(self,
dist: torch.Tensor
) -> torch.Tensor:
dist = dist.unsqueeze(-1) / self.cutoff
envelope = self.envelope(dist)
envelope[dist > self.cutoff] = 0 # this function explodes beyond the cutoff
return envelope * (self.freq * dist).sin()
# return self.envelope(dist) * (self.freq * dist).sin()
[docs]
class Envelope(torch.nn.Module):
def __init__(self,
exponent: float):
super(Envelope, self).__init__()
self.p = exponent + 1
self.a = -(self.p + 1) * (self.p + 2) / 2
self.b = self.p * (self.p + 2)
self.c = -self.p * (self.p + 1) / 2
[docs]
def forward(self, x: torch.Tensor) -> torch.Tensor:
p, a, b, c = self.p, self.a, self.b, self.c
x_pow_p0 = x.pow(p - 1)
x_pow_p1 = x_pow_p0 * x
x_pow_p2 = x_pow_p1 * x
return 1. / x + a * x_pow_p0 + b * x_pow_p1 + c * x_pow_p2
[docs]
class GaussianEmbedding(torch.nn.Module):
def __init__(self,
start: float = 0.0,
stop: float = 5.0,
num_gaussians: int = 50):
super(GaussianEmbedding, self).__init__()
offset = torch.linspace(start, stop, num_gaussians)
coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
self.register_buffer('offset', offset)
self.register_buffer('coeff', torch.tensor([coeff], dtype=torch.float32))
[docs]
def forward(self, dist: torch.Tensor) -> torch.Tensor:
dist = dist.view(-1, 1) - self.offset.view(1, -1)
return torch.exp(self.coeff * torch.pow(dist, 2))