import sys
from typing import Optional, Tuple
import numpy as np
import torch
from torch import nn as nn, Tensor
from torch.nn import functional as F
from torch_geometric import nn as gnn
from mxtaltools.models.modules.vector_LayerNorm import VectorLayerNorm
[docs]
class Scalarizer(nn.Module):
[docs]
def __init__(self,
hidden_dim: int,
embedding_dim: int,
norm_mode: str,
act_func: str,
dropout: float = 0,
output_dim: int = None,
):
"""
Generate a learned invariant representation of dimension :math:`(k)` from a list of vectors of dimension :math:`v=(k x 3)`.
Generate m vectors as a linear combination of the k vectors of v, take their normalized dot products with the components of v, concatenate to the norms of v, and linearly combine to the so-called scalarized representation of v.
Args:
hidden_dim (int): feature depth of input and output
embedding_dim (int): number of vectors to use for dot-product projection
norm_mode (str): type of normalization to use
act_func (str): type of activation to use
dropout (float): dropout probability
"""
super(Scalarizer, self).__init__()
if output_dim is None:
output_dim = hidden_dim
self.embedding_dim = embedding_dim
self.hidden_dim = hidden_dim
reduced_dim = max(hidden_dim // 4, 1)
#self.linear = nn.Linear(int(hidden_dim * (1 + embedding_dim)), hidden_dim, bias=True)
self.embedding = nn.Linear(reduced_dim, embedding_dim, bias=False)
self.dim_red = nn.Linear(hidden_dim, reduced_dim, bias=False)
self.linear = nn.Linear(reduced_dim * 4, output_dim, bias=True)
self.norm = Normalization(norm_mode, output_dim)
self.activation = Activation(act_func, output_dim)
self.dropout = nn.Dropout(p=dropout)
[docs]
def forward(self,
v: torch.Tensor
) -> torch.Tensor:
eps = 1e-3
v_red = self.dim_red(v)
norm = torch.linalg.norm(v_red, dim=1)
normed_v_red = v_red / (norm[:, None, :] + eps)
directions = self.embedding(v_red)
normed_directions = directions / (torch.linalg.norm(directions, dim=1)[:, None, :] + eps)
projections = torch.einsum('nik,nij->njk', normed_v_red, normed_directions)
v2 = torch.cat([norm, projections.flatten(1)], dim=1)
return self.dropout(self.activation(self.norm(self.linear(v2))))
[docs]
class VectorActivation(nn.Module):
r"""
Modified implementation of the vector activation function from https://github.com/FlyingGiraffe/vnn/blob/master/models/vn_layers.py
Generates an axis as a learned linear combination of input v, then the normalized overlaps of all the components of v.
Applies an activation function on the vector overlaps, such that, e.g., for ReLU activation, vectors with negative overlap are rotated to be perpendicular to the learned axis (zero overlap) and vectors with positive overlap are untouched.
Args:
hidden_dim (int): feature depth of input/output vectors, :math:`(k\times 3)`
act_func (str): activation function to apply to the normalized vector overlaps with the learned axis
"""
def __init__(self,
hidden_dim: int,
act_func: str):
super(VectorActivation, self).__init__()
self.embedding = nn.Linear(hidden_dim, 1, bias=False)
self.activation = Activation(act_func, hidden_dim)
[docs]
def forward(self,
v: torch.Tensor
) -> torch.Tensor:
direction = self.embedding(v)[..., -1]
normed_direction = direction / (torch.linalg.norm(direction, dim=1, keepdim=True) + 1e-5)
projection = torch.einsum('nik,ni->nk', v, normed_direction).clip(max=0)
correction = -self.activation(projection[..., None]) * normed_direction[:, None, :]
activated_output = v + correction.permute(0, 2, 1)
return activated_output
'''
# tests # todo write equivariance check
# assert projection.max() <= 1
# assert projeciton.min() >= -1
# assert torch.einsum('nik,ni->nk', activated_output, normed_direction).min() >= 1e-3, "Vector Activation Failed"
import numpy as np
import plotly.graph_objects as go
fig = go.Figure()
fig.add_histogram(x=projection.flatten(),nbinsx=100)
fig.add_histogram(x=projection2.flatten(), nbinsx=100)
fig.show(renderer='browser')
'''
# noinspection PyAttributeOutsideInit
[docs]
class Normalization(nn.Module):
r"""
Wrapper module for several normalization options
Args:
norm (str): type of normalization function
filters (int): feature depth of objects to be normalized
"""
def __init__(self,
norm: str,
filters: int):
super().__init__()
self.norm_type = norm
if norm == 'batch':
self.norm = nn.BatchNorm1d(filters)
elif norm == 'graph layer':
self.norm = gnn.LayerNorm(filters)
elif norm == 'layer':
self.norm = nn.LayerNorm(filters)
elif norm == 'instance':
self.norm = nn.InstanceNorm1d(filters) # not tested
elif norm == 'graph':
self.norm = gnn.GraphNorm(filters)
elif norm == 'graph vector layer':
self.norm = VectorLayerNorm(filters, mode='graph')
elif norm == 'vector layer':
self.norm = VectorLayerNorm(filters, mode='node')
elif norm is None:
self.norm = nn.Identity()
else:
print(norm + " is not a valid normalization")
sys.exit()
[docs]
def forward(self,
x: torch.Tensor,
batch: Optional[torch.LongTensor] = None):
if batch is not None and self.norm_type != 'batch' and self.norm_type != 'layer' and self.norm_type is not None:
return self.norm(x, batch)
return self.norm(x)
[docs]
class Activation(nn.Module):
r"""
Wrapper module for several activation options
Args:
activation_func (str): type of activation function
filters (int): feature depth of objects to be normalized
"""
def __init__(self,
activation_func: str,
filters: int):
super().__init__()
if activation_func is not None:
if activation_func.lower() == 'relu':
self.activation = F.relu
elif activation_func.lower() == 'gelu':
self.activation = F.gelu
elif activation_func.lower() == 'kernel': # rather expensive
self.activation = KernelActivation(n_basis=10, span=4, filters=filters)
elif activation_func.lower() == 'leaky relu':
self.activation = F.leaky_relu
elif activation_func.lower() == 'tanh':
self.activation = F.tanh
elif activation_func.lower() == 'sigmoid':
self.activation = F.sigmoid
elif activation_func is None:
self.activation = nn.Identity()
[docs]
def forward(self,
x: torch.Tensor
) -> torch.Tensor:
return self.activation(x)
[docs]
class KernelActivation(nn.Module):
r"""
Function for learning an activation function for every node in a given layer, as a linear combination of basis functions over a given span.
Uses nn.Conv1d groups option for efficient evaluation.
Args:
n_basis (int): number of basis functions
span (float): span over which to define localized basis functions
filters (int): feature depth of inputs to be activated
"""
def __init__(self,
n_basis: int,
span: float,
filters: int):
super(KernelActivation, self).__init__()
self.channels, self.n_basis = filters, n_basis
# define the space of basis functions
self.register_buffer('dict',
torch.linspace(-span, span, n_basis)) # positive and negative values for Dirichlet Kernel
gamma = 1 / (6 * (self.dict[-1] - self.dict[
-2]) ** 2) # optimum gaussian spacing parameter should be equal to 1/(6*spacing^2) according to KAFnet paper
self.register_buffer('gamma', torch.ones(1) * gamma) #
# self.register_buffer('dict', torch.linspace(0, n_basis-1, n_basis)) # positive values for ReLU kernel
# define module to learn parameters
# 1d convolutions allow for grouping of terms, unlike nn.linear which is always fully-connected.
# #This way should be fast and efficient, and play nice with pytorch optim
self.linear = nn.Conv1d(filters * n_basis, filters, kernel_size=(1, 1), groups=int(filters), bias=False)
# nn.init.normal(self.linear.weight.data, std=0.1)
[docs]
def kernel(self,
x: torch.Tensor
) -> torch.Tensor:
# x has dimention batch, features, y, x
# must return object of dimension batch, features, y, x, basis
x = x.unsqueeze(2)
if len(x) == 2:
x = x.reshape(2, self.channels, 1)
return torch.exp(-self.gamma * (x - self.dict) ** 2)
[docs]
def forward(self,
x: torch.Tensor
) -> torch.Tensor:
x = self.kernel(x).unsqueeze(-1).unsqueeze(-1) # run activation, output shape batch, features, y, x, basis
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3],
x.shape[4]) # concatenate basis functions with filters
x = self.linear(x).squeeze(-1).squeeze(-1) # apply linear coefficients and sum
return x
# noinspection PyAttributeOutsideInit
[docs]
class scalarMLP(nn.Module): # todo simplify and smooth out +1's and other custom methods for a general depth controller
r"""
Flexible multi-layer perceptron module, with several options.
Args:
layers (int): number of fully-connected layers
filters (int): feature depth with FC layers
input_dim (int): feature depth of inputs
output_dim (int): feature depth of outputs
activation (str): activation function
seed (int): random seed
dropout (float): dropout probability
ramp_depth (bool): whether to ramp the feature depth exponentially from input_dim to output_dim through the network
"""
def __init__(self,
layers: int,
filters: int,
input_dim: int,
output_dim: int,
activation: str = 'gelu',
seed: int = 0,
dropout: float = 0,
norm: 'str' = None,
norm_after_linear: bool = True,
bias: bool = True,
ramp_depth: bool = False,
):
super(scalarMLP, self).__init__()
# initialize constants and layers
self.n_layers = layers
self.output_dim = output_dim
self.input_dim = input_dim
self.norm_mode = norm
self.dropout_p = dropout
self.activation = activation
self.bias = bias
self.norm_after_linear = norm_after_linear
self.ramp_depth = ramp_depth
torch.manual_seed(seed)
self.init_scalar_filters(filters)
self.init_scalar_transforms()
[docs]
def init_scalar_filters(self, filters):
"""
returns a list with layers + 1 integer elements
"""
if self.ramp_depth: # smoothly ramp feature depth across layers
# linear scaling
# self.n_filters = torch.linspace(self.input_dim, self.output_dim, self.n_layers).long().tolist()
# log scaling for consistent growth ratio
p = np.log(self.output_dim) / np.log(self.input_dim)
n_filters = [self.input_dim] + [int(self.input_dim ** (1 + (p - 1) * (i / self.n_layers))) for i in
range(self.n_layers)]
else:
n_filters = [self.input_dim] + [filters for _ in range(self.n_layers)]
self.s_filters_in = n_filters[:-1]
self.s_filters_out = n_filters[1:]
if n_filters.count(n_filters[0]) != len(
n_filters): # if they are not all the same, we need residue adjustments
self.same_depth = False
self.residue_adjust = torch.nn.ModuleList([
nn.Linear(self.s_filters_in[i], self.s_filters_out[i], bias=False)
for i in range(self.n_layers)
])
else:
self.same_depth = True
[docs]
def forward(self,
x: torch.Tensor,
return_latent: bool = False,
batch: Optional[torch.LongTensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""initialize to correct feature dimension"""
for i, (norm, linear, activation, dropout) in enumerate(
zip(self.fc_norms, self.fc_layers, self.fc_activations, self.fc_dropouts)):
'get residue'
if self.same_depth:
res_x = x.clone()
else:
res_x = self.residue_adjust[i](x)
'linear layer'
if self.norm_after_linear:
x = res_x + dropout(activation(norm(linear(x), batch=batch)))
else:
x = res_x + dropout(activation(linear(norm(x, batch=batch))))
if return_latent:
return self.output_layer(x), x
else:
return self.output_layer(x)
# noinspection PyAttributeOutsideInit
[docs]
class vectorMLP(scalarMLP):
r"""
scalarMLP model with l=1 vector track added with o3 equivariance
"""
def __init__(self,
layers: int,
filters: int,
input_dim: int,
output_dim: int,
vector_input_dim: int,
vector_output_dim: int,
activation: str = 'gelu',
seed: int = 0,
dropout: float = 0,
norm: str = None,
norm_after_linear: bool = True,
bias: bool = True,
vector_norm: str = None,
ramp_depth: bool = False,
v_to_s_combination: str = 'sum'):
super(scalarMLP, self).__init__()
# initialize constants and layers
self.n_layers = layers
self.output_dim = output_dim
self.v_output_dim = vector_output_dim
self.v_input_dim = vector_input_dim
self.input_dim = input_dim
self.norm_mode = norm
self.dropout_p = dropout
self.activation = activation
self.bias = bias
self.norm_after_linear = norm_after_linear
self.v_norm_mode = vector_norm
self.ramp_depth = ramp_depth
self.v_to_s_combination = v_to_s_combination
torch.manual_seed(seed)
# addition of two normally distributed 3-vectors increases the norm by roughly this factor on average
# divide this out to combat vector elongation & poor gradient flow
self.vector_addition_rescaling_factor = 1.6
self.init_scalar_filters(filters)
self.init_vector_filters(filters)
self.init_scalar_transforms()
self.init_vector_transforms()
[docs]
def init_vector_filters(self, filters):
if self.ramp_depth: # smoothly ramp feature depth across layers
# linear scaling
# self.n_filters = torch.linspace(self.input_dim, self.output_dim, self.n_layers).long().tolist()
# exp scaling for consistent growth ratio
p = np.log(self.v_output_dim) / np.log(self.input_dim)
v_n_filters = [self.v_input_dim] + [int(self.v_input_dim ** (1 + (p - 1) * (i / (self.n_layers)))) for i in
range(self.n_layers)]
else:
v_n_filters = [self.v_input_dim] + [filters for _ in range(self.n_layers)]
self.v_filters_in = v_n_filters[:-1]
self.v_filters_out = v_n_filters[1:]
# if they are not all the same, we need residue adjustments
if v_n_filters.count(v_n_filters[0]) != len(v_n_filters):
self.v_same_depth = False
self.v_residue_adjust = torch.nn.ModuleList([
nn.Linear(self.v_filters_in[i], self.v_filters_out[i], bias=False)
for i in range(self.n_layers)
])
else:
self.v_same_depth = True
[docs]
def forward(self,
x: torch.Tensor,
v: torch.Tensor,
return_latent: bool = False,
batch: Optional[torch.LongTensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
for i, (s_norm, s_linear, s_act, s_dropout,
v_norm, v_linear, v_act,
s2v_norm, s2v_linear, s2v_act,
v2s_linear) in enumerate(
zip(self.fc_norms, self.fc_layers, self.fc_activations, self.fc_dropouts,
self.v_fc_norms, self.v_fc_layers, self.vector_activation,
self.scalar_to_vector_norm, self.s_to_v_gating_layers, self.s_to_v_activations,
self.vector_to_scalar)):
res_x, res_v = self.get_residues(i, x, v)
if torch.sum(torch.isnan(x)) != 0:
print("NaN values in EMLP scalars")
if torch.sum(torch.isnan(v)) != 0:
print("NaN values in EMLP vectors")
'scalar forward'
if self.v_to_s_combination == 'concatenate':
# concatenate vector lengths to scalar values
x = torch.cat([x, v2s_linear(v)], dim=-1)
elif self.v_to_s_combination == 'sum':
x = x + v2s_linear(v)
else:
assert False, f'{self.v_to_s_combination} not implemented'
if self.norm_after_linear:
x = res_x + s_dropout(s_act(s_norm(s_linear(x), batch=batch)))
else:
x = res_x + s_dropout(s_act(s_linear(s_norm(x, batch=batch))))
'vector forward' # A(FC(x)) * FC(N(v)) # rescaling factor keeps norm from exploding
s2v_gating = s2v_act(s2v_norm(s2v_linear(x))[:, None, :])
v = v_act(v_norm(v_linear(v), batch=batch))
v = (res_v + s2v_gating * v) / self.vector_addition_rescaling_factor
if torch.sum(torch.isnan(x)) != 0:
print("NaN values in EMLP scalars")
if torch.sum(torch.isnan(v)) != 0:
print("NaN values in EMLP vectors")
if return_latent:
return self.output_layer(x), self.v_output_layer(v), x
else:
return self.output_layer(x), self.v_output_layer(v)
[docs]
def get_residues(self,
i: int,
x: torch.Tensor,
v: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.same_depth:
x = x.clone()
else:
x = self.residue_adjust[i](x)
if self.v_same_depth:
v = v.clone()
else:
v = self.v_residue_adjust[i](v)
return x, v
[docs]
class EmbeddingBlock(torch.nn.Module):
def __init__(self,
init_node_embedding_dim: int,
num_input_classes: int,
num_scalar_input_features: int,
atom_type_embedding_dim: int):
super(EmbeddingBlock, self).__init__()
self.embeddings = nn.Embedding(num_input_classes + 1, atom_type_embedding_dim)
self.linear = nn.Linear(atom_type_embedding_dim + num_scalar_input_features - 1, init_node_embedding_dim)
[docs]
def forward(self,
x: Tensor) -> Tensor:
if x.dim() == 1:
x = x.unsqueeze(1) # make dim 1 explicit
# always embed the first dimension only (by convention, atomic number)
embedding = self.embeddings(x[:, 0].long())
return self.linear(torch.cat([embedding, x[:, 1:]], dim=-1))
[docs]
class ResidualMLP(nn.Module):
"""
a simpler general-purpose MLP
"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers,
norm, dropout_p):
super().__init__()
self.input_layer = nn.Linear(input_dim, hidden_dim)
self.hidden_layers = nn.ModuleList([
nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)
])
if norm is not None:
self.norms = nn.ModuleList([
nn.LayerNorm(hidden_dim) for _ in range(num_layers)
])
else:
self.norms = nn.ModuleList([
nn.Identity() for _ in range(num_layers)
])
self.dropouts = nn.ModuleList([
nn.Dropout(dropout_p) for _ in range(num_layers)
])
self.output_layer = nn.Linear(hidden_dim, output_dim)
[docs]
def forward(self, x):
x = F.relu(self.input_layer(x))
for layer, dropout, norm in zip(self.hidden_layers, self.dropouts, self.norms):
residual = x
x = dropout(F.relu(norm(layer(x))))
x = x + residual # skip connection
return self.output_layer(x)