Source code for mxtaltools.models.task_models.autoencoder_models

import torch
from torch import nn as nn
import torch.nn.functional as F

from mxtaltools.models.autoencoder_utils import decoding2mol_batch, ae_reconstruction_loss, batch_rmsd
from mxtaltools.models.graph_models.base_graph_model import BaseGraphModel
from mxtaltools.models.graph_models.graph_neural_network import VectorGNN
from mxtaltools.models.graph_models.molecule_graph_model import VectorMoleculeGraphModel
from mxtaltools.models.modules.components import Scalarizer, vectorMLP
from mxtaltools.reporting.ae_reporting import swarm_vs_tgt_fig


# noinspection PyAttributeOutsideInit
class Mo3ENet(BaseGraphModel):
    """3D O(3)-equivariant multi-type point cloud autoencoder model (Mo3ENet).

    The model is made of three parts built in ``__init__``:

    * ``encoder``   - a :class:`Mo3ENetEncoder` mapping a molecule batch to a
      vector embedding (documented in ``decode`` as ``n x 3 x k``).
    * ``scalarizer`` - a :class:`Scalarizer` producing scalar features from
      that vector embedding.
    * ``decoder``   - either a :class:`Mo3ENetDecoder` (``'mlp'``) or a
      :class:`Mo3ENetGraphDecoder` (``'gnn'``), selected by
      ``config.decoder.model_type`` (``'mlp'`` when the attribute is absent,
      which is the case for old configs).

    Args:
        seed: value passed to ``torch.manual_seed`` and forwarded to the
            encoder and the MLP decoder.
        config: configuration object; this class reads ``bottleneck_dim``,
            ``decoder.num_nodes``, ``decoder.model_type`` (optional),
            ``encoder.graph.cutoff`` and passes ``config.encoder`` /
            ``config.decoder`` to the sub-models.
        num_atom_types: number of atom type classes in the output.
        atom_embedding_vector: lookup tensor indexed by ``mol_batch.z`` in
            ``check_embedding_quality``; stored as a buffer.
        radial_normalization: factor positions are divided by before encoding
            and multiplied by after decoding; stored as a buffer.
        protons_in_input: flag stored as a boolean buffer.

    Raises:
        ValueError: if the decoder type is neither ``'mlp'`` nor ``'gnn'``.
    """

    def __init__(self,
                 seed,
                 config,
                 num_atom_types: int,
                 atom_embedding_vector: torch.Tensor,
                 radial_normalization: float,
                 protons_in_input: bool,
                 ):
        super(Mo3ENet, self).__init__()

        torch.manual_seed(seed)

        self.cartesian_dimension = 3
        self.num_classes = num_atom_types
        # per decoder node: atom type scores + xyz position + one extra scalar channel
        self.output_depth = self.num_classes + self.cartesian_dimension + 1
        self.num_decoder_nodes = config.decoder.num_nodes
        self.bottleneck_dim = config.bottleneck_dim
        # configs of old models have no `model_type`; they used the MLP decoder
        self.decoder_type = getattr(config.decoder, 'model_type', 'mlp')

        # todo add type distance scaling and num atom types and node weight temperature
        self.register_buffer('atom_embedding_vector', atom_embedding_vector)
        self.register_buffer('radial_normalization',
                             torch.tensor(radial_normalization, dtype=torch.float32))
        self.register_buffer('protons_in_input',
                             torch.tensor(protons_in_input, dtype=torch.bool))
        self.register_buffer('inferring_protons',
                             torch.tensor(False, dtype=torch.bool))
        # the encoder sees normalized coordinates, so the cutoff is rescaled to match
        self.register_buffer('convolution_cutoff',
                             config.encoder.graph.cutoff / self.radial_normalization)

        self.encoder = Mo3ENetEncoder(seed,
                                      config.encoder,
                                      config.bottleneck_dim,
                                      override_cutoff=self.convolution_cutoff)

        if self.decoder_type == 'mlp':
            self.decoder = Mo3ENetDecoder(seed,
                                          config.decoder,
                                          config.bottleneck_dim,
                                          self.output_depth,
                                          self.num_decoder_nodes)
        elif self.decoder_type == 'gnn':
            self.decoder = Mo3ENetGraphDecoder(config.decoder,
                                               config.bottleneck_dim,
                                               self.output_depth,
                                               self.num_decoder_nodes,
                                               )
        else:
            # an explicit exception: an `assert` would be stripped under `python -O`
            raise ValueError(f"Unknown decoder type {self.decoder_type}")

        self.scalarizer = Scalarizer(config.bottleneck_dim, self.cartesian_dimension, None, None, 0)

    def forward(self,
                mol_batch,
                return_latent: bool = False,
                return_dists: bool = False,
                **kwargs):
        """Encode then decode a molecule batch.

        NaNs in the encoding or decoding are reported with ``print`` and do
        not stop execution. ``return_dists`` and ``**kwargs`` are accepted but
        not used by this method.

        NOTE: ``encode`` rescales ``mol_batch.pos`` in place, so the batch
        passed in is modified.

        Returns:
            ``decoding``, or ``(decoding, encoding)`` when ``return_latent``.
        """
        encoding = self.encode(mol_batch)
        if torch.isnan(encoding).any():
            print("NaN values in encoding")

        decoding = self.decode(encoding)
        if torch.isnan(decoding).any():
            print("NaN values in decoding")

        if return_latent:
            return decoding, encoding
        return decoding

    def encode(self, mol_batch, override_centering: bool = False):
        """Normalize the batch positions and return the encoder embedding.

        NOTE: ``mol_batch.pos`` is divided by ``radial_normalization``
        *in place*. Pass a clone if the original coordinates are still needed
        (``check_embedding_quality`` does this).

        Args:
            mol_batch: batch exposing ``pos`` plus the fields the encoder reads.
            override_centering: skip the centering sanity check when True.

        Returns:
            The second output of the encoder (the first is discarded).
        """
        if not override_centering:
            # sanity check on the mean position over all nodes in the batch;
            # like any `assert`, it is skipped when Python runs with -O
            assert torch.linalg.norm(mol_batch.pos.mean(0)) < 1e-3, \
                "Encoder trained only for centered molecules!"

        # normalize radii
        mol_batch.pos /= self.radial_normalization
        _, encoding = self.encoder(mol_batch)
        return encoding

    def decode(self, encoding):
        """Decode an embedding into one row of features per decoder node.

        Args:
            encoding: vector embedding, nx3xk.

        Returns:
            Tensor whose rows are ``[position (3), scalar features]``, with the
            positions multiplied back by ``radial_normalization``.

        Raises:
            ValueError: if ``self.decoder_type`` is neither ``'mlp'`` nor ``'gnn'``.
        """
        s = self.scalarizer(encoding)
        if torch.isnan(s).any():
            print("NaN values in scalarized encoding")

        scalar_decoding, vector_decoding = self.decoder(s, v=encoding)

        # combine vector and scalar features to n*nodes x m:
        # de-normalize predicted node positions and rearrange to correct format
        if self.decoder_type == 'mlp':
            # MLP output is per graph; unfold the node axis into the row axis
            positions = vector_decoding.permute(0, 2, 1).reshape(
                len(vector_decoding) * self.num_decoder_nodes, 3) * self.radial_normalization
            scalars = scalar_decoding.reshape(
                len(scalar_decoding) * self.num_decoder_nodes, self.output_depth - 3)
            decoding = torch.cat([positions, scalars], dim=-1)
        elif self.decoder_type == 'gnn':
            # the graph decoder output is indexed directly, with no reshape
            decoding = torch.cat(
                [vector_decoding[:, :, 0] * self.radial_normalization,
                 scalar_decoding],
                dim=1)
        else:
            # an explicit exception: an `assert` would be stripped under `python -O`
            raise ValueError(f"Unknown decoder type {self.decoder_type}")

        return decoding

    def compile_self(self, dynamic=True, fullgraph=False):
        """Replace encoder, decoder and scalarizer with ``torch.compile``-d versions."""
        self.encoder = torch.compile(self.encoder, dynamic=dynamic, fullgraph=fullgraph)
        self.decoder = torch.compile(self.decoder, dynamic=dynamic, fullgraph=fullgraph)
        self.scalarizer = torch.compile(self.scalarizer, dynamic=dynamic, fullgraph=fullgraph)

    def check_embedding_quality(self,
                                mol_batch,
                                sigma=0.35,
                                type_distance_scaling=2,
                                # todo next two should be properties of the model
                                node_weight_temperature=1,
                                num_atom_types=5,
                                visualize=False,
                                ):
        """Reconstruct a batch and score the reconstruction.

        The batch is cloned before encoding, so ``mol_batch.pos`` is left
        untouched; ``mol_batch.x`` is overwritten with the flattened
        ``atom_embedding_vector[mol_batch.z]``.

        Args:
            mol_batch: batch to reconstruct.
            sigma, type_distance_scaling: forwarded to ``ae_reconstruction_loss``.
            node_weight_temperature: forwarded to ``decoding2mol_batch``.
            num_atom_types: number of classes for the one-hot atom types and
                for ``ae_reconstruction_loss``.
            visualize: when True, open one figure per graph in the browser.

        Returns:
            ``(graph_reconstruction_loss, rmsd, complete_graph_bools)``.
        """
        # encode a clone because `encode` rescales positions in place
        encoding = self.encode(mol_batch.clone())
        decoding = self.decode(encoding)

        mol_batch.x = self.atom_embedding_vector[mol_batch.z].flatten()
        decoded_mol_batch, _, nodewise_weights, nodewise_weights_tensor = (
            decoding2mol_batch(mol_batch,
                               decoding,
                               self.num_decoder_nodes,
                               node_weight_temperature,
                               mol_batch.x.device))

        # only the graph-wise reconstruction loss (third output) is used here
        _, _, graph_reconstruction_loss, *_ = ae_reconstruction_loss(
            mol_batch,
            decoded_mol_batch,
            nodewise_weights,
            nodewise_weights_tensor,
            num_atom_types,
            type_distance_scaling,
            sigma,
        )

        true_node_one_hot = F.one_hot(mol_batch.x.flatten().long(),
                                      num_classes=num_atom_types).float()

        # only the first and third outputs of batch_rmsd are used here
        rmsd, _, complete_graph_bools, *_ = batch_rmsd(mol_batch,
                                                       decoded_mol_batch,
                                                       true_node_one_hot)

        if visualize:
            for ind in range(mol_batch.num_graphs):
                swarm_vs_tgt_fig(mol_batch,
                                 decoded_mol_batch,
                                 [1, 6, 7, 8, 9],
                                 graph_ind=ind).show(renderer='browser')

        return graph_reconstruction_loss, rmsd, complete_graph_bools
class Mo3ENetDecoder(nn.Module):
    """MLP decoder of Mo3ENet: a thin wrapper around a single ``vectorMLP``.

    Args:
        seed: forwarded to ``vectorMLP``.
        config: decoder configuration; reads ``fc.num_layers``,
            ``fc.hidden_dim``, ``fc.norm``, ``fc.dropout``, ``fc.vector_norm``,
            ``activation`` and ``ramp_depth``.
        bottleneck_dim: used as both scalar and vector input dimension.
        output_depth: number of output channels per decoder node; the 3
            position channels are taken from the vector output, the
            remaining ``output_depth - 3`` from the scalar output.
        num_nodes: number of decoder nodes (also the vector output dimension).
    """

    def __init__(self, seed, config, bottleneck_dim, output_depth, num_nodes):
        super().__init__()

        fc_config = config.fc
        # scalar channels for every node are emitted as one flat vector per graph
        scalar_output_dim = (output_depth - 3) * num_nodes

        self.model = vectorMLP(
            seed=seed,
            layers=fc_config.num_layers,
            filters=fc_config.hidden_dim,
            input_dim=bottleneck_dim,
            vector_input_dim=bottleneck_dim,
            vector_output_dim=num_nodes,
            output_dim=scalar_output_dim,
            activation=config.activation,
            norm=fc_config.norm,
            dropout=fc_config.dropout,
            vector_norm=fc_config.vector_norm,
            ramp_depth=config.ramp_depth,
        )

    def forward(self, x, v):
        """Apply the wrapped ``vectorMLP`` to scalar input ``x`` and vector input ``v``."""
        output = self.model(x, v)
        return output
class Mo3ENetGraphDecoder(nn.Module):
    """Graph decoder of Mo3ENet.

    Three bias-free linear layers expand a per-graph embedding into
    ``num_nodes`` nodes (scalar features, vector features and positions),
    and a ``VectorGNN`` is then run on the fully connected graph of those
    nodes.

    Args:
        config: decoder configuration; reads ``fc.hidden_dim``,
            ``fc.num_layers``, ``fc.norm``, ``fc.vector_norm`` and
            ``fc.dropout``.
        bottleneck_dim: input dimension of the three linear layers.
        output_depth: channels per decoder node; ``output_depth - 3`` is
            passed to the GNN as ``embedding_dim``.
        num_nodes: number of decoder nodes generated per graph.
    """

    def __init__(self, config, bottleneck_dim, output_depth, num_nodes):
        super(Mo3ENetGraphDecoder, self).__init__()
        self.num_nodes = num_nodes
        self.hidden_dim = config.fc.hidden_dim

        self.model = VectorGNN(
            input_node_dim=config.fc.hidden_dim,
            node_dim=config.fc.hidden_dim,
            fcs_per_gc=1,
            message_dim=config.fc.hidden_dim // 4,
            embedding_dim=output_depth - 3,
            num_convs=config.fc.num_layers,
            num_radial=32,
            num_input_classes=101,
            cutoff=2,
            max_num_neighbors=32,
            envelope_exponent=5,
            activation='gelu',
            atom_type_embedding_dim=0,
            norm=('graph ' + config.fc.norm) if config.fc.norm is not None else None,
            vector_norm=('graph ' + config.fc.vector_norm) if config.fc.vector_norm is not None else None,
            dropout=config.fc.dropout,
            radial_embedding='gaussian',
            override_cutoff=None,
            v_embedding_dim=1,
            v_input_node_dim=config.fc.hidden_dim,
        )

        # no bias terms on the embedding -> node maps
        self.s_to_nodes = nn.Linear(bottleneck_dim, config.fc.hidden_dim * num_nodes, bias=False)
        self.v_to_nodes = nn.Linear(bottleneck_dim, config.fc.hidden_dim * num_nodes, bias=False)
        self.v_to_pos = nn.Linear(bottleneck_dim, num_nodes, bias=False)

    def forward(self, x, v):
        """Expand per-graph embeddings into nodes and run the GNN on them.

        Args:
            x: scalar embedding with one row per graph; its last dimension
                must be ``bottleneck_dim``.
            v: vector embedding; the reshapes below require
                ``(num_graphs, 3, bottleneck_dim)``.

        Returns:
            Whatever ``self.model`` returns for
            ``(node scalars, node vectors, positions, batch, {'edge_index': ...})``.
        """
        eps = 1e-1
        num_graphs = len(x)

        # all ordered pairs of distinct nodes within one graph, listed as the
        # i<j combinations followed by the same pairs flipped
        node_pairs = torch.combinations(
            torch.arange(self.num_nodes), r=2, with_replacement=False).to(x.device)
        directed_pairs = torch.cat([node_pairs, torch.fliplr(node_pairs)], dim=0)

        # shift the single-graph template by each graph's node offset in one
        # broadcasted add; edges come out graph by graph, and an empty batch
        # gives an empty edge list
        graph_offsets = torch.arange(num_graphs, device=x.device) * self.num_nodes
        edges = (directed_pairs[None, :, :] + graph_offsets[:, None, None]).reshape(
            num_graphs * directed_pairs.shape[0], 2)
        edges_dict = {'edge_index': edges.T}

        # graph index of every generated node
        batch = torch.arange(num_graphs, device=x.device).repeat_interleave(self.num_nodes)

        x = self.s_to_nodes(x).reshape(num_graphs * self.num_nodes, self.hidden_dim)

        # one 3-vector per node, rescaled by 1 / (eps + norm) so its length stays below 1
        directions = self.v_to_pos(v).permute(0, 2, 1).reshape(
            num_graphs * self.num_nodes, 3, 1)[..., 0]
        pos = directions / (eps + torch.linalg.norm(directions, dim=1))[:, None]

        # (graphs, 3, hidden * nodes) -> (graphs * nodes, 3, hidden)
        v = self.v_to_nodes(v).permute(0, 2, 1).reshape(
            num_graphs * self.num_nodes, self.hidden_dim, 3).permute(0, 2, 1)

        return self.model(x, v, pos, batch, edges_dict)
''' equivariance test def v_to_node(v, num_graphs): v2 = self.v_to_nodes(v).reshape(num_graphs, 3, self.hidden_dim, self.num_nodes) v2 = v2.permute(0, 3, 1, 2).flatten(0, 1) return v2 from scipy.spatial.transform import Rotation as R import numpy as np 'initialize rotations' rotations = torch.tensor( R.random(num_graphs).as_matrix() * np.random.choice((-1, 1), replace=True, size=num_graphs)[:, None, None], dtype=torch.float, device=x.device) 'rotate input' r_v = torch.einsum('ij, njk -> nik', rotations[0], v) 'get output' out1 = v_to_node(v, num_graphs) out2 = v_to_node(r_v, num_graphs) 'rotated output' r_out1 = torch.einsum('ij, njk -> nik', rotations[0], out1) print(torch.mean(torch.abs(r_out1 - out2)/out2.abs())) import plotly.graph_objects as go fig = go.Figure(go.Histogram(x=((out2 - r_out1)/out2.abs()).flatten().abs().log10().cpu().detach().numpy(), nbinsx=100)).show() def v_to_node(v, num_graphs): v2 = self.v_to_nodes(v).reshape(num_graphs, 3, self.hidden_dim, self.num_nodes) v2 = v2.permute(0, 3, 1, 2).flatten(0, 1) return v2 # ---- graph model --- from scipy.spatial.transform import Rotation as R import numpy as np 'initialize rotations' rotations = torch.tensor( R.random(num_graphs).as_matrix() * np.random.choice((-1, 1), replace=True, size=num_graphs)[:, None, None], dtype=torch.float, device=x.device) 'rotate input' r_v = torch.einsum('ij, njk -> nik', rotations[0], v) r_pos = torch.einsum('ij, nj -> ni', rotations[0], pos) 'get output' s1, out1 = self.model(x, v, pos, batch, edges_dict) s2, out2 = self.model(x, r_v, r_pos, batch, edges_dict) 'rotated output' r_out1 = torch.einsum('ij, njk -> nik', rotations[0], out1) print(torch.mean(torch.abs(r_out1 - out2)/out2.abs())) # final equivariance test if not hasattr(self, 'v0'): self.x0 = x.clone() self.v0 = v.clone() from scipy.spatial.transform import Rotation as R import numpy as np num_graphs = len(x) # all combinations of edges within each graph edges = [] edges_i = 
torch.combinations(torch.arange(self.num_nodes), r=2, with_replacement=False).to(x.device) for ind in range(num_graphs): batch_ind = ind * self.num_nodes edges.append( batch_ind + torch.cat([edges_i, torch.fliplr(edges_i)], dim=0) ) edges = torch.cat(edges, dim=0) edges_dict = {'edge_index': edges.T} batch = torch.arange(num_graphs, device=x.device).repeat_interleave(self.num_nodes) 'initialize rotations' rotations = torch.tensor( R.random(num_graphs).as_matrix() * np.random.choice((-1, 1), replace=True, size=num_graphs)[:, None, None], dtype=torch.float, device=x.device) x = self.x0.clone() v = self.v0.clone() def rotate_object(rotations, thing, batch, num_graphs): return torch.cat( [torch.einsum('ij, njk->nik', rotations[ind], thing[batch == ind]) for ind in range(num_graphs)]) rv = rotate_object(rotations, v, torch.arange(num_graphs, device=x.device), num_graphs) xf = self.s_to_nodes(x).reshape(num_graphs * self.num_nodes, self.hidden_dim) pos = self.v_to_pos(v).permute(0, 2, 1).reshape(num_graphs * self.num_nodes, 3, 1)[..., 0] vf = self.v_to_nodes(v).permute(0, 2, 1).reshape(num_graphs * self.num_nodes, self.hidden_dim, 3).permute(0, 2, 1) rpos = self.v_to_pos(rv).permute(0, 2, 1).reshape(num_graphs * self.num_nodes, 3, 1)[..., 0] posr = rotate_object(rotations, pos[:, :, None], batch, num_graphs)[..., 0] vfr = rotate_object(rotations, vf, batch, num_graphs) rvf = self.v_to_nodes(rv).permute(0, 2, 1).reshape(num_graphs * self.num_nodes, self.hidden_dim, 3).permute(0, 2, 1) xo, yo = self.model(xf, vf, pos, batch, edges_dict) rxo, ryo = self.model(xf, rvf, rpos, batch, edges_dict) yor = rotate_object(rotations, yo, batch, num_graphs) print(((vfr-rvf).abs()/rvf.abs()).mean()) print(((yor-ryo).abs()/ryo.abs()).mean()) print(((rpos-posr).abs()/rpos.abs()).mean()) '''
class Mo3ENetEncoder(nn.Module):
    """Encoder of Mo3ENet: a thin wrapper around ``VectorMoleculeGraphModel``.

    Args:
        seed: forwarded to the graph model.
        config: encoder configuration; reads ``activation``, ``fc`` and ``graph``.
        bottleneck_dim: output dimension of the graph model.
        override_cutoff: optional cutoff forwarded to the graph model.
    """

    def __init__(self, seed, config, bottleneck_dim, override_cutoff=None):
        super().__init__()

        graph_model_kwargs = dict(
            input_node_dim=1,
            num_mol_feats=0,
            output_dim=bottleneck_dim,
            seed=seed,
            concat_pos_to_node_dim=True,
            concat_mol_to_node_dim=False,
            activation=config.activation,
            fc_config=config.fc,
            graph_config=config.graph,
            override_cutoff=override_cutoff,
        )
        self.model = VectorMoleculeGraphModel(**graph_model_kwargs)

    def forward(self, mol_batch):
        """Run the graph model on the ``z``, ``pos``, ``batch`` and ``ptr`` fields of ``mol_batch``."""
        graph_count = mol_batch.num_graphs
        return self.model(
            mol_batch.z,
            mol_batch.pos,
            mol_batch.batch,
            mol_batch.ptr,
            num_graphs=graph_count,
        )