"""Source code for mxtaltools.models.autoencoder_utils."""

from typing import Union, Tuple

import torch
from torch import nn as nn
from torch.nn import functional as F
from torch_scatter import scatter, scatter_softmax

from mxtaltools.models.functions.radial_graph import radius


def compute_gaussian_overlap(ref_types,
                             mol_batch,
                             decoded_data,
                             sigma,
                             nodewise_weights,
                             dist_to_self=False,
                             isolate_dimensions: list = None,
                             type_distance_scaling=0.1,
                             return_dists=False
                             ):
    """
    Compute, for every reference point, its summed Gaussian overlap with a set
    of predicted points, in a joint space made of positions concatenated with
    scaled atom-type features.

    Parameters
    ----------
    ref_types
        Per-reference-point type features; multiplied by
        ``type_distance_scaling`` and concatenated column-wise after
        ``mol_batch.pos`` (presumably one-hot atom types - confirm with caller).
    mol_batch
        Reference batch; ``pos``, ``batch`` and ``num_nodes`` are read.
    decoded_data
        Predicted batch; ``batch`` is always read, ``pos`` and ``x`` only when
        ``dist_to_self`` is False.
    sigma
        Gaussian width. Neighbours are searched within ``4 * sigma``.
    nodewise_weights
        Weight of each predicted point, indexed by predicted-point index.
    dist_to_self
        If True, the reference points are also used as the predicted points.
    isolate_dimensions
        Optional ``[start, stop]`` pair; when given, only columns
        ``start:stop`` of the joint space enter the distance.
    type_distance_scaling
        Multiplier applied to the type features before concatenation.
    return_dists
        If True, also return the edge index and the per-edge distances.

    Returns
    -------
    ``nodewise_overlap`` with one entry per reference node, or the tuple
    ``(nodewise_overlap, edges, dists)`` when ``return_dists`` is True.
    """
    # todo this could be simplified
    reference = torch.cat((mol_batch.pos, ref_types * type_distance_scaling), dim=1)

    if dist_to_self:
        predicted = reference
    else:
        # predicted type features are assumed to be normalized already
        predicted = torch.cat((decoded_data.pos, decoded_data.x * type_distance_scaling), dim=1)

    if isolate_dimensions is not None:
        # restrict the distance to a sub-block of the joint space
        start, stop = isolate_dimensions[0], isolate_dimensions[1]
        reference = reference[:, start:stop]
        predicted = predicted[:, start:stop]

    # Only pairs closer than 4 sigma are kept: beyond that the overlap (and
    # its gradient) is vanishingly small, so the edges would be wasted work.
    edges = radius(reference, predicted,
                   r=4 * sigma,
                   max_num_neighbors=1000,
                   batch_x=mol_batch.batch,
                   batch_y=decoded_data.batch)
    # row 0 indexes the predicted points, row 1 the reference points
    predicted_index, reference_index = edges[0], edges[1]

    dists = torch.linalg.norm(reference[reference_index] - predicted[predicted_index], dim=1)
    weighted_overlap = torch.exp(-torch.pow(dists / sigma, 2)) * nodewise_weights[predicted_index]
    nodewise_overlap = scatter(weighted_overlap, reference_index,
                               reduce='sum', dim_size=mol_batch.num_nodes)

    if return_dists:
        return nodewise_overlap, edges, dists
    return nodewise_overlap
def compute_type_evaluation_overlap(config, data, num_atom_types, decoded_data, true_nodes):
    """
    Compute type-only Gaussian overlaps at the evaluation sigma.

    Only columns ``3 : 3 + num_atom_types`` of the joint (position, type)
    space are used, i.e. the columns that follow the three position columns.

    Reads ``config.autoencoder.evaluation_sigma`` and
    ``config.autoencoder.type_distance_scaling``.

    Returns
    -------
    ``(self_type_overlap, type_overlap)``

    NOTE: the self-overlap comes FIRST here, which is the opposite order to
    ``compute_coord_evaluation_overlap`` and ``compute_full_evaluation_overlap``.
    """
    # todo could be more flexible
    sigma = config.autoencoder.evaluation_sigma
    scaling = config.autoencoder.type_distance_scaling
    type_columns = [3, 3 + num_atom_types]

    # overlap of the reference with the decoded (predicted) distribution
    type_overlap = compute_gaussian_overlap(
        true_nodes, data, decoded_data, sigma,
        nodewise_weights=decoded_data.aux_ind,
        isolate_dimensions=type_columns,
        type_distance_scaling=scaling,
    )

    # overlap of the reference with itself; every input node has weight 1
    unit_weights = torch.ones(len(data.z), device=data.z.device, dtype=torch.float32)
    self_type_overlap = compute_gaussian_overlap(
        true_nodes, data, data, sigma,
        nodewise_weights=unit_weights,
        dist_to_self=True,
        isolate_dimensions=type_columns,
        type_distance_scaling=scaling,
    )

    return self_type_overlap, type_overlap
def compute_coord_evaluation_overlap(config, data, decoded_data, true_nodes):
    """
    Compute position-only Gaussian overlaps at the evaluation sigma.

    Only columns ``0:3`` of the joint (position, type) space are used.

    Reads ``config.autoencoder.evaluation_sigma`` and
    ``config.autoencoder.type_distance_scaling``.

    Returns
    -------
    ``(coord_overlap, self_coord_overlap)`` - predicted overlap first, then
    the overlap of the reference with itself.
    """
    # todo could be more flexible
    sigma = config.autoencoder.evaluation_sigma
    scaling = config.autoencoder.type_distance_scaling
    position_columns = [0, 3]

    # overlap of the reference with the decoded (predicted) distribution
    coord_overlap = compute_gaussian_overlap(
        true_nodes, data, decoded_data, sigma,
        nodewise_weights=decoded_data.aux_ind,
        isolate_dimensions=position_columns,
        type_distance_scaling=scaling,
    )

    # overlap of the reference with itself; every input node has weight 1
    unit_weights = torch.ones(len(data.z), device=data.z.device, dtype=torch.float32)
    self_coord_overlap = compute_gaussian_overlap(
        true_nodes, data, data, sigma,
        nodewise_weights=unit_weights,
        dist_to_self=True,
        isolate_dimensions=position_columns,
        type_distance_scaling=scaling,
    )

    return coord_overlap, self_coord_overlap
def compute_full_evaluation_overlap(mol_batch, decoded_mol_batch, true_nodes, sigma=None, distance_scaling=None):
    """
    Compute Gaussian overlaps over the full joint (position, type) space.

    Parameters
    ----------
    mol_batch
        Reference batch; also used as the "prediction" for the self-overlap.
    decoded_mol_batch
        Decoded batch; its ``aux_ind`` supplies the predicted-point weights.
    true_nodes
        Per-reference-node type features forwarded to
        ``compute_gaussian_overlap``.
    sigma
        Gaussian width. Although the signature defaults it to None, a value
        is required.
    distance_scaling
        Scaling applied to the type features. When None, the default of
        ``compute_gaussian_overlap`` is used.

    Returns
    -------
    ``(full_overlap, self_overlap)``

    Raises
    ------
    TypeError
        If ``sigma`` is None.
    """
    # Fail early with a clear message: a None sigma would otherwise surface
    # as an opaque "unsupported operand" TypeError deep in the overlap code.
    if sigma is None:
        raise TypeError("compute_full_evaluation_overlap() requires a numeric `sigma`, got None")

    # Forward the scaling only when given, so that None falls back to the
    # callee's own default instead of being multiplied into a tensor.
    scaling_kwargs = {}
    if distance_scaling is not None:
        scaling_kwargs['type_distance_scaling'] = distance_scaling

    full_overlap = compute_gaussian_overlap(
        true_nodes, mol_batch, decoded_mol_batch, sigma,
        nodewise_weights=decoded_mol_batch.aux_ind,
        **scaling_kwargs,
    )

    # every input node carries weight 1
    unit_weights = torch.ones(len(mol_batch.z), device=mol_batch.z.device, dtype=torch.float32)
    self_overlap = compute_gaussian_overlap(
        true_nodes, mol_batch, mol_batch, sigma,
        nodewise_weights=unit_weights,
        dist_to_self=True,
        **scaling_kwargs,
    )

    return full_overlap, self_overlap
def get_node_weights(mol_batch, decoded_mol_batch, decoding, num_decoder_nodes, node_weight_temperature):
    """
    Extract per-node weights from the decoder swarm.

    Parameters
    ----------
    mol_batch
        Input batch; ``num_atoms`` (one entry per graph) is read.
    decoded_mol_batch
        Decoded batch; ``batch`` and ``num_nodes`` are read.
    decoding
        Raw decoder output; its last column holds the weight logits.
    num_decoder_nodes
        Number of decoder nodes emitted per graph.
    node_weight_temperature
        Divisor applied to the logits before the softmax.

    Returns
    -------
    weight_per_swarm_point
        ``num_atoms / num_decoder_nodes`` of the owning graph, repeated for
        each of its decoder nodes.
    nodewise_weights
        Softmax of the temperature-scaled logits, taken within each graph.
    nodewise_weights_tensor
        ``nodewise_weights`` multiplied by the atom count of the owning graph.
    """
    # atom count of the owning graph, broadcast to every decoder node
    atoms_per_swarm_point = mol_batch.num_atoms.repeat_interleave(num_decoder_nodes)

    # uniform share of the graph's atoms carried by each decoder node
    weight_per_swarm_point = atoms_per_swarm_point / num_decoder_nodes

    # graph-wise softmax over the weight logits, adjusted by temperature
    logits = decoding[:, -1] / node_weight_temperature
    nodewise_weights = scatter_softmax(logits,
                                       decoded_mol_batch.batch,
                                       dim=0,
                                       dim_size=decoded_mol_batch.num_nodes)

    # reweigh against the number of atoms
    nodewise_weights_tensor = nodewise_weights * atoms_per_swarm_point

    return weight_per_swarm_point, nodewise_weights, nodewise_weights_tensor
def init_decoded_data(mol_batch, decoded_batch, device, num_nodes):
    """
    Build a batch object for the decoder output, starting from a copy of the
    input batch.

    Parameters
    ----------
    mol_batch
        Input batch. It is detached and cloned, never modified;
        ``num_graphs`` is read.
    decoded_batch
        Raw decoder output; its first three columns become ``pos``.
    device
        Device on which the new batch index is created.
    num_nodes
        Number of decoded nodes per graph.

    Returns
    -------
    The cloned batch with ``pos`` replaced by the decoded positions and
    ``batch`` replaced by ``[0]*num_nodes + [1]*num_nodes + ...``. All other
    attributes are whatever the clone inherited from ``mol_batch``.
    """
    decoded_data = mol_batch.detach().clone()
    decoded_data.pos = decoded_batch[:, :3]
    # allocate the index directly on the target device rather than building
    # it on the CPU and copying it over
    decoded_data.batch = torch.arange(mol_batch.num_graphs, device=device).repeat_interleave(num_nodes)
    return decoded_data
def test_decoder_equivariance(data,
                              encoding: torch.Tensor,
                              rotated_encoding: torch.Tensor,
                              rotations: torch.Tensor,
                              autoencoder: nn.Module,
                              device: Union[torch.device, str]) -> torch.Tensor:
    """
    Check decoder end-to-end equivariance.

    The encoding is decoded and the decoded positions are rotated; separately
    the already-rotated encoding is decoded. For an equivariant decoder the
    two sets of positions agree.

    Parameters
    ----------
    data
        Batch object; only ``num_graphs`` is read.
    encoding
        Encoding passed to ``autoencoder.decode``; its last dimension is the
        per-axis embedding width.
    rotated_encoding
        Rotated encoding; reshaped to ``(num_graphs, 3, encoding.shape[-1])``
        before decoding.
    rotations
        One 3x3 matrix per graph.
    autoencoder
        Must expose ``decode`` and ``num_decoder_nodes``.
    device
        Unused; kept so existing call sites keep working.

    Returns
    -------
    Per decoded node, the mean over the three coordinates of
    ``|rotated - decoded_rotated| / (1e-3 + |rotated|)``.
    """
    # take a given embedding and decode it
    decoding = autoencoder.decode(encoding)

    # rotate embedding and decode
    decoding2 = autoencoder.decode(
        rotated_encoding.reshape(data.num_graphs, 3, encoding.shape[-1]))

    # Rotate the first decoding and compare. Every graph owns a contiguous
    # run of num_decoder_nodes rows, so all graphs are rotated in one batched
    # contraction instead of a Python loop over boolean masks.
    positions = decoding[:, :3].reshape(data.num_graphs, autoencoder.num_decoder_nodes, 3)
    rotated_positions = torch.einsum('gij,gkj->gki', rotations, positions).reshape(-1, 3)

    # only the first three dimensions are compared; they should be equivariant
    decoder_equivariance_loss = (
            torch.abs(rotated_positions - decoding2[:, :3])
            / (1e-3 + torch.abs(rotated_positions)))

    return decoder_equivariance_loss.mean(-1)


# The "test_" prefix would make pytest collect this library function as a test
# (and fail on missing fixtures) whenever it is imported into a test module.
test_decoder_equivariance.__test__ = False
def test_encoder_equivariance(data,
                              rotations: torch.Tensor,
                              autoencoder) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Check encoder end-to-end equivariance.

    The input is encoded and the encoding rotated; separately the input
    positions are rotated and then encoded. For an equivariant encoder the
    two results agree.

    Parameters
    ----------
    data
        Batch object; ``clone()``, ``pos``, ``batch`` and ``num_graphs`` are
        used. NOTE: ``data.pos`` is overwritten in place with the rotated
        positions, so the caller's object is modified. Pass a clone if the
        original positions are still needed.
    rotations
        One 3x3 matrix per graph.
    autoencoder
        Must expose ``encode(data, override_centering=True)``. The result is
        contracted as ``(graph, 3, k)``.

    Returns
    -------
    encoder_equivariance_loss
        Per graph, the mean relative deviation between the rotated encoding
        and the encoding of the rotated input.
    encoding
        Encoding of the un-rotated input, as returned by the encoder.
    rotated_encoding
        Rotated encoding, flattened to ``(num_graphs, 3 * k)``.
    """
    # embed the input data then rotate the embedding
    encoding = autoencoder.encode(data.clone(), override_centering=True)
    rotated_encoding = torch.einsum('nij, njk->nik', rotations, encoding)  # rotate in 3D
    rotated_encoding = rotated_encoding.reshape(data.num_graphs, rotated_encoding.shape[-1] * 3)

    # Rotate the input data and embed it. Each node picks up its own graph's
    # rotation, which keeps node order intact whatever the order of
    # data.batch and avoids a Python loop over graphs.
    data.pos = torch.einsum('nij,nj->ni', rotations[data.batch], data.pos)
    encoding2 = autoencoder.encode(data.clone(), override_centering=True)
    encoding2 = encoding2.reshape(data.num_graphs, encoding2.shape[-1] * 3)

    # Compare the embeddings - they should be identical for an equivariant
    # encoder. The denominator is floored at machine epsilon so an exactly
    # zero latent component yields a finite value rather than NaN / inf.
    denominator = torch.abs(rotated_encoding).clamp_min(torch.finfo(rotated_encoding.dtype).eps)
    encoder_equivariance_loss = (torch.abs(rotated_encoding - encoding2) / denominator).mean(-1)

    return encoder_equivariance_loss, encoding, rotated_encoding


# The "test_" prefix would make pytest collect this library function as a test
# (and fail on missing fixtures) whenever it is imported into a test module.
test_encoder_equivariance.__test__ = False
def decoding2mol_batch(mol_batch, decoding, num_decoder_nodes, node_weight_temperature, device):
    """
    Turn a raw decoder output into a batch object carrying positions, type
    probabilities and node weights.

    Column layout read from ``decoding``: ``[:, :3]`` positions (via
    ``init_decoded_data``), ``[:, 3:-1]`` type logits, ``[:, -1]`` weight
    logit (via ``get_node_weights``).

    Side effect: ``mol_batch.aux_ind`` is overwritten with a vector of ones
    of length ``mol_batch.num_nodes``.

    Returns
    -------
    ``(decoded_mol_batch, nodewise_graph_weights, graph_weighted_node_weights,
    node_weighted_node_weights)``; the last three are the outputs of
    ``get_node_weights`` in order. ``decoded_mol_batch.aux_ind`` is set to
    ``node_weighted_node_weights``.
    """
    # generate input reconstructed as a data type
    decoded_mol_batch = init_decoded_data(mol_batch, decoding, device, num_decoder_nodes)

    # compute the distributional weight of each node
    weights = get_node_weights(mol_batch,
                               decoded_mol_batch,
                               decoding,
                               num_decoder_nodes,
                               node_weight_temperature)
    nodewise_graph_weights, graph_weighted_node_weights, node_weighted_node_weights = weights
    decoded_mol_batch.aux_ind = node_weighted_node_weights

    # input node weights are always 1 - corresponding each to an atom
    mol_batch.aux_ind = torch.ones(mol_batch.num_nodes, dtype=torch.float32, device=device)

    # get probability distribution over type dimensions
    type_probabilities = F.softmax(decoding[:, 3:-1], dim=1)
    decoded_mol_batch.x = type_probabilities
    decoded_mol_batch.num_nodes = len(type_probabilities)

    return (decoded_mol_batch,
            nodewise_graph_weights,
            graph_weighted_node_weights,
            node_weighted_node_weights)
def ae_reconstruction_loss(mol_batch,
                           decoding_batch,
                           graph_weighted_node_weights,
                           node_weighted_node_weights,
                           num_atom_types,
                           type_distance_scaling,
                           autoencoder_sigma,
                           ):
    """
    Compute the autoencoder reconstruction losses between an input batch and
    its decoded counterpart.

    Parameters
    ----------
    mol_batch
        Input batch; ``x`` (integer atom types), ``batch``, ``num_nodes``,
        ``num_graphs`` and ``aux_ind`` are read.
    decoding_batch
        Decoded batch; ``x`` (type probabilities), ``batch``, ``num_nodes``
        and ``aux_ind`` are read.
    graph_weighted_node_weights
        Per decoded node weight used to pool type probabilities per graph.
    node_weighted_node_weights
        Per decoded node weight summed near each atom for the clumping term.
    num_atom_types
        Number of classes for the one-hot encoding of ``mol_batch.x``.
    type_distance_scaling
        Scaling of the type dimensions in the joint space.
    autoencoder_sigma
        Gaussian width used for the overlaps.

    Returns
    -------
    Tuple, in order:
    ``nodewise_reconstruction_loss`` (per atom),
    ``nodewise_type_loss`` (a scalar despite its name: both cross-entropy
    terms use the default mean reduction),
    ``graph_reconstruction_loss`` (per graph),
    ``self_likelihoods`` (per atom),
    ``nearest_node_loss`` (per graph),
    ``graph_clumping_loss`` (per graph),
    ``nearest_component_dist`` (per atom),
    ``nearest_component_loss`` (per graph).
    """
    true_node_one_hot = F.one_hot(mol_batch.x.flatten().long(), num_classes=num_atom_types).float()

    # overlap of every true atom with the decoded mixture, plus the edges used
    decoder_likelihoods, input2output_edges, input2output_dists = compute_gaussian_overlap(
        true_node_one_hot, mol_batch, decoding_batch, autoencoder_sigma,
        nodewise_weights=decoding_batch.aux_ind,
        type_distance_scaling=type_distance_scaling,
        return_dists=True,
    )
    # row 0 indexes decoded components, row 1 indexes input atoms
    component_index, atom_index = input2output_edges[0], input2output_edges[1]

    # if sigma is too large, the overlaps can be > 1, so the target is the
    # overlap of the true density with itself rather than a constant
    self_likelihoods = compute_gaussian_overlap(
        true_node_one_hot, mol_batch, mol_batch, autoencoder_sigma,
        nodewise_weights=mol_batch.aux_ind,
        dist_to_self=True,
        type_distance_scaling=type_distance_scaling,
    )

    # typewise agreement for whole graph
    per_graph_true_types = scatter(true_node_one_hot,
                                   mol_batch.batch[:, None],
                                   dim=0, reduce='mean')
    per_graph_pred_types = scatter(decoding_batch.x * graph_weighted_node_weights[:, None],
                                   decoding_batch.batch[:, None],
                                   dim=0, reduce='sum')
    # subtract the target's own entropy so a perfect prediction scores ~0
    prediction_cross_entropy = F.binary_cross_entropy(
        per_graph_pred_types.clip(min=1e-6, max=1 - 1e-6), per_graph_true_types)
    target_entropy = F.binary_cross_entropy(per_graph_true_types, per_graph_true_types)
    nodewise_type_loss = prediction_cross_entropy - target_entropy

    nodewise_reconstruction_loss = F.smooth_l1_loss(decoder_likelihoods, self_likelihoods, reduction='none')
    graph_reconstruction_loss = scatter(nodewise_reconstruction_loss, mol_batch.batch, reduce='mean')

    # 1: penalize output components for distance to nearest atom
    # NOTE(review): components / atoms that have no edge at all receive
    # whatever fill value scatter uses for empty groups - confirm intended.
    nearest_node_dist = scatter(input2output_dists, component_index,
                                reduce='min', dim_size=decoding_batch.num_nodes)
    nearest_node_loss = scatter(nearest_node_dist, decoding_batch.batch,
                                reduce='mean', dim_size=mol_batch.num_graphs)

    # 1a: reciprocal distance from each atom to its nearest component
    nearest_component_dist = scatter(input2output_dists, atom_index,
                                     reduce='min', dim_size=mol_batch.num_nodes)
    nearest_component_loss = scatter(nearest_component_dist, mol_batch.batch,
                                     reduce='mean', dim_size=mol_batch.num_graphs)

    # 2: penalize the area near an atom for not holding exactly one atom's
    # worth of weight (edges shorter than 0.5 count as "near")
    is_near = input2output_dists < 0.5
    near_components = component_index[is_near]
    near_atoms = atom_index[is_near]
    near_weights = node_weighted_node_weights[near_components]
    pred_particle_weights = scatter(near_weights, near_atoms,
                                    reduce='sum', dim_size=mol_batch.num_nodes)
    nodewise_clumping_loss = F.smooth_l1_loss(pred_particle_weights,
                                              torch.ones_like(pred_particle_weights),
                                              reduction='none')
    graph_clumping_loss = scatter(nodewise_clumping_loss, mol_batch.batch, reduce='mean')

    return (nodewise_reconstruction_loss,
            nodewise_type_loss,
            graph_reconstruction_loss,
            self_likelihoods,
            nearest_node_loss,
            graph_clumping_loss,
            nearest_component_dist,
            nearest_component_loss)
def batch_rmsd(mol_batch,
               decoded_mol_batch,
               true_node_one_hot,
               intrapoint_cutoff: float = 0.5,
               probability_threshold: float = 0.25,
               type_distance_scaling: float = 2):
    """
    Match decoded components to input atoms and measure the per-graph
    deviation of the reconstruction.

    Decoded components within ``intrapoint_cutoff`` of an atom (in the joint
    position + scaled-type space) are collapsed into one weighted point per
    atom. An atom counts as matched when the collected weight is within
    ``probability_threshold`` of 1.

    NOTE: the per-graph value is the MEAN of the atom-wise Euclidean
    distances in the joint space (type dimensions included), not a
    root-mean-square over positions.

    Parameters
    ----------
    mol_batch
        Input batch; ``pos``, ``batch``, ``num_nodes``, ``num_graphs`` are read.
    decoded_mol_batch
        Decoded batch; ``pos``, ``x``, ``batch`` and ``aux_ind`` are read.
    true_node_one_hot
        Type features of the input atoms; cast to float.
    intrapoint_cutoff
        Radius within which a component is assigned to an atom.
    probability_threshold
        Maximum allowed ``|1 - collected weight|`` for a matched atom.
    type_distance_scaling
        Scaling of the type dimensions in the joint space.

    Returns
    -------
    Tuple, in order:
    ``rmsd`` (per graph, NaN where any atom is unmatched),
    ``pred_dists`` (per atom),
    ``complete_graph_bools`` (per graph, True if every atom matched),
    matched-atom mask (per atom),
    ``pred_particle_points`` (per atom, NaN rows for unmatched atoms),
    ``pred_particle_weights`` (per atom).
    """
    reference = torch.cat((mol_batch.pos, true_node_one_hot.float() * type_distance_scaling), dim=1)
    # predicted type features are assumed to be normalized already
    predicted = torch.cat((decoded_mol_batch.pos, decoded_mol_batch.x * type_distance_scaling), dim=1)
    nodewise_weights = decoded_mol_batch.aux_ind

    edges = radius(reference, predicted,
                   r=intrapoint_cutoff,
                   max_num_neighbors=1000,
                   batch_x=mol_batch.batch,
                   batch_y=decoded_mol_batch.batch)
    # row 0 indexes decoded components, row 1 indexes input atoms
    component_index, atom_index = edges[0], edges[1]

    dists = torch.linalg.norm(reference[atom_index] - predicted[component_index], dim=1)
    is_near = dists < intrapoint_cutoff
    near_components = component_index[is_near]
    near_atoms = atom_index[is_near]

    near_points = predicted[near_components]
    near_weights = nodewise_weights[near_components]

    # total weight collected around every atom
    pred_particle_weights = scatter(near_weights, near_atoms,
                                    reduce='sum', dim_size=mol_batch.num_nodes)

    # atoms we fail to match: no nearby components, or too little / too much
    # probability mass collected around them
    missing_particle_bools = (1 - pred_particle_weights).abs() >= probability_threshold
    # product of the per-atom flags: a graph is complete only if all match
    complete_graph_bools = scatter((~missing_particle_bools).long(), mol_batch.batch,
                                   reduce='mul', dim_size=mol_batch.num_graphs, dim=0).bool()

    # weighted sum of the collected components gives one point per atom
    pred_particle_points = scatter(near_points * near_weights[:, None], near_atoms,
                                   reduce='sum', dim=0, dim_size=mol_batch.num_nodes)

    pred_dists = torch.linalg.norm(reference - pred_particle_points, dim=1)
    rmsd = scatter(pred_dists, mol_batch.batch, reduce='mean', dim_size=mol_batch.num_graphs)

    # blank out everything that rests on an unmatched atom
    rmsd[~complete_graph_bools] = torch.nan
    pred_particle_points[missing_particle_bools] *= torch.nan

    return (rmsd,
            pred_dists,
            complete_graph_bools,
            ~missing_particle_bools,
            pred_particle_points,
            pred_particle_weights)