mxtaltools.models.task_models.autoencoder_models

class mxtaltools.models.task_models.autoencoder_models.Mo3ENet(seed, config, num_atom_types: int, atom_embedding_vector: Tensor, radial_normalization: float, protons_in_input: bool)[source]

Bases: BaseGraphModel

check_embedding_quality(mol_batch, sigma=0.35, type_distance_scaling=2, node_weight_temperature=1, num_atom_types=5, visualize=False)[source]
compile_self(dynamic=True, fullgraph=False)[source]
decode(encoding)[source]

encoding — expected shape n × 3 × k.

encode(mol_batch, override_centering: bool = False)[source]
forward(mol_batch, return_latent: bool = False, return_dists: bool = False, **kwargs)[source]
class mxtaltools.models.task_models.autoencoder_models.Mo3ENetDecoder(seed, config, bottleneck_dim, output_depth, num_nodes)[source]

Bases: Module

forward(x, v)[source]
class mxtaltools.models.task_models.autoencoder_models.Mo3ENetEncoder(seed, config, bottleneck_dim, override_cutoff=None)[source]

Bases: Module

forward(mol_batch)[source]
class mxtaltools.models.task_models.autoencoder_models.Mo3ENetGraphDecoder(config, bottleneck_dim, output_depth, num_nodes)[source]

Bases: Module

forward(x, v)[source]