mxtaltools.models.task_models.mol_classifier

class mxtaltools.models.task_models.mol_classifier.MoleculeClusterClassifier(seed, config, output_dim, atom_features: list, molecule_features: list, node_standardization_tensor: Tensor, graph_standardization_tensor: Tensor)[source]

Bases: BaseGraphModel

forward(data_batch, return_dists: bool = False, return_latent: bool = False, return_embedding: bool = False) → Tuple[Tensor, dict | None][source]