import warnings
from typing import Union
import numpy as np
import torch
from rdkit import Chem
from torch import Tensor
from torch_geometric.typing import OptTensor
from torch_scatter import scatter
from mxtaltools.common.geometry_utils import compute_Ip_handedness, rotvec2rotmat, batch_molecule_principal_axes_torch, \
extract_rotmat, apply_rotation_to_batch, rotmat2rotvec, \
fractional_transform, center_batch
from mxtaltools.common.utils import block_repeat_interleave
[docs]
def generate_sorted_fractional_translations(supercell_size):
xrange = torch.arange(-supercell_size, supercell_size + 1, dtype=torch.int16)
xx, yy, zz = torch.meshgrid(xrange, xrange, xrange,
indexing='ij'
)
# Flatten the meshgrid and stack the results into a 2D tensor
fractional_translations = torch.stack([xx.flatten(), yy.flatten(), zz.flatten()], dim=1).float()
# sort fractional vectors from closest to furthest from central unit cell
return fractional_translations[torch.argsort(fractional_translations.abs().sum(1))]
[docs]
def ucell2cluster(crystal_batch, cutoff: float = 6,
supercell_size: int = 10,
zp_buffer: Union[float, torch.Tensor] = 0): # todo this could be modularized/simplified
"""
1) generate supercell cluster including only unit cells which could plausibly interact with the asymmetric unit
2) instantiate the cluster
3) pare cluster to asymmetric units which actually interact with the asymmetric unit
"""
"""
Generate set of unit cells to instantiate
"""
# this method is simple but depends on the canonical conformer being properly set as the raw position
cc_centroids = scatter(crystal_batch.pos, crystal_batch.batch,
dim_size=crystal_batch.num_graphs, dim=0, reduce='mean')
good_translations, ucells_per_cluster = get_cart_translations(cc_centroids,
crystal_batch.T_fc,
crystal_batch.radius,
cutoff + zp_buffer,
supercell_size)
"""
Build constants & indices
"""
# unit cell
atoms_per_ucell = (crystal_batch.num_atoms * crystal_batch.sym_mult).long()
# cluster
atoms_per_cluster = atoms_per_ucell * ucells_per_cluster # total atoms per crystal
aunits_per_cluster = ucells_per_cluster * crystal_batch.sym_mult
# get atomwise translations
atoms_per_translation = atoms_per_ucell.repeat_interleave(ucells_per_cluster, dim=0)
atomwise_translation = good_translations.repeat_interleave(atoms_per_translation, dim=0)
"""
Instantiate explicit clusters
"""
cluster_batch = instantiate_cluster(atoms_per_ucell,
atoms_per_cluster,
atomwise_translation,
ucells_per_cluster,
crystal_batch,
)
"""
we then can further pare by molecule centroid, since our original box calculation is rather permissive
"""
cluster_batch = _pare_cluster_molwise(atoms_per_cluster, cc_centroids,
cluster_batch, crystal_batch,
cutoff + zp_buffer, aunits_per_cluster)
return cluster_batch
# zpg_inds = torch.argwhere(cluster_batch.z_prime > 1).flatten().tolist()
# cluster_batch.visualize(zpg_inds, mode='distance', cutoff=4, highlight_canonical_conformer=True)
# cluster_batch.visualize([1, 2, 3, 4, 5], mode='unit cell')
# cluster_batch.visualize([1, 2, 3, 4, 5], mode='convolve with')
def _pare_cluster_molwise(atoms_per_cluster, cc_centroids, cluster_batch,
crystal_batch, cutoff, aunits_per_cluster):
# get all mol centroids
# indexes which molecule each atom belongs to, globally (not within-crystal)
aunit_ptr = torch.cat([torch.zeros(1, device=cluster_batch.device, dtype=torch.long),
torch.cumsum(aunits_per_cluster, dim=0)])[:-1]
atom2aunit_ind = cluster_batch.mol_ind + aunit_ptr.repeat_interleave(atoms_per_cluster, dim=0)
# this is extremely memory heavy with float32, we can make huge savings with half/float16, but it can overflow
aunit_centroids = scatter(cluster_batch.pos, #.half(),
atom2aunit_ind, reduce='mean', dim=0)
# cc_centroids = scatter(crystal_batch.pos.half(), crystal_batch.batch, reduce='mean', dim=0) # already compute this above
# get the mol indices within the widest conv cutoff
molwise_batch = (torch.arange(crystal_batch.num_graphs, device=cc_centroids.device)
.repeat_interleave(aunits_per_cluster, dim=0))
mol_inds = torch.arange(len(molwise_batch), device=cluster_batch.device)
# get the distances
mol_to_cc_dist_sq = (aunit_centroids - cc_centroids.repeat_interleave(aunits_per_cluster, dim=0)).square().sum(1)
cluster_cutoff = (cutoff + 2 * crystal_batch.radius + 0.1).square()
mol_cutoff_sq = cluster_cutoff[molwise_batch]
# molecules to discard
bad_mol_inds = mol_inds[mol_to_cc_dist_sq > mol_cutoff_sq]
# convert back to atomic indices
# atom for each molecule in the kept list
# atoms_in_bad_mols_bool = torch.isin(atom_to_flat_mol_ind, bad_mol_inds)
# dramatically faster alternative to isin
# bad_sorted = bad_mol_inds.sort().values
# atom_to_mol index of each atom which is in a 'bad' molecule
# this works because atom_to_flat_mol_ind is already sorted
idx = torch.searchsorted(bad_mol_inds, atom2aunit_ind) # sorting step
valid = idx < bad_mol_inds.numel() # safety check for wayward searchsorted outputs > length of tensors
atoms_in_bad_mols_bool = torch.zeros_like(atom2aunit_ind, dtype=torch.bool)
atoms_in_bad_mols_bool[valid] = bad_mol_inds[idx[valid]] == atom2aunit_ind[valid] # atom is actually in a bad mol
# --- diagnostic: central aunit must never be marked outside the conv window ---
# By construction, the canonical conformer (mol_ind == 0 within each cluster)
# is at zero distance from cc_centroids, so it should never appear in bad_mol_inds.
# If it does, something upstream produced inf/garbage in cluster_batch.pos or
# aunit_centroids, or the centroid invariant has been violated.
# Per-cluster index of the central aunit in the global aunit_centroids tensor:
central_aunit_inds = aunit_ptr # mol_ind == 0 within each cluster, plus the cluster's offset
central_dist_sq = mol_to_cc_dist_sq[central_aunit_inds]
central_marked_bad = mol_to_cc_dist_sq[central_aunit_inds] > mol_cutoff_sq[central_aunit_inds]
if central_marked_bad.any():
bad_graphs = central_marked_bad.nonzero(as_tuple=True)[0]
msg_lines = ["central aunit marked outside conv window — invariant violated"]
for g in bad_graphs.tolist():
atoms_mask = cluster_batch.batch == g
graph_pos = cluster_batch.pos[atoms_mask]
graph_aux_zero_count = (cluster_batch.aux_ind[atoms_mask] == 0).sum().item()
msg_lines.append(
f" graph={g}: "
f"central_dist_sq={central_dist_sq[g].item():.4g}, "
f"cutoff_sq={mol_cutoff_sq[central_aunit_inds[g]].item():.4g}, "
f"pos_finite={graph_pos.isfinite().all().item()}, "
f"pos_absmax={graph_pos.abs().max().item():.4g}, "
f"aunit_centroid_finite={aunit_centroids[central_aunit_inds[g]].isfinite().all().item()}, "
f"aunit_centroid={aunit_centroids[central_aunit_inds[g]].tolist()}, "
f"cc_centroid={cc_centroids[g].tolist()}, "
f"cell_lengths={crystal_batch.cell_lengths[g].tolist()}, "
f"cell_angles={crystal_batch.cell_angles[g].tolist()}, "
f"radius={crystal_batch.radius[g].item():.4g}, "
f"current_aux_zeros_in_graph={graph_aux_zero_count}, "
)
raise AssertionError("\n".join(msg_lines))
# --- end diagnostic ---
cluster_batch.aux_ind[atoms_in_bad_mols_bool] = 2 # marker for "outside convolutional field"
return cluster_batch
[docs]
def instantiate_cluster(atoms_per_ucell,
atoms_per_cluster,
atomwise_translation,
ucells_per_cluster,
crystal_batch):
""""""
"""
Generate necessary indices to instantiate clusters in parallel
"""
unit_cell_batch = crystal_batch.unit_cell_batch
unit_cell_ptr = torch.cat(
[torch.tensor([0], device=crystal_batch.device),
torch.cumsum(atoms_per_ucell, dim=0)]).long()
atom_in_unit_cell_batch, molwise_nodes_within_unit_cell, supercell_batch = _index_cluster_nodes(atoms_per_cluster,
atoms_per_ucell,
crystal_batch,
unit_cell_ptr)
# add ptr offset per graph
node_within_mol_in_ucell = molwise_nodes_within_unit_cell + crystal_batch.ptr[:-1][unit_cell_batch]
# this indexes from the full cluster to the atom in each aunit
cluster_node2aunit_node = node_within_mol_in_ucell[atom_in_unit_cell_batch]
mol_ind = _index_mols(atom_in_unit_cell_batch, atoms_per_cluster, atoms_per_ucell, crystal_batch,
ucells_per_cluster)
"""
do the instantiation
"""
cluster_batch = _instantiate_cluster(atom_in_unit_cell_batch,
atoms_per_cluster, atomwise_translation,
crystal_batch, cluster_node2aunit_node,
supercell_batch,
mol_ind)
return cluster_batch
def _index_cluster_nodes(atoms_per_cluster, atoms_per_ucell, crystal_batch, unit_cell_ptr):
tot_num_atoms = atoms_per_cluster.sum()
cluster_atom_ind = torch.arange(tot_num_atoms, device=crystal_batch.device)
supercell_batch = torch.repeat_interleave(torch.arange(crystal_batch.num_graphs,
device=crystal_batch.device),
atoms_per_cluster)
supercell_ptr = torch.cumsum(torch.cat([torch.tensor([0], device=crystal_batch.device),
atoms_per_cluster[:-1]]), dim=0)
atom_in_crystal_index = cluster_atom_ind - supercell_ptr[supercell_batch]
atom_in_unit_cell_batch = (atom_in_crystal_index % atoms_per_ucell[supercell_batch]) + unit_cell_ptr[
supercell_batch]
# within-block index (0..num_atoms[i]-1), repeated for sym_mult[i]
# TODO replace with some nice batched offset, if possible
molwise_nodes_within_unit_cell = torch.cat([
torch.arange(n, device=crystal_batch.device).repeat(m)
for n, m in zip(crystal_batch.num_atoms.tolist(), crystal_batch.sym_mult.tolist())
])
''' test and compare this method on a large batch
num_atoms = crystal_batch.num_atoms
sym_mult = crystal_batch.sym_mult
# total length = sum(num_atoms * sym_mult)
lengths = num_atoms * sym_mult
total = lengths.sum()
# assign each position to its "unit cell"
ucell_idx = torch.arange(total, device=crystal_batch.device) \
.repeat_interleave(1) # just a flat arange
# expand each unit cell according to lengths
ucell_idx = torch.repeat_interleave(torch.arange(len(lengths), device=total.device), lengths)
# within each block, cycle through atom indices
within = torch.arange(total, device=total.device) - torch.repeat_interleave(
torch.cumsum(torch.cat([torch.zeros(1, device=total.device, dtype=num_atoms.dtype),
lengths[:-1]]), 0),
lengths
)
# reduce modulo num_atoms[i]
molwise_nodes_within_unit_cell = within % torch.repeat_interleave(num_atoms, lengths)
'''
return atom_in_unit_cell_batch, molwise_nodes_within_unit_cell, supercell_batch
def _index_mols(atom_in_unit_cell_batch, atoms_per_cluster, atoms_per_ucell, crystal_batch, ucells_per_cluster):
"""
Finally, Z'>1-functional, asymmetric unit indexing (for mol_ind/aux_ind)
"""
if crystal_batch.z_prime.amax() > 1:
assert False, "Don't try cluster construction for Z'>1 crystals - the mol indexing is hideous and not up-to-date"
atomwise_mol_ind = _zp_gr1_mol_indexing(atom_in_unit_cell_batch, atoms_per_cluster, atoms_per_ucell,
crystal_batch,
ucells_per_cluster)
else: # much simpler to do it this way
mols_per_ucell = crystal_batch.sym_mult
mols_per_cluster = ucells_per_cluster * mols_per_ucell
tot_num_mols = torch.sum(mols_per_cluster)
global_mol_index = torch.arange(tot_num_mols, device=crystal_batch.device)
mol_offset = torch.cat([torch.zeros(1, device=crystal_batch.device, dtype=torch.long),
torch.cumsum(mols_per_cluster, dim=0)[:-1]])
local_mol_index = global_mol_index - mol_offset.repeat_interleave(mols_per_cluster)
atoms_per_mol_in_cluster = crystal_batch.num_atoms.repeat_interleave(mols_per_cluster)
atomwise_mol_ind = local_mol_index.repeat_interleave(
atoms_per_mol_in_cluster
)
return atomwise_mol_ind
def _zp_gr1_mol_indexing(atom_in_unit_cell_batch, atoms_per_cluster, atoms_per_ucell, crystal_batch,
ucells_per_cluster):
"""
This appears technically functional, but is an avoidable monstrosity
:param atom_in_unit_cell_batch:
:param atoms_per_cluster:
:param atoms_per_ucell:
:param crystal_batch:
:param ucells_per_cluster:
:return:
"""
if not getattr(_zp_gr1_mol_indexing, '_warned', False):
warnings.warn(
"WARNING: You have done cluster building directly on a crystal batch with some Z prime greater than 1."
"This calls a complicated and avoidable indexing function. We recommend instead splitting to separate Zp=1 "
"graphs then recombining.", UserWarning, stacklevel=2)
_zp_gr1_mol_indexing._warned = True
mols_per_ucell = atoms_per_ucell // crystal_batch.num_atoms * crystal_batch.z_prime
ucell_mol_offset = torch.cumsum(torch.cat([torch.zeros(1, device=mols_per_ucell.device, dtype=mols_per_ucell.dtype),
mols_per_ucell[:-1]]), 0)
within_ucell_mol_ind = (torch.arange(mols_per_ucell.sum(), device=mols_per_ucell.device, dtype=torch.long)
- ucell_mol_offset.repeat_interleave(mols_per_ucell))
atoms_per_mol = crystal_batch.num_atoms // crystal_batch.z_prime
molwise_ucell_sym_mult = crystal_batch.sym_mult.repeat_interleave(mols_per_ucell)
# indexes the asymmetric units within each unit cell
within_ucell_aunit_ind = within_ucell_mol_ind % molwise_ucell_sym_mult
# ditto, atomwise
within_ucell_aunit_atom_ind = within_ucell_aunit_ind.repeat_interleave(
atoms_per_mol.repeat_interleave(mols_per_ucell))
# we then need to expand this to the cluster batch
cluster_aunit_inner_ind = within_ucell_aunit_atom_ind[atom_in_unit_cell_batch]
# and add offsets to account for the new tiles in the cluster
tot_num_ucells = ucells_per_cluster.sum()
aunits_per_cluster = crystal_batch.sym_mult * ucells_per_cluster
ucell_num_aunits = crystal_batch.sym_mult.repeat_interleave(ucells_per_cluster)
aunits_per_cluster_ucell = crystal_batch.sym_mult.repeat_interleave(ucells_per_cluster)
aunit_cum = torch.cumsum(
torch.cat([torch.zeros(1, device=ucells_per_cluster.device, dtype=ucells_per_cluster.dtype),
ucell_num_aunits[:-1]]), 0)
cluster_aunit_offset = torch.cumsum(
torch.cat([torch.zeros(1, device=ucells_per_cluster.device, dtype=ucells_per_cluster.dtype),
aunits_per_cluster[:-1]]), 0)
cluster_nodewise_offset = cluster_aunit_offset.repeat_interleave(atoms_per_cluster)
aunit2ucell = torch.arange(tot_num_ucells, device=crystal_batch.device).repeat_interleave(
aunits_per_cluster_ucell
)
aunit_offset = aunit_cum[aunit2ucell]
atoms_per_aunit = crystal_batch.num_atoms.repeat_interleave(aunits_per_cluster)
nodewise_offset = aunit_offset.repeat_interleave(atoms_per_aunit)
mol_ind = cluster_aunit_inner_ind + nodewise_offset - cluster_nodewise_offset
return mol_ind
def _instantiate_cluster(atom_in_unit_cell_batch,
atoms_per_cluster,
atomwise_translation,
crystal_batch,
cluster_node2aunit_node,
supercell_batch,
mol_ind):
cluster_batch = crystal_batch.clone()
# init cluster with right number of duplicated unit cells
cluster_batch.pos = crystal_batch.unit_cell_pos[atom_in_unit_cell_batch]
# build cluster by translating all atoms manually
cluster_batch.pos += atomwise_translation
# pattern out atom properties
cluster_batch.z = crystal_batch.z[cluster_node2aunit_node]
cluster_batch.x = crystal_batch.x[cluster_node2aunit_node]
# rebuild batch indices
cluster_batch.ptr = torch.cat([torch.zeros(1, dtype=torch.long, device=crystal_batch.device),
torch.cumsum(atoms_per_cluster, dim=0)]).long()
cluster_batch.batch = supercell_batch
# asymmetric unit indices
cluster_batch.mol_ind = mol_ind
cluster_batch.aux_ind = torch.ones_like(mol_ind)
cluster_batch.aux_ind[mol_ind == 0] = 0
return cluster_batch
def _molwise_indexing(cluster_batch,
crystal_batch,
mols_per_subunit,
tot_num_subunits):
"""
Indexing for downstream analysis
aux_ind labels 0: asymmetric unit, 1: convolutional window, 2: outside convolutional window
mol_ind indexes the molecule (within cluster) for each atom
"""
# which subunit each entry belongs to
mol2subunit = torch.arange(tot_num_subunits, device=mols_per_subunit.device).repeat_interleave(mols_per_subunit)
# index *within* the cluster
offsets = torch.cumsum(torch.cat([torch.zeros(1, device=mols_per_subunit.device, dtype=mols_per_subunit.dtype),
mols_per_subunit[:-1]]), 0)
mol_within_subunit = torch.arange(mols_per_subunit.sum(), device=mols_per_subunit.device) - offsets[mol2subunit]
atoms_per_mol_by_crystal = crystal_batch.num_atoms // crystal_batch.z_prime
atoms_per_mol_by_subunit = atoms_per_mol_by_crystal.repeat_interleave(crystal_batch.z_prime)
atoms_per_mol = atoms_per_mol_by_subunit.repeat_interleave(mols_per_subunit)
cluster_batch.mol_ind = mol_within_subunit.repeat_interleave(
atoms_per_mol
)
cluster_batch.aux_ind = torch.zeros_like(cluster_batch.z)
cluster_batch.aux_ind[cluster_batch.mol_ind != 0] = 1 # MK this may have to be adjusted for Z'>1
return cluster_batch
[docs]
def get_cart_translations(cc_centroids,
T_fc,
mol_radii,
cutoff,
supercell_size: int = 9,
):
# precompute some stuff
# get the set of all possible cartesian translations
cell_vectors = T_fc.permute(0, 2, 1)
# cell diagonals
cell_diag = cell_vectors.sum(1).norm(dim=-1)
# bounding box cutoff
cutoff_distance = mol_radii * 2 + cutoff + 0.1 + cell_diag # convolutional radius plus cell diagonal plus z'>1 buffer
# identify actual supercell size
box_lengths = torch.linalg.norm(cell_vectors, dim=-1)
max_translations_per_dim = int(torch.ceil(cutoff_distance[:, None] / box_lengths).amax())
actual_supercell_size = min(supercell_size, max_translations_per_dim) # use the argument size as a cap
sorted_fractional_translations = generate_sorted_fractional_translations(actual_supercell_size).to(T_fc.device)
# for the n samples, for the 3 (i) cell vectors, get the inner product with all the fractional translation vectors
parallelpiped_centroids = torch.einsum('nij, mi -> nmj', cell_vectors, sorted_fractional_translations)
# we need to include any parallelpiped which overlaps with a sphere of
# radius cutoff_distance around the canonical conformer for each cell - a bounding box calculation
# we compute this in a brute force way, adding the diagonal length of each parallelpiped to the cutoff
# this will take extra boxes, but is fast and simple
# centroid_dists = torch.linalg.norm(cc_centroids[:, None, :] - parallelpiped_centroids, dim=-1)
# good_translations_bool = centroid_dists < cutoff_distance[:, None]
# this is faster without the sqrt, and comparing against the squared dist
centroid_dists = (cc_centroids[:, None, :] - parallelpiped_centroids).square().sum(dim=-1)
good_translations_bool = centroid_dists < (cutoff_distance[:, None].square())
# build only the unit cells which are relevant to each cluster
good_translations_bool[:, 0] = True # always take the 0,0,0 element
good_translations = parallelpiped_centroids[good_translations_bool]
cells_per_cluster = good_translations_bool.sum(1)
return good_translations, cells_per_cluster
[docs]
def find_coord_in_box_torch(coords, box, epsilon=0):
# which of the given coords is inside the specified box, with option for a little leeway
return torch.where((coords[:, 0] <= (box[0] + epsilon)) *
(coords[:, 1] <= (box[1] + epsilon)) *
(coords[:, 2] <= (box[2] + epsilon)) *
(coords[:, 0] >= 0) *
(coords[:, 1] >= 0) *
(coords[:, 2] >= 0)
)[0]
[docs]
def parameterize_crystal_batch(crystal_batch,
asym_unit_dict,
enforce_right_handedness: bool = False,
return_aunit: bool = False,
):
"""
Asymmetric Unit Centroid Analysis
identify "canonical" asymmetric unit out of the Z asymmetric units in the unit cell
# TODO would be nice if this could be parallelized. Would be nice indeed
"""
canonical_conformer_coords_list = []
mol_position_list = []
well_defined_asym_unit_list = []
for i, (T_cf, sg_ind) in enumerate(zip(crystal_batch.T_cf, crystal_batch.sg_ind)):
unit_cell_coords = crystal_batch.unit_cell_pos[crystal_batch.unit_cell_batch == i]
unit_cell_atom_types = crystal_batch.z[crystal_batch.batch == i].repeat(crystal_batch.sym_mult[i])
# filter for heavy only
num_heavy_atoms = sum(crystal_batch.z[crystal_batch.batch == i] > 1)
heavy_ucell_coords = unit_cell_coords[unit_cell_atom_types > 1]
heavy_ucell_coords = heavy_ucell_coords.reshape(crystal_batch.sym_mult[i], num_heavy_atoms, 3)
(canonical_conformer_index,
centroids_fractional,
well_defined_asym_unit) = (
identify_canonical_asymmetric_unit(
T_cf,
asym_unit_dict,
sg_ind,
heavy_ucell_coords))
unit_cell_coords = unit_cell_coords.reshape(crystal_batch.sym_mult[i], crystal_batch.num_atoms[i], 3)
canonical_conformer_coords_list.append(unit_cell_coords[canonical_conformer_index[0]])
mol_position_list.append(centroids_fractional[canonical_conformer_index[0]])
well_defined_asym_unit_list.extend([well_defined_asym_unit])
mol_position_list = torch.stack(mol_position_list)
crystal_batch.pos = torch.cat(canonical_conformer_coords_list)
'''
Pose Analysis
compute the inverse of the rotation required to align the molecule with the cartesian axes
help from http://motion.pratt.duke.edu/RoboticSystems/3DRotations.html#:~:text=Another%20popular%20rotation%20representation%20is,be%20represented%20in%20this%20form!
'''
rotvec_list, handedness_list = extract_aunit_orientation(crystal_batch,
enforce_right_handedness,
canonicalize_orientation=True)
if return_aunit:
return mol_position_list, rotvec_list, handedness_list, well_defined_asym_unit_list, canonical_conformer_coords_list
else:
return mol_position_list, rotvec_list, handedness_list, well_defined_asym_unit_list
[docs]
def canonicalize_rotvec(rotvecs: torch.Tensor):
"""
since the direction of the axis is arbitrary, (x,y,z) is the same rotation as (-x,-y,-z),
we can obtain uniqueness constraining the axis to a half-sphere.
Here we will take the +z direction as 'canonical'
Swap the direction and take 2pi-norm to recapture the identical rotation.
"""
flip_inds = torch.where(rotvecs[:, -1] < 0)[0]
flip_vecs = rotvecs[flip_inds]
flip_norms = torch.linalg.norm(flip_vecs, dim=-1)
new_norms = 2 * torch.pi - flip_norms
new_vecs = -flip_vecs / flip_norms[:, None] * new_norms[:, None]
rotvecs[flip_inds] = new_vecs
return rotvecs
[docs]
def cleanup_invalid_rotvecs(rotation_matrix_list, rotvec_list):
r_arg_list = (torch.einsum('ijj->i', rotation_matrix_list) - 1) / 2
bad_rotations = torch.argwhere(r_arg_list >= 1).flatten()
# if we are close enough to one, the arccos will NaN
# situation corresponds to a rotation by ~pi, with an unknown direction
# fortunately, either direction results in the same transformation (C2)
# some ill conditioned rotations may also throw |args| greater than 1 -
# for safety must set these as something, may as well be this # todo look at such ill-conditioned cases
# set rotation then as pi
rotvec_list[bad_rotations, :] = 1
r_list = torch.arccos(r_arg_list)
r_list[bad_rotations] = torch.pi
more_bad_rotations = torch.argwhere(~torch.isfinite(r_list)).flatten()
if len(more_bad_rotations) > 0: # manually catch other edge cases
r_list[more_bad_rotations] = torch.pi
# more manual edge cases
bad_directions = torch.argwhere(torch.logical_or(
(~torch.isfinite(rotvec_list.sum(1))),
(rotvec_list.abs().sum(1) == 0))
).flatten()
if len(bad_directions) > 0:
r_list[bad_directions] = torch.pi
rotvec_list[bad_directions] = 1
rotvec_list = rotvec_list / torch.linalg.norm(rotvec_list, dim=-1)[:, None] * r_list[:, None]
return rotvec_list
[docs]
def identify_canonical_asymmetric_unit(T_cf, asym_unit_dict, sg_ind, unit_cell_coords):
# cartesian mol centroids -> fractional mol centroids -> confirm inside box
centroids_cartesian = unit_cell_coords.mean(-2) # assume a shape [Z, n_atoms, 3]
centroids_fractional = fractional_transform(centroids_cartesian, T_cf)
centroids_fractional -= torch.floor(centroids_fractional)
aunit_size = asym_unit_dict[str(int(sg_ind))]
canonical_conformer_index_i = find_coord_in_box_torch(centroids_fractional, aunit_size)
# if we didn't find one, or found more than one pick the closest.
# In some cases, they are truly indistinguishable.
if len(canonical_conformer_index_i) == 0:
well_defined_asym_unit = False
canonical_conformer_index = [torch.argmin(torch.linalg.norm(centroids_fractional, axis=1))]
elif len(canonical_conformer_index_i) == 1:
well_defined_asym_unit = True # if there is any ambiguity, it is not 'well defined'
canonical_conformer_index = canonical_conformer_index_i * 1
elif len(canonical_conformer_index_i) > 1:
# if we find more than one inside
well_defined_asym_unit = False
dists = torch.linalg.norm(centroids_fractional[canonical_conformer_index_i], axis=1)
if len(set(dists)) < len(dists): # if any are equal, go dimension-by-dimension
canonical_conformer_index = [canonical_conformer_index_i[0]] # pre-set a random ind, overwrite later
for dim in range(3):
dists = centroids_fractional[canonical_conformer_index_i][:, dim]
if len(set(dists)) < len(dists):
pass # if are equal, go to next dim
else:
canonical_conformer_index = [canonical_conformer_index_i[torch.argmin(dists)]]
break # stop at this dim
else:
canonical_conformer_index = [canonical_conformer_index_i[torch.argmin(dists)]]
return canonical_conformer_index, centroids_fractional, well_defined_asym_unit
[docs]
def align_mol_batch_to_standard_axes(mol_batch, handedness=None, return_rot=False):
"""
align principal inertial axes of molecules in a crystaldata object to the xyz or xy(-z) axes
only works for geometric principal axes (all atoms mass = 1)
"""
Ip, Ipm, I = batch_molecule_principal_axes_torch(
mol_batch.pos,
mol_batch.batch,
mol_batch.num_graphs,
mol_batch.num_atoms,
heavy_atoms_only=True,
atom_types=mol_batch.z
)
eye = torch.tile(torch.eye(3,
device=mol_batch.pos.device,
dtype=torch.float32),
(mol_batch.num_graphs, 1, 1)
) # set as right-handed by default
if handedness is not None: # otherwise, custom
# alignment vector is (x,y,z) by default, or (-x,y,z) for left-handed structures
if torch.is_tensor(handedness):
eye[:, 0, 0] = handedness[:, :1].flatten() # flatten - in case it's higher dimensional [n,1]
else:
eye[:, 0, 0] = handedness[:, :1]
# rotation is p_align = R @ p_orig,
# R = p_align \ p_orig
std_rotations = extract_rotmat(eye, torch.linalg.inv(Ip))
mol_batch.recenter_molecules() # ensures molecules are centered
mol_batch.pos = apply_rotation_to_batch(mol_batch.pos,
std_rotations,
mol_batch.batch)
# test
# Ip, Ipm, I = batch_molecule_principal_axes_torch(
# mol_batch.pos,
# mol_batch.batch,
# mol_batch.num_graphs,
# mol_batch.num_atoms
# )
# assert torch.all(torch.isclose(Ip, eye, atol=1e-4))
if return_rot:
return mol_batch, std_rotations
else:
return mol_batch
[docs]
def aunit2ucell(mol_batch):
"""
:param mol_batch:
:return:
"""
atoms_per_mol = mol_batch.num_atoms
sym_mult = mol_batch.sym_mult
atoms_per_unit_cell = mol_batch.num_atoms * sym_mult
tot_num_mols = torch.sum(sym_mult)
aunit_coords_c = mol_batch.pos
centroids_f = mol_batch.aunit_centroid
# repeat each molecule Z times
mol2ucell = torch.arange(len(sym_mult), device=sym_mult.device).repeat_interleave(sym_mult)
ucell_centroids_f = mol_batch.assign_aunit_centroid(mol_batch.aunit_centroid[mol2ucell])
fc_transform_list = mol_batch.T_fc
cf_transform_list = mol_batch.T_cf
# same, atomwise
ucell_node2mol_node = block_repeat_interleave(mol_batch.num_atoms, sym_mult)
unit_cell_coords_c = aunit_coords_c[ucell_node2mol_node]
unit_cell_batch = mol_batch.batch[ucell_node2mol_node]
# assign which molecule each atom is in
ucell_node2ucell_mol = torch.arange(tot_num_mols, device=mol_batch.device).repeat_interleave(
atoms_per_mol.repeat_interleave(sym_mult))
# MK todo this tensor(np.concatenate) can't be very efficient
molwise_sym_ops = torch.tensor(np.concatenate(mol_batch.symmetry_operators, axis=0),
device=centroids_f.device, dtype=torch.float32)
# add 4th dimension as a dummy for affine transforms, should have shape (num_aunits, 4, 4)
unit_cell_affine_centroids_f = torch.cat(
(
ucell_centroids_f,
torch.ones(ucell_centroids_f.shape[:-1] + (1,), dtype=torch.float32)
.to(aunit_coords_c.device)),
dim=-1)
# get all molecule centroids via symmetry ops
ucell_centroids_f = torch.einsum('nij,nj->ni', (molwise_sym_ops, unit_cell_affine_centroids_f))[..., :-1]
# force centroids within unit cell # note this disallows centroids at fractional 1.0 and moves them to 0.0, which causes some artifacting
# centroids_f_in_cell = ucell_centroids_f - torch.floor(ucell_centroids_f)
# this one lets them sit at 1.0 exactly
mask = (ucell_centroids_f != 1.0)
centroids_f_in_cell = torch.where(mask,
ucell_centroids_f - torch.floor(ucell_centroids_f),
ucell_centroids_f)
# subtract centroids and apply point symmetry to the molecule coordinates in fractional frame
centered_coords_c = center_batch(unit_cell_coords_c,
unit_cell_batch,
mol_batch.num_graphs,
atoms_per_unit_cell,
center_on_heavy_atoms=True,
atom_types=mol_batch.z[ucell_node2mol_node])
centered_coords_f = fractional_transform(centered_coords_c,
cf_transform_list[unit_cell_batch])
rot_coords_f = torch.einsum('nj,nij->ni', (centered_coords_f, molwise_sym_ops[ucell_node2ucell_mol, :3, :3]))
# translate rotated aunits to their correct position
unit_cell_pos = torch.einsum('nij,nj->ni',
(fc_transform_list[unit_cell_batch],
rot_coords_f + centroids_f_in_cell[ucell_node2ucell_mol, :]))
return unit_cell_pos, unit_cell_batch, ucell_node2ucell_mol
[docs]
def descale_asymmetric_unit(asym_unit_dict, mol_position, sg_inds):
"""
input fractional coordinates are scaled on 0-1
rescale these for the specific ranges according to each space group
only space groups in asym_unit_dict will work - not all have been manually encoded
this approach will not work for asymmetric units which are not neat parallelpipeds
Parameters
----------
asym_unit_dict
mol_position
sg_inds
Returns
-------
"""
return mol_position * torch.stack([asym_unit_dict[str(int(ind))] for ind in sg_inds])
[docs]
def rescale_asymmetric_unit(asym_unit_dict, mol_position, sg_inds):
"""
input fractional coordinates are scaled on 0-max
rescale these for the specific ranges according to each space group
only space groups in asym_unit_dict will work - not all have been manually encoded
this approach will not work for asymmetric units which are not neat parallelpipeds
Parameters
----------
asym_unit_dict
mol_position
sg_inds
Returns
-------
"""
return mol_position / torch.stack([asym_unit_dict[str(int(ind))] for ind in sg_inds])
[docs]
def get_aunit_positions(mol_batch,
std_orientation: bool = True,
mol_handedness: OptTensor = None,
) -> Tensor:
"""
pose asymmetric unit via fractional translations and rotations
always take the zp=1 component
"""
rotations_list = rotvec2rotmat(mol_batch.aunit_orientation[:, :3])
if std_orientation: # align canonical conformers principal axes to cartesian axes - not usually done here, but allowed
assert mol_handedness is not None, "Must provide explicit handedness when aligning mol to canonical axes"
mol_batch = align_mol_batch_to_standard_axes(mol_batch, handedness=mol_handedness[:, :1])
else:
mol_batch.recenter_molecules()
# apply rotations and fractional translations
pos = (apply_rotation_to_batch(mol_batch.pos,
rotations_list,
mol_batch.batch)
+ fractional_transform(mol_batch.aunit_centroid[:, :3], mol_batch.T_fc)[mol_batch.batch])
# # ensure embeddings rotate along with the aunit positions
# this is a bit risky - leave it for now
# if hasattr(mol_batch, 'embedding'):
# if mol_batch.embedding.shape[1] == 3: # if there is a vector embedding, we should also rotate it
# mol_batch.embedding = torch.einsum('nij, njk -> njk', rotations_list, mol_batch.embedding)
return pos
[docs]
def canonicalize_aunit_order(batch):
# canonical aunit ordering according to distance to the origin, with x, y, z, tiebreak
# TODO implement tiebreak
per_aunit_centroids = batch.aunit_centroid.reshape(batch.num_graphs, batch.max_z_prime, 3)
# mask out elements with lower z prime than max
idx = torch.arange(batch.max_z_prime, device=batch.device)[None, ...]
mask = (idx >= (batch.z_prime[..., None]))[..., None].expand(-1, -1, 3)
per_aunit_centroids[mask] = 1 # this will put lower Z' options always at the end
origin_dists = per_aunit_centroids.norm(dim=2)
canonical_order = origin_dists.argsort(dim=1)
canonical_order_exp = canonical_order.unsqueeze(-1).expand(-1, -1, 3) # [n, k, 3]
per_aunit_centroids[mask] = 0.5 # this is the placeholder value for 'not being used'
sorted_centroids = torch.gather(per_aunit_centroids, dim=1, index=canonical_order_exp)
aunit_centroid = sorted_centroids.reshape(batch.num_graphs, batch.max_z_prime * 3)
# also reorder the orientations accordingly
per_aunit_orient = batch.aunit_orientation.reshape(batch.num_graphs, batch.max_z_prime, 3)
per_aunit_orient[mask] = torch.pi / 2 # this places all unused samples in the same state
canonical_order_exp = canonical_order.unsqueeze(-1).expand(-1, -1, 3) # [n, k, 3]
sorted_orient = torch.gather(per_aunit_orient, dim=1, index=canonical_order_exp)
aunit_orientation = sorted_orient.reshape(batch.num_graphs, batch.max_z_prime * 3)
aunit_handedness = torch.gather(batch.aunit_handedness, dim=1, index=canonical_order)
return aunit_centroid, aunit_orientation, aunit_handedness
[docs]
def protonate_mol(atom_types, coords):
"""
atom_types : list/array of atomic numbers
coords : (N,3) float array in Å
returns: (new_atom_types, new_coords) with added hydrogens
"""
# 1. Build an RDKit mol without sanitization
mol = Chem.RWMol()
for Z in atom_types:
mol.AddAtom(Chem.Atom(int(Z)))
# 2. Add a conformer with your coords
conf = Chem.Conformer(len(atom_types))
for i, (x, y, z) in enumerate(coords):
conf.SetAtomPosition(i, (float(x), float(y), float(z)))
mol = mol.GetMol()
Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ADJUSTHS)
mol.AddConformer(conf, assignId=True)
# 3. Add hydrogens (fast, valence-based)
molH = Chem.AddHs(mol, addCoords=True)
# If your input had 3D coordinates, this preserves heavy atom positions
# and places hydrogens using RDKit's built-in valence geometry rules.
# No re-embedding needed for most molecules.
# 4. Extract arrays back out
confH = molH.GetConformer()
n = molH.GetNumAtoms()
new_coords = np.array([list(confH.GetAtomPosition(i)) for i in range(n)], dtype=float)
new_atom_types = np.array([molH.GetAtomWithIdx(i).GetAtomicNum() for i in range(n)], dtype=int)
return new_atom_types, new_coords
"""
m1 = Atoms(positions=coords, numbers=atom_types)
m2 = Atoms(positions=new_coords, numbers=new_atom_types)
view([m1, m2])
"""