from argparse import Namespace
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from mxtaltools.models.graph_models.graph_neural_network import ScalarGNN, VectorGNN, \
MolCrystalScalarGNN
from mxtaltools.models.modules.augmented_softmax_aggregator import AugSoftmaxAggregation, VectorAugSoftmaxAggregation
from mxtaltools.models.modules.components import scalarMLP, vectorMLP
from mxtaltools.models.functions.radial_graph import build_radial_graph
from mxtaltools.models.functions.minimum_image_neighbors import argwhere_minimum_image_convention_edges
# noinspection PyAttributeOutsideInit
class ScalarMoleculeGraphModel(nn.Module):
    """Graph model mapping each molecule graph to one output feature vector.

    Pipeline: optional node-feature augmentation (raw coordinates and/or
    broadcast per-molecule features) -> ``ScalarGNN`` -> learnable softmax
    pooling per graph -> optional MLP (with embedded ``mol_x`` concatenated)
    -> linear projection to ``output_dim`` (identity when widths already match).
    """

    def __init__(self,
                 input_node_dim: int,
                 output_dim: int,
                 fc_config: Namespace,
                 graph_config: Namespace,
                 activation: str = 'gelu',
                 num_mol_feats: int = 0,
                 concat_pos_to_node_dim: bool = False,
                 concat_mol_to_node_dim: bool = False,
                 seed: int = 5,
                 override_cutoff=None
                 ):
        """Build the GNN, pooling layer, optional MLP and output projection.

        Args:
            input_node_dim: Number of raw per-node input features.
            output_dim: Width of the returned output.
            fc_config: Post-pooling MLP settings; ``num_layers``, ``hidden_dim``,
                ``norm`` and ``dropout`` are read.
            graph_config: GNN settings; ``cutoff``, ``max_num_neighbors`` and
                ``embedding_dim`` are read here and the whole namespace is
                forwarded to ``ScalarGNN`` as keyword arguments.
            activation: Activation name forwarded to ``ScalarGNN``.
            num_mol_feats: Width of the per-molecule features ``mol_x``. When
                nonzero, a linear layer embeds ``mol_x`` before the MLP.
            concat_pos_to_node_dim: Append node coordinates to node features.
            concat_mol_to_node_dim: Broadcast ``mol_x`` onto every node.
            seed: Seeds the global torch RNG and is forwarded to the MLP.
            override_cutoff: If not None, used instead of ``graph_config.cutoff``
                (for the ``convolution_cutoff`` buffer and for ``ScalarGNN``).
        """
        super().__init__()
        torch.manual_seed(seed)
        self.concat_pos_to_node_dim = concat_pos_to_node_dim
        self.concat_mol_to_node_dim = concat_mol_to_node_dim

        cutoff = graph_config.cutoff if override_cutoff is None else override_cutoff
        self.register_buffer('convolution_cutoff', self._cutoff_tensor(cutoff))
        self.max_num_neighbors = graph_config.max_num_neighbors
        self.num_fc_layers = fc_config.num_layers

        # NOTE(review): with concat_pos_to_node_dim, append_init_node_features adds
        # coordinate columns to x, but input_node_dim is not widened for them here
        # (unlike num_mol_feats below) -- confirm callers already account for it.
        if concat_mol_to_node_dim:
            input_node_dim += num_mol_feats

        self.graph_net = ScalarGNN(
            activation=activation,
            input_node_dim=input_node_dim,
            override_cutoff=override_cutoff,
            **graph_config.__dict__
        )

        # learnable softmax aggregation that pools node embeddings per graph
        self.global_pool = AugSoftmaxAggregation(
            temperature=1,
            learn=True,
            bias=0.1,
            channels=graph_config.embedding_dim)

        # embedding for per-molecule features (only applied when the MLP exists)
        self.mol_fc = nn.Linear(num_mol_feats, num_mol_feats) if num_mol_feats != 0 else None

        # optional MLP to post-process the pooled graph embedding
        if fc_config.num_layers > 0:
            self.gnn_mlp = scalarMLP(layers=fc_config.num_layers,
                                     filters=fc_config.hidden_dim,
                                     norm=fc_config.norm,
                                     dropout=fc_config.dropout,
                                     input_dim=graph_config.embedding_dim + num_mol_feats,
                                     output_dim=fc_config.hidden_dim,
                                     seed=seed,
                                     )
            graph_output_dim = fc_config.hidden_dim
        else:
            graph_output_dim = graph_config.embedding_dim

        # only add a projection when the output width actually has to change
        if graph_output_dim != output_dim:
            self.output_fc = nn.Linear(graph_output_dim, output_dim, bias=False)
        else:
            self.output_fc = nn.Identity()

    def forward(self,
                x: torch.Tensor,
                pos: torch.FloatTensor,
                batch: torch.LongTensor,
                ptr: torch.LongTensor,
                mol_x: torch.Tensor,
                num_graphs: int,
                edge_index: Optional[torch.LongTensor] = None,
                edges_dict: Optional[dict] = None,
                return_latent: bool = False,
                return_dists: bool = False,
                return_embedding: bool = False,
                ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        """Encode a batch of molecule graphs.

        Args:
            x: Per-node input features (a 1-D tensor is treated as one feature).
            pos: Node coordinates.
            batch: Per-node graph index, used as the pooling index.
            ptr: Node offsets per graph; ``diff(ptr)`` gives nodes per graph.
            mol_x: Per-graph molecule features.
            num_graphs: Number of graphs in the batch (rows of the pooled output).
            edge_index: Optional precomputed edges, forwarded to the GNN.
            edges_dict: Only forwarded to ``collect_extra_outputs``.
            return_latent, return_dists, return_embedding: Flags forwarded to
                ``collect_extra_outputs``; ``return_embedding`` additionally
                keeps a copy of the per-node GNN output.

        Returns:
            The output tensor, or ``(output, extra_outputs)`` when
            ``collect_extra_outputs`` returns a non-empty result.
        """
        x = self.append_init_node_features(x, pos, ptr, mol_x)
        x = self.graph_net(x, pos, batch, edge_index)  # per-node encoding
        embedding = x.clone() if return_embedding else None

        # aggregate atoms to molecule / graph representation
        x = self.global_pool(x, batch, dim_size=num_graphs)

        if self.num_fc_layers > 0:
            if self.mol_fc is not None:
                x = torch.cat([x, self.mol_fc(mol_x)], dim=-1)
            x = self.gnn_mlp(x)

        output = self.output_fc(x)

        # NOTE(review): collect_extra_outputs is not defined on this class in the
        # visible source (nn.Module does not provide it) -- confirm where it comes from.
        extra_outputs = self.collect_extra_outputs(x,
                                                   edges_dict,
                                                   return_dists,
                                                   return_latent,
                                                   return_embedding,
                                                   embedding)
        if len(extra_outputs) > 0:
            return output, extra_outputs
        return output

    def append_init_node_features(self, x, pos, ptr, mol_x):
        """Return node features optionally augmented with coordinates and ``mol_x``.

        A 1-D ``x`` becomes a single column. Coordinates are appended
        (PointNet style) when ``concat_pos_to_node_dim`` is set; each graph's
        ``mol_x`` row is repeated over its nodes when ``concat_mol_to_node_dim``
        is set.
        """
        if x.ndim == 1:
            x = x[:, None]
        if self.concat_pos_to_node_dim:
            x = torch.cat((x, pos), dim=-1)
        if self.concat_mol_to_node_dim:
            nodes_per_graph = torch.diff(ptr)
            x = torch.cat((x, torch.repeat_interleave(mol_x, nodes_per_graph, 0)),
                          dim=-1)
        return x

    @staticmethod
    def _cutoff_tensor(cutoff) -> torch.Tensor:
        """Return ``cutoff`` as a float32 tensor suitable for a buffer.

        Existing tensors are cloned and detached rather than passed through
        ``torch.tensor``, which warns on copy-construction from a tensor.
        """
        if torch.is_tensor(cutoff):
            return cutoff.clone().detach().to(torch.float32)
        return torch.tensor(cutoff, dtype=torch.float32)
# noinspection PyAttributeOutsideInit
class VectorMoleculeGraphModel(nn.Module):
    """Graph model producing scalar and vector outputs per molecule graph.

    Node scalars (optionally augmented with the radial distance of each node
    and broadcast per-molecule features) and node vectors (the positions) go
    through a ``VectorGNN``. Scalars and vectors are pooled per graph with
    separate learnable softmax aggregations, optionally post-processed by a
    ``vectorMLP``, and projected to ``output_dim``.
    """

    def __init__(self,
                 input_node_dim: int,
                 output_dim: int,
                 fc_config: Namespace,
                 graph_config: Namespace,
                 activation: str = 'gelu',
                 num_mol_feats: int = 0,
                 concat_pos_to_node_dim: bool = False,
                 concat_mol_to_node_dim: bool = False,
                 seed: int = 5,
                 override_cutoff=None
                 ):
        """Build the GNN, scalar/vector pooling, optional MLP and projections.

        Args:
            input_node_dim: Number of raw per-node scalar input features.
            output_dim: Width of the returned scalar and vector outputs.
            fc_config: Post-pooling MLP settings; ``num_layers``, ``hidden_dim``,
                ``norm``, ``dropout`` and ``vector_norm`` are read.
            graph_config: GNN settings; ``cutoff``, ``max_num_neighbors`` and
                ``embedding_dim`` are read here and the whole namespace is
                forwarded to ``VectorGNN`` as keyword arguments.
            activation: Activation name forwarded to ``VectorGNN``.
            num_mol_feats: Width of the per-molecule features ``mol_x``. When
                nonzero, a linear layer embeds ``mol_x`` before the MLP.
            concat_pos_to_node_dim: Append each node's radial distance as an
                extra scalar feature (adds one input channel).
            concat_mol_to_node_dim: Broadcast ``mol_x`` onto every node.
            seed: Seeds the global torch RNG and is forwarded to the MLP.
            override_cutoff: If not None, used instead of ``graph_config.cutoff``
                (for the ``convolution_cutoff`` buffer and for ``VectorGNN``).
        """
        super().__init__()
        torch.manual_seed(seed)
        self.concat_pos_to_node_dim = concat_pos_to_node_dim
        self.concat_mol_to_node_dim = concat_mol_to_node_dim
        if self.concat_pos_to_node_dim:
            input_node_dim += 1  # radial distance scalar; the vector part is passed separately

        cutoff = graph_config.cutoff if override_cutoff is None else override_cutoff
        self.register_buffer('convolution_cutoff', self._cutoff_tensor(cutoff))
        self.max_num_neighbors = graph_config.max_num_neighbors
        self.num_fc_layers = fc_config.num_layers

        if concat_mol_to_node_dim:
            input_node_dim += num_mol_feats

        self.graph_net = VectorGNN(
            activation=activation,
            input_node_dim=input_node_dim,
            override_cutoff=override_cutoff,
            **graph_config.__dict__
        )

        # learnable softmax aggregations pooling node scalars / vectors per graph
        self.global_pool = AugSoftmaxAggregation(
            temperature=1,
            learn=True,
            bias=0.1,
            channels=graph_config.embedding_dim)
        self.v_global_pool = VectorAugSoftmaxAggregation(
            temperature=1,
            learn=True,
            bias=0.1,
            channels=graph_config.embedding_dim)

        # embedding for per-molecule features (only applied when the MLP exists)
        self.mol_fc = nn.Linear(num_mol_feats, num_mol_feats) if num_mol_feats != 0 else None

        # optional MLP to post-process the pooled graph embedding
        if fc_config.num_layers > 0:
            self.gnn_mlp = vectorMLP(layers=fc_config.num_layers,
                                     filters=fc_config.hidden_dim,
                                     norm=fc_config.norm,
                                     dropout=fc_config.dropout,
                                     input_dim=graph_config.embedding_dim + num_mol_feats,
                                     output_dim=fc_config.hidden_dim,
                                     seed=seed,
                                     # the pooled vectors fed to this MLP carry
                                     # graph_config.embedding_dim channels (see
                                     # v_global_pool and the no-MLP v_output_fc);
                                     # fc_config.hidden_dim here broke whenever
                                     # hidden_dim != embedding_dim
                                     vector_input_dim=graph_config.embedding_dim,
                                     v_to_s_combination='sum',
                                     vector_norm=fc_config.vector_norm,
                                     vector_output_dim=fc_config.hidden_dim,
                                     )
            graph_output_dim = fc_config.hidden_dim
        else:
            graph_output_dim = graph_config.embedding_dim

        # only add projections when the output width actually has to change
        if graph_output_dim != output_dim:
            self.output_fc = nn.Linear(graph_output_dim, output_dim, bias=False)
            self.v_output_fc = nn.Linear(graph_output_dim, output_dim, bias=False)
        else:
            self.output_fc = nn.Identity()
            self.v_output_fc = nn.Identity()

    def forward(self,
                x: torch.Tensor,
                pos: torch.FloatTensor,
                batch: torch.LongTensor,
                ptr: torch.LongTensor,
                num_graphs: int,
                mol_x: Optional[torch.Tensor] = None,
                edges_dict: Optional[dict] = None,
                return_latent: bool = False,
                return_dists: bool = False,
                return_embedding: bool = False
                ) -> Union[Tuple[torch.Tensor, torch.Tensor],
                           Tuple[torch.Tensor, torch.Tensor, dict]]:
        """Encode a batch of molecule graphs into scalar and vector outputs.

        Args:
            x: Per-node scalar input features (1-D is treated as one feature).
            pos: Node coordinates; also used as the initial node vectors.
            batch: Per-node graph index, used as the pooling index.
            ptr: Node offsets per graph; ``diff(ptr)`` gives nodes per graph.
            num_graphs: Number of graphs in the batch (rows of pooled outputs).
            mol_x: Per-graph molecule features; required when ``mol_fc`` exists
                or ``concat_mol_to_node_dim`` is set.
            edges_dict: Optional prebuilt radial graph. Built on the fly when
                needed and missing; discarded when the GNN has no interaction
                blocks and ``return_dists`` is False.
            return_latent, return_dists, return_embedding: Flags forwarded to
                ``collect_extra_outputs``; ``return_embedding`` additionally
                keeps a copy of the per-node scalar GNN output.

        Returns:
            ``(scalar_out, vector_out)``, or
            ``(scalar_out, vector_out, extra_outputs)`` when
            ``collect_extra_outputs`` returns a non-empty result.
        """
        if len(self.graph_net.interaction_blocks) > 0 or return_dists:
            # todo clean up options around prebuilt radial graphs
            if edges_dict is None:  # build the radial graph on the fly
                edges_dict = build_radial_graph(
                    pos,
                    batch,
                    ptr,
                    self.convolution_cutoff,
                    self.max_num_neighbors,
                )
        else:
            # no message passing and no distances requested: edges are not used
            edges_dict = None

        x, v = self.append_init_node_features(x, pos, ptr, mol_x)
        x, v = self.graph_net(x, v, pos, batch, edges_dict)  # per-node encoding
        embedding = x.clone() if return_embedding else None

        # aggregate atoms to molecule / graph representation
        x = self.global_pool(x, batch, dim_size=num_graphs)
        v = self.v_global_pool(v, batch, dim_size=num_graphs, dim=0, cart_dim=1)

        if self.num_fc_layers > 0:
            if self.mol_fc is not None:
                x = torch.cat([x, self.mol_fc(mol_x)], dim=-1)
            x, v = self.gnn_mlp(x, v)

        x_out, v_out = self.output_fc(x), self.v_output_fc(v)

        # NOTE(review): collect_extra_outputs is not defined on this class in the
        # visible source (nn.Module does not provide it) -- confirm where it comes from.
        extra_outputs = self.collect_extra_outputs(x,
                                                   pos,
                                                   batch,
                                                   edges_dict,
                                                   return_dists,
                                                   return_latent,
                                                   return_embedding,
                                                   embedding)
        if len(extra_outputs) > 0:
            return x_out, v_out, extra_outputs
        return x_out, v_out

    def append_init_node_features(self,
                                  x: torch.Tensor,
                                  pos: torch.Tensor,
                                  ptr: torch.LongTensor,
                                  mol_x: Optional[torch.Tensor] = None,
                                  ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the initial node scalars and vectors.

        Scalars: ``x`` (1-D becomes a single column), plus the per-node norm of
        ``pos`` when ``concat_pos_to_node_dim`` is set, plus each graph's
        ``mol_x`` row repeated over its nodes when ``concat_mol_to_node_dim``
        is set. Vectors: ``pos`` with a trailing singleton channel axis.
        """
        if x.ndim == 1:
            x = x[:, None]
        if self.concat_pos_to_node_dim:
            radius = torch.linalg.norm(pos, dim=1)
            x = torch.cat((x, radius[:, None]), dim=-1)
        # positions are the single vector channel, laid out as [n, 3, 1]
        v = pos[..., None]
        if self.concat_mol_to_node_dim:
            nodes_per_graph = torch.diff(ptr)
            x = torch.cat((x, torch.repeat_interleave(mol_x, nodes_per_graph, 0)),
                          dim=-1)
        return x, v

    @staticmethod
    def _cutoff_tensor(cutoff) -> torch.Tensor:
        """Return ``cutoff`` as a float32 tensor suitable for a buffer.

        Existing tensors are cloned and detached rather than passed through
        ``torch.tensor``, which warns on copy-construction from a tensor.
        """
        if torch.is_tensor(cutoff):
            return cutoff.clone().detach().to(torch.float32)
        return torch.tensor(cutoff, dtype=torch.float32)
# noinspection PyAttributeOutsideInit
class MolecularCrystalGraphModel(nn.Module):
    """Graph model mapping each molecular-crystal graph to one output vector.

    Node features (optionally augmented with per-node ``mol_ind``/``aux_ind``
    columns and broadcast per-molecule features) go through a
    ``MolCrystalScalarGNN``. The result is pooled per graph over the nodes
    selected by ``edges_dict['inside_batch']``, optionally post-processed by an
    MLP, and projected to ``output_dim``.
    """

    def __init__(self,
                 input_node_dim: int,
                 output_dim: int,
                 fc_config: Namespace,
                 graph_config: Namespace,
                 activation: str = 'gelu',
                 num_mol_feats: int = 0,
                 concat_mol_ind_to_node_dim: bool = False,
                 concat_mol_to_node_dim: bool = False,
                 seed: int = 5,
                 override_cutoff=None
                 ):
        """Build the GNN, pooling layer, optional MLP and output projection.

        Args:
            input_node_dim: Number of raw per-node input features.
            output_dim: Width of the returned output.
            fc_config: Post-pooling MLP settings; ``num_layers``, ``hidden_dim``,
                ``norm`` and ``dropout`` are read.
            graph_config: GNN settings; ``cutoff``, ``max_num_neighbors`` and
                ``embedding_dim`` are read here and the whole namespace is
                forwarded to ``MolCrystalScalarGNN`` as keyword arguments.
            activation: Activation name forwarded to the GNN.
            num_mol_feats: Width of the per-molecule features ``mol_x``. When
                nonzero, a linear layer embeds ``mol_x`` before the MLP.
            concat_mol_ind_to_node_dim: Append ``mol_ind`` and ``aux_ind`` as two
                extra node features.
            concat_mol_to_node_dim: Broadcast ``mol_x`` onto every node.
            seed: Seeds the global torch RNG and is forwarded to the MLP.
            override_cutoff: If not None, used instead of ``graph_config.cutoff``
                (for the ``convolution_cutoff`` buffer and for the GNN).
        """
        super().__init__()
        torch.manual_seed(seed)
        self.concat_mol_to_node_dim = concat_mol_to_node_dim
        self.concat_mol_ind_to_node_dim = concat_mol_ind_to_node_dim

        cutoff = graph_config.cutoff if override_cutoff is None else override_cutoff
        self.register_buffer('convolution_cutoff', self._cutoff_tensor(cutoff))
        self.max_num_neighbors = graph_config.max_num_neighbors
        self.num_fc_layers = fc_config.num_layers

        if concat_mol_to_node_dim:
            input_node_dim += num_mol_feats
        if concat_mol_ind_to_node_dim:
            input_node_dim += 2  # mol_ind and aux_ind columns are appended

        self.graph_net = MolCrystalScalarGNN(
            activation=activation,
            input_node_dim=input_node_dim,
            override_cutoff=override_cutoff,
            **graph_config.__dict__
        )

        # learnable softmax aggregation that pools node embeddings per graph
        self.global_pool = AugSoftmaxAggregation(
            temperature=1,
            learn=True,
            bias=0.1,
            channels=graph_config.embedding_dim)

        # embedding for per-molecule features (only applied when the MLP exists)
        self.mol_fc = nn.Linear(num_mol_feats, num_mol_feats) if num_mol_feats != 0 else None

        # optional MLP to post-process the pooled graph embedding
        if fc_config.num_layers > 0:
            self.gnn_mlp = scalarMLP(layers=fc_config.num_layers,
                                     filters=fc_config.hidden_dim,
                                     norm=fc_config.norm,
                                     dropout=fc_config.dropout,
                                     input_dim=graph_config.embedding_dim + num_mol_feats,
                                     output_dim=fc_config.hidden_dim,
                                     seed=seed,
                                     )
            graph_output_dim = fc_config.hidden_dim
        else:
            graph_output_dim = graph_config.embedding_dim

        # only add a projection when the output width actually has to change
        if graph_output_dim != output_dim:
            self.output_fc = nn.Linear(graph_output_dim, output_dim, bias=False)
        else:
            self.output_fc = nn.Identity()

    def forward(self,
                x: torch.Tensor,
                pos: torch.FloatTensor,
                batch: torch.LongTensor,
                ptr: torch.LongTensor,
                mol_x: torch.Tensor,
                num_graphs: int,
                aux_ind: torch.LongTensor,
                mol_ind: torch.LongTensor,
                edges_dict: Optional[dict] = None,
                return_latent: bool = False,
                return_dists: bool = False,
                return_embedding: bool = False,
                force_edges_rebuild: bool = False,
                ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        """Encode a batch of molecular-crystal graphs.

        Args:
            x: Per-node input features (a 1-D tensor is treated as one feature).
            pos: Node coordinates.
            batch: Per-node graph index.
            ptr: Node offsets per graph; ``diff(ptr)`` gives nodes per graph.
            mol_x: Per-graph molecule features.
            num_graphs: Number of graphs in the batch (rows of the pooled output).
            aux_ind: Per-node auxiliary index, forwarded to
                ``build_radial_graph`` and the GNN (semantics defined there).
            mol_ind: Per-node molecule index, forwarded to ``build_radial_graph``.
            edges_dict: Optional prebuilt radial graph; must contain
                ``'inside_batch'``, which selects the nodes that are pooled.
            return_latent, return_dists, return_embedding: Flags forwarded to
                ``collect_extra_outputs``; ``return_embedding`` additionally
                keeps a copy of the per-node GNN output.
            force_edges_rebuild: Rebuild the radial graph even if one is given.

        Returns:
            The output tensor, or ``(output, extra_outputs)`` when
            ``collect_extra_outputs`` returns a non-empty result.
        """
        # Pooling below always reads edges_dict['inside_batch'], so the graph must
        # exist whenever none was supplied. (Previously it was skipped when the
        # GNN had no interaction blocks and return_dists was False, which then
        # failed on None['inside_batch'].)
        if edges_dict is None or force_edges_rebuild:
            edges_dict = build_radial_graph(
                pos,
                batch,
                ptr,
                self.convolution_cutoff,
                self.max_num_neighbors,
                aux_ind=aux_ind,
                mol_ind=mol_ind,
            )

        x = self.append_init_node_features(x, ptr, mol_x, aux_ind, mol_ind)
        g = self.graph_net(x, pos, batch, aux_ind, ptr, edges_dict)  # per-node encoding
        embedding = g.clone() if return_embedding else None

        # aggregate atoms to molecule / graph representation, only over 'inside' nodes
        g = self.global_pool(g, edges_dict['inside_batch'], dim_size=num_graphs)

        if self.num_fc_layers > 0:
            if self.mol_fc is not None:
                g = torch.cat([g, self.mol_fc(mol_x)], dim=-1)
            g = self.gnn_mlp(g)

        output = self.output_fc(g)

        # The graph-level latent g (not the raw node inputs x) is passed on,
        # matching what the sibling models hand to collect_extra_outputs.
        # NOTE(review): collect_extra_outputs is not defined on this class in the
        # visible source (nn.Module does not provide it) -- confirm where it comes from.
        extra_outputs = self.collect_extra_outputs(g,
                                                   pos,
                                                   batch,
                                                   edges_dict,
                                                   return_dists,
                                                   return_latent,
                                                   return_embedding,
                                                   embedding)
        if len(extra_outputs) > 0:
            return output, extra_outputs
        return output

    def append_init_node_features(self, x, ptr, mol_x, aux_ind, mol_ind):
        """Return node features optionally augmented with indices and ``mol_x``.

        A 1-D ``x`` becomes a single column. When ``concat_mol_ind_to_node_dim``
        is set, ``mol_ind`` then ``aux_ind`` are appended as columns; when
        ``concat_mol_to_node_dim`` is set, each graph's ``mol_x`` row is
        repeated over its nodes.
        """
        if x.ndim == 1:
            x = x[:, None]
        if self.concat_mol_ind_to_node_dim:
            x = torch.cat((x, mol_ind[:, None], aux_ind[:, None]), dim=-1)
        if self.concat_mol_to_node_dim:
            nodes_per_graph = torch.diff(ptr)
            x = torch.cat((x, torch.repeat_interleave(mol_x, nodes_per_graph, 0)),
                          dim=-1)
        return x

    @staticmethod
    def _cutoff_tensor(cutoff) -> torch.Tensor:
        """Return ``cutoff`` as a float32 tensor suitable for a buffer.

        Existing tensors are cloned and detached rather than passed through
        ``torch.tensor``, which warns on copy-construction from a tensor.
        """
        if torch.is_tensor(cutoff):
            return cutoff.clone().detach().to(torch.float32)
        return torch.tensor(cutoff, dtype=torch.float32)
# noinspection PyAttributeOutsideInit
class MoleculeClusterModel(nn.Module):
    """Graph model producing one output vector per molecule in a cluster.

    Node features (optionally augmented with the broadcast per-molecule
    features) go through a ``ScalarGNN`` using minimum-image edges. They are
    pooled per molecule (grouped by ``mol_ind``), optionally post-processed by
    an MLP, and projected to ``output_dim``.
    """

    def __init__(self,
                 input_node_dim: int,
                 output_dim: int,
                 fc_config: Namespace,
                 graph_config: Namespace,
                 activation: str = 'gelu',
                 num_mol_feats: int = 0,
                 concat_mol_to_node_dim: bool = False,
                 seed: int = 5,
                 override_cutoff=None
                 ):
        """Build the GNN, pooling layer, optional MLP and output projection.

        Args:
            input_node_dim: Number of raw per-node input features.
            output_dim: Width of the returned output.
            fc_config: Post-pooling MLP settings; ``num_layers``, ``hidden_dim``,
                ``norm`` and ``dropout`` are read.
            graph_config: GNN settings; ``cutoff``, ``max_num_neighbors`` and
                ``embedding_dim`` are read here and the whole namespace is
                forwarded to ``ScalarGNN`` as keyword arguments.
            activation: Activation name forwarded to ``ScalarGNN``.
            num_mol_feats: Width of the per-molecule features ``mol_x``. When
                nonzero, a linear layer embeds ``mol_x`` before the MLP.
            concat_mol_to_node_dim: Broadcast ``mol_x`` onto every node.
            seed: Seeds the global torch RNG and is forwarded to the MLP.
            override_cutoff: If not None, used instead of ``graph_config.cutoff``
                (for the ``convolution_cutoff`` buffer and for ``ScalarGNN``).
        """
        super().__init__()
        torch.manual_seed(seed)
        self.concat_mol_to_node_dim = concat_mol_to_node_dim

        cutoff = graph_config.cutoff if override_cutoff is None else override_cutoff
        self.register_buffer('convolution_cutoff', self._cutoff_tensor(cutoff))
        self.max_num_neighbors = graph_config.max_num_neighbors
        self.num_fc_layers = fc_config.num_layers

        if concat_mol_to_node_dim:
            input_node_dim += num_mol_feats

        self.graph_net = ScalarGNN(
            activation=activation,
            input_node_dim=input_node_dim,
            override_cutoff=override_cutoff,
            **graph_config.__dict__
        )

        # learnable softmax aggregation that pools node embeddings per molecule
        self.global_pool = AugSoftmaxAggregation(
            temperature=1,
            learn=True,
            bias=0.1,
            channels=graph_config.embedding_dim)

        # embedding for per-molecule features (only applied when the MLP exists)
        self.mol_fc = nn.Linear(num_mol_feats, num_mol_feats) if num_mol_feats != 0 else None

        # optional MLP to post-process the pooled embedding
        if fc_config.num_layers > 0:
            self.gnn_mlp = scalarMLP(layers=fc_config.num_layers,
                                     filters=fc_config.hidden_dim,
                                     norm=fc_config.norm,
                                     dropout=fc_config.dropout,
                                     input_dim=graph_config.embedding_dim + num_mol_feats,
                                     output_dim=fc_config.hidden_dim,
                                     seed=seed,
                                     )
            graph_output_dim = fc_config.hidden_dim
        else:
            graph_output_dim = graph_config.embedding_dim

        # only add a projection when the output width actually has to change
        if graph_output_dim != output_dim:
            self.output_fc = nn.Linear(graph_output_dim, output_dim, bias=False)
        else:
            self.output_fc = nn.Identity()

    def forward(self,
                x: torch.Tensor,
                pos: torch.FloatTensor,
                ptr: torch.LongTensor,
                mol_x: torch.Tensor,
                num_graphs: int,
                mol_ind: torch.Tensor,
                T_fc: torch.Tensor,
                edge_index: Optional[torch.LongTensor] = None,
                edge_attr: Optional[torch.Tensor] = None,
                edges_dict: Optional[dict] = None,
                return_latent: bool = False,
                return_dists: bool = False,
                return_embedding: bool = False
                ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        """Encode a molecular cluster into one output row per molecule.

        Args:
            x: Per-node input features (a 1-D tensor is treated as one feature).
            pos: Node coordinates.
            ptr: Node offsets per graph; ``diff(ptr)`` gives nodes per graph.
            mol_x: Per-graph molecule features.
            num_graphs: Forwarded to ``argwhere_minimum_image_convention_edges``.
            mol_ind: Per-node molecule index; pooling groups nodes by it, and
                indices are assumed contiguous from ``mol_ind.min()``.
            T_fc: Forwarded to ``argwhere_minimum_image_convention_edges``
                (presumably the fractional-to-Cartesian cell transform -- confirm).
            edge_index, edge_attr: Optional precomputed edges and attributes;
                when ``edge_index`` is given it is used as-is.
            edges_dict: Optional prebuilt edges (``'edge_index'``, ``'dists'``);
                built when needed and neither it nor ``edge_index`` is given.
            return_latent, return_dists, return_embedding: Flags forwarded to
                ``collect_extra_outputs``; ``return_embedding`` additionally
                keeps a copy of the per-node GNN output.

        Returns:
            The output tensor, or ``(output, extra_outputs)`` when
            ``collect_extra_outputs`` returns a non-empty result.
        """
        # NOTE(review): ptr[-1] == len(x) holds for any well-formed batch, so this
        # does not actually enforce one cluster per batch (len(ptr) == 2 would).
        assert ptr[-1] == len(x), "Only one cluster sample per batch is allowed for this model"

        if len(self.graph_net.interaction_blocks) > 0 or return_dists:
            if edge_index is None:
                if edges_dict is None:  # build minimum-image edges on the fly
                    edges_dict = argwhere_minimum_image_convention_edges(
                        num_graphs, pos, T_fc, self.convolution_cutoff)
                edge_index = edges_dict['edge_index']
                edge_attr = edges_dict['dists']

        x = self.append_init_node_features(x, ptr, mol_x)
        x = self.graph_net(x, pos, mol_ind, edge_index, edge_attr)  # per-node encoding
        embedding = x.clone() if return_embedding else None

        # Aggregate atoms to molecules. Shifting by the minimum makes indices
        # start at 0, so max - min + 1 groups are needed. (The previous
        # dim_size=int(mol_ind.max()) dropped one group for 0-based mol_ind.)
        offset = mol_ind.min()
        num_molecules = int(mol_ind.max() - offset) + 1
        x = self.global_pool(x,
                             index=mol_ind - offset,
                             dim_size=num_molecules)

        if self.num_fc_layers > 0:
            if self.mol_fc is not None:
                x = torch.cat([x, self.mol_fc(mol_x)], dim=-1)
            x = self.gnn_mlp(x)

        # NOTE(review): collect_extra_outputs is not defined on this class in the
        # visible source (nn.Module does not provide it) -- confirm where it comes from.
        extra_outputs = self.collect_extra_outputs(x,
                                                   edges_dict,
                                                   return_dists,
                                                   return_latent,
                                                   return_embedding,
                                                   embedding)
        output = self.output_fc(x)
        if len(extra_outputs) > 0:
            return output, extra_outputs
        return output

    def append_init_node_features(self, x, ptr, mol_x):
        """Return node features optionally augmented with ``mol_x``.

        A 1-D ``x`` becomes a single column. When ``concat_mol_to_node_dim``
        is set, each graph's ``mol_x`` row is repeated over its nodes.
        """
        if x.ndim == 1:
            x = x[:, None]
        if self.concat_mol_to_node_dim:
            nodes_per_graph = torch.diff(ptr)
            x = torch.cat((x, torch.repeat_interleave(mol_x, nodes_per_graph, 0)),
                          dim=-1)
        return x

    @staticmethod
    def _cutoff_tensor(cutoff) -> torch.Tensor:
        """Return ``cutoff`` as a float32 tensor suitable for a buffer.

        Existing tensors are cloned and detached rather than passed through
        ``torch.tensor``, which warns on copy-construction from a tensor.
        """
        if torch.is_tensor(cutoff):
            return cutoff.clone().detach().to(torch.float32)
        return torch.tensor(cutoff, dtype=torch.float32)