mxtaltools.models.graph_models.graph_neural_network

class mxtaltools.models.graph_models.graph_neural_network.MolCrystalScalarGNN(input_node_dim: int, node_dim: int, fcs_per_gc: int, message_dim: int, embedding_dim: int, num_convs: int, num_radial: int, num_input_classes=101, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, activation='gelu', atom_type_embedding_dim: int = 5, norm: str | None = None, dropout: float = 0, radial_embedding: str = 'bessel', override_cutoff: float | None = None)[source]

Bases: Module

forward(z: Tensor, pos: Tensor, batch: LongTensor, aux_ind: Tensor, ptr: LongTensor, edges_dict: dict) → Tensor[source]
periodize_molecular_crystal(inside_batch, inside_inds, n, n_repeats, ptr, x, aux_ind)[source]
radial_embedding(edge_index: LongTensor, pos: Tensor) → Tuple[Tensor, Tensor][source]

Compute the elements for the radial and spherical embeddings.

class mxtaltools.models.graph_models.graph_neural_network.ScalarGNN(input_node_dim: int, node_dim: int, fcs_per_gc: int, message_dim: int, embedding_dim: int, num_convs: int, num_radial: int, num_input_classes=101, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, activation='gelu', atom_type_embedding_dim: int = 5, norm: str | None = None, dropout: float = 0, radial_embedding: str = 'bessel', override_cutoff: float | None = None)[source]

Bases: Module

forward(z: Tensor, pos: Tensor, batch: LongTensor, edge_index: LongTensor, dist: Tensor | None = None) → Tensor[source]
radial_embedding(edge_index, pos: Tensor, dist: Tensor | None = None) → Tuple[Tensor, Tensor][source]

Compute the elements for the radial and spherical embeddings.

class mxtaltools.models.graph_models.graph_neural_network.VectorGNN(input_node_dim: int, node_dim: int, fcs_per_gc: int, message_dim: int, embedding_dim: int, num_convs: int, num_radial: int, num_input_classes=101, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, activation='gelu', atom_type_embedding_dim: int = 5, norm: str | None = None, vector_norm: str | None = None, dropout: float = 0, radial_embedding: str = 'bessel', override_cutoff: float | None = None, v_embedding_dim: int | None = None, v_input_node_dim: int | None = None)[source]

Bases: Module

forward(x: Tensor, v: Tensor, pos: Tensor, batch: LongTensor, edges_dict: dict) → Tuple[Tensor, Tensor][source]
radial_embedding(edge_index: LongTensor, pos: Tensor) → Tuple[Tensor, Tensor][source]

Compute the elements for the radial and spherical embeddings.