mxtaltools.models.autoencoder_utils
- mxtaltools.models.autoencoder_utils.ae_reconstruction_loss(mol_batch, decoding_batch, graph_weighted_node_weights, node_weighted_node_weights, num_atom_types, type_distance_scaling, autoencoder_sigma)[source]
- mxtaltools.models.autoencoder_utils.batch_rmsd(mol_batch, decoded_mol_batch, true_node_one_hot, intrapoint_cutoff: float = 0.5, probability_threshold: float = 0.25, type_distance_scaling: float = 2)[source]
- mxtaltools.models.autoencoder_utils.compute_coord_evaluation_overlap(config, data, decoded_data, true_nodes)[source]
Compute positional overlaps at the evaluation sigma. (TODO: could be made more flexible.)
- mxtaltools.models.autoencoder_utils.compute_full_evaluation_overlap(mol_batch, decoded_mol_batch, true_nodes, sigma=None, distance_scaling=None)[source]
Compute overall overlaps at the evaluation sigma.
- mxtaltools.models.autoencoder_utils.compute_gaussian_overlap(ref_types, mol_batch, decoded_data, sigma, nodewise_weights, dist_to_self=False, isolate_dimensions: list = None, type_distance_scaling=0.1, return_dists=False)[source]
Compute the distance between Gaussian mixtures in a high-dimensional space, treating atom types as one-hot dimensions.
- mxtaltools.models.autoencoder_utils.compute_type_evaluation_overlap(config, data, num_atom_types, decoded_data, true_nodes)[source]
Compute type-wise overlaps at the evaluation sigma. (TODO: could be made more flexible.)
- mxtaltools.models.autoencoder_utils.decoding2mol_batch(mol_batch, decoding, num_decoder_nodes, node_weight_temperature, device)[source]
- mxtaltools.models.autoencoder_utils.get_node_weights(mol_batch, decoded_mol_batch, decoding, num_decoder_nodes, node_weight_temperature)[source]
Extract node-wise normalized weights from the decoder swarm.
- mxtaltools.models.autoencoder_utils.init_decoded_data(mol_batch, decoded_batch, device, num_nodes)[source]