mxtaltools.models.task_models.crystal_models

class mxtaltools.models.task_models.crystal_models.MolecularCrystalModel(seed, config, atom_features: list, molecule_features: list, output_dim: int, node_standardization_tensor: OptTensor = None, graph_standardization_tensor: OptTensor = None)

Bases: BaseGraphModel

__init__(seed, config, atom_features: list, molecule_features: list, output_dim: int, node_standardization_tensor: OptTensor = None, graph_standardization_tensor: OptTensor = None)

Wrapper around a molecule model, with appropriate input/output (I/O) handling.
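The constructor accepts optional node-level and graph-level standardization tensors. The sketch below is a hypothetical illustration of feature standardization, not the library's implementation: the helper name `standardize` is invented, it assumes each standardization tensor stacks per-feature means (row 0) and standard deviations (row 1), and it uses plain Python lists in place of torch tensors so it runs without dependencies.

```python
from typing import List, Optional, Sequence


def standardize(rows: Sequence[Sequence[float]],
                standardization: Optional[Sequence[Sequence[float]]] = None) -> List[List[float]]:
    """Hypothetical helper: z-score each feature column.

    Assumes standardization[0] holds per-feature means and
    standardization[1] holds per-feature standard deviations (nonzero).
    With no standardization tensor, features pass through unchanged.
    """
    if standardization is None:
        return [list(row) for row in rows]
    means, stds = standardization
    return [[(x - m) / s for x, m, s in zip(row, means, stds)] for row in rows]


# Two atoms with two features each, standardized with assumed mean/std rows.
atoms = [[1.0, 10.0], [3.0, 30.0]]
stats = [[2.0, 20.0], [1.0, 10.0]]
print(standardize(atoms, stats))  # → [[-1.0, -1.0], [1.0, 1.0]]
```

Leaving the tensor as None mirrors the constructor's defaults, where standardization is optional.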

forward(crystal_batch, return_dists=False, return_latent=False, force_edges_rebuild=False)

Overrides the corresponding method of the base class, BaseGraphModel.
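forward exposes flags that optionally request extra outputs. The following is a hypothetical sketch of how a subclass might override a base-class forward and reuse it via super(); the class names ToyBaseModel and ToyCrystalModel, the toy "latent" and "dists" computations, and the (output, extras) return layout are all assumptions for illustration, not the mxtaltools API. force_edges_rebuild is omitted.

```python
class ToyBaseModel:
    """Stand-in for a base graph model with a simple forward method."""

    def forward(self, batch):
        # Base behaviour: sum each graph's node features into one prediction.
        return [sum(nodes) for nodes in batch]


class ToyCrystalModel(ToyBaseModel):
    """Hypothetical subclass that overrides forward and can return extras."""

    def forward(self, batch, return_dists=False, return_latent=False):
        # Toy "embedding": scale every node feature by 2.
        latent = [[2.0 * x for x in nodes] for nodes in batch]
        # Reuse the base-class forward on the embedded batch.
        output = super().forward(latent)
        extras = {}
        if return_dists:
            # Toy "distances": absolute gaps between consecutive nodes.
            extras["dists"] = [[abs(b - a) for a, b in zip(nodes, nodes[1:])]
                               for nodes in batch]
        if return_latent:
            extras["latent"] = latent
        # Return the bare output unless extras were requested.
        return (output, extras) if extras else output


model = ToyCrystalModel()
print(model.forward([[1.0, 2.0], [3.0]]))  # → [6.0, 6.0]
```

Returning the bare output by default keeps the override a drop-in replacement for callers of the base method, while the flags opt in to additional outputs.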