mxtaltools.models.autoencoder_utils

mxtaltools.models.autoencoder_utils.ae_reconstruction_loss(mol_batch, decoding_batch, graph_weighted_node_weights, node_weighted_node_weights, num_atom_types, type_distance_scaling, autoencoder_sigma)[source]
mxtaltools.models.autoencoder_utils.batch_rmsd(mol_batch, decoded_mol_batch, true_node_one_hot, intrapoint_cutoff: float = 0.5, probability_threshold: float = 0.25, type_distance_scaling: float = 2)[source]
mxtaltools.models.autoencoder_utils.compute_coord_evaluation_overlap(config, data, decoded_data, true_nodes)[source]

Compute positional (coordinate-only) overlaps at the evaluation sigma. TODO: this could be made more flexible.

mxtaltools.models.autoencoder_utils.compute_full_evaluation_overlap(mol_batch, decoded_mol_batch, true_nodes, sigma=None, distance_scaling=None)[source]

Compute overall overlaps at the evaluation sigma.

mxtaltools.models.autoencoder_utils.compute_gaussian_overlap(ref_types, mol_batch, decoded_data, sigma, nodewise_weights, dist_to_self=False, isolate_dimensions: list = None, type_distance_scaling=0.1, return_dists=False)[source]

Compute the distance between Gaussian mixtures in high dimension, treating atom types as one-hot dimensions.

mxtaltools.models.autoencoder_utils.compute_type_evaluation_overlap(config, data, num_atom_types, decoded_data, true_nodes)[source]

Compute type-wise overlaps at the evaluation sigma. TODO: this could be made more flexible.

mxtaltools.models.autoencoder_utils.decoding2mol_batch(mol_batch, decoding, num_decoder_nodes, node_weight_temperature, device)[source]
mxtaltools.models.autoencoder_utils.get_node_weights(mol_batch, decoded_mol_batch, decoding, num_decoder_nodes, node_weight_temperature)[source]

Extract node-wise normalized weights from the decoder swarm.

mxtaltools.models.autoencoder_utils.init_decoded_data(mol_batch, decoded_batch, device, num_nodes)[source]
mxtaltools.models.autoencoder_utils.test_decoder_equivariance(data, encoding: Tensor, rotated_encoding: Tensor, rotations: Tensor, autoencoder: Module, device: device | str) → Tensor[source]

Check end-to-end equivariance of the decoder.

mxtaltools.models.autoencoder_utils.test_encoder_equivariance(data, rotations: Tensor, autoencoder) → Tuple[Tensor, Tensor, Tensor][source]

Check end-to-end equivariance of the encoder.