Source code for mxtaltools.common.mol_classifier_utils

import numpy as np
import torch
import torch.nn.functional as F

from mxtaltools.common.utils import softmax_np
from mxtaltools.common.geometry_utils import coor_trans_matrix_np
from mxtaltools.constants.classifier_constants import defect_names

from sklearn.metrics import roc_auc_score, confusion_matrix, f1_score
import plotly.graph_objects as go

import pandas as pd

"""
utility functions for the molecular crystal local environment classification module
"""


def delete_pandas_dataframe_rows(df: pd.DataFrame, inds) -> pd.DataFrame:
    """
    Remove the rows labelled `inds` from `df` and renumber the remaining rows.

    Both operations run in place, so the caller's DataFrame is modified;
    the same object is handed back for convenience.

    Args:
        df: DataFrame to prune.
        inds: index label(s) to remove, forwarded to `DataFrame.drop(index=...)`.

    Returns:
        `df` itself, after the drop and the index reset.
    """
    df.drop(index=inds, inplace=True)
    # drop=True discards the old labels instead of keeping them as a new column
    df.reset_index(inplace=True, drop=True)
    return df
#
def convert_box_to_cell_params(box_params):
    """
    Return the per-sample matrices produced by `convert_box_to_cell_vectors`,
    after cross-checking them against matrices rebuilt from the cell lengths
    and angles with `coor_trans_matrix_np('f_to_c', ...)`.

    Args:
        box_params: batch of box descriptions, passed straight through to
            `convert_box_to_cell_vectors`.

    Returns:
        The stacked matrices from `convert_box_to_cell_vectors`.

    Raises:
        AssertionError: if the summed absolute difference between the two
            constructions reaches 1e-3.
    """
    T_fc_list, angles, lengths = convert_box_to_cell_vectors(box_params)

    # double check the answer: rebuild every matrix from its (lengths, angles) pair
    reference = np.stack([
        coor_trans_matrix_np('f_to_c', cell_lengths, cell_angles)
        for cell_lengths, cell_angles in zip(lengths, angles)
    ])
    total_deviation = np.abs(T_fc_list - reference).sum()
    assert total_deviation < 1e-3

    return T_fc_list
def convert_box_to_cell_vectors(box_params):
    """
    Convert a batch of LAMMPS box descriptions into cell matrices, angles and lengths.

    LAMMPS periodic box style
        ITEM: BOX BOUNDS xy xz yz
        xlo_bound xhi_bound xy
        ylo_bound yhi_bound xz
        zlo_bound zhi_bound yz

    Cell vectors
        a = xhi-xlo, 0, 0
        b = xy, yhi-ylo, 0
        c = xz, yz, zhi-zlo

    Bounds
        xlo = xlo_bound - MIN(0, xy, xz, xy+xz)
        xhi = xhi_bound - MAX(0, xy, xz, xy+xz)
        ylo = ylo_bound - MIN(0, yz)
        yhi = yhi_bound - MAX(0, yz)
        zlo = zlo_bound
        zhi = zhi_bound

    Args:
        box_params: iterable of per-sample boxes, each either (3, 2)
            (lo/hi bounds only) or (3, 3) (bounds plus a tilt column).
            Any iterable of array-likes works, including a plain list that
            mixes the two shapes.

    Returns:
        T_fc_list: (n, 3, 3) float array whose columns are the a, b, c vectors.
        angles: (n, 3) array of alpha, beta, gamma in radians.
        lengths: (n, 3) array of a, b, c.
    """
    try:
        box_params = np.stack(box_params)
    except ValueError:
        # shapes differ between samples: pad a zero tilt column onto the
        # orthorhombic (3, 2) boxes so everything stacks as (3, 3)
        padded = []
        for box in box_params:
            box = np.asarray(box)
            if box.shape[-1] == 2:
                box = np.concatenate([box, np.zeros(3)[:, None]], axis=-1)
            padded.append(box)
        box_params = np.stack(padded)

    xlo_bound = box_params[:, 0, 0]
    ylo_bound = box_params[:, 1, 0]
    zlo_bound = box_params[:, 2, 0]
    xhi_bound = box_params[:, 0, 1]
    yhi_bound = box_params[:, 1, 1]
    zhi_bound = box_params[:, 2, 1]

    if box_params[0].shape == (3, 3):  # non-orthogonal box
        xy = box_params[:, 0, 2]
        xz = box_params[:, 1, 2]
        yz = box_params[:, 2, 2]
    else:
        xy = np.zeros_like(xhi_bound)
        xz = np.zeros_like(xy)
        yz = np.zeros_like(xy)

    # strip the tilt contribution from the reported bounding box
    x_tilt_terms = np.stack((np.zeros_like(xy), xy, xz, xy + xz))
    y_tilt_terms = np.stack((np.zeros_like(yz), yz))
    xlo = xlo_bound - x_tilt_terms.min(axis=0)
    xhi = xhi_bound - x_tilt_terms.max(axis=0)
    ylo = ylo_bound - y_tilt_terms.min(axis=0)
    yhi = yhi_bound - y_tilt_terms.max(axis=0)
    zlo = zlo_bound
    zhi = zhi_bound

    zeros = np.zeros_like(xhi)
    av = np.stack((xhi - xlo, zeros, zeros), axis=-1)
    bv = np.stack((xy, yhi - ylo, zeros), axis=-1)
    cv = np.stack((xz, yz, zhi - zlo), axis=-1)

    # cell vectors go in as COLUMNS of each matrix (the original author
    # marked this orientation "warning dubious")
    T_fc_list = np.stack((av, bv, cv), axis=-1).astype(np.float64)

    a = xhi - xlo
    b = np.sqrt((yhi - ylo) ** 2 + xy ** 2)
    c = np.sqrt((zhi - zlo) ** 2 + xz ** 2 + yz ** 2)
    alpha = np.arccos((xy * xz + (yhi - ylo) * yz) / (b * c))
    beta = np.arccos(xz / c)
    gamma = np.arccos(xy / b)

    lengths = np.stack([a, b, c]).T
    angles = np.stack([alpha, beta, gamma]).T

    return T_fc_list, angles, lengths
def reindex_mols(dataset, i, mol_num_atoms):
    """
    Pull one sample out of `dataset` and convert its per-atom fields to tensors.

    Args:
        dataset: object indexed as `dataset.loc[i][column][0]`; the columns
            'coordinates', 'atom_type' and 'mol_ind' are read.
        i: label of the sample to read.
        mol_num_atoms: atoms per molecule, used to derive the molecule count.

    Returns:
        atomic_numbers (long tensor), mol_ind (long tensor),
        num_molecules (int), ref_coords (tensor built with `torch.Tensor`).

    Raises:
        AssertionError: if the derived molecule count differs from the number
            of distinct values in `mol_ind`.
    """
    record = dataset.loc[i]

    ref_coords = torch.Tensor(record['coordinates'][0])
    atomic_numbers = torch.tensor(record['atom_type'][0], dtype=torch.long)
    mol_ind = torch.tensor(record['mol_ind'][0], dtype=torch.long)

    # todo this must be adaptive e.g., inclusive of defects
    num_molecules = len(ref_coords) // mol_num_atoms
    assert num_molecules == len(torch.unique(mol_ind))

    return atomic_numbers, mol_ind, num_molecules, ref_coords
# # def filter_mols(dataset, dataset_path, early_only, filter_early, melt_only, no_melt, temperatures, # periodic_only, aperiodic_only, max_box_volume, min_box_volume): # if temperatures is not None: # good_inds = [] # for temperature in temperatures: # good_inds.append(np.argwhere(np.asarray(dataset['temperature']) == temperature)[:, 0]) # # good_inds = np.unique(np.concatenate(good_inds)) # bad_inds = np.asarray([ind for ind in np.arange(len(dataset)) if ind not in good_inds]) # print(f"Temperature filter killed {len(bad_inds)} out of {len(dataset)} samples") # # dataset = delete_from_dataframe(dataset, bad_inds) # dataset = dataset.reset_index().drop(columns='index') # if filter_early: # bad_inds = np.argwhere(np.asarray(dataset['time_step']) <= int(1e4))[:, # 0] # filter first 10ps steps for equilibration # print(f"Early filter killed {len(bad_inds)} out of {len(dataset)} samples") # # dataset = delete_from_dataframe(dataset, bad_inds) # dataset = dataset.reset_index().drop(columns='index') # if early_only: # bad_inds = np.argwhere(np.asarray(dataset['time_step']) >= int(1e6))[:, 0] # keep only 1 ns maximum # print(f"Early only filter killed {len(bad_inds)} out of {len(dataset)} samples") # # dataset = delete_from_dataframe(dataset, bad_inds) # dataset = dataset.reset_index().drop(columns='index') # # if max_box_volume is not None: # T_fc_list = torch.Tensor(convert_box_to_cell_params(dataset['cell_params'])) # approx_box_volume = (T_fc_list[:, 0, 0] * T_fc_list[:, 1, 1] * T_fc_list[:, 2, 2]) # bad_inds = np.argwhere(approx_box_volume > max_box_volume)[0, :] # print(f"Max box filter killed {len(bad_inds)} out of {len(dataset)} samples") # # dataset = delete_from_dataframe(dataset, bad_inds) # dataset = dataset.reset_index().drop(columns='index') # # if min_box_volume is not None: # T_fc_list = torch.Tensor(convert_box_to_cell_params(dataset['cell_params'])) # approx_box_volume = (T_fc_list[:, 0, 0] * T_fc_list[:, 1, 1] * T_fc_list[:, 2, 2]) # bad_inds = 
np.argwhere(approx_box_volume < min_box_volume)[0, :] # print(f"Min box filter killed {len(bad_inds)} out of {len(dataset)} samples") # # dataset = delete_from_dataframe(dataset, bad_inds) # dataset = dataset.reset_index().drop(columns='index') # # if periodic_only or aperiodic_only: # num_atoms = np.asarray([len(thing[0]) for thing in dataset['atom_type']]) # T_fc_list = torch.Tensor(convert_box_to_cell_params(dataset['cell_params'])) # density = num_atoms / (T_fc_list[:, 0, 0] * T_fc_list[:, 1, 1] * T_fc_list[:, 2, 2]) # # if periodic_only: # bad_inds = np.argwhere(density <= 0.025).flatten() # print(f"Periodic only filter killed {len(bad_inds)} out of {len(dataset)} samples") # # dataset = delete_from_dataframe(dataset, bad_inds) # dataset = dataset.reset_index().drop(columns='index') # if aperiodic_only: # bad_inds = np.argwhere(density > 0.025).flatten() # print(f"Aperiodic only filter killed {len(bad_inds)} out of {len(dataset)} samples") # # dataset = delete_from_dataframe(dataset, bad_inds) # dataset = dataset.reset_index().drop(columns='index') # if True: # if 'gap_rate' in dataset.columns: # bad_inds = np.argwhere(np.asarray(dataset['gap_rate']) > 0)[:, 0] # cannot process gaps right now # print(f"No Gaps filter killed {len(bad_inds)} out of {len(dataset)} samples") # # dataset = delete_from_dataframe(dataset, bad_inds) # dataset = dataset.reset_index().drop(columns='index') # # forms = np.sort(np.unique(dataset['form'])) # # forms2tgt = {form: i for i, form in enumerate(forms)} # targets = np.asarray( # dataset['form']) - 1 # no longer need to reindex, as we have this now managed through the constants # # this will throw an error later if the combined dataset is missing any forms, but it shouldn't be missing any forms in general # # so that's fine # # set high temperature samples to 'melted' class # if 'urea' in dataset_path: # melt_class_num = 6 # else: # melt_class_num = 9 # nicotinamide # if melt_only: # bad_inds = np.argwhere(targets != 
melt_class_num)[:, 0] # print(f"Melt only filter killed {len(bad_inds)} out of {len(dataset)} samples") # # good_inds = np.argwhere(targets == melt_class_num)[:, 0] # dataset = delete_from_dataframe(dataset, bad_inds) # dataset = dataset.reset_index().drop(columns='index') # targets = targets[good_inds] # if no_melt: # bad_inds = np.argwhere(targets == melt_class_num)[:, 0] # print(f"No Melt filter killed {len(bad_inds)} out of {len(dataset)} samples") # # good_inds = np.argwhere(targets != melt_class_num)[:, 0] # # dataset = delete_from_dataframe(dataset, bad_inds) # dataset = dataset.reset_index().drop(columns='index') # targets = targets[good_inds] # return dataset, targets
def identify_surface_molecules(cluster_coords, cluster_targets, conv_cutoff, good_mols, mol_num_atoms, mol_radii):
    """
    Flag sparsely-coordinated molecules as surface defects.

    A molecule's coordination number is the count of molecule centroids
    closer than (largest radius among `mol_radii[good_mols]`) + `conv_cutoff`.
    The molecule's own centroid is at distance zero, so it is counted too
    whenever that cutoff is positive. Molecules with a count below 20 get
    defect type 1.

    Args:
        cluster_coords: coordinates, averaged over dim 1 to get centroids.
        cluster_targets: per-molecule labels; only used as the template for
            the defect tensor (same shape, dtype, device).
        conv_cutoff: distance added to the largest molecular radius.
        good_mols: indices into `mol_radii`; its length sets the number of
            rows of the returned molecule-index grid.
        mol_num_atoms: number of columns of the molecule-index grid.
        mol_radii: per-molecule radii.

    Returns:
        centroids, cluster_mol_ind (row k filled with k), coordination_number,
        defect_type (1 for surface molecules, 0 otherwise).
    """
    min_bulk_neighbors = 20  # counts below this are treated as surface

    centroids = cluster_coords.mean(1)
    largest_radius = mol_radii[good_mols].amax()
    neighbor_range = largest_radius + conv_cutoff

    pairwise = torch.cdist(centroids, centroids)
    coordination_number = (pairwise < neighbor_range).sum(1)

    defect_type = torch.zeros_like(cluster_targets)
    surface_rows = torch.argwhere(coordination_number < min_bulk_neighbors).flatten()
    defect_type[surface_rows] = 1  # defect type 1 is surfaces
    # cluster_targets[surface_mols_ind] = len(forms2tgt)  # label surface molecules as 'disordered'

    cluster_mol_ind = torch.arange(len(good_mols))[:, None].repeat(1, mol_num_atoms)

    return centroids, cluster_mol_ind, coordination_number, defect_type
def pare_fragmented_molecules(cluster_atoms, cluster_coords, cluster_targets, pare_fragmented):
    """
    Identify molecules which are fragmented, or split across a periodic boundary, and delete them.
    Fragmented molecules are identified as having significantly larger molecular radii than normal.

    A molecule's radius is the largest distance from its centroid to any of
    its atoms. When paring is enabled, a molecule is kept only if its radius
    is strictly below 1.25x the 5% quantile of all radii.

    Args:
        cluster_atoms: per-molecule atom data, indexed by molecule along dim 0.
        cluster_coords: coordinates, averaged over dim 1 for centroids and
            normed over dim 2 (presumably (n_molecules, n_atoms, 3) - confirm
            with the caller).
        cluster_targets: per-molecule labels, indexed by molecule along dim 0.
        pare_fragmented: if truthy apply the radius filter, otherwise keep
            every molecule.

    Returns:
        cluster_atoms, cluster_coords, cluster_targets restricted to the kept
        molecules; good_mols (indices of the kept molecules); mol_radii for
        ALL input molecules (not filtered).
    """
    mol_centroids = cluster_coords.mean(1)
    intramolecular_centroid_dists = torch.linalg.norm(mol_centroids[:, None, :] - cluster_coords, dim=2)
    mol_radii = intramolecular_centroid_dists.amax(1)

    if pare_fragmented:
        # the quantile is only needed (and only well-defined for a non-empty
        # cluster) when we actually filter
        max_mol_radius = torch.quantile(mol_radii, 0.05) * 1.25  # 25% leniency on the 5% quantile
        good_mols = torch.argwhere(mol_radii < max_mol_radius)[:, 0]
    else:
        good_mols = torch.arange(len(mol_radii))  # keep everything if periodic

    cluster_coords = cluster_coords[good_mols]
    cluster_atoms = cluster_atoms[good_mols]
    cluster_targets = cluster_targets[good_mols]

    return cluster_atoms, cluster_coords, cluster_targets, good_mols, mol_radii
def compute_mol_radii(cluster_coords, pare_fragmented):
    """
    Measure each molecule's radius and pick the molecules to keep.

    The radius is the largest centroid-to-atom distance within a molecule.
    With `pare_fragmented` set, only molecules whose radius is strictly below
    1.25x the 5% quantile of all radii are kept; otherwise all are kept.

    Args:
        cluster_coords: coordinates, averaged over dim 1 for centroids and
            normed over dim 2.
        pare_fragmented: whether to apply the radius filter.

    Returns:
        good_mols (indices of kept molecules), mol_radii (one per input molecule).
    """
    centroids = cluster_coords.mean(1, keepdim=True)
    atom_offsets = centroids - cluster_coords
    mol_radii = torch.linalg.norm(atom_offsets, dim=2).amax(1)

    if not pare_fragmented:
        return torch.arange(len(mol_radii)), mol_radii

    radius_limit = torch.quantile(mol_radii, 0.05) * 1.25  # 25% leniency on the 5% quantile
    good_mols = torch.argwhere(mol_radii < radius_limit).flatten()
    return good_mols, mol_radii
def reindex_molecules(atomic_numbers, i, mol_ind, num_molecules, ref_coords, targets):
    """
    Regroup flat per-atom tensors into per-molecule stacks.

    Atoms are gathered for each distinct value of `mol_ind` (in the order
    returned by `torch.unique`) and the groups are stacked, so every molecule
    must own the same number of atoms.

    Args:
        atomic_numbers: per-atom tensor, indexed with the same mask as `ref_coords`.
        i: index of this sample's entry in `targets`.
        mol_ind: per-atom molecule id tensor.
        num_molecules: how many times the sample's target is repeated.
        ref_coords: per-atom coordinates.
        targets: indexable whose i-th entry supports `.repeat(n)`
            (e.g. a numpy array of labels).

    Returns:
        cluster_atoms, cluster_coords (both stacked per molecule), and
        cluster_targets (long tensor holding `targets[i]` repeated
        `num_molecules` times).
    """
    member_masks = [mol_ind == mol_id for mol_id in torch.unique(mol_ind)]

    cluster_coords = torch.stack([ref_coords[mask] for mask in member_masks])
    cluster_atoms = torch.stack([atomic_numbers[mask] for mask in member_masks])

    repeated_label = targets[i].repeat(num_molecules)
    cluster_targets = torch.tensor(repeated_label, dtype=torch.long)

    return cluster_atoms, cluster_coords, cluster_targets
# force_molecules_into_box(T_fc_list, cluster_coords, i, periodic) -> cluster_coords
#
# Translates whole molecules by integer multiples of the cell vectors:
#   1. subtracts the per-axis minimum taken over dims (0, 1) of cluster_coords;
#   2. averages over dim 1 to get one centroid per molecule;
#   3. multiplies the centroids by inv(T_fc_list[i].T) to express them in cell units;
#   4. takes -floor(...) of that as a per-molecule shift, maps it back through
#      T_fc_list[i].T, and adds it to every atom of the molecule.
# After step 4 each centroid's fractional part lies in [0, 1).
# The `-=` / `+=` operators modify the caller's `cluster_coords` tensor in place;
# the same tensor is returned.
#
# NOTE(review): the line structure of this function was lost in this copy of
# the file. It is not clear whether `if periodic:` guards only the recentering
# in step 1 or the whole sequence 1-4 - confirm against the upstream source
# before reformatting this block.
[docs] def force_molecules_into_box(T_fc_list, cluster_coords, i, periodic): """ will fail on fragmented molecules or molecules otherwise wrapped """ # recenter about zero if periodic: # don't do this for clusters or other floating objects cluster_coords -= cluster_coords.amin((0, 1))[None, None, :] mol_centroids = cluster_coords.mean(1) frac_mol_centroids = mol_centroids @ torch.linalg.inv(T_fc_list[i].T) adjustment_fractional_vector = -torch.floor(frac_mol_centroids) adjustment_cart_vector = adjustment_fractional_vector @ T_fc_list[i].T cluster_coords += adjustment_cart_vector[:, None, :] return cluster_coords
def pare_cluster_radius(cluster_atoms, cluster_coords, cluster_targets, max_cluster_radius):
    """
    Keep only molecules whose centroid lies within `max_cluster_radius` of the
    mean position of all atoms in the cluster (strict inequality).

    Args:
        cluster_atoms: per-molecule atom data, indexed by molecule along dim 0.
        cluster_coords: coordinates; dim 1 is averaged for molecule centroids,
            dims (0, 1) for the cluster centre.
        cluster_targets: per-molecule labels, indexed by molecule along dim 0.
        max_cluster_radius: distance limit (original author's note:
            "60 angstrom sphere at maximum").

    Returns:
        cluster_atoms, cluster_coords, cluster_targets for the kept molecules.
    """
    mol_centroids = cluster_coords.mean(1)
    cluster_center = cluster_coords.mean((0, 1))
    radial_distance = torch.linalg.norm(mol_centroids - cluster_center, dim=1)

    inside = torch.argwhere(radial_distance < max_cluster_radius).flatten()

    kept_coords = cluster_coords[inside]
    kept_atoms = cluster_atoms[inside]
    kept_targets = cluster_targets[inside]
    return kept_atoms, kept_coords, kept_targets
def classifier_reporting(true_labels, true_defects, probs, class_names, ordered_class_names, wandb, epoch_type):
    """
    Compute classification metrics for form and defect predictions and log them.

    For the form ("type") task, the ROC AUC, micro-averaged F1 score, their
    complements, and a row-normalised confusion-matrix heatmap are passed to
    `wandb.log` under keys prefixed with `epoch_type`. The same set is logged
    for the defect task, but only when more than one defect type occurs in
    `true_defects`.

    Args:
        true_labels: integer form labels, one per sample.
        true_defects: integer defect labels, one per sample.
        probs: 2-D array of raw scores; columns are selected by label value
            for the form task and by a trailing slice for the defect task.
        class_names: unused.
        ordered_class_names: names indexable by form label value.
        wandb: object exposing `.log(dict)`.
        epoch_type: string prefix for every logged key.

    Returns:
        None.
    """
    present_classes = np.unique(true_labels)
    present_class_names = [ordered_class_names[ind] for ind in present_classes]

    type_probs = softmax_np(probs[:, present_classes])
    # argmax gives a column position inside the `present_classes` subset; map
    # it back to the actual label value so it is comparable with true_labels
    predicted_class = present_classes[np.argmax(type_probs, axis=1)]

    present_defects = np.unique(true_defects)
    present_defect_names = [defect_names[ind] for ind in present_defects]

    # NOTE(review): the defect columns are taken to start after
    # len(present_classes) columns. If they actually start after ALL form
    # columns, this slice is misaligned whenever some form is absent from
    # true_labels - confirm against the model's output layout.
    defect_probs = softmax_np(probs[:, len(present_classes):])
    predicted_defect = np.argmax(defect_probs, axis=-1)

    type_auc = roc_auc_score(true_labels, type_probs, multi_class='ovo')
    type_f1 = f1_score(true_labels, predicted_class, average='micro')
    type_cmat = confusion_matrix(true_labels, predicted_class, normalize='true')

    fig = go.Figure(go.Heatmap(z=type_cmat, x=present_class_names, y=present_class_names))
    fig.update_layout(xaxis=dict(title="Predicted Forms"),
                      yaxis=dict(title="True Forms")
                      )

    wandb.log({f"{epoch_type} ROC_AUC": type_auc,
               f"{epoch_type} F1 Score": type_f1,
               f"{epoch_type} 1-ROC_AUC": 1 - type_auc,
               f"{epoch_type} 1-F1 Score": 1 - type_f1,
               f"{epoch_type} Confusion Matrix": fig})

    if len(present_defects) > 1:
        # column 1 alone is used as the score for the defect ROC AUC
        defect_auc = roc_auc_score(true_defects, defect_probs[:, 1], multi_class='ovo')
        defect_f1 = f1_score(true_defects, predicted_defect, average='micro')
        defect_cmat = confusion_matrix(true_defects, predicted_defect, normalize='true')

        fig = go.Figure(go.Heatmap(z=defect_cmat, x=present_defect_names, y=present_defect_names))
        fig.update_layout(xaxis=dict(title="Predicted Defect"),
                          yaxis=dict(title="True Defect")
                          )

        wandb.log({f"{epoch_type} Defect ROC_AUC": defect_auc,
                   f"{epoch_type} Defect F1 Score": defect_f1,
                   f"{epoch_type} 1-Defect ROC_AUC": 1 - defect_auc,
                   f"{epoch_type} 1-Defect F1 Score": 1 - defect_f1,
                   f"{epoch_type} Defect Confusion Matrix": fig})
def reload_model(model, device, optimizer, path, reload_optimizer=False):
    """
    load model and state dict from path
    includes fix for potential dataparallel issue

    Args:
        model: module whose `load_state_dict` receives the checkpoint weights.
        device: forwarded to `torch.load` as `map_location`.
        optimizer: optimizer to (optionally) restore, or None.
        path: checkpoint file; must contain 'model_state_dict' and, when the
            optimizer is reloaded, 'optimizer_state_dict'.
        reload_optimizer: restore the optimizer state only if this is truthy
            and `optimizer` is not None.

    Returns:
        (model, optimizer) - the same objects that were passed in.
    """
    # NOTE: weights_only=False unpickles arbitrary Python objects, so this must
    # only be pointed at checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location=device, weights_only=False)

    state_dict = checkpoint['model_state_dict']
    prefix = 'module.'
    keys = list(state_dict)
    if keys and keys[0].startswith(prefix):
        # when we use dataparallel it breaks the state_dict - fix it by removing
        # the leading 'module.' from every key that carries it. The dict is
        # edited in place so any attributes attached to it are preserved.
        for key in keys:
            if key.startswith(prefix):
                state_dict[key[len(prefix):]] = state_dict.pop(key)

    model.load_state_dict(state_dict)

    if optimizer is not None and reload_optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return model, optimizer
def record_step_results(results_dict, output, sample, data, latents, embeddings, step, config, index_offset=0):
    """
    Append one step's outputs and bookkeeping to the running results dict.

    Args:
        results_dict: dict of lists to extend, or None to start a new one.
        output: model output; the first `config['num_forms']` columns are
            softmaxed as type predictions, the remaining columns as defect
            predictions.
        sample: batch object; `y`, `defect`, `pos`, `z`, `mol_ind`,
            `centroid_pos` and `coord_number` are read.
        data: object whose `tracking[0][0]` is recorded under 'Temperature'
            and `tracking[0][1]` under 'Time_Step'.
        latents: tensor stored under 'Latents'.
        embeddings: tensor stored under 'Embeddings'.
        step: step number; `step + index_offset` is recorded for every entry
            of `sample.y` under 'Sample_Index'.
        config: mapping providing 'num_forms'.
        index_offset: constant added to the recorded sample index.

    Returns:
        The (possibly newly created) results dict.
    """
    if results_dict is None:
        results_dict = {field: [] for field in (
            'Temperature',
            'Time_Step',
            'Loss',
            'Type_Prediction',
            'Defect_Prediction',
            'Targets',
            'Defects',
            'Latents',
            'Sample_Index',
            'Coordinates',
            'Atom_Types',
            'Molecule_Index',
            'Molecule_Centroids',
            'Coordination_Numbers',
            'Embeddings',
        )}

    def as_numpy(tensor):
        # move to host memory and drop the autograd graph before converting
        return tensor.cpu().detach().numpy()

    num_forms = config['num_forms']
    num_entries = len(sample.y)
    tracking = data.tracking[0]

    results_dict['Loss'].append(as_numpy(get_loss(output, sample, num_forms)))
    results_dict['Type_Prediction'].append(as_numpy(F.softmax(output[:, :num_forms], dim=1)))
    results_dict['Defect_Prediction'].append(as_numpy(F.softmax(output[:, num_forms:], dim=1)))
    results_dict['Targets'].append(as_numpy(sample.y))
    results_dict['Defects'].append(as_numpy(sample.defect))
    results_dict['Latents'].append(as_numpy(latents))  # ['final_activation'])
    results_dict['Embeddings'].append(as_numpy(embeddings))
    results_dict['Temperature'].append(np.ones(num_entries) * tracking[0])
    results_dict['Time_Step'].append(np.ones(num_entries) * tracking[1])
    results_dict['Sample_Index'].append(np.ones(num_entries) * step + index_offset)
    results_dict['Coordinates'].append(as_numpy(sample.pos))
    results_dict['Atom_Types'].append(as_numpy(sample.z))
    results_dict['Molecule_Index'].append(as_numpy(sample.mol_ind))
    # these two are stored as-is (first element only, no numpy conversion)
    results_dict['Molecule_Centroids'].append(sample.centroid_pos[0])
    results_dict['Coordination_Numbers'].append(sample.coord_number[0])

    return results_dict
def process_trajectory_results_dict(results_dict, loader, mol_num_atoms):
    """
    Split flat trajectory results into per-sample lists and sort them by time step.

    Entries whose length equals that of 'Atomwise_Sample_Index' are split with
    that index and keep their key; entries whose length equals that of
    'Sample_Index' are split with it and stored under 'Molecule_' + key;
    anything else is reported and skipped. Entries are removed from
    `results_dict` as they are processed, except those with 'Index' in their
    key. Samples are then ordered by the first 'Molecule_Time_Step' value of
    each sample, and a 'Centroid Radii' entry is added holding each molecule
    centroid's distance from the mean centroid of its sample.

    Args:
        results_dict: dict of arrays supporting boolean-mask indexing
            (modified in place).
        loader: sized object; `len(loader)` is the number of samples.
        mol_num_atoms: atoms per molecule, used to reshape 'Coordinates'.

    Returns:
        (sorted per-sample dict, time steps in sorted order).
    """
    num_atoms = len(results_dict['Atomwise_Sample_Index'])
    num_mols = len(results_dict['Sample_Index'])
    sample_ids = range(len(loader))

    molwise_results_dict = {}
    for key in list(results_dict.keys()):
        values = results_dict[key]
        if len(values) == num_atoms:
            index = results_dict['Atomwise_Sample_Index']
            molwise_results_dict[key] = [values[index == ind] for ind in sample_ids]
        elif len(values) == num_mols:
            index = results_dict['Sample_Index']
            molwise_results_dict['Molecule_' + key] = [values[index == ind] for ind in sample_ids]
            # molwise_results_dict[key] = [results_dict[key][index == ind].repeat(mol_num_atoms) for ind in range(len(loader))]
        else:
            print(f"{key} is omitted from results dict")

        # NOTE(review): taken to run for every key (not only omitted ones);
        # the index entries must survive because later keys are split with them
        if 'Index' not in key:
            del results_dict[key]

    time_inds = [time[0] for time in molwise_results_dict['Molecule_Time_Step']]
    sort_inds = np.argsort(np.asarray(time_inds))

    sorted_molwise_results_dict = {
        key: [per_sample[ind] for ind in sort_inds]
        for key, per_sample in molwise_results_dict.items()
    }

    centroid_dists = []
    for coords in sorted_molwise_results_dict['Coordinates']:
        centroids = coords.reshape(coords.shape[0] // mol_num_atoms, mol_num_atoms, 3).mean(1)
        centroid_dists.append(np.linalg.norm(centroids - centroids.mean(0), axis=1))
    sorted_molwise_results_dict['Centroid Radii'] = centroid_dists

    return sorted_molwise_results_dict, np.asarray(time_inds)[sort_inds]
def get_loss(output, sample, num_forms):
    """
    Sum of two cross-entropy terms computed from one output tensor.

    Args:
        output: 2-D logits; the first `num_forms` columns are scored against
            `sample.y`, the remaining columns against `sample.defect`.
        sample: object exposing the targets `y` and `defect`.
        num_forms: column at which `output` is split.

    Returns:
        The form cross-entropy plus the defect cross-entropy.
    """
    form_logits = output[:, :num_forms]
    defect_logits = output[:, num_forms:]

    form_loss = F.cross_entropy(form_logits, sample.y)
    defect_loss = F.cross_entropy(defect_logits, sample.defect)
    return form_loss + defect_loss