Source code for mxtaltools.models.graph_models.graph_neural_network

from typing import Optional, Tuple

import torch
import torch.nn as nn

from mxtaltools.models.modules.basis_functions import GaussianEmbedding, BesselBasisLayer
from mxtaltools.models.modules.components import scalarMLP, vectorMLP, EmbeddingBlock
from mxtaltools.models.modules.graph_convolution import MConv, v_MConv


class ScalarGNN(torch.nn.Module):
    """Graph neural network operating on scalar node features.

    The model embeds the input nodes, applies an initial MLP, then alternates
    ``MConv`` graph convolutions (conditioned on a radial basis expansion of
    the edge lengths) with per-node MLPs, and finally projects the node
    features to ``embedding_dim``.

    Args:
        input_node_dim: Forwarded to ``EmbeddingBlock``.
        node_dim: Width of the node features used by every MLP and
            convolution in the model.
        fcs_per_gc: Number of layers in each ``scalarMLP``.
        message_dim: Forwarded to each ``MConv`` as ``message_dim``.
        embedding_dim: Width of the returned node features. When it equals
            ``node_dim`` the output layer is an identity.
        num_convs: Number of (convolution, MLP) pairs.
        num_radial: Number of radial basis functions; also the edge
            embedding width given to ``MConv``.
        num_input_classes: Forwarded to ``EmbeddingBlock``.
        cutoff: Cutoff handed to the radial basis, unless overridden.
        max_num_neighbors: Stored on the instance; not otherwise used here.
        envelope_exponent: Forwarded to ``BesselBasisLayer``.
        activation: Activation name forwarded to the MLPs and convolutions.
        atom_type_embedding_dim: Forwarded to ``EmbeddingBlock``.
        norm: Normalization name forwarded to the MLPs (the convolutions are
            always built with ``norm=None``).
        dropout: Dropout value forwarded to the MLPs.
        radial_embedding: ``'bessel'`` or ``'gaussian'``.
        override_cutoff: If not ``None``, used in place of ``cutoff``.
            May be a number or a tensor.

    Raises:
        ValueError: If ``radial_embedding`` is not a supported name.
    """

    def __init__(self,
                 input_node_dim: int,
                 node_dim: int,
                 fcs_per_gc: int,
                 message_dim: int,
                 embedding_dim: int,
                 num_convs: int,
                 num_radial: int,
                 num_input_classes=101,
                 cutoff: float = 5.0,
                 max_num_neighbors: int = 32,
                 envelope_exponent: int = 5,
                 activation='gelu',
                 atom_type_embedding_dim: int = 5,
                 norm: Optional[str] = None,
                 dropout: float = 0,
                 radial_embedding: str = 'bessel',
                 override_cutoff: Optional[float] = None
                 ):
        super(ScalarGNN, self).__init__()

        # Fail fast: any other value would leave ``self.rbf`` undefined and
        # only surface later as an AttributeError inside ``forward``.
        if radial_embedding not in ('bessel', 'gaussian'):
            raise ValueError(
                f"Unsupported radial_embedding '{radial_embedding}'; "
                f"expected 'bessel' or 'gaussian'")

        self.max_num_neighbors = max_num_neighbors

        # Register the cutoff as a float32 buffer. Tensors are copied with
        # clone().detach() rather than re-wrapped with torch.tensor().
        cutoff_value = cutoff if override_cutoff is None else override_cutoff
        if torch.is_tensor(cutoff_value):
            cutoff_buffer = cutoff_value.clone().detach().to(torch.float32)
        else:
            cutoff_buffer = torch.tensor(cutoff_value, dtype=torch.float32)
        self.register_buffer('cutoff', cutoff_buffer)

        if radial_embedding == 'bessel':
            self.rbf = BesselBasisLayer(num_radial, self.cutoff, envelope_exponent)
        else:  # 'gaussian'
            self.rbf = GaussianEmbedding(start=0.0, stop=self.cutoff, num_gaussians=num_radial)

        self.init_node_embedding = EmbeddingBlock(node_dim,
                                                  num_input_classes,
                                                  input_node_dim,
                                                  atom_type_embedding_dim)

        self.zeroth_fc_block = scalarMLP(layers=fcs_per_gc,
                                         filters=node_dim,
                                         input_dim=node_dim,
                                         output_dim=node_dim,
                                         activation=activation,
                                         norm=norm,
                                         dropout=dropout)

        self.interaction_blocks = torch.nn.ModuleList([
            MConv(
                message_dim=message_dim,
                node_dim=node_dim,
                edge_embedding_dim=num_radial,
                norm=None,
                activation_fn=activation)
            for _ in range(num_convs)
        ])

        self.fc_blocks = torch.nn.ModuleList([
            scalarMLP(layers=fcs_per_gc,
                      filters=node_dim,
                      input_dim=node_dim,
                      output_dim=node_dim,
                      activation=activation,
                      norm=norm,
                      dropout=dropout)
            for _ in range(num_convs)
        ])

        # Only project when the output width differs from the internal width.
        if node_dim != embedding_dim:
            self.output_layer = nn.Linear(node_dim, embedding_dim, bias=False)
        else:
            self.output_layer = nn.Identity()

    def radial_embedding(self,
                         edge_index,
                         pos: torch.Tensor,
                         dist: Optional[torch.Tensor] = None,
                         ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute edge lengths and their radial basis expansion.

        Args:
            edge_index: Pair of index rows ``(i, j)`` selecting the two
                endpoints of every edge in ``pos``.
            pos: Node positions; the last dimension is summed over when
                computing distances.
            dist: Optional precomputed edge lengths. When given, they are
                used as-is and ``edge_index``/``pos`` are ignored.

        Returns:
            ``(dist, rbf)``: the edge lengths and ``self.rbf(dist)``.
        """
        if dist is None:
            i, j = edge_index  # i->j source-to-target
            dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

        return dist, self.rbf(dist)  # apply radial basis functions

    def forward(self,
                z: torch.Tensor,
                pos: torch.Tensor,
                batch: torch.LongTensor,
                edge_index: torch.LongTensor,
                dist: Optional[torch.Tensor] = None
                ) -> torch.Tensor:
        """Embed the nodes and run all convolution / MLP rounds.

        Args:
            z: Raw node input passed to the initial embedding block.
            pos: Node positions, used for edge lengths when ``dist`` is None.
            batch: Per-node batch assignment forwarded to the MLPs.
            edge_index: Edge endpoints forwarded to the convolutions.
            dist: Optional precomputed edge lengths.

        Returns:
            Node features after the output layer.
        """
        x = self.init_node_embedding(z)
        x = self.zeroth_fc_block(x=x, batch=batch)

        # With zero convolutions the radial basis is never needed.
        if len(self.interaction_blocks) > 0:
            dist, rbf = self.radial_embedding(edge_index, pos, dist)

            for convolution, fc in zip(self.interaction_blocks, self.fc_blocks):
                x = convolution(x, edge_index, rbf)
                x = fc(x, batch=batch)

        return self.output_layer(x)
class VectorGNN(torch.nn.Module):
    """Graph neural network with parallel scalar and vector node tracks.

    Scalar features ``x`` are updated by ``MConv`` and vector features ``v``
    by ``v_MConv``; both convolutions use the same radial basis expansion of
    the edge lengths. After every convolution round a ``vectorMLP`` updates
    both tracks together. Separate output layers project each track.

    Args:
        input_node_dim: Forwarded to ``EmbeddingBlock``.
        node_dim: Internal width of both the scalar and the vector track.
        fcs_per_gc: Number of layers in each ``vectorMLP``.
        message_dim: Forwarded to ``MConv`` (``message_dim``) and
            ``v_MConv`` (``message_depth``).
        embedding_dim: Output width of the scalar track.
        num_convs: Number of (scalar conv, vector conv, MLP) rounds.
        num_radial: Number of radial basis functions; also the edge
            embedding width of the convolutions.
        num_input_classes: Forwarded to ``EmbeddingBlock``.
        cutoff: Cutoff handed to the radial basis, unless overridden.
        max_num_neighbors: Stored on the instance; not otherwise used here.
        envelope_exponent: Forwarded to ``BesselBasisLayer``.
        activation: Activation name forwarded to the MLPs and ``MConv``.
        atom_type_embedding_dim: Forwarded to ``EmbeddingBlock``. A value of
            ``0`` replaces the scalar embedding block with an identity.
        norm: Normalization name forwarded to the MLPs (the convolutions are
            always built with ``norm=None``).
        vector_norm: Vector normalization name forwarded to the MLPs.
        dropout: Dropout value forwarded to the MLPs.
        radial_embedding: ``'bessel'`` or ``'gaussian'``.
        override_cutoff: If not ``None``, used in place of ``cutoff``.
            May be a number or a tensor.
        v_embedding_dim: Output width of the vector track; defaults to
            ``embedding_dim``.
        v_input_node_dim: Input width of the vector embedding layer;
            defaults to ``1``.

    Raises:
        ValueError: If ``radial_embedding`` is not a supported name.
    """

    def __init__(self,
                 input_node_dim: int,
                 node_dim: int,
                 fcs_per_gc: int,
                 message_dim: int,
                 embedding_dim: int,
                 num_convs: int,
                 num_radial: int,
                 num_input_classes=101,
                 cutoff: float = 5.0,
                 max_num_neighbors: int = 32,
                 envelope_exponent: int = 5,
                 activation='gelu',
                 atom_type_embedding_dim: int = 5,
                 norm: Optional[str] = None,
                 vector_norm: Optional[str] = None,
                 dropout: float = 0,
                 radial_embedding: str = 'bessel',
                 override_cutoff: Optional[float] = None,
                 v_embedding_dim: Optional[int] = None,
                 v_input_node_dim: Optional[int] = None,
                 ):
        super(VectorGNN, self).__init__()

        # Fail fast: any other value would leave ``self.rbf`` undefined and
        # only surface later as an AttributeError inside ``forward``.
        if radial_embedding not in ('bessel', 'gaussian'):
            raise ValueError(
                f"Unsupported radial_embedding '{radial_embedding}'; "
                f"expected 'bessel' or 'gaussian'")

        self.max_num_neighbors = max_num_neighbors

        if override_cutoff is None:
            cutoff_buffer = torch.tensor(cutoff, dtype=torch.float32)
        elif torch.is_tensor(override_cutoff):
            # Copy an existing tensor instead of re-wrapping it.
            cutoff_buffer = override_cutoff.clone().detach()
        else:
            cutoff_buffer = torch.tensor(override_cutoff, dtype=torch.float32)
        self.register_buffer('cutoff', cutoff_buffer)

        if radial_embedding == 'bessel':
            self.rbf = BesselBasisLayer(num_radial, self.cutoff, envelope_exponent)
        else:  # 'gaussian'
            self.rbf = GaussianEmbedding(start=0.0, stop=self.cutoff, num_gaussians=num_radial)

        # With no atom-type embedding the scalar input is passed through as-is.
        if atom_type_embedding_dim == 0:
            self.init_node_embedding = nn.Identity()
        else:
            self.init_node_embedding = EmbeddingBlock(node_dim,
                                                      num_input_classes,
                                                      input_node_dim,
                                                      atom_type_embedding_dim)

        if v_input_node_dim is None:
            v_input_node_dim = 1
        self.init_vector_embedding = nn.Linear(v_input_node_dim, node_dim, bias=False)

        self.zeroth_fc_block = vectorMLP(layers=fcs_per_gc,
                                         filters=node_dim,
                                         input_dim=node_dim,
                                         output_dim=node_dim,
                                         activation=activation,
                                         norm=norm,
                                         dropout=dropout,
                                         vector_input_dim=node_dim,
                                         vector_output_dim=node_dim,
                                         vector_norm=vector_norm)

        self.interaction_blocks = torch.nn.ModuleList([
            MConv(
                message_dim=message_dim,
                node_dim=node_dim,
                edge_embedding_dim=num_radial,
                norm=None,
                activation_fn=activation)
            for _ in range(num_convs)
        ])

        self.vector_interaction_blocks = torch.nn.ModuleList([
            v_MConv(
                message_depth=message_dim,
                node_depth=node_dim,
                edge_embedding_dim=num_radial,
                norm=None,
            )
            for _ in range(num_convs)
        ])

        self.fc_blocks = torch.nn.ModuleList([
            vectorMLP(layers=fcs_per_gc,
                      filters=node_dim,
                      input_dim=node_dim,
                      output_dim=node_dim,
                      activation=activation,
                      norm=norm,
                      dropout=dropout,
                      vector_norm=vector_norm,
                      vector_input_dim=node_dim,
                      vector_output_dim=node_dim)
            for _ in range(num_convs)
        ])

        # Only project when the output width differs from the internal width.
        if node_dim != embedding_dim:
            self.output_layer = nn.Linear(node_dim, embedding_dim, bias=False)
        else:
            self.output_layer = nn.Identity()

        if v_embedding_dim is None:
            v_embedding_dim = embedding_dim
        if node_dim != v_embedding_dim:
            self.v_output_layer = nn.Linear(node_dim, v_embedding_dim, bias=False)
        else:
            self.v_output_layer = nn.Identity()

    def radial_embedding(self,
                         edge_index: torch.LongTensor,
                         pos: torch.Tensor
                         ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute edge lengths and their radial basis expansion.

        Args:
            edge_index: Pair of index rows ``(i, j)`` selecting the two
                endpoints of every edge in ``pos``.
            pos: Node positions; the last dimension is summed over when
                computing distances.

        Returns:
            ``(dist, rbf)``: the edge lengths and ``self.rbf(dist)``.
        """
        i, j = edge_index  # i->j source-to-target
        dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

        return dist, self.rbf(dist)  # apply radial basis functions

    def forward(self,
                x: torch.Tensor,
                v: torch.Tensor,
                pos: torch.Tensor,
                batch: torch.LongTensor,
                edges_dict: dict
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed both tracks and run all convolution / MLP rounds.

        Args:
            x: Scalar node input passed to the initial node embedding.
            v: Vector node input passed to the initial linear embedding;
                its last dimension must be ``v_input_node_dim``.
                NOTE(review): presumably shaped (n_nodes, 3, v_input_node_dim)
                - confirm against the caller.
            pos: Node positions, used to compute edge lengths.
            batch: Per-node batch assignment forwarded to the MLPs.
            edges_dict: Must contain ``'edge_index'``.

        Returns:
            ``(x, v)`` after their respective output layers.
        """
        x = self.init_node_embedding(x)
        v = self.init_vector_embedding(v)
        x, v = self.zeroth_fc_block(x=x, v=v, batch=batch)

        # With zero convolutions the radial basis is never needed.
        if len(self.interaction_blocks) > 0:
            edge_index = edges_dict['edge_index']
            dist, rbf = self.radial_embedding(edge_index, pos)

            rounds = zip(self.interaction_blocks,
                         self.vector_interaction_blocks,
                         self.fc_blocks)
            for convolution, vector_convolution, fc in rounds:
                x = convolution(x, edge_index, rbf)
                v = vector_convolution(v, edge_index, rbf)
                x, v = fc(x=x, v=v, batch=batch)

        return self.output_layer(x), self.v_output_layer(v)
class MolCrystalScalarGNN(torch.nn.Module):
    """Scalar graph neural network for periodic molecular crystal graphs.

    The architecture matches a plain scalar GNN (embedding, initial MLP,
    alternating ``MConv`` convolutions and MLPs, output projection), but the
    input graph is assumed to have an inside/outside structure: after every
    convolution only the "inside" nodes are updated by the MLP, and their
    features are then copied onto all periodic images.

    Args:
        input_node_dim: Forwarded to ``EmbeddingBlock``.
        node_dim: Width of the node features used by every MLP and
            convolution in the model.
        fcs_per_gc: Number of layers in each ``scalarMLP``.
        message_dim: Forwarded to each ``MConv`` as ``message_dim``.
        embedding_dim: Width of the returned node features. When it equals
            ``node_dim`` the output layer is an identity.
        num_convs: Number of (convolution, MLP) pairs.
        num_radial: Number of radial basis functions; also the edge
            embedding width given to ``MConv``.
        num_input_classes: Forwarded to ``EmbeddingBlock``.
        cutoff: Cutoff handed to the radial basis, unless overridden.
        max_num_neighbors: Stored on the instance; not otherwise used here.
        envelope_exponent: Forwarded to ``BesselBasisLayer``.
        activation: Activation name forwarded to the MLPs and convolutions.
        atom_type_embedding_dim: Forwarded to ``EmbeddingBlock``.
        norm: Normalization name forwarded to the MLPs (the convolutions are
            always built with ``norm=None``).
        dropout: Dropout value forwarded to the MLPs.
        radial_embedding: ``'bessel'`` or ``'gaussian'``.
        override_cutoff: If not ``None``, used in place of ``cutoff``.
            May be a number or a tensor.

    Raises:
        ValueError: If ``radial_embedding`` is not a supported name.
    """

    def __init__(self,
                 input_node_dim: int,
                 node_dim: int,
                 fcs_per_gc: int,
                 message_dim: int,
                 embedding_dim: int,
                 num_convs: int,
                 num_radial: int,
                 num_input_classes=101,
                 cutoff: float = 5.0,
                 max_num_neighbors: int = 32,
                 envelope_exponent: int = 5,
                 activation='gelu',
                 atom_type_embedding_dim: int = 5,
                 norm: Optional[str] = None,
                 dropout: float = 0,
                 radial_embedding: str = 'bessel',
                 override_cutoff: Optional[float] = None
                 ):
        super(MolCrystalScalarGNN, self).__init__()

        # Fail fast: any other value would leave ``self.rbf`` undefined and
        # only surface later as an AttributeError inside ``forward``.
        if radial_embedding not in ('bessel', 'gaussian'):
            raise ValueError(
                f"Unsupported radial_embedding '{radial_embedding}'; "
                f"expected 'bessel' or 'gaussian'")

        self.max_num_neighbors = max_num_neighbors

        # Register the cutoff as a float32 buffer. Tensors are copied with
        # clone().detach() rather than re-wrapped with torch.tensor().
        cutoff_value = cutoff if override_cutoff is None else override_cutoff
        if torch.is_tensor(cutoff_value):
            cutoff_buffer = cutoff_value.clone().detach().to(torch.float32)
        else:
            cutoff_buffer = torch.tensor(cutoff_value, dtype=torch.float32)
        self.register_buffer('cutoff', cutoff_buffer)

        if radial_embedding == 'bessel':
            self.rbf = BesselBasisLayer(num_radial, self.cutoff, envelope_exponent)
        else:  # 'gaussian'
            self.rbf = GaussianEmbedding(start=0.0, stop=self.cutoff, num_gaussians=num_radial)

        self.init_node_embedding = EmbeddingBlock(node_dim,
                                                  num_input_classes,
                                                  input_node_dim,
                                                  atom_type_embedding_dim)

        self.zeroth_fc_block = scalarMLP(layers=fcs_per_gc,
                                         filters=node_dim,
                                         input_dim=node_dim,
                                         output_dim=node_dim,
                                         activation=activation,
                                         norm=norm,
                                         dropout=dropout)

        self.interaction_blocks = torch.nn.ModuleList([
            MConv(
                message_dim=message_dim,
                node_dim=node_dim,
                edge_embedding_dim=num_radial,
                norm=None,
                activation_fn=activation)
            for _ in range(num_convs)
        ])

        self.fc_blocks = torch.nn.ModuleList([
            scalarMLP(layers=fcs_per_gc,
                      filters=node_dim,
                      input_dim=node_dim,
                      output_dim=node_dim,
                      activation=activation,
                      norm=norm,
                      dropout=dropout)
            for _ in range(num_convs)
        ])

        # Only project when the output width differs from the internal width.
        if node_dim != embedding_dim:
            self.output_layer = nn.Linear(node_dim, embedding_dim, bias=False)
        else:
            self.output_layer = nn.Identity()

    def radial_embedding(self,
                         edge_index: torch.LongTensor,
                         pos: torch.Tensor
                         ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute edge lengths and their radial basis expansion.

        Args:
            edge_index: Pair of index rows ``(i, j)`` selecting the two
                endpoints of every edge in ``pos``.
            pos: Node positions; the last dimension is summed over when
                computing distances.

        Returns:
            ``(dist, rbf)``: the edge lengths and ``self.rbf(dist)``.
        """
        i, j = edge_index  # i->j source-to-target
        dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

        return dist, self.rbf(dist)  # apply radial basis functions

    def forward(self,
                z: torch.Tensor,
                pos: torch.Tensor,
                batch: torch.LongTensor,
                aux_ind: torch.Tensor,
                ptr: torch.LongTensor,
                edges_dict: dict
                ) -> torch.Tensor:
        """Run the periodic convolution stack.

        Args:
            z: Raw node input passed to the initial embedding block.
            pos: Node positions, used to compute edge lengths.
            batch: Per-node batch assignment forwarded to the MLPs.
            aux_ind: Written into the last feature column after every
                convolution round to re-mark the inside/outside structure.
            ptr: Boundaries of each crystal's node range in ``x``;
                crystal ``ii`` occupies rows ``ptr[ii]:ptr[ii + 1]``.
            edges_dict: Must contain ``'edge_index'``, ``'edge_index_inter'``,
                ``'inside_inds'``, ``'inside_batch'`` and ``'n_repeats'``.

        Returns:
            Node features after the output layer. When at least one
            convolution is present, only the inside nodes are returned.
        """
        x = self.init_node_embedding(z)
        x = self.zeroth_fc_block(x=x, batch=batch)

        # assumes input with inside-outside structure, and enforces
        # periodicity after each convolution
        inside_inds = edges_dict['inside_inds']
        inside_batch = edges_dict['inside_batch']
        n_repeats = edges_dict['n_repeats']

        # all edges (intra- and inter-molecular) counted in one big batch
        edge_index = torch.cat(
            (edges_dict['edge_index'], edges_dict['edge_index_inter']), dim=1)

        # With zero convolutions the radial basis is never needed.
        if len(self.interaction_blocks) > 0:
            dist, rbf = self.radial_embedding(edge_index, pos)

            for n, (convolution, fc) in enumerate(zip(self.interaction_blocks, self.fc_blocks)):
                x = convolution(x, edge_index, rbf)

                # update only the inside nodes with the MLP ...
                x[inside_inds] = fc(x[inside_inds], batch=batch[inside_inds])

                # ... then broadcast node features to all symmetry images
                x = self.periodize_molecular_crystal(inside_batch,
                                                     inside_inds,
                                                     n,
                                                     n_repeats,
                                                     ptr,
                                                     x,
                                                     aux_ind)

        return self.output_layer(x)

    def periodize_molecular_crystal(self,
                                    inside_batch,
                                    inside_inds,
                                    n,
                                    n_repeats,
                                    ptr,
                                    x,
                                    aux_ind):
        """Copy inside-node features onto every periodic image, in place.

        Args:
            inside_batch: Crystal index of each entry of ``inside_inds``.
            inside_inds: Row indices in ``x`` of the inside nodes.
            n: Index of the current convolution round.
            n_repeats: Per-crystal number of times the inside block is
                tiled to fill that crystal's rows.
            ptr: Row boundaries of each crystal in ``x``.
            x: Node features; modified in place.
            aux_ind: Values written into the last feature column.

        Returns:
            ``x`` after periodization; on the final convolution round
            only the rows at ``inside_inds`` are returned.
        """
        # enforce periodicity for each crystal, assuming invariant node features
        for ii in range(len(ptr) - 1):
            # copy the first asymmetric unit to all periodic images
            # (safe since all are SE(3) invariant)
            # assumes the crystals are ordered specifically in this way
            # todo check if this could be done faster with repeat_interleave
            x[ptr[ii]:ptr[ii + 1], :] = x[inside_inds[inside_batch == ii]].repeat(n_repeats[ii], 1)

        x[:, -1] = aux_ind  # manually re-indicate inside/outside structure

        if n == len(self.interaction_blocks) - 1:
            x = x[inside_inds]  # reduce to inside image on the final convolution

        return x
# old method - has since been broken up into separate models # class GraphNeuralNetwork(torch.nn.Module): # def __init__(self, # input_node_dim: int, # node_dim: int, # fcs_per_gc: int, # message_dim: int, # embedding_dim: int, # num_convs: int, # num_radial: int, # num_input_classes=101, # cutoff: float = 5.0, # max_num_neighbors: int = 32, # envelope_exponent: int = 5, # activation='gelu', # atom_type_embedding_dim=5, # norm: Optional[str] = None, # dropout: float = 0, # radial_embedding: str = 'bessel', # num_attention_heads: int = 1, # # periodize_inside_nodes: bool = False, # outside_convolution_type: str = 'none', # add_vector_track: bool = False, # vector_norm: Optional[str] = None, # override_cutoff: Optional[float] = None # ): # super(GraphNeuralNetwork, self).__init__() # # self.max_num_neighbors = max_num_neighbors # if override_cutoff is None: # self.register_buffer('cutoff', torch.tensor(cutoff, dtype=torch.float32)) # else: # self.register_buffer('cutoff', torch.tensor(override_cutoff, dtype=torch.float32)) # # self.periodize_inside_nodes = periodize_inside_nodes # self.outside_convolution_type = outside_convolution_type # self.add_vector_track = add_vector_track # self.vector_addition_rescaling_factor = 1.6 # # if radial_embedding == 'bessel': # self.rbf = BesselBasisLayer(num_radial, self.cutoff, envelope_exponent) # elif radial_embedding == 'gaussian': # self.rbf = GaussianEmbedding(start=0.0, stop=self.cutoff, num_gaussians=num_radial) # # self.init_node_embedding = EmbeddingBlock(node_dim, # num_input_classes, # input_node_dim, # atom_type_embedding_dim) # # if self.add_vector_track: # self.init_vector_embedding = nn.Linear(1, node_dim, bias=False) # # self.zeroth_fc_block = FCBlock( # fcs_per_gc, # node_dim, # activation, # norm, # dropout, # equivariant=self.add_vector_track, # vector_norm=vector_norm) # # self.interaction_blocks = torch.nn.ModuleList([ # GCBlock(message_dim, # node_dim, # num_radial, # heads=num_attention_heads, # 
add_vector_channel=add_vector_track, # ) # for _ in range(num_convs) # ]) # # self.fc_blocks = torch.nn.ModuleList([ # FCBlock( # fcs_per_gc, # node_dim, # activation, # norm, # dropout, # equivariant=self.add_vector_track, # vector_norm=vector_norm, # ) # for _ in range(num_convs) # ]) # # self.output_block = OutputBlock(node_dim, embedding_dim, add_vector_track) # # def radial_embedding(self, edge_index, pos): # """ # compute elements for radial & spherical embeddings # """ # i, j = edge_index # i->j source-to-target # dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt() # # return dist, self.rbf(dist) # apply radial basis functions # # def forward(self, z, pos, batch, ptr, edges_dict: dict): # # x, v = self.init_node_embeddings(z) # # if self.add_vector_track: # x, v = self.zeroth_fc_block(x=x, v=v, batch=batch) # else: # x = self.zeroth_fc_block(x=x, v=v, batch=batch) # # if len(self.interaction_blocks) > 0: # (edge_index, edge_index_inter, # inside_batch, inside_inds, # n_repeats, # rbf, rbf_inter) = self.get_edges(edges_dict, pos) # # for n, (convolution, fc) in enumerate(zip(self.interaction_blocks, self.fc_blocks)): # if v is not None: # x_res, v_res = x.clone(), v.clone() # x, v = convolution(x, v, rbf, edge_index) # x = x + x_res # v = (v + v_res) / self.vector_addition_rescaling_factor # else: # x = x + convolution(x, v, rbf, edge_index) # # if not self.periodize_inside_nodes: # inside/outside periodic convolution # if self.add_vector_track: # x_res, v_res = x.clone(), v.clone() # x, v = fc(x, v=v, batch=batch) # x = x + x_res # v = (v + v_res) / self.vector_addition_rescaling_factor # else: # x = x + fc(x, v=v, batch=batch) # # #assert torch.sum(torch.isnan(x)) == 0, f"NaN in fc_block output {get_model_nans(self.fc_blocks)}" # # else: # assert v is None, "Vector embeddings not set up for periodic molecular crystal graph convolutions" # # update only the inside inds # x[inside_inds] = (x[inside_inds] + fc(x[inside_inds], batch=batch[inside_inds])) # # # 
then broadcast to all symmetry images # x = self.periodize_molecular_crystal(inside_batch, inside_inds, n, n_repeats, ptr, x) # # return self.output_block(x, v) # # def init_node_embeddings(self, z): # if self.add_vector_track: # x, v = z[:, :-3], z[:, -3:] # vector features are trailing 3 dimensions of node input # v = self.init_vector_embedding(v[:, :, None]) # [n_nodes, 3] -> [n_nodes, 3, n_dim] # x = self.init_node_embedding(x) # embed atomic numbers & compute initial atom-wise feature vector # else: # x, v = self.init_node_embedding(z), None # embed atomic numbers & compute initial atom-wise feature vector # return x, v # # def periodize_molecular_crystal(self, inside_batch, inside_inds, n, n_repeats, ptr, x): # for ii in range(len(ptr) - 1): # enforce periodicity for each crystal, assuming invariant node features # # copy the first asymmetric unit to all periodic images (safe since all are SE(3) invariant) # x[ptr[ii]:ptr[ii + 1], :] = x[inside_inds[inside_batch == ii]].repeat(n_repeats[ii], 1) # # if n == len(self.interaction_blocks) - 1: # x = x[inside_inds] # reduce to inside image on the final convolution # # return x # # def get_edges(self, # edges_dict: dict, # pos: torch.Tensor): # if self.outside_convolution_type == 'none': # # no inside/outside distinctions # edge_index = edges_dict['edge_index'] # if 'dists' in edges_dict.keys(): # previously generated distances - e.g., for periodic MIC # dist = edges_dict['dists'] # rbf = self.rbf(dist) # else: # dist, rbf = self.radial_embedding(edge_index, pos) # edge_index_inter, inside_batch, inside_inds, n_repeats, rbf_inter = None, None, None, None, None # assert not self.periodize_inside_nodes, "Cannot periodize to outside nodes if there are no outside nodes" # # elif self.outside_convolution_type == 'all_layers': # # assumes input with inside-outside structure, and enforces periodicity after each convolution # # edge_index, edge_index_inter, inside_inds, outside_inds, inside_batch, n_repeats = 
list(edges_dict.values()) # edge_index = torch.cat((edge_index, edge_index_inter), dim=1) # all edges counted in one big batch # dist, rbf = self.radial_embedding(edge_index, pos) # rbf_inter = None # # elif self.outside_convolution_type == 'last_layer': # assert False, "Last layer outside convolution is deprecated" # # edge_index, edge_index_inter, inside_inds, outside_inds, inside_batch, n_repeats = list(edges_dict.values()) # # dist, rbf = self.radial_embedding(edge_index, pos) # # dist_inter, rbf_inter = self.radial_embedding(torch.cat((edge_index, edge_index_inter), dim=1), pos) # # # re-integrate this in forward method to bring it back # # if n == (len(self.interaction_blocks) - 1) and self.outside_convolution_type == 'last_layer': # # # return only the results of the intermolecular convolution, omitting intramolecular features # # x = convolution(x, v, # # rbf_inter, # # torch.cat((edge_index, edge_index_inter), dim=1), # # batch) # # else: # assert False, "Must select a valid treatment of inside vs outside nodes" # # return edge_index, edge_index_inter, inside_batch, inside_inds, n_repeats, rbf, rbf_inter