# Source code for mxtaltools.models.utils

from argparse import Namespace
from typing import Union, Optional

import numpy as np
import torch
from torch.nn import functional as F
from torch_scatter import scatter
from tqdm import tqdm

from mxtaltools.common.geometry_utils import cell_vol_torch, components2angle, enforce_crystal_system, \
    batch_molecule_principal_axes_torch
from mxtaltools.common.training_utils import get_n_config
from mxtaltools.common.utils import softmax_np
from mxtaltools.crystal_building.utils import descale_asymmetric_unit, rescale_asymmetric_unit
from mxtaltools.dataset_utils.utils import collate_data_list
from mxtaltools.models.task_models.autoencoder_models import Mo3ENet


def softmax_and_score(raw_classwise_output,
                      temperature=1,
                      old_method=False,
                      correct_discontinuity=True) -> Union[torch.Tensor, np.ndarray]:
    """
    Convert two-class activations into a single linearized score.

    Parameters
    ----------
    raw_classwise_output:
        numpy array or torch tensor with dimension [n, 2], representing the
        non-normalized [false, true] probabilities
    temperature:
        softmax temperature; only applied when ``old_method`` is True
        (the default path has never used it)
    old_method:
        use the more complicated tangent-based method from the first paper
    correct_discontinuity:
        old method only - add 1 inside the logarithm to condition the
        function about zero

    Returns
    -------
    score:
        linearizes the input probabilities from (0,1) to [-inf, inf] for easier visualization

    Raises
    ------
    TypeError
        if ``old_method`` is True and the input is neither a numpy array nor
        a torch tensor
    """
    if not old_method:
        # log10(softmax(x)[1] / softmax(x)[0]) is exactly (x[1] - x[0]) * log10(e).
        # Using the closed form avoids the +/-inf that the explicit ratio
        # produces once either softmax probability underflows to zero.
        log10_e = float(np.log10(np.e))
        if not torch.is_tensor(raw_classwise_output):
            raw_classwise_output = np.asarray(raw_classwise_output)
        return (raw_classwise_output[:, 1] - raw_classwise_output[:, 0]) * log10_e

    correction = 1 if correct_discontinuity else 0

    if isinstance(raw_classwise_output, np.ndarray):
        # values get too close to zero for float32
        softmax_output = softmax_np(raw_classwise_output.astype('float64'),
                                    temperature)[:, 1].astype('float64')
        tanned = np.tan((softmax_output - 0.5) * np.pi)
        # values very close to zero can realize a sign error
        sign = (raw_classwise_output[:, 1] > raw_classwise_output[:, 0]) * 2 - 1
        # new factor of 1+ conditions the function about zero
        return sign * np.log10(correction + np.abs(tanned))

    if torch.is_tensor(raw_classwise_output):
        softmax_output = F.softmax(raw_classwise_output / temperature, dim=-1)[:, 1]
        tanned = torch.tan((softmax_output - 0.5) * torch.pi)
        # values very close to zero can realize a sign error
        sign = (raw_classwise_output[:, 1] > raw_classwise_output[:, 0]) * 2 - 1
        return sign * torch.log10(correction + torch.abs(tanned))

    raise TypeError(
        "raw_classwise_output must be a numpy array or torch tensor, "
        f"got {type(raw_classwise_output).__name__}"
    )
def norm_scores(score, tracking_features, dataDims):
    """
    Norm the incoming score according to some feature of the molecule
    (here: the 'molecule volume' tracking feature).

    The volume column is located by looking up 'molecule volume' in
    ``dataDims['tracking_features']``; the score is divided by that column.
    """
    volume_column = dataDims['tracking_features'].index('molecule volume')
    molecule_volume = tracking_features[:, volume_column]
    return score / molecule_volume
def enforce_1d_bound(x: torch.tensor, x_span, x_center, mode='soft'):  # soft or hard
    """
    Constrain values to the range x_center plus/minus x_span.

    Parameters
    ----------
    x
        values to be bounded
    x_span
        half-width of the allowed range
    x_center
        midpoint of the allowed range
    mode
        'soft': smoothly converge to (center-span, center+span) via tanh
        'hard': linear scaling with a hard stop at [center-span, center+span]

    Returns
    -------
    the bounded values

    Raises
    ------
    ValueError
        if ``mode`` is neither 'soft' nor 'hard'
    """
    if mode not in ('soft', 'hard'):
        raise ValueError("bound must be of type 'hard' or 'soft'")

    squash = F.tanh if mode == 'soft' else F.hardtanh
    normalized = (x - x_center) / x_span
    return squash(normalized) * x_span + x_center
def undo_1d_bound(x: torch.tensor, x_span, x_center, mode='soft'):
    """
    Undo / rescale an enforced 1d bound.

    Only set up for soft rescaling: applies atanh to the normalized value
    (x - x_center) / x_span, then multiplies by x_span and adds x_center.

    Raises
    ------
    ValueError
        for mode 'hard' (not yet implemented) and for any other mode
        that is not 'soft'
    """
    # todo: write a version for hard bounds
    if mode == 'hard':  # linear scaling to hard stop at [center-span, center+span]
        raise ValueError("'hard' bound not yet implemented")
    if mode != 'soft':
        raise ValueError("bound must be of type 'soft'")

    normalized = (x - x_center) / x_span
    return x_span * torch.atanh(normalized) + x_center
def compute_reduced_volume_fraction(cell_lengths: torch.tensor,
                                    cell_angles: torch.tensor,
                                    atom_radii: torch.tensor,
                                    batch: torch.tensor,
                                    crystal_multiplicity: torch.tensor):
    """
    # TODO DEPRECATE IN FAVOUR OF PACKING COEFFICIENT

    Args:
        cell_lengths: per-cell lengths, passed one cell at a time to cell_vol_torch
        cell_angles: per-cell angles, passed one cell at a time to cell_vol_torch
        atom_radii: per-atom radii, used to build spherical atom volumes
        batch: index used to sum the atom volumes with scatter
        crystal_multiplicity: divisor applied to each cell volume

    Returns:
        asymmetric unit volume / sum of vdw volumes - so-called 'reduced volume fraction'
    """
    num_cells = len(cell_lengths)
    # volumes are always accumulated in float32 on the device of cell_lengths
    cell_volumes = torch.zeros(num_cells, dtype=torch.float32, device=cell_lengths.device)
    for cell_ind in range(num_cells):  # todo switch to the parallel version of this function
        cell_volumes[cell_ind] = cell_vol_torch(cell_lengths[cell_ind], cell_angles[cell_ind])

    asymmetric_unit_volumes = cell_volumes / crystal_multiplicity
    # sphere volume 4/3 * pi * r^3 per atom, summed over the groups given by `batch`
    summed_atom_volumes = scatter(4 / 3 * torch.pi * atom_radii ** 3, batch, reduce='sum')
    return asymmetric_unit_volumes / summed_atom_volumes
def clean_generator_output(samples=None,
                           lattice_lengths=None,
                           lattice_angles=None,
                           mol_positions=None,
                           mol_orientations=None,
                           lattice_means=None,
                           lattice_stds=None,
                           destandardize=True,
                           mode='soft',
                           skip_angular_dof=False):
    """
    # TODO rewrite - this is a very important function but it's currently a disaster

    Convert from raw model output to the actual cell parameters with appropriate bounds.
    Considering raw outputs to be in the standardized basis, we destandardize, then enforce bounds.

    Parameters
    ----------
    samples
        optional [n, 9 + k] tensor; when given it is split into lengths [:, :3],
        angles [:, 3:6], positions [:, 6:9] and orientations [:, 9:], overriding
        the four individual arguments
    lattice_lengths, lattice_angles, mol_positions, mol_orientations
        the individual components, used when ``samples`` is None
    lattice_means, lattice_stds
        per-parameter statistics, only read when ``destandardize`` is True
    destandardize
        if True, multiply by stds and add means; 6-component orientations are
        never destandardized
    mode
        'soft', 'hard', or None (no bounds enforced)
    skip_angular_dof
        currently unused by this function

    Returns
    -------
    clean_lattice_lengths, clean_lattice_angles, clean_mol_positions, clean_mol_orientations

    Raises
    ------
    ValueError
        if the orientations have neither 3 nor 6 components
    """
    # separate components
    if samples is not None:
        lattice_lengths = samples[:, :3]
        lattice_angles = samples[:, 3:6]
        mol_positions = samples[:, 6:9]
        mol_orientations = samples[:, 9:]

    orientation_dim = mol_orientations.shape[-1]
    if orientation_dim not in (3, 6):
        raise ValueError(
            f"mol_orientations must have 3 or 6 components, got {orientation_dim}")

    # destandardize & decode angles
    if destandardize:
        real_lattice_lengths = lattice_lengths * lattice_stds[:3] + lattice_means[:3]
        # not bothering to encode as an angle
        real_lattice_angles = lattice_angles * lattice_stds[3:6] + lattice_means[3:6]
        real_mol_positions = mol_positions * lattice_stds[6:9] + lattice_means[6:9]
        if orientation_dim == 3:
            real_mol_orientations = mol_orientations * lattice_stds[9:] + lattice_means[9:]
        else:
            real_mol_orientations = mol_orientations * 1
    else:  # optionally, skip destandardization if we are already in the real basis
        real_lattice_lengths = lattice_lengths * 1
        real_lattice_angles = lattice_angles * 1
        real_mol_positions = mol_positions * 1
        real_mol_orientations = mol_orientations * 1

    if orientation_dim == 6:
        # todo deprecate - we will only use spherical components in future
        theta, phi, r_i = decode_to_sph_rotvec(real_mol_orientations)
    elif mode is not None:
        # already have angles, no need to decode - only bound them
        theta = enforce_1d_bound(real_mol_orientations[:, 0],
                                 x_span=torch.pi / 4, x_center=torch.pi / 4, mode=mode)[:, None]
        phi = enforce_1d_bound(real_mol_orientations[:, 1],
                               x_span=torch.pi, x_center=0, mode=mode)[:, None]
        r_i = enforce_1d_bound(real_mol_orientations[:, 2],
                               x_span=torch.pi, x_center=torch.pi, mode=mode)[:, None]
    else:
        # split along the component axis into three [n, 1] columns; unpacking
        # the tensor directly would iterate over the batch axis instead
        theta, phi, r_i = real_mol_orientations.split(1, dim=-1)

    r = torch.maximum(r_i, torch.ones_like(r_i) * 0.01)  # MUST be nonzero
    clean_mol_orientations = torch.cat((theta, phi, r), dim=-1)

    # enforce physical bounds
    if mode is not None:
        if mode == 'soft':
            # smoothly enforces positive nonzero
            clean_lattice_lengths = F.softplus(real_lattice_lengths - 0.01) + 0.01
        elif mode == 'hard':
            # harshly enforces positive nonzero (elementwise floor of 1)
            clean_lattice_lengths = torch.maximum(F.relu(real_lattice_lengths),
                                                  torch.ones_like(real_lattice_lengths))
        # any other mode is rejected by enforce_1d_bound below

        # range from (0,pi) with 20% limit to prevent too-skinny cells
        clean_lattice_angles = enforce_1d_bound(real_lattice_angles,
                                                x_span=torch.pi / 2 * 0.8,
                                                x_center=torch.pi / 2,
                                                mode=mode)
        # enforce fractional centroids between 0 and 1
        clean_mol_positions = enforce_1d_bound(real_mol_positions, 0.5, 0.5, mode=mode)
    else:  # do nothing
        clean_lattice_lengths = real_lattice_lengths
        clean_lattice_angles = real_lattice_angles
        clean_mol_positions = real_mol_positions

    return clean_lattice_lengths, clean_lattice_angles, clean_mol_positions, clean_mol_orientations
def decode_to_sph_rotvec(mol_orientations):
    """
    Decode a 6-column orientation encoding into three angle columns.

    Each angle is predicted with 2 params, decoded by components2angle.
    We bound the encodings for theta on 0-1 (via a sigmoid) to restrict the
    range of theta to [0,pi/2].

    Returns
    -------
    theta, phi, r : three tensors, each with a trailing singleton axis added
    """
    theta_components = F.sigmoid(mol_orientations[:, 0:2])  # restrict to positive quadrant
    phi_components = mol_orientations[:, 2:4]
    r_components = mol_orientations[:, 4:6]

    theta = components2angle(theta_components)  # from the sigmoid, [0, pi/2]
    phi = components2angle(phi_components)  # unrestricted [-pi,pi]
    # shift from [-pi,pi] to [0, 2pi] - want vector to have a positive norm
    r = components2angle(r_components) + torch.pi

    return theta[:, None], phi[:, None], r[:, None]
def decode_to_sph_rotvec2(mol_orientation_components):
    """
    # todo decide whether to use/keep or deprecate this

    Variant of decode_to_sph_rotvec that considers theta as a simple scalar:
    column 0 is passed through unchanged, columns 1:3 encode phi and
    columns 3:5 encode r, each decoded with components2angle.

    [n, 5] input to [n, 3] output (returned as three [n, 1] columns)
    """
    theta = mol_orientation_components[:, 0, None]

    phi = components2angle(mol_orientation_components[:, 1:3])  # unrestricted [-pi,pi]
    # shift from [-pi,pi] to [0, 2pi] - want vector to have a positive norm
    r = components2angle(mol_orientation_components[:, 3:5]) + torch.pi

    return theta, phi[:, None], r[:, None]
def get_regression_loss(regressor, data, targets, mean, std):
    """
    Evaluate a regressor and return its elementwise smooth-L1 loss.

    Returns
    -------
    losses
        unreduced smooth-L1 loss between the flattened predictions and targets
    predictions
        detached predictions, multiplied by ``std`` and shifted by ``mean``
    targets
        detached targets, multiplied by ``std`` and shifted by ``mean``
    """
    predictions = regressor(data).flatten()
    assert targets.shape == predictions.shape

    losses = F.smooth_l1_loss(predictions, targets, reduction='none')
    rescaled_predictions = predictions.detach() * std + mean
    rescaled_targets = targets.detach() * std + mean
    return losses, rescaled_predictions, rescaled_targets
def dict_of_tensors_to_cpu_numpy(stats):
    """
    In-place conversion of the values of ``stats``.

    Torch tensors are replaced by numpy arrays (after moving to cpu); values
    whose type name contains 'DataBatch' are replaced by their ``.cpu()``
    result. Everything else is left untouched. Returns None.
    """
    for key in list(stats):
        entry = stats[key]
        if torch.is_tensor(entry):
            stats[key] = entry.cpu().numpy()
        elif 'DataBatch' in str(type(entry)):
            stats[key] = entry.cpu()
def increment_value(value, increment, maxval, minval=0):
    """
    Add ``increment`` to ``value`` and clamp the result.

    The sum is first capped at ``maxval``, then floored at ``minval``.
    """
    capped = min(value + increment, maxval)
    return max(minval, capped)
def clean_cell_params(samples,
                      sg_inds,
                      lattice_means,
                      lattice_stds,
                      symmetries_dict,
                      asym_unit_dict,
                      rescale_asymmetric_unit=True,
                      destandardize=False,
                      mode='soft',
                      fractional_basis='asymmetric_unit',
                      skip_angular_dof=False):
    """
    # todo deprecate

    An important function for enforcing physical limits on cell parameterization
    with randomly generated samples of different sources.

    Parameters
    ----------
    samples: torch.Tensor
        columns [:3] lengths, [3:6] angles, [6:9] positions, [9:] orientations
    sg_inds: torch.LongTensor
    lattice_means: torch.Tensor
    lattice_stds: torch.Tensor
    symmetries_dict: dict
    asym_unit_dict: dict
    rescale_asymmetric_unit: bool
        if True, the cleaned positions are passed through descale_asymmetric_unit
    destandardize: bool
    mode: str, "hard" or "soft"
    fractional_basis: str, 'asymmetric_unit' or 'unit_cell'
    skip_angular_dof
        forwarded to clean_generator_output

    Returns
    -------
    tensor of the cleaned lengths, angles, positions and orientations,
    concatenated along the last axis
    """
    # NOTE(review): the `rescale_asymmetric_unit` argument shadows the imported
    # function of the same name inside this body, so only
    # descale_asymmetric_unit is reachable here - confirm that is intended.
    raw_positions = samples[:, 6:9]
    if fractional_basis == 'asymmetric_unit':  # basis is 0-1 within the asymmetric unit
        positions = raw_positions
    elif fractional_basis == 'unit_cell':  # basis is 0-1 within the unit cell
        positions = descale_asymmetric_unit(asym_unit_dict, raw_positions, sg_inds)
    else:
        assert False, f"{fractional_basis} is not an implemented fractional basis"

    cleaned = clean_generator_output(
        lattice_lengths=samples[:, :3],
        lattice_angles=samples[:, 3:6],
        mol_positions=positions,
        mol_orientations=samples[:, 9:],
        lattice_means=lattice_means,
        lattice_stds=lattice_stds,
        destandardize=destandardize,
        mode=mode,
        skip_angular_dof=skip_angular_dof,
    )
    clean_lengths, clean_angles, clean_positions, clean_orientations = cleaned

    fixed_lengths, fixed_angles = enforce_crystal_system(
        clean_lengths, clean_angles, sg_inds, symmetries_dict)

    if rescale_asymmetric_unit:
        fixed_positions = descale_asymmetric_unit(asym_unit_dict, clean_positions, sg_inds)
    else:
        fixed_positions = clean_positions * 1

    # collect
    return torch.cat(
        (fixed_lengths, fixed_angles, fixed_positions, clean_orientations),
        dim=-1,
    )
def denormalize_generated_cell_params(
        normed_cell_samples: torch.FloatTensor,
        mol_data,
        asym_unit_dict: dict):
    """
    Map normalized cell samples back to buildable cell parameters.

    Cell lengths (columns 0:3) are multiplied by the cube root of
    ``mol_data.sym_mult * mol_data.mol_volume``; positions (columns 6:9) are
    passed through descale_asymmetric_unit; angles (3:6) and orientations
    (9:12) are copied through unchanged.
    """
    # denormalize the predicted cell lengths
    length_scale = torch.pow(mol_data.sym_mult * mol_data.mol_volume, 1 / 3)[:, None]
    cell_lengths = length_scale * normed_cell_samples[:, :3]

    # rescale asymmetric units
    # todo add assertions around these
    mol_positions = descale_asymmetric_unit(
        asym_unit_dict, normed_cell_samples[:, 6:9], mol_data.sg_ind)

    cell_angles = normed_cell_samples[:, 3:6]
    mol_orientations = normed_cell_samples[:, 9:12]
    return torch.cat([cell_lengths, cell_angles, mol_positions, mol_orientations], dim=1)
def renormalize_generated_cell_params(
        generator_raw_samples,
        mol_data,
        asym_unit_dict):
    """
    Map raw cell samples into the normalized basis.

    Cell lengths (columns 0:3) are divided by the cube root of
    ``mol_data.sym_mult * mol_data.mol_volume``; positions (columns 6:9) are
    passed through rescale_asymmetric_unit; angles (3:6) and orientations
    (9:12) are copied through unchanged.
    """
    # renormalize the predicted cell lengths
    length_scale = torch.pow(mol_data.sym_mult * mol_data.mol_volume, 1 / 3)[:, None]
    cell_lengths = generator_raw_samples[:, :3] / length_scale

    # rescale asymmetric units
    # todo add assertions around these
    mol_positions = rescale_asymmetric_unit(
        asym_unit_dict, generator_raw_samples[:, 6:9], mol_data.sg_ind)

    cell_angles = generator_raw_samples[:, 3:6]
    mol_orientations = generator_raw_samples[:, 9:12]
    return torch.cat([cell_lengths, cell_angles, mol_positions, mol_orientations], dim=1)
def compute_prior_loss(norm_factors: torch.Tensor,
                       sg_inds: torch.LongTensor,
                       generator_raw_samples: torch.Tensor,
                       prior: torch.Tensor,
                       variation_factor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Take the norm of the scaled distances between prior and generated samples,
    and apply a quadratic penalty when it is larger than variation_factor.

    Parameters
    ----------
    norm_factors
        table of scaling factors, indexed along its first axis by ``sg_inds``
    sg_inds
        row indices into ``norm_factors``
    generator_raw_samples
    prior
    variation_factor
        threshold below which no penalty is applied

    Returns
    -------
    prior_loss
        squared excess of the per-sample deviation norm over variation_factor
    scaled_deviation
        elementwise absolute deviation divided by the scaling factors
    """
    # small offset keeps the divisor away from zero
    scaling_factor = norm_factors[sg_inds, :] + 1e-4
    scaled_deviation = (prior - generator_raw_samples).abs() / scaling_factor

    deviation_norm = torch.linalg.norm(scaled_deviation, dim=1)
    excess = F.relu(deviation_norm - variation_factor)  # 'flashlight' search
    return excess ** 2, scaled_deviation
def get_mol_embedding_for_proxy(crystal_batch, embedding_type, encoder: Optional = None):
    """
    Build a per-molecule embedding for the proxy discriminator.

    The batch is first posed and recentered in place (``pose_aunit`` then
    ``recenter_molecules``). The result concatenates a scalar embedding with a
    flattened vector embedding along the last axis.

    Parameters
    ----------
    crystal_batch
        batch object providing pos, batch, num_graphs, num_atoms, z,
        mol_volume and device
    embedding_type
        'autoencoder', 'principal_axes', 'principal_moments', 'mol_volume' or None
    encoder
        only used for the 'autoencoder' embedding

    Returns
    -------
    concatenation of the scalar embedding and the vector embedding with its
    last two axes flattened
    """
    crystal_batch.pose_aunit()  # get correct orientation
    crystal_batch.recenter_molecules()

    # fixed normalization constants for the principal-moment and volume embeddings
    ipm_means = torch.tensor([35, 91, 105], dtype=torch.float32, device=crystal_batch.device)
    mol_volume_mean = 70

    def zero_vector_embedding():
        # placeholder [num_graphs, 3, 3] block for embeddings without a vector part
        return torch.zeros((crystal_batch.num_graphs, 3, 3),
                           dtype=torch.float32,
                           device=crystal_batch.device)

    if embedding_type == 'autoencoder':
        v_embedding = encoder.encode(crystal_batch.clone())
        s_embedding = encoder.scalarizer(v_embedding)

    elif embedding_type in ('principal_axes', 'principal_moments'):
        axes, moments, _ = batch_molecule_principal_axes_torch(
            crystal_batch.pos,
            crystal_batch.batch,
            crystal_batch.num_graphs,
            crystal_batch.num_atoms,
            heavy_atoms_only=True,
            atom_types=crystal_batch.z
        )
        s_embedding = moments / ipm_means[None, :]
        if embedding_type == 'principal_axes':
            v_embedding = axes.permute(0, 2, 1)
        else:
            # moments only: the vector part is zeroed out
            v_embedding = torch.zeros_like(axes)

    elif embedding_type == 'mol_volume':
        s_embedding = crystal_batch.mol_volume[:, None].repeat(1, 3) / mol_volume_mean
        v_embedding = zero_vector_embedding()

    elif embedding_type is None:
        s_embedding = torch.zeros_like(crystal_batch.mol_volume[:, None].repeat(1, 3))
        v_embedding = zero_vector_embedding()

    else:
        assert False, f"{embedding_type} is not an implemented proxy discriminator embedding"

    flat_v_embedding = v_embedding.flatten(-2)
    return torch.cat([s_embedding, flat_v_embedding], dim=-1)
def embed_crystal_list(
        batch_size: int,
        crystal_list: list,
        embedding_type: str,
        encoder_checkpoint_path: Optional = None,
        device: Optional[str] = 'cpu',
        redo_crystal_analysis: Optional[bool] = False
) -> list:
    """
    Compute an embedding for every crystal in ``crystal_list`` and attach it
    to the crystal as ``.embedding``.

    Crystals are processed in consecutive chunks of ``batch_size`` (no
    shuffling), so results line up with the input order.

    Parameters
    ----------
    batch_size
        number of crystals collated per chunk; must be at least 1
    crystal_list
        crystals to embed; modified in place
    embedding_type
        forwarded to ``crystal_batch.do_embedding``
    encoder_checkpoint_path
        if given, an encoder is loaded from this checkpoint and forwarded to
        ``do_embedding``; otherwise ``None`` is forwarded
    device
        device on which batches are processed and results are accumulated
    redo_crystal_analysis
        if True, each batch is also rebuilt and analyzed, and the resulting
        values are attached to each crystal as ``lj``, ``es_pot``,
        ``scaled_lj`` and ``bh_pot``

    Returns
    -------
    the same ``crystal_list``, with the new attributes set

    Raises
    ------
    ValueError
        if ``batch_size`` is smaller than 1
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    # encoder must be bound even when no checkpoint is supplied, since it is
    # always forwarded to do_embedding below
    encoder = None
    if encoder_checkpoint_path is not None:
        encoder = load_encoder(encoder_checkpoint_path).to(device)

    num_crystals = len(crystal_list)
    embeddings = []
    num_chunks = num_crystals // batch_size + int(num_crystals % batch_size != 0)
    lj_pots, scaled_lj_pots, es_pots, bh_pots = (
        torch.zeros(num_crystals, dtype=torch.float32, device=device) for _ in range(4))

    with torch.no_grad():
        for chunk_ind in tqdm(range(num_chunks)):
            # do it this way so to avoid shuffling
            chunk_start = chunk_ind * batch_size
            chunk_stop = min(chunk_start + batch_size, num_crystals)
            sample_inds = torch.arange(chunk_start, chunk_stop)
            crystal_batch = collate_data_list(crystal_list[chunk_start:chunk_stop]).to(device)

            if redo_crystal_analysis:
                (lj_pots[sample_inds],
                 es_pots[sample_inds],
                 scaled_lj_pots[sample_inds],
                 cluster_batch) = crystal_batch.build_and_analyze(cutoff=10, return_cluster=True)
                bh_pots[sample_inds] = cluster_batch.compute_buckingham_energy()

            embedding = crystal_batch.do_embedding(embedding_type, encoder).cpu().detach()

            if chunk_ind == 0:
                # embedding width is only known after the first chunk
                embeddings = torch.zeros(
                    (num_crystals, embedding.shape[-1]),
                    dtype=torch.float32,
                    device=device
                )
            embeddings[chunk_start:chunk_stop] = embedding

    for crystal_ind, crystal in enumerate(crystal_list):
        crystal.embedding = embeddings[None, crystal_ind]  # keep a leading axis of 1
        if redo_crystal_analysis:
            crystal.lj = lj_pots[crystal_ind].cpu()
            crystal.es_pot = es_pots[crystal_ind].cpu()
            crystal.scaled_lj = scaled_lj_pots[crystal_ind].cpu()
            crystal.bh_pot = bh_pots[crystal_ind].cpu()

    return crystal_list
def load_encoder(checkpoint_path):
    """
    Build a Mo3ENet model and load its weights from a checkpoint file.

    The checkpoint is expected to hold a 'config' mapping (whose ``model``
    entry configures the network) and a 'model_state_dict'.

    Parameters
    ----------
    checkpoint_path
        path of the checkpoint, loaded onto the cpu

    Returns
    -------
    the Mo3ENet model with the checkpoint weights loaded
    """
    # NOTE: weights_only=False unpickles arbitrary Python objects - only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    model_config = Namespace(**checkpoint['config'])  # overwrite the settings for the model

    # lookup table: atomic number -> index within allowed_types, -1 for everything else
    allowed_types = np.array([1, 6, 7, 8, 9])
    type_translation_index = np.zeros(allowed_types.max() + 1) - 1
    type_translation_index[allowed_types] = np.arange(len(allowed_types))
    autoencoder_type_index = torch.tensor(type_translation_index, dtype=torch.long, device='cpu')

    model = Mo3ENet(
        0,
        model_config.model,
        5,
        autoencoder_type_index,
        1,  # will get overwritten
        protons_in_input=True
    )

    # the checkpoint loaded above is reused; reading the file a second time is unnecessary
    state_dict = checkpoint['model_state_dict']
    if state_dict and next(iter(state_dict)).startswith('module'):
        # when we use dataparallel it breaks the state_dict - fix it by removing
        # the 'module.' prefix from in front of every key that carries it
        state_dict = {key.removeprefix('module.'): value for key, value in state_dict.items()}

    model.load_state_dict(state_dict)
    return model
def get_model_sizes(models_dict: dict):
    """
    Count and print the number of parameters of each model.

    Returns
    -------
    dict mapping ``<model_name>_num_params`` to the value returned by
    get_n_config for that model
    """
    num_params_dict = {}
    for name, model in models_dict.items():
        num_params_dict[name + "_num_params"] = get_n_config(model)

    # report after all counts are collected; keys already carry the '_num_params' suffix
    for model_name in num_params_dict.keys():
        print(
            f'{model_name} {num_params_dict[model_name] / 1e6:.3f} million or {int(num_params_dict[model_name])} parameters')

    return num_params_dict
def test_gradient_flow(grad_params, operation):
    """
    Debugging helper: check that gradients reach ``grad_params.pos``.

    Runs ``operation(grad_params)``, backpropagates the scalar
    ``sum(|log(output + 1)|)``, then prints the position gradient, the loss,
    and the fraction of nonzero gradient entries. The gradient is zeroed
    afterwards.
    """
    grad_params.pos.requires_grad_(True)
    output = operation(grad_params)
    grad_params.pos.retain_grad()

    loss = (output + 1).log().abs().sum()
    loss.backward()

    pos_grad = grad_params.pos.grad
    print(pos_grad)
    print(loss)

    num_entries = len(grad_params.pos.flatten())
    print(torch.count_nonzero(pos_grad) / num_entries)

    grad_params.pos.grad.zero_()