mxtaltools.models.graph_models.molecule_graph_model

class mxtaltools.models.graph_models.molecule_graph_model.MolecularCrystalGraphModel(input_node_dim: int, output_dim: int, fc_config: Namespace, graph_config: Namespace, activation: str = 'gelu', num_mol_feats: int = 0, concat_mol_ind_to_node_dim: bool = False, concat_mol_to_node_dim: bool = False, seed: int = 5, override_cutoff=None)

Bases: Module

append_init_node_features(x, ptr, mol_x, aux_ind, mol_ind)
static collect_extra_outputs(x: Tensor, pos: Tensor, batch: LongTensor, edges_dict: dict, return_dists: bool, return_latent: bool, return_embedding: bool, embedding: Tensor | None) → dict
forward(x: Tensor, pos: FloatTensor, batch: LongTensor, ptr: LongTensor, mol_x: Tensor, num_graphs: int, aux_ind: LongTensor, mol_ind: LongTensor, edges_dict: dict | None = None, return_latent: bool = False, return_dists: bool = False, return_embedding: bool = False, force_edges_rebuild: bool = False) → Tuple[Tensor, dict | None]
mol_fc

Optional MLP that post-processes the graph embedding.
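The forward signature above takes both a batch vector and a ptr vector. Assuming these follow the PyTorch Geometric batching convention (batch[i] is the graph index of node i, and ptr holds cumulative node offsets with length num_graphs + 1), the two carry the same information and can be converted into each other. The sketch below shows the conversion in plain PyTorch; the helper names batch_to_ptr and ptr_to_batch are hypothetical and not part of mxtaltools.

```python
import torch


def batch_to_ptr(batch: torch.LongTensor, num_graphs: int) -> torch.LongTensor:
    """Hypothetical helper: node-to-graph assignment -> cumulative node offsets.

    Assumes nodes are sorted by graph, as in PyTorch Geometric batches.
    """
    counts = torch.bincount(batch, minlength=num_graphs)  # nodes per graph
    return torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])


def ptr_to_batch(ptr: torch.LongTensor) -> torch.LongTensor:
    """Hypothetical helper: cumulative node offsets -> node-to-graph assignment."""
    counts = ptr[1:] - ptr[:-1]  # nodes per graph
    graph_ids = torch.arange(len(counts), device=ptr.device)
    return torch.repeat_interleave(graph_ids, counts)


# Three graphs with 2, 3 and 1 nodes respectively
batch = torch.tensor([0, 0, 1, 1, 1, 2])
ptr = batch_to_ptr(batch, num_graphs=3)
print(ptr.tolist())                # [0, 2, 5, 6]
print(ptr_to_batch(ptr).tolist())  # [0, 0, 1, 1, 1, 2]
```

Passing both tensors, as forward does, simply avoids recomputing one from the other inside the model.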

class mxtaltools.models.graph_models.molecule_graph_model.MoleculeClusterModel(input_node_dim: int, output_dim: int, fc_config: Namespace, graph_config: Namespace, activation: str = 'gelu', num_mol_feats: int = 0, concat_mol_to_node_dim: bool = False, seed: int = 5, override_cutoff=None)

Bases: Module

append_init_node_features(x, ptr, mol_x)
static collect_extra_outputs(x: Tensor, edges_dict: dict, return_dists: bool, return_latent: bool, return_embedding: bool, embedding: Tensor | None) → dict
forward(x: Tensor, pos: FloatTensor, ptr: LongTensor, mol_x: Tensor, num_graphs: int, mol_ind: Tensor, T_fc: Tensor, edge_index: LongTensor | None = None, edge_attr: Tensor | None = None, edges_dict: dict | None = None, return_latent: bool = False, return_dists: bool = False, return_embedding: bool = False) → Tuple[Tensor, dict | None]
mol_fc

Optional MLP that post-processes the graph embedding.
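The concat_mol_to_node_dim flag, together with the (x, ptr, mol_x) arguments of append_init_node_features, suggests that molecule-level features can be appended to every node's input features. The sketch below shows one way to do that, assuming mol_x holds one row of num_mol_feats features per graph and ptr holds cumulative node offsets. The function name append_mol_features and this exact behaviour are assumptions for illustration, not the library's implementation.

```python
import torch


def append_mol_features(x: torch.Tensor, ptr: torch.LongTensor, mol_x: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: concatenate per-graph features onto each node's features.

    x:     [num_nodes, node_dim] node features
    ptr:   [num_graphs + 1] cumulative node offsets
    mol_x: [num_graphs, num_mol_feats] graph-level features (assumed layout)
    """
    nodes_per_graph = ptr[1:] - ptr[:-1]
    # Repeat each graph's feature row once per node belonging to that graph
    mol_x_per_node = torch.repeat_interleave(mol_x, nodes_per_graph, dim=0)
    return torch.cat([x, mol_x_per_node], dim=1)


x = torch.zeros(5, 4)                           # 5 nodes, 4 features each
ptr = torch.tensor([0, 2, 5])                   # graph 0: nodes 0-1, graph 1: nodes 2-4
mol_x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # 2 graph-level features per graph
out = append_mol_features(x, ptr, mol_x)
print(out.shape)  # torch.Size([5, 6])
```

Under this assumption, the first embedding layer would have to accept input_node_dim + num_mol_feats features when the flag is enabled.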

class mxtaltools.models.graph_models.molecule_graph_model.ScalarMoleculeGraphModel(input_node_dim: int, output_dim: int, fc_config: Namespace, graph_config: Namespace, activation: str = 'gelu', num_mol_feats: int = 0, concat_pos_to_node_dim: bool = False, concat_mol_to_node_dim: bool = False, seed: int = 5, override_cutoff=None)

Bases: Module

append_init_node_features(x, pos, ptr, mol_x)
static collect_extra_outputs(x: Tensor, edges_dict: dict, return_dists: bool, return_latent: bool, return_embedding: bool, embedding: Tensor | None) → dict
forward(x: Tensor, pos: FloatTensor, batch: LongTensor, ptr: LongTensor, mol_x: Tensor, num_graphs: int, edge_index: LongTensor | None = None, edges_dict: dict | None = None, return_latent: bool = False, return_dists: bool = False, return_embedding: bool = False) → Tuple[Tensor, dict | None]
mol_fc

Optional MLP that post-processes the graph embedding.
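Every class here documents mol_fc as an optional MLP that post-processes the graph embedding. A common PyTorch pattern for an optional stage is to fall back to nn.Identity when it is disabled, so forward can call it unconditionally. The sketch below illustrates that pattern, using GELU to match the default activation='gelu'; the build_mol_fc helper, its arguments, and the layer layout are hypothetical, since the actual construction in mxtaltools is driven by fc_config.

```python
import torch
from torch import nn


def build_mol_fc(embedding_dim: int, hidden_dim: int, num_layers: int) -> nn.Module:
    """Hypothetical sketch of an optional post-processing MLP.

    Returns nn.Identity() when num_layers == 0, so callers can always apply it.
    """
    if num_layers == 0:
        return nn.Identity()
    layers = []
    in_dim = embedding_dim
    for _ in range(num_layers):
        layers += [nn.Linear(in_dim, hidden_dim), nn.GELU()]
        in_dim = hidden_dim
    return nn.Sequential(*layers)


graph_embedding = torch.randn(8, 32)  # 8 graphs, 32-dim embedding each
print(build_mol_fc(32, 64, num_layers=0)(graph_embedding).shape)  # torch.Size([8, 32])
print(build_mol_fc(32, 64, num_layers=2)(graph_embedding).shape)  # torch.Size([8, 64])
```

The Identity fallback keeps the forward pass free of branches on whether the MLP exists.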

class mxtaltools.models.graph_models.molecule_graph_model.VectorMoleculeGraphModel(input_node_dim: int, output_dim: int, fc_config: Namespace, graph_config: Namespace, activation: str = 'gelu', num_mol_feats: int = 0, concat_pos_to_node_dim: bool = False, concat_mol_to_node_dim: bool = False, seed: int = 5, override_cutoff=None)

Bases: Module

append_init_node_features(x: Tensor, pos: Tensor, ptr: LongTensor, mol_x: Tensor | None = None) → Tuple[Tensor, Tensor]
static collect_extra_outputs(x: Tensor, pos: Tensor, batch: LongTensor, edges_dict: dict, return_dists: bool, return_latent: bool, return_embedding: bool, embedding: Tensor | None) → dict
forward(x: Tensor, pos: FloatTensor, batch: LongTensor, ptr: LongTensor, num_graphs: int, mol_x: Tensor | None = None, edges_dict: dict | None = None, return_latent: bool = False, return_dists: bool = False, return_embedding: bool = False) → Tuple[Tensor, Tensor, dict | None]
mol_fc

Optional MLP that post-processes the graph embedding.
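Unlike ScalarMoleculeGraphModel.forward, which returns Tuple[Tensor, dict | None], VectorMoleculeGraphModel.forward returns two tensors before the optional dict, presumably a scalar output and a vector output. Assuming the vector output is read out from per-node 3D vector features, the readout must be rotation-equivariant: rotating the inputs should rotate the output by the same matrix. Per-graph mean pooling has this property because it is linear. The sketch below checks it numerically; the [num_nodes, channels, 3] layout and the mean_pool_vectors helper are assumptions for illustration, not the library's readout.

```python
import math

import torch


def mean_pool_vectors(v: torch.Tensor, batch: torch.LongTensor, num_graphs: int) -> torch.Tensor:
    """Hypothetical sketch: average per-node vector features within each graph.

    v:     [num_nodes, channels, 3] vector features (assumed layout)
    batch: [num_nodes] graph index of each node
    returns [num_graphs, channels, 3]
    """
    summed = torch.zeros(num_graphs, *v.shape[1:], dtype=v.dtype, device=v.device)
    summed.index_add_(0, batch, v)  # sum node vectors into their graph's slot
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1).to(v.dtype)
    return summed / counts.view(-1, 1, 1)


def rotation_z(theta: float) -> torch.Tensor:
    """3x3 rotation matrix about the z axis."""
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


torch.manual_seed(0)
v = torch.randn(6, 4, 3)                  # 6 nodes, 4 vector channels
batch = torch.tensor([0, 0, 1, 1, 1, 2])  # 3 graphs
R = rotation_z(0.7)

# Row-vector convention: v @ R.T applies R to every 3D vector
pooled_then_rotated = mean_pool_vectors(v, batch, 3) @ R.T
rotated_then_pooled = mean_pool_vectors(v @ R.T, batch, 3)
print(torch.allclose(pooled_then_rotated, rotated_then_pooled, atol=1e-6))  # True
```

A non-linear readout applied to the raw vector components (for example an element-wise activation) would generally break this property, which is why vector and scalar outputs are kept as separate tensors in this sketch.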