Source code for mxtaltools.common.geometry_utils

import sys
from math import sqrt
from typing import Optional

import numpy as np
import torch
from torch import Tensor
from torch_scatter import scatter, scatter_max

from mxtaltools.common.sym_utils import init_sym_info
from mxtaltools.constants.atom_properties import VDW_RADII
from mxtaltools.models.functions.radial_graph import radius


def compute_principal_axes_np(coords):
    """
    Principal inertial axes of a point cloud, treating every point as unit mass.

    The cloud is centred on its centroid, the inertial tensor is diagonalised,
    and every axis is oriented to have positive overlap with the unit vector
    pointing from the centroid to the farthest point (if several points are
    equidistant after rounding to 8 decimals, the lowest index is used).
    If any overlap is vanishing (< 1e-3) and the resulting frame is
    left-handed, the axis with the smallest overlap is flipped. The result is
    not unique for certain symmetric inputs.

    Parameters
    ----------
    coords : np.array(n, 3)

    Returns
    -------
    Ip : np.array(3, 3)
        Principal inertial axes, one per row, ordered by ascending moment.
    Ipm : np.array(3)
        Principal inertial moments in ascending order.
    I : np.array(3, 3)
        Inertial tensor in the centred, original orientation.
    """
    # todo harmonize with torch version - currently disagrees ~0.5% of the time
    centred = coords - coords.mean(0)
    px, py, pz = centred.T

    # products of inertia (off-diagonal terms)
    Ixy = -np.sum(px * py)
    Iyz = -np.sum(py * pz)
    Ixz = -np.sum(px * pz)
    I = np.array([
        [np.sum((py ** 2 + pz ** 2)), Ixy, Ixz],
        [Ixy, np.sum((px ** 2 + pz ** 2)), Iyz],
        [Ixz, Iyz, np.sum((px ** 2 + py ** 2))],
    ])

    eigenvalues, eigenvectors = np.linalg.eig(I)
    eigenvalues = np.real(eigenvalues)
    eigenvectors = np.real(eigenvectors)
    order = np.argsort(eigenvalues)
    Ipm = eigenvalues[order]
    Ip = eigenvectors.T[order]  # eigenvectors as rows, sorted like the moments

    # cardinal direction: centroid -> farthest point, ties broken by lowest index
    radii = np.linalg.norm(centred, axis=1)
    rounded_radii = np.round(radii, 8)
    tied = np.argwhere(rounded_radii == rounded_radii[np.argmax(radii)])[:, 0]
    farthest = int(tied.min())
    direction = centred[farthest]
    direction = direction / np.linalg.norm(direction)

    overlaps = np.dot(Ip, direction)
    # np.sign gives 0 for an exactly-zero overlap; treat that case as +1
    signs = np.where(overlaps == 0, 1.0, np.sign(overlaps))
    Ip = Ip * signs[:, None]  # flip every axis that points away from the cardinal direction

    # an axis with vanishing overlap has an ill-defined sign: fix it with the
    # right-hand rule (cannot work if two overlaps vanish simultaneously)
    if np.any(np.abs(overlaps) < 1e-3) and compute_Ip_handedness(Ip) < 0:
        free_axis = np.argmin(np.abs(overlaps))
        Ip[free_axis] = -Ip[free_axis]

    return Ip, Ipm, I
def compute_inertial_tensor_torch(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
    """
    Build the mass-free inertial tensor of a point cloud and diagonalise it.

    The coordinates are used as given; no centring is applied here.

    Parameters
    ----------
    x : torch.Tensor(n)
    y : torch.Tensor(n)
    z : torch.Tensor(n)

    Returns
    -------
    I : torch.Tensor(3, 3)
        Inertial tensor in the input frame.
    Ip : torch.Tensor(3, 3)
        Eigenvectors as returned by ``torch.linalg.eig`` (complex dtype,
        one eigenvector per column, unsorted).
    Ipm : torch.Tensor(3)
        Eigenvalues as returned by ``torch.linalg.eig`` (complex dtype, unsorted).
    """
    # todo harmonize with numpy version - currently disagrees ~0.5% of the time
    Ixx = torch.sum((y ** 2 + z ** 2))
    Iyy = torch.sum((x ** 2 + z ** 2))
    Izz = torch.sum((x ** 2 + y ** 2))
    Ixy = -torch.sum(x * y)
    Iyz = -torch.sum(y * z)
    Ixz = -torch.sum(x * z)

    # torch.stack (rather than torch.tensor on a nested list of 0-d tensors)
    # keeps the input device and dtype and does not cut the autograd graph
    I = torch.stack((
        torch.stack((Ixx, Ixy, Ixz)),
        torch.stack((Ixy, Iyy, Iyz)),
        torch.stack((Ixz, Iyz, Izz)),
    ))  # inertial tensor

    Ipm, Ip = torch.linalg.eig(I)  # principal moments and axes

    return I, Ip, Ipm
def single_molecule_principal_axes_torch(coords: torch.Tensor, masses=None, return_direction=False):
    """
    Compute the principal inertial axes for a single set of particle coordinates,
    ignoring particle mass.

    The axes are oriented to have positive overlap with the vector pointing to
    the point farthest from the origin (lowest index wins among exact ties).
    The coordinates are NOT centred here: the centroid is assumed to already
    sit at the origin. If an overlap is vanishing, that axis is rebuilt from
    the other two with the right-hand rule; this cannot work if two overlaps
    vanish, e.g., for certain symmetric molecules.

    Parameters
    ----------
    coords : torch.Tensor(n, 3)
    masses : None
        Not used; a notice is printed if anything else is passed.
    return_direction : bool
        Whether to append the (un-normalised) cardinal direction to the output.

    Returns
    -------
    Ip : torch.Tensor(3, 3)
        Principal inertial axes, one per row, ordered by ascending moment.
    Ipm : torch.Tensor(3)
        Principal inertial moments in ascending order.
    I : torch.Tensor(3, 3)
        Inertial tensor in the original frame.
    direction : torch.Tensor(3)
        Only returned when ``return_direction`` is True.
    """
    if masses is not None:
        print('Inertial tensor is purely geometric! Calculation will not account for varying masses')

    x, y, z = coords.T
    I, Ip, Ipm = compute_inertial_tensor_torch(x, y, z)
    Ipm, Ip = torch.real(Ipm), torch.real(Ip)
    sort_inds = torch.argsort(Ipm)
    Ipm = Ipm[sort_inds]
    Ip = Ip.T[sort_inds]  # want eigenvectors to be sorted row-wise (rather than column-wise)

    # cardinal direction is vector from CoM to farthest atom
    dists = torch.linalg.norm(coords, dim=1)  # CoM is at 0,0,0
    max_ind = torch.argmax(dists)
    # if there are multiple equidistant atoms - pick the one with the lowest index
    max_equivs = torch.where(dists == dists[max_ind])[0]
    max_ind = int(torch.amin(max_equivs))
    direction = coords[max_ind]  # magnitude doesn't matter, only the sign of the overlaps

    # check if the principal components point towards or away from the cardinal direction
    overlaps = torch.inner(Ip, direction)
    exactly_zero = overlaps == 0
    if exactly_zero.any():  # exactly zero is invalid: sign() would zero out the axis
        overlaps[exactly_zero] = 1e-9

    # if the vectors have negative overlap, flip the direction
    Ip = (Ip.T * torch.sign(overlaps)).T

    if (torch.abs(overlaps) < 1e-8).any():
        # The sign of the axis with vanishing overlap is ill-defined, so rebuild it
        # from the two 'good' axes. The cross product must be taken in CYCLIC order
        # (v0 = v1 x v2, v1 = v2 x v0, v2 = v0 x v1); using ascending index order
        # yields -v1 for the middle axis and hence a left-handed frame.
        fix_ind = int(torch.argmin(torch.abs(overlaps)))
        Ip[fix_ind] = torch.cross(Ip[(fix_ind + 1) % 3], Ip[(fix_ind + 2) % 3], dim=0)

    if return_direction:
        return Ip, Ipm, I, direction
    else:
        return Ip, Ipm, I
def list_molecule_principal_axes_torch(coords_list: list = None, skip_centring=False):
    """
    Principal inertial axes for a list of point clouds, computed in one batched pass.

    All clouds are concatenated, the inertial tensors are diagonalised together,
    and every frame is oriented against the vector from the origin to its
    farthest point.

    Parameters
    ----------
    coords_list : list(torch.tensor(n,3))
    skip_centring : bool
        If True the clouds are used as given (e.g., already centred);
        otherwise each cloud has its mean subtracted first.

    Returns
    -------
    Ip_fin : torch.tensor(num_graphs, 3, 3)
    Ipm_fin : torch.tensor(num_graphs, 3)
    I : torch.tensor(num_graphs, 3, 3)
    """
    if skip_centring:
        all_coords = torch.cat(coords_list)
    else:
        # todo accelerate with scatter
        all_coords = torch.cat([cloud - cloud.mean(0) for cloud in coords_list])

    # todo pass batch info as an argument instead calculating here
    batch, ptrs = extract_batching_info(coords_list, all_coords.device)

    Ip, Ipm_fin, I = scatter_compute_Ip(all_coords, batch)

    # cardinal direction: vector from the origin (CoM) to the farthest atom
    direction = batch_get_furthest_node_vector(all_coords, batch, num_graphs=len(coords_list))
    direction_norms = torch.linalg.norm(direction, dim=1)
    normed_direction = direction / direction_norms[:, None]

    overlaps, signs = get_overlaps(Ip, normed_direction)
    # somehow, fails for mirror planes, on top of symmetric and spherical tops
    Ip_fin = correct_Ip_directions(Ip, overlaps, signs)

    return Ip_fin, Ipm_fin, I
def batch_molecule_principal_axes_torch(coords_i: torch.FloatTensor,
                                        batch: torch.LongTensor,
                                        num_graphs: int,
                                        nodes_per_graph: torch.LongTensor,
                                        heavy_atoms_only: bool = True,
                                        atom_types: Optional[torch.LongTensor] = None,
                                        ):
    """
    Parallel computation of principal inertial axes for a batch of point clouds
    stored in flat (node-wise) form.

    Parameters
    ----------
    coords_i : torch.FloatTensor(n, 3)
        Coordinates of all nodes in the batch.
    batch : torch.LongTensor(n)
        Graph index of each node.
    num_graphs : int
    nodes_per_graph : torch.LongTensor(num_graphs)
        Must equal ``torch.bincount(batch, minlength=num_graphs)``.
    heavy_atoms_only : bool
        If True, only nodes with ``atom_types > 1`` enter the inertial tensor
        and the cardinal direction (and centring is requested on heavy atoms).
    atom_types : torch.LongTensor(n), optional
        Required when ``heavy_atoms_only`` is True.

    Returns
    -------
    Ip_fin : torch.tensor(num_graphs, 3, 3)
    Ipm_fin : torch.tensor(num_graphs, 3)
    I : torch.tensor(num_graphs, 3, 3)

    Raises
    ------
    ValueError
        If ``heavy_atoms_only`` is True but ``atom_types`` is None.
    """
    if heavy_atoms_only and atom_types is None:
        # the default arguments are mutually inconsistent; fail with a clear
        # message instead of an opaque error from `None > 1` further down
        raise ValueError("atom_types must be provided when heavy_atoms_only=True")

    bc = torch.bincount(batch, minlength=num_graphs)
    assert torch.equal(bc, nodes_per_graph), \
        f"bincount(batch) vs num_atoms mismatch:\n bincount={bc.tolist()}\n num_atoms={nodes_per_graph.tolist()}"

    coords = center_batch(coords_i, batch, num_graphs, nodes_per_graph,
                          center_on_heavy_atoms=heavy_atoms_only, atom_types=atom_types)

    if heavy_atoms_only:
        mask = atom_types > 1
        # index once and reuse for both the tensor and the direction
        sel_coords, sel_batch = coords[mask], batch[mask]
    else:
        sel_coords, sel_batch = coords, batch

    Ip, Ipm_fin, I = scatter_compute_Ip(sel_coords, sel_batch)
    # cardinal direction is vector from CoM to the farthest atom
    # todo ON NEXT REPARAMETERIZATION SET THIS TO A MORE STABLE ANCHOR like median atom or some inner product
    direction = batch_get_furthest_node_vector(sel_coords, sel_batch, num_graphs)

    normed_direction = direction / torch.linalg.norm(direction, dim=1)[:, None]
    overlaps, signs = get_overlaps(Ip, normed_direction)
    # fails for mirror planes, on top of symmetric and spherical tops
    Ip_fin = correct_Ip_directions(Ip, overlaps, signs)

    return Ip_fin, Ipm_fin, I

    # Debugging aid, never executed: visualize clouds and axes for testing
    #
    # ind = 16
    # x, y, z = coords[batch == ind].T.cpu().detach().numpy()
    # types = atom_types[batch==ind].cpu().detach().numpy()
    # fig = go.Figure(go.Scatter3d(x=x, y=y, z=z, mode='markers'))
    # a, b, c = torch.stack([torch.zeros_like(direction[ind]), direction[ind]]).T.cpu().detach().numpy()
    # fig.add_trace(go.Scatter3d(x=a, y=b, z=c, marker_color=types))
    # colors = ['red', 'green', 'blue']
    # for i in range(3):
    #     a, b, c = torch.stack([torch.zeros_like(direction[ind]), Ip[ind, i]]).T.cpu().detach().numpy()
    #     fig.add_trace(go.Scatter3d(x=a, y=b, z=c, marker_color=colors[i],
    #                                name='Initial principal axes',
    #                                legendgroup='Initial principal axes',
    #                                showlegend=i == 0))
    # for i in range(3):
    #     a, b, c = torch.stack([torch.zeros_like(direction[ind]), Ip_fin[ind, i]]).T.cpu().detach().numpy()
    #     fig.add_trace(go.Scatter3d(x=a, y=b, z=c, marker_color=colors[i],
    #                                name='Fixed principal axes',
    #                                legendgroup='Fixed principal axes',
    #                                showlegend=i == 0))
    # fig.show(renderer='browser')
def correct_Ip_directions(Ip, overlaps, signs, overlap_threshold: float = 1e-5):
    """
    Loop-based reference implementation: enforce positive overlaps between
    batched inertial principal axes and a canonical direction.

    NOTE: a second ``correct_Ip_directions`` is defined further down in this
    module and rebinds the name at import time, so this version is not the one
    callers get. It is kept as a readable, per-sample reference.

    # TODO speed this up for large batches

    Parameters
    ----------
    Ip : torch.tensor(B, 3, 3)
        Principal axes, one per row, for each sample.
    overlaps : torch.tensor(B, 3)
        Overlap of each axis with the canonical direction.
    signs : torch.tensor(B, 3)
        Sign to apply to each axis.
    overlap_threshold : float
        Overlaps with magnitude below this are treated as vanishing.

    Returns
    -------
    Ip_fin : torch.tensor(B, 3, 3)
        Re-oriented principal axes. The input is not modified.
    """
    Ip_fixed_list = []
    for ii, Ip_i in enumerate(Ip):
        # if the vectors have negative overlap, flip the direction
        Ip_fixed = (Ip_i.T * signs[ii]).T

        # Vanishing overlaps (up to 32 bit precision) happen when the cardinal
        # direction is too close to an existing principal axis. In that case
        # determine the direction via the RHR (if two overlaps are vanishing,
        # this will not work).
        abs_overlaps = torch.abs(overlaps[ii])
        if (abs_overlaps < overlap_threshold).any():
            fix_ind = torch.argmin(abs_overlaps)  # vector with vanishing overlap
            # handedness is evaluated on the un-flipped input frame
            if compute_Ip_handedness(Ip_i) < 0:
                Ip_fixed[fix_ind] = -Ip_fixed[fix_ind]

        Ip_fixed_list.append(Ip_fixed)

    if not Ip_fixed_list:
        # empty batch: torch.stack would raise on an empty list
        return Ip.clone()

    return torch.stack(Ip_fixed_list)
def correct_Ip_directions(Ip, overlaps, signs, overlap_threshold: float = 1e-5):
    """
    Vectorised re-orientation of batched inertial principal axes so that they
    have positive overlap with a canonical direction.

    Every axis is multiplied by its sign. Additionally, for samples where at
    least one overlap is vanishing AND the input frame is left-handed, the
    axis with the smallest absolute overlap is flipped once more.

    Parameters
    ----------
    Ip : torch.tensor(B, 3, 3)
        Principal axes, one per row, for each sample.
    overlaps : torch.tensor(B, 3)
    signs : torch.tensor(B, 3)
    overlap_threshold : float
        Overlaps with magnitude below this are treated as vanishing.

    Returns
    -------
    Ip_fin : torch.tensor(B, 3, 3)
        Re-oriented principal axes.
    """
    # NOTE(review): handedness is measured on the INPUT frame, before the sign
    # flips below are applied, whereas the numpy routine in this module checks
    # it after flipping - confirm which convention is intended.
    left_handed = compute_Ip_handedness(Ip) < 0

    # row i of every frame is scaled by signs[:, i]
    Ip_fin = Ip * signs[:, :, None]

    # Vanishing overlaps (up to 32 bit precision) occur when the cardinal
    # direction is too close to an existing principal axis; the sign is then
    # settled by the RHR (cannot work if two overlaps vanish at once).
    magnitude = overlaps.abs()
    has_vanishing = (magnitude < overlap_threshold).any(dim=1)
    weakest_axis = torch.argmin(magnitude, dim=-1)  # [B]

    needs_flip = left_handed & has_vanishing
    Ip_fin[needs_flip, weakest_axis[needs_flip]] *= -1

    return Ip_fin
def get_overlaps(Ip, direction):
    """
    Overlaps (dot products) of batched principal axes with a per-sample
    direction, together with their signs.

    Parameters
    ----------
    Ip : torch.tensor(B, 3, 3)
        Principal axes, one per row, for each sample.
    direction : torch.tensor(B, 3)

    Returns
    -------
    overlaps : torch.tensor(B, 3)
        ``overlaps[n, i]`` is the dot product of ``Ip[n, i]`` with ``direction[n]``.
    signs : torch.tensor(B, 3)
        Sign of each overlap, with exactly-zero overlaps mapped to +1.
    """
    # row-wise dot product: tells whether each axis points towards or away
    # from the given direction
    overlaps = torch.einsum('nij,nj->ni', Ip, direction)

    raw_signs = torch.sign(overlaps)
    # we want any exactly zero overlaps to come with positive signs
    signs = torch.where(raw_signs == 0, torch.ones_like(raw_signs), raw_signs)

    return overlaps, signs
def batch_get_furthest_node_vector(all_coords: torch.FloatTensor,
                                   batch: torch.LongTensor,
                                   num_graphs: int) -> Tensor:
    """
    Cardinal direction for each graph in a batch: the coordinate of the node
    with the largest norm, i.e. the vector from the origin to the furthest
    node. The clouds are expected to be centred already. Output is not unique
    for certain symmetric inputs.

    Parameters
    ----------
    all_coords : torch.tensor(n,3)
    batch : torch.tensor(n)
    num_graphs : int

    Returns
    -------
    direction : torch.tensor(num_graphs, 3)
    """
    # CoM is at 0,0,0 by construction, so the norm is the distance from it
    radial_dists = torch.linalg.norm(all_coords, dim=1)
    _, furthest_node = scatter_max(radial_dists, batch, dim_size=num_graphs)
    return all_coords[furthest_node]
def nan_hook(name, tensor_ref, batch):
    """
    Return a backward hook that prints debug info and halts on NaN gradients.

    Parameters
    ----------
    name : str
        Label printed when a NaN gradient is detected.
    tensor_ref
        Object printed as "Original tensor" for debugging.
    batch
        Object printed as "Batch" for debugging.

    Returns
    -------
    _hook : callable
        ``_hook(grad)`` raises ``AssertionError`` if ``grad`` contains NaN,
        otherwise returns ``torch.nan_to_num(grad)``.
    """

    def _hook(grad):
        if torch.isnan(grad).any():
            print(f"NaNs in grad of {name}")
            print(f"Original tensor: {tensor_ref}")
            print(f"Batch: {batch}")
            # raised explicitly (not via `assert`) so the check survives `python -O`
            raise AssertionError("Stop the code!!")
        return torch.nan_to_num(grad)

    return _hook
def scatter_compute_Ip(all_coords, batch, eps: float = 0.05, add_noise: bool = False):
    """
    Parallel computation of inertial tensors and their eigendecomposition for a
    batch of unequal sized sets of coordinates.

    The inertial tensor is symmetric by construction. When backpropagating we
    need to ensure there are no degenerate eigenvalues, so Gaussian noise of
    scale ``eps`` is added to the coordinates whenever ``all_coords`` requires
    grad (or ``add_noise`` is set). Empirically the noise has to be bigger than
    ~0.0127 to lift degeneracies; the default of 0.05 is taken for safety.

    Parameters
    ----------
    all_coords : torch.tensor(n,3)
    batch : torch.tensor(n)
    eps : float
        Standard deviation of the noise added to the coordinates.
    add_noise : bool
        Force the noise on even if ``all_coords`` does not require grad.

    Returns
    -------
    Ip : torch.tensor(num_graphs, 3, 3)
        Principal axes, one per row, ordered by ascending moment.
    Ipm : torch.tensor(num_graphs, 3)
        Principal moments in ascending order.
    I : torch.tensor(num_graphs, 3, 3)
        Inertial tensors (computed from the noised coordinates when noise is on).

    Raises
    ------
    AssertionError
        If the eigendecomposition fails on a non-CUDA device.
    """
    if all_coords.requires_grad or add_noise:
        coords_to_compute = all_coords + torch.randn_like(all_coords) * eps
    else:
        coords_to_compute = all_coords

    cx = coords_to_compute[:, 0]
    cy = coords_to_compute[:, 1]
    cz = coords_to_compute[:, 2]

    Ixy = -scatter(cx * cy, batch, reduce='sum')
    Iyz = -scatter(cy * cz, batch, reduce='sum')
    Ixz = -scatter(cx * cz, batch, reduce='sum')
    Ixx = scatter(cy ** 2 + cz ** 2, batch, reduce='sum')
    Iyy = scatter(cx ** 2 + cz ** 2, batch, reduce='sum')
    Izz = scatter(cx ** 2 + cy ** 2, batch, reduce='sum')

    # assemble [num_graphs, 3, 3]: inner stacks build the rows, outer stack the matrix
    inertial_tensor = torch.stack((
        torch.stack((Ixx, Ixy, Ixz), dim=-1),
        torch.stack((Ixy, Iyy, Iyz), dim=-1),
        torch.stack((Ixz, Iyz, Izz), dim=-1),
    ), dim=-2)

    try:
        Ipm_c, Ip_c = torch.linalg.eigh(inertial_tensor)
    except RuntimeError as e:
        if inertial_tensor.device.type != 'cuda':
            raise AssertionError("Ipm error") from e
        # CUDA solver failed: retry on CPU, then move the results back to the
        # device the input actually lives on (not just the default 'cuda' device)
        Ipm_c, Ip_c = torch.linalg.eigh(inertial_tensor.cpu())
        Ipm_c = Ipm_c.to(inertial_tensor.device)
        Ip_c = Ip_c.to(inertial_tensor.device)

    Ipms, Ip_o = torch.real(Ipm_c), torch.real(Ip_c)
    Ips = Ip_o.permute(0, 2, 1)  # switch to row-wise eigenvectors

    # sort moments and reorder the eigenvectors accordingly; gather is much
    # faster than per-sample fancy indexing in a Python loop
    sort_inds = torch.argsort(Ipms, dim=1)
    Ipm = torch.gather(Ipms, dim=1, index=sort_inds)
    Ip = torch.gather(Ips, dim=1, index=sort_inds.unsqueeze(2).expand(-1, -1, Ips.shape[2]))

    return Ip, Ipm, inertial_tensor
def extract_batching_info(nodes_list, device='cpu'):
    """
    Extract batch and ptr info from a list of sets of coordinates.

    Parameters
    ----------
    nodes_list : list(torch.tensor(n,3))
        with different n throughout
    device : str

    Returns
    -------
    batch : torch.tensor(num_nodes)
        int64 graph index of every node, e.g. [0, 0, 1, 1, 1, ...].
    ptr : torch.tensor(num_graphs + 1)
        Cumulative node counts, starting at 0.
    """
    counts = [len(nodes) for nodes in nodes_list]
    ptrs = torch.tensor([0] + counts, dtype=torch.int, device=device).cumsum(0)

    if not counts:
        # nothing to index; avoids failing on an empty list of graphs
        return torch.zeros(0, dtype=torch.int64, device=device), ptrs

    # one vectorised call instead of allocating a tensor of ones per graph
    batch = torch.repeat_interleave(
        torch.arange(len(counts), dtype=torch.int64, device=device),
        torch.tensor(counts, dtype=torch.int64, device=device),
    )

    return batch, ptrs
def sph2cart_rotvec(angles):
    """
    Transform from axis-angle in polar coordinates to rotation vector

    Parameters
    ----------
    angles : (nx3) or (3)
        theta, phi, r as np.ndarray or torch.tensor

    Returns
    -------
    rotvec : (nx3) or (3)
        x, y, z, same array type and rank as the input.
        None (after printing a message) for unsupported input types.
    """
    if isinstance(angles, np.ndarray):
        if angles.ndim > 1:
            theta, phi, r = angles.T
            return r[:, None] * np.stack((np.sin(theta) * np.cos(phi),
                                          np.sin(theta) * np.sin(phi),
                                          np.cos(theta))).T
        theta, phi, r = angles
        return r * np.asarray((np.sin(theta) * np.cos(phi),
                               np.sin(theta) * np.sin(phi),
                               np.cos(theta)))

    elif torch.is_tensor(angles):
        if angles.ndim > 1:
            theta, phi, r = angles.T
            return r[:, None] * torch.stack((theta.sin() * phi.cos(),
                                             theta.sin() * phi.sin(),
                                             theta.cos())).T
        theta, phi, r = angles
        # torch.stack assembles the three 0-d components into a (3,) tensor;
        # the torch.Tensor(...) constructor cannot be called with three tensors
        return r * torch.stack((theta.sin() * phi.cos(),
                                theta.sin() * phi.sin(),
                                theta.cos()))

    else:
        print("Array type not supported! Must be np.ndarray or torch.tensor")
        return None
def cart2sph_rotvec(rotvec):
    """
    transform rotation vector with axis rotvec/norm(rotvec) and angle ||rotvec||
    to spherical coordinates theta, phi and r=||rotvec||

    A zero rotation vector maps to theta = phi = r = 0 (instead of NaN angles).

    Parameters
    ----------
    rotvec : (nx3) or (3)
        x, y, z as np.ndarray or torch.tensor

    Returns
    -------
    angles : (nx3)
        theta (polar), phi (azimuthal), r (applied rotation). A 1-D input is
        promoted and returned with shape (1, 3).
        None (after printing a message) for unsupported input types.
    """
    # NOTE(review): 1-D inputs come back as (1, 3), not (3,). The original code
    # had a branch meant to return (3,), but it was unreachable because the
    # input had already been promoted to 2-D. Shape kept for compatibility.
    if isinstance(rotvec, np.ndarray):
        r = np.linalg.norm(rotvec, axis=-1)
        if rotvec.ndim == 1:
            rotvec = rotvec[None, :]
            r = np.asarray(r)[None]

        # divide by 1 where the norm is zero so the unit vector is 0, not NaN
        safe_r = np.where(r > 0, r, 1.0)
        unit_vector = rotvec / safe_r[:, None]

        # convert unit vector to angles
        theta = np.arctan2(np.sqrt(unit_vector[:, 0] ** 2 + unit_vector[:, 1] ** 2), unit_vector[:, 2])
        phi = np.arctan2(unit_vector[:, 1], unit_vector[:, 0])

        return np.concatenate((theta[:, None], phi[:, None], r[:, None]), axis=-1)

    elif torch.is_tensor(rotvec):
        r = torch.linalg.norm(rotvec, dim=-1)
        if rotvec.ndim == 1:
            rotvec = rotvec[None, :]
            # plain indexing keeps dtype and device (torch.Tensor(r) does not)
            r = r[None]

        safe_r = torch.where(r > 0, r, torch.ones_like(r))
        unit_vector = rotvec / safe_r[:, None]

        # convert unit vector to angles
        theta = torch.arctan2(torch.sqrt(unit_vector[:, 0] ** 2 + unit_vector[:, 1] ** 2), unit_vector[:, 2])
        phi = torch.arctan2(unit_vector[:, 1], unit_vector[:, 0])

        return torch.cat((theta[:, None], phi[:, None], r[:, None]), dim=-1)

    else:
        print("Array type not supported! Must be np.ndarray or torch.tensor")
        return None
def batch_compute_fractional_transform(cell_lengths, cell_angles):
    """
    Fractional <-> cartesian transforms and cell volumes for a batch of unit
    cells, computed in a vectorized, differentiable way.

    Parameters
    ----------
    cell_lengths : torch.tensor(nx3)
        a, b, c
    cell_angles : torch.tensor(nx3)
        alpha, beta, gamma

    Returns
    -------
    fc_transform : torch.tensor(n,3,3)
        Fractional-to-cartesian matrices (upper triangular).
    cf_transform : torch.tensor(n,3,3)
        Cartesian-to-fractional matrices (upper triangular).
    cell_volumes : torch.tensor(n)
        Absolute value of the cell volume.
    """
    cos_a = torch.cos(cell_angles)
    sin_a = torch.sin(cell_angles)

    a, b, c = cell_lengths[:, 0], cell_lengths[:, 1], cell_lengths[:, 2]
    cos_alpha, cos_beta, cos_gamma = cos_a[:, 0], cos_a[:, 1], cos_a[:, 2]
    sin_gamma = sin_a[:, 2]

    # unit cell volume; the sign of `val` is carried so that, technically,
    # the volume is a signed quantity until the final abs()
    val = 1.0 - cos_alpha ** 2 - cos_beta ** 2 - cos_gamma ** 2 + 2.0 * cos_alpha * cos_beta * cos_gamma
    vol = torch.sign(val) * torch.prod(cell_lengths, dim=1) * torch.sqrt(torch.abs(val))

    num_cells = len(cell_lengths)
    T_fc_list = cell_lengths.new_zeros((num_cells, 3, 3))
    T_cf_list = cell_lengths.new_zeros((num_cells, 3, 3))

    # cartesian -> fractional
    T_cf_list[:, 0, 0] = 1.0 / a
    T_cf_list[:, 0, 1] = -cos_gamma / a / sin_gamma
    T_cf_list[:, 0, 2] = b * c * (cos_alpha * cos_gamma - cos_beta) / vol / sin_gamma
    T_cf_list[:, 1, 1] = 1.0 / b / sin_gamma
    T_cf_list[:, 1, 2] = a * c * (cos_beta * cos_gamma - cos_alpha) / vol / sin_gamma
    T_cf_list[:, 2, 2] = a * b * sin_gamma / vol

    # fractional -> cartesian
    T_fc_list[:, 0, 0] = a
    T_fc_list[:, 0, 1] = b * cos_gamma
    T_fc_list[:, 0, 2] = c * cos_beta
    T_fc_list[:, 1, 1] = b * sin_gamma
    T_fc_list[:, 1, 2] = c * (cos_alpha - cos_beta * cos_gamma) / sin_gamma
    T_fc_list[:, 2, 2] = vol / a / b / sin_gamma

    return T_fc_list, T_cf_list, torch.abs(vol)
def cell_vol_torch(v: torch.tensor, a: torch.tensor):
    """
    Volume of a parallelepiped from its basis vector lengths and internal angles.

    Parameters
    ----------
    v : torch.tensor(3)
        [a b c]
    a : torch.tensor(3)
        [alpha beta gamma], in radians

    Returns
    -------
    cell_volume : torch.tensor (scalar)
    """
    cosines = torch.cos(a)  # angles in natural units
    cos_alpha, cos_beta, cos_gamma = cosines[0], cosines[1], cosines[2]

    # abs() guards the square root against slightly negative arguments
    angle_term = torch.abs(
        1.0 - cos_alpha ** 2 - cos_beta ** 2 - cos_gamma ** 2 + 2.0 * cos_alpha * cos_beta * cos_gamma)

    return v[0] * v[1] * v[2] * torch.sqrt(angle_term)
def batch_cell_vol_torch(v: torch.tensor, a: torch.tensor):
    """
    Unit cell volumes for a batch of cells, from basis vector lengths and
    internal angles.

    Parameters
    ----------
    v : torch.tensor(n, 3)
        batch of [a, b, c] lengths
    a : torch.tensor(n, 3)
        batch of [alpha, beta, gamma] angles in radians

    Returns
    -------
    cell_volumes : torch.tensor(n)
    """
    cosines = torch.cos(a)  # angles in natural units
    cos_alpha = cosines[:, 0]
    cos_beta = cosines[:, 1]
    cos_gamma = cosines[:, 2]

    # abs() guards the square root against slightly negative arguments
    angle_term = torch.abs(
        1.0 - cos_alpha ** 2 - cos_beta ** 2 - cos_gamma ** 2 + 2.0 * cos_alpha * cos_beta * cos_gamma)

    return v[:, 0] * v[:, 1] * v[:, 2] * torch.sqrt(angle_term)
def cell_vol_angle_factor(cell_angles):
    """
    Angular factor sqrt(|1 - cos²α - cos²β - cos²γ + 2cosα cosβ cosγ|) used in
    cell volume calculations.

    Parameters
    ----------
    cell_angles : torch.tensor(..., 3)
        [alpha, beta, gamma] in radians; leading batch dimensions are supported

    Returns
    -------
    factor : torch.tensor(...)
    """
    cosines = torch.cos(cell_angles)  # angles in natural units
    cos_alpha = cosines[..., 0]
    cos_beta = cosines[..., 1]
    cos_gamma = cosines[..., 2]

    radicand = 1.0 - cos_alpha ** 2 - cos_beta ** 2 - cos_gamma ** 2 + 2.0 * cos_alpha * cos_beta * cos_gamma
    return torch.sqrt(torch.abs(radicand))
def compute_Ip_handedness(Ip):
    """
    determine the right or left handedness from the cross products of principal inertial axes
    np.array or torch.tensor input, single or multiple samples

    The handedness is the sign of the scalar triple product
    ``Ip[0] . (Ip[1] x Ip[2])`` of the row vectors of each frame.

    Parameters
    ----------
    Ip : (opt n, 3, 3)
        principal inertial tensor(s), axes stored as rows

    Returns
    -------
    handedness : scalar for a single (3,3) frame, (n) for a batch
        +/- 1, the handedness of the cross products of principal inertial axes
        (0 if the triple product vanishes)

    Raises
    ------
    TypeError
        If ``Ip`` is neither a numpy array nor a torch tensor.
    ValueError
        If ``Ip`` is not 2- or 3-dimensional.
    """
    if isinstance(Ip, np.ndarray):
        if Ip.ndim == 2:
            return np.sign(np.dot(Ip[0], np.cross(Ip[1], Ip[2])).sum())
        elif Ip.ndim == 3:
            # per-frame triple product: elementwise multiply then sum over xyz.
            # (A matrix product against the transposed cross products would mix
            # the frames of the batch together.)
            return np.sign((Ip[:, 0] * np.cross(Ip[:, 1], Ip[:, 2], axis=1)).sum(1))

    elif torch.is_tensor(Ip):
        if Ip.ndim == 2:
            return torch.sign(torch.mul(Ip[0], torch.cross(Ip[1], Ip[2], dim=0)).sum()).float()
        elif Ip.ndim == 3:
            return torch.sign(torch.mul(Ip[:, 0], torch.cross(Ip[:, 1], Ip[:, 2], dim=1)).sum(1))

    else:
        # raise instead of killing the interpreter from library code
        raise TypeError("Ip handedness calculation failed! Inputs were neither torch.tensor or numpy.array")

    raise ValueError(f"Ip must have shape (3, 3) or (n, 3, 3), got {Ip.ndim} dimension(s)")
def cell_vol_np(v, a):
    """
    Volume of a parallelepiped from its basis vector lengths and internal angles.

    Parameters
    ----------
    v : np.array(3)
        [a b c]
    a : np.array(3)
        [alpha beta gamma], in radians

    Returns
    -------
    cell_volume : float
    """
    cosines = np.cos(a)  # angles in natural units
    cos_alpha, cos_beta, cos_gamma = cosines[0], cosines[1], cosines[2]

    # technically a signed quantity; the absolute value is taken before the root
    radicand = 1.0 - cos_alpha ** 2 - cos_beta ** 2 - cos_gamma ** 2 + 2.0 * cos_alpha * cos_beta * cos_gamma

    return v[0] * v[1] * v[2] * np.sqrt(np.abs(radicand))
def coor_trans_matrix_np(opt, v, a, return_vol=False):
    """
    Compute the fractional->cartesian or cartesian->fractional transform for a
    single unit cell, and optionally its volume.

    Parameters
    ----------
    opt : str
        'c_to_f' or 'f_to_c' which direction to transform between fractional and cartesian
    v : np.array(3)
        a, b, c
    a : np.array(3)
        alpha, beta, gamma, in radians
    return_vol : bool, optional
        return the absolute value of the cell volume

    Returns
    -------
    transform : np.array(3,3)
        Upper-triangular transformation matrix.
    cell_volumes : float
        Only returned when ``return_vol`` is True.

    Raises
    ------
    ValueError
        If ``opt`` is neither 'c_to_f' nor 'f_to_c'.
    """
    # todo test - enforce this agrees with the torch version
    if opt not in ('c_to_f', 'f_to_c'):
        # an unknown option used to yield an all-zero matrix silently
        raise ValueError(f"opt must be 'c_to_f' or 'f_to_c', got {opt!r}")

    if np.amax(a) > np.pi:
        print('Warning - large angles! Remember to convert to natural units!')

    cos_a = np.cos(a)
    sin_a = np.sin(a)

    # volume of the unit cell, technically a signed quantity
    val = 1.0 - cos_a[0] ** 2 - cos_a[1] ** 2 - cos_a[2] ** 2 + 2.0 * cos_a[0] * cos_a[1] * cos_a[2]
    vol = np.sign(val) * v[0] * v[1] * v[2] * np.sqrt(np.abs(val))

    m = np.zeros((3, 3), dtype=np.float64)
    if opt == 'c_to_f':
        # cartesian -> fractional
        m[0, 0] = 1.0 / v[0]
        m[0, 1] = -cos_a[2] / v[0] / sin_a[2]
        m[0, 2] = v[1] * v[2] * (cos_a[0] * cos_a[2] - cos_a[1]) / vol / sin_a[2]
        m[1, 1] = 1.0 / v[1] / sin_a[2]
        m[1, 2] = v[0] * v[2] * (cos_a[1] * cos_a[2] - cos_a[0]) / vol / sin_a[2]
        m[2, 2] = v[0] * v[1] * sin_a[2] / vol
    else:
        # fractional -> cartesian
        m[0, 0] = v[0]
        m[0, 1] = v[1] * cos_a[2]
        m[0, 2] = v[2] * cos_a[1]
        m[1, 1] = v[1] * sin_a[2]
        m[1, 2] = v[2] * (cos_a[0] - cos_a[1] * cos_a[2]) / sin_a[2]
        m[2, 2] = vol / v[0] / v[1] / sin_a[2]

    if return_vol:
        return m, np.abs(vol)
    else:
        return m
def mol_batch_vdW_volume(mol_batch):
    """
    Convenience wrapper around batch_compute_molecule_volume for a molecule batch.

    Reads ``z``, ``pos``, ``batch`` and ``num_graphs`` from ``mol_batch`` and
    builds the radii lookup table from the values of ``VDW_RADII`` as a
    float32 tensor on the same device as ``mol_batch.z``.
    """
    # presumably VDW_RADII is ordered so that position == atomic number,
    # since the table is later indexed by atom type - verify against the constant
    radii_table = torch.tensor(list(VDW_RADII.values()),
                               device=mol_batch.z.device,
                               dtype=torch.float32)
    return batch_compute_molecule_volume(
        mol_batch.z,
        mol_batch.pos,
        mol_batch.batch,
        mol_batch.num_graphs,
        radii_table,
    )
def batch_compute_molecule_volume(
        atom_types: torch.LongTensor,
        pos: torch.FloatTensor,
        batch: torch.LongTensor,
        num_graphs: int,
        vdw_radii_tensor: torch.Tensor
):
    """
    Estimate molecular vdW volumes for a batch of molecules with a
    sphere-overlap correction.

    The atomic sphere volumes are summed per molecule, then the pairwise
    sphere-sphere intersection volumes, scaled by an empirical factor of
    0.7272 that coarsely accounts for triple overlaps, are subtracted.

    Parameters
    ----------
    atom_types : torch.LongTensor(n)
        Indices into vdw_radii_tensor for each atom
    pos : torch.FloatTensor(n, 3)
        Atomic coordinates
    batch : torch.LongTensor(n)
        Graph index for each atom
    num_graphs : int
    vdw_radii_tensor : torch.Tensor
        Lookup table of vdW radii

    Returns
    -------
    corrected_mol_volume : torch.FloatTensor(num_graphs)
    """
    atom_volumes = 4 / 3 * torch.pi * vdw_radii_tensor[atom_types] ** 3
    raw_vdw_volumes = scatter(atom_volumes, batch, dim=0, dim_size=num_graphs, reduce='sum')

    # candidate overlapping pairs: everything within twice the largest radius
    pair_i, pair_j = radius(pos, pos,
                            r=2 * vdw_radii_tensor.max(),
                            batch_x=batch,
                            batch_y=batch,
                            max_num_neighbors=100)
    unique_pairs = pair_i < pair_j  # eliminate duplicates (and self-pairs)
    pair_i = pair_i[unique_pairs]
    pair_j = pair_j[unique_pairs]

    d = torch.linalg.norm(pos[pair_i] - pos[pair_j], dim=1)
    r_i = vdw_radii_tensor[atom_types[pair_i]]
    r_j = vdw_radii_tensor[atom_types[pair_j]]

    # lens volume of two intersecting spheres,
    # https://mathworld.wolfram.com/Sphere-SphereIntersection.html
    #   V = pi (r1 + r2 - d)^2 (d^2 + 2 d r1 + 2 d r2 + 6 r1 r2 - 3 r1^2 - 3 r2^2) / (12 d)
    prefactor = torch.pi * (r_i + r_j - d) ** 2
    polynomial = d ** 2 + 2 * d * (r_i + r_j) - 3 * r_i ** 2 - 3 * r_j ** 2 + 6 * r_i * r_j
    denominator = 12 * d
    sphere_overlaps = prefactor * polynomial / denominator
    # pairs farther apart than the sum of their radii do not intersect
    sphere_overlaps[d > (r_i + r_j)] = 0

    molwise_sphere_overlaps = scatter(sphere_overlaps, batch[pair_i], dim=0, dim_size=num_graphs, reduce='sum')

    # correction factor (omits triple overlaps) was estimated by comparison of a
    # representative molecule batch with the well-converged probe method
    corrected_mol_volume = raw_vdw_volumes - molwise_sphere_overlaps * 0.7272  # a very coarse correction factor

    return corrected_mol_volume
def probe_compute_molecule_volume(
        atom_types: torch.LongTensor,
        pos: torch.FloatTensor,
        batch: torch.LongTensor,
        num_graphs: int,
        vdw_radii_tensor: torch.Tensor,
        probes_per_mol: int = 100,
        eps: float = 1e-2,
        max_iters: int = 1000,
        min_iters: int = 5
):
    """
    Estimate molecular vdW volumes via Monte Carlo probe sampling, iterating
    until convergence.

    Random probes are scattered within each molecule's bounding box (padded by
    the largest vdW radius); the fraction of probes inside any atomic vdW
    sphere, times the box volume, gives one volume sample. Sampling stops once
    more than ``min_iters`` iterations were done and the mean absolute relative
    change of the running mean over the last ``min_iters`` steps is below
    ``eps`` for every molecule, or after ``max_iters`` iterations.

    Parameters
    ----------
    atom_types : torch.LongTensor(n)
    pos : torch.FloatTensor(n, 3)
    batch : torch.LongTensor(n)
    num_graphs : int
    vdw_radii_tensor : torch.Tensor
        Lookup table of vdW radii indexed by atom type
    probes_per_mol : int
        Number of random probes per molecule per iteration
    eps : float
        Convergence threshold on relative change in running mean
    max_iters : int
        Maximum number of sampling iterations
    min_iters : int
        Minimum iterations before convergence is checked

    Returns
    -------
    volumes : torch.FloatTensor(num_graphs)
        Running mean of the volume samples at the last iteration.

    Raises
    ------
    ValueError
        If ``max_iters`` is smaller than 1 (no sample could be drawn).
    """
    if max_iters < 1:
        raise ValueError("max_iters must be at least 1")

    def _running_means(record):
        # cumulative mean of the volume samples after each iteration, [iters, num_graphs]
        samples = torch.stack(record)
        counts = torch.arange(1, len(samples) + 1, device=samples.device)[:, None]
        return torch.cumsum(samples, dim=0) / counts

    # the bounding boxes do not change between iterations, so build them once
    max_radius = vdw_radii_tensor.max()
    mol_min_corner = scatter(pos, batch, dim=0, reduce='min') - max_radius
    mol_max_corner = scatter(pos, batch, dim=0, reduce='max') + max_radius
    bounding_volumes = torch.prod(mol_max_corner - mol_min_corner, dim=1)
    probe_batch = torch.arange(num_graphs, device=batch.device).repeat_interleave(probes_per_mol)
    num_probes = probes_per_mol * num_graphs

    volume_record = []
    converged = False
    iteration = 0
    while not converged and iteration < max_iters:
        iteration += 1

        random_probes = mol_min_corner[probe_batch] + (
                mol_max_corner[probe_batch] - mol_min_corner[probe_batch]) * torch.rand(
            (num_probes, 3), device=pos.device)

        # pairs of (probe, atom) closer than the largest vdW radius
        edge_i, edge_j = radius(x=pos,
                                y=random_probes,
                                r=max_radius,
                                batch_x=batch,
                                batch_y=probe_batch,
                                max_num_neighbors=probes_per_mol)
        dists = torch.linalg.norm(random_probes[edge_i] - pos[edge_j], dim=1)

        # check each pair against the radius of that particular atom
        inside_any_sphere = (dists < vdw_radii_tensor[atom_types[edge_j]]).float()

        # we want a maximum of one contact per probe
        probe_has_a_contact = scatter(inside_any_sphere, edge_i, reduce='max',
                                      dim_size=num_probes, dim=0)
        inside_sphere_frac = scatter(probe_has_a_contact, probe_batch, reduce='sum', dim=0,
                                     dim_size=num_graphs) / probes_per_mol

        volume_record.append(bounding_volumes * inside_sphere_frac)

        if iteration > min_iters:
            cum_means = _running_means(volume_record)
            rel_diffs = torch.diff(cum_means / cum_means.mean(0)[None, :], dim=0)
            criteria = rel_diffs[-min_iters:].abs().mean(0)
            if torch.all(criteria < eps):  # relative change in running average less than eps
                converged = True

    # recomputed here so a result exists even if the convergence check never
    # ran (max_iters <= min_iters)
    return _running_means(volume_record)[-1, :]
def grid_compute_molecule_volume(atom_types, pos, vdw_radii_tensor, eps):
    """
    Brute force grid approach to computing vdW volume for a single molecule.

    A regular grid is laid over the molecule's bounding box (padded by the
    largest radius in ``vdw_radii_tensor``). The fraction of grid points inside
    at least one atomic sphere, times the box volume, is the volume estimate.
    The grid spacing starts at 0.1 * 0.75 and shrinks by a factor 0.75 per
    iteration until successive estimates agree to within ``eps``
    (relative to the second estimate) or the iteration cap is reached.

    Parameters
    ----------
    atom_types
        Indices into vdw_radii_tensor, one per atom
    pos
        Atom coordinates, indexed as (n, 3)
    vdw_radii_tensor
        Lookup table of vdW radii
    eps
        Convergence threshold on the relative change between successive estimates

    Returns
    -------
    occupied_volume
        Volume estimate from the final (finest) grid
    """
    max_radius = vdw_radii_tensor.amax()

    # the bounding box does not depend on the grid spacing, so build it once
    xmin, ymin, zmin = (pos.amin(0) - max_radius)
    xmax, ymax, zmax = (pos.amax(0) + max_radius)
    box_volume = (xmax - xmin) * (ymax - ymin) * (zmax - zmin)

    convergence_history = []
    dx = 0.1
    converged = False
    ind = -1
    max_iters = 10
    while not converged and ind < max_iters:
        dx *= 0.75
        num_x = int((xmax - xmin) / dx)
        num_y = int((ymax - ymin) / dx)
        num_z = int((zmax - zmin) / dx)

        # create the grid on the same device as the coordinates
        grid = torch.meshgrid(torch.linspace(xmin, xmax, num_x, device=pos.device),
                              torch.linspace(ymin, ymax, num_y, device=pos.device),
                              torch.linspace(zmin, zmax, num_z, device=pos.device),
                              indexing='xy'
                              )
        grid = torch.stack(grid)
        grid_flat = grid.reshape(3, num_x * num_y * num_z).T

        edges = radius(x=grid_flat, y=pos, r=max_radius, max_num_neighbors=int(1e12))
        dists = torch.linalg.norm(pos[edges[0]] - grid_flat[edges[1]], dim=1)
        close_enough = dists <= vdw_radii_tensor[atom_types[edges[0]]]
        # count each grid point once even if it lies in several spheres
        overlapped_points = len(edges[1, close_enough].unique())

        overlapping_fraction = overlapped_points / len(grid_flat)
        occupied_volume = box_volume * overlapping_fraction
        convergence_history.append(float(occupied_volume))

        ind += 1
        if ind > 1:
            # change between the last two estimates, normalized by the second estimate
            conv = abs(convergence_history[-2] - convergence_history[-1]) / convergence_history[1]
            if conv < eps:
                converged = True

    return occupied_volume
def grid_compute_molecule_volume_pointwise(atom_types, pos, vdw_radii_tensor, eps):
    """
    Brute force grid approach to computing vdW volume for a single molecule.

    This function's implementation was a line-for-line copy of
    ``grid_compute_molecule_volume``; it now simply forwards to it so the two
    cannot drift apart. The name is kept for existing callers.

    Parameters
    ----------
    atom_types
        Indices into vdw_radii_tensor, one per atom
    pos
        Atom coordinates, indexed as (n, 3)
    vdw_radii_tensor
        Lookup table of vdW radii
    eps
        Convergence threshold on the relative change between successive estimates

    Returns
    -------
    occupied_volume
        Volume estimate from the final (finest) grid
    """
    return grid_compute_molecule_volume(atom_types, pos, vdw_radii_tensor, eps)
def norm_circular_components(components: torch.Tensor):
    """
    Use Pythagoras to norm the sum of squares to the unit circle.

    Each vector along the last dimension is divided by its Euclidean length.
    Any number of leading dimensions is accepted. A vector of all zeros
    yields NaNs (0 / 0).

    Parameters
    ----------
    components : torch.Tensor(..., 2)

    Returns
    -------
    normed_components : torch.Tensor(..., 2)
    """
    # keepdim lets the division broadcast for any number of leading dimensions
    lengths = torch.sqrt(torch.sum(components ** 2, dim=-1, keepdim=True))
    return components / lengths
def components2angle(components: torch.tensor, norm_components=True):
    """
    Recover angles from (possibly non-normalized) sin/cos component pairs.

    Column 0 is treated as sin(angle) and column 1 as cos(angle), following
    https://ai.stackexchange.com/questions/38045/how-can-i-encode-angle-data-to-train-neural-networks
    Optionally the pairs are first normed onto the unit circle.

    Parameters
    ----------
    components : torch.tensor(n, 2)
    norm_components : bool, optional
        If True, pass the components through norm_circular_components first

    Returns
    -------
    angles : torch.tensor(n)
        atan2(components[:, 0], components[:, 1])
    """
    if norm_components:
        components = norm_circular_components(components)

    sin_part = components[:, 0]
    cos_part = components[:, 1]
    return torch.atan2(sin_part, cos_part)
def angle2components(angle: torch.tensor):
    """
    Decompose angles into their sine and cosine.

    Parameters
    ----------
    angle : torch.tensor(n)

    Returns
    -------
    components : torch.tensor(n, 2)
        Column 0 holds sin(angle), column 1 holds cos(angle)
    """
    sin_part = torch.sin(angle)
    cos_part = torch.cos(angle)
    return torch.stack((sin_part, cos_part), dim=1)
def enforce_crystal_system(lattice_lengths,
                           lattice_angles,
                           sg_inds,
                           symmetries_dict: Optional[dict] = None
                           ):
    """
    Enforce crystal-system constraints on cell parameters.
    https://en.wikipedia.org/wiki/Crystal_system

    The lattice type of each sample is looked up from its space group index in
    symmetries_dict['lattice_type']. Where lengths or angles must be equal, the
    first of them (a, or alpha) is copied to the others.

    Parameters
    ----------
    lattice_lengths : tensor indexed as (n, 3)
    lattice_angles : tensor indexed as (n, 3), in radians
    sg_inds : tensor of space group indices, 0-dim or one per sample
    symmetries_dict : dict, optional
        Built with init_sym_info() when omitted

    Returns
    -------
    fixed_lengths, fixed_angles : tensors shaped like the inputs
    """
    # todo vectorize this function, and clean it up in general
    if symmetries_dict is None:
        symmetries_dict = init_sym_info()

    lattice_types = symmetries_dict['lattice_type']
    if sg_inds.ndim == 0:
        lattices = [lattice_types[int(sg_inds)]]
    else:
        lattices = [lattice_types[int(sg_inds[n])] for n in range(len(lattice_lengths))]

    # scalar constants matching the dtype/device of the lengths
    pi_tensor = torch.ones_like(lattice_lengths[0, 0]) * torch.pi
    right_angle = pi_tensor / 2
    all_right_angles = torch.stack((right_angle, right_angle, right_angle), dim=-1)

    fixed_lengths = torch.zeros_like(lattice_lengths)
    fixed_angles = torch.zeros_like(lattice_angles)

    for i in range(len(lattice_lengths)):
        lengths, angles = lattice_lengths[i], lattice_angles[i]
        lattice = lattices[i]
        system = lattice.lower()

        if system == 'triclinic':  # anything goes
            new_lengths, new_angles = lengths, angles
        elif system == 'monoclinic':  # alpha and gamma fixed to pi/2
            new_lengths = lengths
            new_angles = torch.stack((right_angle, angles[1], right_angle), dim=-1)
        elif system == 'orthorhombic':  # all angles pi/2
            new_lengths, new_angles = lengths, all_right_angles
        elif system == 'tetragonal':  # all angles pi/2 and a=b
            new_lengths = torch.stack((lengths[0], lengths[0], lengths[2]), dim=-1)
            new_angles = all_right_angles
        elif system == 'hexagonal':  # a=b, alpha=beta=pi/2, gamma=2pi/3
            new_lengths = torch.stack((lengths[0], lengths[0], lengths[2]), dim=-1)
            new_angles = torch.stack((right_angle, right_angle, pi_tensor * 2 / 3), dim=-1)
        elif system == 'rhombohedral':  # a=b=c and alpha=beta=gamma
            new_lengths = torch.stack((lengths[0], lengths[0], lengths[0]), dim=-1)
            new_angles = torch.stack((angles[0], angles[0], angles[0]), dim=-1)
        elif system == 'cubic':  # all angles pi/2, all lengths equal
            new_lengths = torch.stack((lengths[0], lengths[0], lengths[0]), dim=-1)
            new_angles = all_right_angles
        else:
            print(lattice + ' is not a valid crystal lattice!')
            sys.exit()

        fixed_lengths[i] = new_lengths
        fixed_angles[i] = new_angles

    return fixed_lengths, fixed_angles
def enforce_crystal_system2(lattice_lengths, lattice_angles, lattices):
    """
    Enforce crystal-system constraints on cell parameters.
    https://en.wikipedia.org/wiki/Crystal_system

    Where lengths or angles must be equal, they are replaced by the mean of
    the values being tied together.

    Parameters
    ----------
    lattice_lengths : tensor indexed as (n, 3)
    lattice_angles : tensor indexed as (n, 3), in radians
    lattices : sequence of lattice type names, one per sample

    Returns
    -------
    fixed_lengths, fixed_angles : tensors shaped like the inputs
    """
    # todo double check these limits
    # scalar constants matching the dtype/device of the lengths
    pi_tensor = torch.ones_like(lattice_lengths[0, 0]) * torch.pi
    right_angle = pi_tensor / 2
    all_right_angles = torch.stack((right_angle, right_angle, right_angle), dim=-1)

    fixed_lengths = torch.zeros_like(lattice_lengths)
    fixed_angles = torch.zeros_like(lattice_angles)

    for i in range(len(lattice_lengths)):
        lengths, angles = lattice_lengths[i], lattice_angles[i]
        lattice = lattices[i]
        system = lattice.lower()

        if system == 'triclinic':  # anything goes
            new_lengths, new_angles = lengths, angles
        elif system == 'monoclinic':  # alpha and gamma fixed to pi/2
            new_lengths = lengths
            new_angles = torch.stack((right_angle, angles[1], right_angle), dim=-1)
        elif system == 'orthorhombic':  # all angles pi/2
            new_lengths, new_angles = lengths, all_right_angles
        elif system == 'tetragonal':  # all angles pi/2, a and b set to their mean
            ab_mean = torch.mean(lengths[0:2])
            new_lengths = torch.stack((ab_mean, ab_mean, lengths[2]), dim=-1)
            new_angles = all_right_angles
        elif system == 'hexagonal':  # mean of ab, c free; alpha=beta=pi/2, gamma=2pi/3
            ab_mean = torch.mean(lengths[0:2])
            new_lengths = torch.stack((ab_mean, ab_mean, lengths[2]), dim=-1)
            new_angles = torch.stack((right_angle, right_angle, pi_tensor * 2 / 3), dim=-1)
        elif system == 'rhombohedral':  # mean of all lengths, mean of all angles
            abc_mean = torch.mean(lengths)
            angle_mean = torch.mean(angles)
            new_lengths = torch.stack((abc_mean, abc_mean, abc_mean), dim=-1)
            new_angles = torch.stack((angle_mean, angle_mean, angle_mean), dim=-1)
        elif system == 'cubic':  # all angles pi/2, all lengths set to their mean
            abc_mean = torch.mean(lengths)
            new_lengths = torch.stack((abc_mean, abc_mean, abc_mean), dim=-1)
            new_angles = all_right_angles
        else:
            print(lattice + ' is not a valid crystal lattice!')
            sys.exit()

        fixed_lengths[i] = new_lengths
        fixed_angles[i] = new_angles

    return fixed_lengths, fixed_angles
def cell_parameters_to_box_vectors(opt: str,
                                   cell_lengths: torch.Tensor,
                                   cell_angles: torch.Tensor,
                                   return_vol: bool = False):
    """
    Convert cell lengths and angles to a fractional transform matrix
    (fractional->cartesian or cartesian->fractional).

    # TODO I believe this is a duplicate function
    Initially borrowed from Nikos.

    Parameters
    ----------
    opt : str
        'f_to_c' for the fractional->cartesian matrix,
        'c_to_f' for the cartesian->fractional matrix
    cell_lengths : torch.Tensor(3)
        a, b, c
    cell_angles : torch.Tensor(3)
        alpha, beta, gamma in radians
    return_vol : bool
        If True, also return the absolute cell volume

    Returns
    -------
    m : torch.FloatTensor(3, 3)
        Upper-triangular transform matrix (always float32)
    vol : torch.Tensor scalar
        Only when return_vol is True

    Raises
    ------
    ValueError
        If opt is not 'c_to_f' or 'f_to_c' (previously an all-zero matrix
        was returned silently).
    """
    if opt not in ('c_to_f', 'f_to_c'):
        raise ValueError(f"opt must be 'c_to_f' or 'f_to_c', got {opt!r}")

    # cos and sin of cell angles
    cos_a = torch.cos(cell_angles)
    sin_a = torch.sin(cell_angles)

    # volume of the unit cell
    val = 1.0 - cos_a[0] ** 2 - cos_a[1] ** 2 - cos_a[2] ** 2 + 2.0 * cos_a[0] * cos_a[1] * cos_a[2]
    vol = torch.sign(val) * cell_lengths[0] * cell_lengths[1] * cell_lengths[2] * torch.sqrt(
        torch.abs(val))  # technically a signed quantity

    # todo create m in a single-step
    m = torch.zeros((3, 3), dtype=torch.float32, device=cell_lengths.device)
    if opt == 'c_to_f':
        # cartesian to fractional
        m[0, 0] = 1.0 / cell_lengths[0]
        m[0, 1] = -cos_a[2] / cell_lengths[0] / sin_a[2]
        m[0, 2] = cell_lengths[1] * cell_lengths[2] * (cos_a[0] * cos_a[2] - cos_a[1]) / vol / sin_a[2]
        m[1, 1] = 1.0 / cell_lengths[1] / sin_a[2]
        m[1, 2] = cell_lengths[0] * cell_lengths[2] * (cos_a[1] * cos_a[2] - cos_a[0]) / vol / sin_a[2]
        m[2, 2] = cell_lengths[0] * cell_lengths[1] * sin_a[2] / vol
    else:
        # fractional to cartesian
        m[0, 0] = cell_lengths[0]
        m[0, 1] = cell_lengths[1] * cos_a[2]
        m[0, 2] = cell_lengths[2] * cos_a[1]
        m[1, 1] = cell_lengths[1] * sin_a[2]
        m[1, 2] = cell_lengths[2] * (cos_a[0] - cos_a[1] * cos_a[2]) / sin_a[2]
        m[2, 2] = vol / cell_lengths[0] / cell_lengths[1] / sin_a[2]

    if return_vol:
        return m, torch.abs(vol)
    return m
def compute_mol_radius(coords: torch.FloatTensor) -> Tensor:
    """
    Radius of a single molecule: the largest distance from the centroid
    (mean position) to any atom.

    Parameters
    ----------
    coords : torch.FloatTensor(n, 3)

    Returns
    -------
    radius : torch.FloatTensor scalar
    """
    centered = coords - coords.mean(0)
    atom_distances = torch.linalg.norm(centered, dim=-1)
    return torch.amax(atom_distances)
def batch_compute_mol_radius(coords: torch.FloatTensor,
                             batch: torch.LongTensor,
                             num_graphs: int,
                             nodes_per_graph: torch.LongTensor,
                             ) -> Tensor:
    """
    Batched version of compute_mol_radius: for each graph, the largest
    distance from its centroid to any of its atoms.

    Parameters
    ----------
    coords : torch.FloatTensor(n, 3)
    batch : torch.LongTensor(n)
        Graph index of each atom
    num_graphs : int
    nodes_per_graph : torch.LongTensor(num_graphs)
        Unused; retained for backward compatibility of the signature

    Returns
    -------
    radii : torch.FloatTensor(num_graphs)
    """
    centroids = get_batch_centroids(coords, batch, num_graphs)
    # index by batch rather than repeat_interleave(nodes_per_graph): identical
    # for sorted batches, and still correct when atoms are not grouped by graph
    dists = torch.linalg.norm(coords - centroids[batch], dim=-1)
    return scatter(dists, batch, dim=0, dim_size=num_graphs, reduce='max')
def get_batch_centroids(coords: torch.FloatTensor,
                        batch: torch.LongTensor,
                        num_graphs: int,
                        ) -> Tensor:
    """
    Mean position of the atoms belonging to each graph in a batch.

    Parameters
    ----------
    coords : torch.FloatTensor(n, 3)
    batch : torch.LongTensor(n)
        Graph index of each atom
    num_graphs : int

    Returns
    -------
    centroids : torch.FloatTensor(num_graphs, 3)
    """
    centroids = scatter(coords, batch, dim=0, reduce='mean', dim_size=num_graphs)
    return centroids
def center_batch(coords: torch.FloatTensor,
                 batch: torch.LongTensor,
                 num_graphs: int,
                 nodes_per_graph: torch.LongTensor,
                 center_on_heavy_atoms: bool = False,
                 atom_types: Optional[torch.LongTensor] = None,
                 ) -> Tensor:
    """
    Subtract the centroid from each graph's coordinates, returning
    zero-centered positions.

    Parameters
    ----------
    coords : torch.FloatTensor(n, 3)
    batch : torch.LongTensor(n)
        Graph index of each atom
    num_graphs : int
    nodes_per_graph : torch.LongTensor(num_graphs)
        Unused; retained for backward compatibility of the signature
    center_on_heavy_atoms : bool
        If True, compute the centroid using only heavy atoms (atom_types > 1)
        but translate all atoms
    atom_types : torch.LongTensor(n), optional
        Required when center_on_heavy_atoms is True

    Returns
    -------
    coords_centered : torch.FloatTensor(n, 3)

    Raises
    ------
    ValueError
        If center_on_heavy_atoms is True and atom_types is not given.
    """
    if center_on_heavy_atoms:
        if atom_types is None:
            raise ValueError("atom_types must be provided when center_on_heavy_atoms is True")
        mask = atom_types > 1
        centroids = get_batch_centroids(coords[mask], batch[mask], num_graphs)
    else:
        centroids = get_batch_centroids(coords, batch, num_graphs)

    # index by batch rather than repeat_interleave(nodes_per_graph): identical
    # for sorted batches, and still correct when atoms are not grouped by graph
    return coords - centroids[batch]
def batch_compute_mol_mass(z: torch.LongTensor,
                           batch: torch.LongTensor,
                           masses_tensor: torch.FloatTensor,
                           num_graphs: int) -> Tensor:
    """
    Total atomic mass of each graph in a batch.

    Parameters
    ----------
    z : torch.LongTensor(n)
        Atomic numbers used to index masses_tensor
    batch : torch.LongTensor(n)
        Graph index of each atom
    masses_tensor : torch.FloatTensor
        Lookup table of atomic masses indexed by atomic number
    num_graphs : int

    Returns
    -------
    masses : torch.FloatTensor(num_graphs)
    """
    atom_masses = masses_tensor[z]
    return scatter(atom_masses, batch, dim=0, reduce='sum', dim_size=num_graphs)
def compute_mol_mass(z: torch.LongTensor,
                     masses_tensor: torch.FloatTensor) -> Tensor:
    """
    Total atomic mass of a single molecule.

    Parameters
    ----------
    z : torch.LongTensor(n)
        Atomic numbers used to index masses_tensor
    masses_tensor : torch.FloatTensor
        Lookup table of atomic masses indexed by atomic number

    Returns
    -------
    mass : torch.FloatTensor scalar
    """
    atom_masses = masses_tensor[z]
    return torch.sum(atom_masses)
def rotvec2rotmat(mol_rotation: torch.Tensor, basis='cartesian'):
    """
    Convert a batch of rotation vectors to rotation matrices with the
    Rodrigues formula, R = I + sin(r) K + (1 - cos(r)) K^2
    (see https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula).

    A rotation of exactly zero magnitude yields the identity matrix
    (previously this divided by zero and produced NaNs).

    Parameters
    ----------
    mol_rotation : torch.Tensor(n, 3)
        Rotation vectors. In the 'cartesian' basis the rotation angle is the
        vector norm; in the 'spherical' basis the last column is taken as the
        rotation angle and the vectors are converted with sph2cart_rotvec.
    basis : str
        'cartesian' or 'spherical'. Any other value prints a message and exits.

    Returns
    -------
    applied_rotation_list : torch.Tensor(n, 3, 3)
    """
    if basis == 'cartesian':
        r = torch.linalg.norm(mol_rotation, dim=1)
    elif basis == 'spherical':  # psi, phi, (spherical unit vector) theta (rotation vector)
        r = mol_rotation[:, -1]  # third dimension in spherical basis is the norm
        mol_rotation = sph2cart_rotvec(mol_rotation)
    else:
        print(f'{basis} is not a valid orientation parameterization!')
        sys.exit()

    # avoid 0/0 for null rotations; the axis is irrelevant there because
    # sin(0) = 0 and 1 - cos(0) = 0 zero out both K terms
    safe_r = torch.where(r == 0, torch.ones_like(r), r)
    unit_vector = mol_rotation / safe_r[:, None]

    ux, uy, uz = unit_vector[:, 0], unit_vector[:, 1], unit_vector[:, 2]
    zeros = torch.zeros_like(ux)
    K = torch.stack((  # cross-product matrix of the rotation axis
        torch.stack((zeros, -uz, uy), dim=1),
        torch.stack((uz, zeros, -ux), dim=1),
        torch.stack((-uy, ux, zeros), dim=1)
    ), dim=1)

    identity_batch = torch.eye(3, device=r.device, dtype=torch.float32)[None, :, :].tile(len(r), 1, 1)
    angle = r[:, None, None]
    applied_rotation_list = identity_batch + torch.sin(angle) * K + (1 - torch.cos(angle)) * (K @ K)

    return applied_rotation_list
def extract_rotmat(target_position: torch.FloatTensor,
                   original_position: torch.FloatTensor) -> Tensor:
    """
    Compute transpose(target_position) @ inverse(original_position), for a
    single pair of 3x3 matrices or a batch of them.

    NOTE(review): presumably used as the rotation taking original_position
    onto target_position; the row/column convention of the inputs is not
    visible here, so confirm against the callers.

    Parameters
    ----------
    target_position : torch.FloatTensor(n, 3, 3) or (3, 3)
    original_position : torch.FloatTensor(n, 3, 3) or (3, 3)
        Must be invertible

    Returns
    -------
    rotmat : torch.FloatTensor(n, 3, 3) or (3, 3)

    Raises
    ------
    AssertionError
        If target_position is neither 2- nor 3-dimensional.
    """
    if target_position.ndim == 3:
        return torch.einsum('nji, njk -> nik', target_position, torch.linalg.inv(original_position))
    if target_position.ndim == 2:
        return torch.einsum('ji, jk -> ik', target_position, torch.linalg.inv(original_position))

    # raised explicitly (not via `assert`) so the check survives `python -O`
    raise AssertionError(
        f"target_position must be 2- or 3-dimensional, got ndim={target_position.ndim}")
def apply_rotation_to_batch(elems, rotations, batch):
    """
    Rotate each vector by the rotation matrix of the graph it belongs to.

    Parameters
    ----------
    elems : torch.FloatTensor(n, 3)
    rotations : torch.FloatTensor(num_graphs, 3, 3)
    batch : torch.LongTensor(n)
        Graph index of each vector

    Returns
    -------
    rotated : torch.FloatTensor(n, 3)
        rotated[n] = rotations[batch[n]] @ elems[n]
    """
    per_elem_rotation = rotations[batch]  # one matrix per vector
    rotated = torch.einsum('nij, nj -> ni', per_elem_rotation, elems)
    return rotated
def rotmat2rotvec(rotation_matrix_list, warn_on_bad_determinant=True):
    """
    Convert a batch of rotation matrices to rotation vectors (axis-angle).

    The axis is taken from the antisymmetric part of the matrix and the angle
    from its trace. Degenerate cases are replaced by a fallback rotation of pi
    around [1, 1, 1]: |(trace - 1) / 2| >= 1 (which includes exact identity and
    exact pi rotations), a NaN angle, an all-zero axis vector, or a NaN in the
    axis vector.

    Parameters
    ----------
    rotation_matrix_list : torch.FloatTensor(n, 3, 3)
    warn_on_bad_determinant : bool
        Print a warning if any matrix has det < 0 (i.e. is a reflection, not a rotation)

    Returns
    -------
    rotvecs : torch.FloatTensor(n, 3)
    """
    if warn_on_bad_determinant:
        det = torch.linalg.det(rotation_matrix_list)
        bad_dets = det < 0.0
        if bad_dets.any():
            num_bad = bad_dets.sum().item()
            print(f"[rotmat2rotvec] WARNING: {num_bad} matrices have determinant < 0. "
                  "These are reflections (not in SO(3)) and cannot be converted to rotation vectors.")

    # un-normalized rotation axis from the antisymmetric part of R
    direction_vector_list = torch.stack([
        rotation_matrix_list[:, 2, 1] - rotation_matrix_list[:, 1, 2],
        rotation_matrix_list[:, 0, 2] - rotation_matrix_list[:, 2, 0],
        rotation_matrix_list[:, 1, 0] - rotation_matrix_list[:, 0, 1]],
        dim=1)

    trace = torch.einsum('nii->n', rotation_matrix_list)
    r_arg = (trace - 1) / 2
    r = torch.arccos(r_arg)

    # the axis is degenerate only when every component vanishes; testing the
    # component *sum* would also reject valid axes such as (1, -1, 0)
    zero_direction = (direction_vector_list == 0).all(dim=1)
    nan_direction = torch.isnan(direction_vector_list).any(dim=1)
    bad_inds = (r_arg.abs() >= 1) | torch.isnan(r) | zero_direction | nan_direction

    direction_vector_list[bad_inds, :] = torch.ones_like(direction_vector_list[bad_inds, :])
    r[bad_inds] = torch.pi

    axis_norm = direction_vector_list.norm(dim=1, keepdim=True).clamp(min=1e-8)
    rotvecs = direction_vector_list / axis_norm * r[:, None]

    return rotvecs
""" old version direction_vector = torch.tensor([ rotation_matrix[2, 1] - rotation_matrix[1, 2], rotation_matrix[0, 2] - rotation_matrix[2, 0], rotation_matrix[1, 0] - rotation_matrix[0, 1]], device=rotation_matrix.device, dtype=torch.float32) # 32 precision is limiting here in some cases rotvec_list.append(direction_vector / torch.linalg.norm(direction_vector) * r) """
def sample_random_valid_rotvecs(num_samples):
    """
    Sample random rotation vectors restricted to the upper half-sphere (z >= 0).

    Directions come from normalized Gaussian samples (uniform coverage of the
    sphere); magnitudes are drawn uniformly from [0, 2*pi) and clipped to be
    at least 0.05 so that no rotation is exactly zero.

    Parameters
    ----------
    num_samples : int

    Returns
    -------
    rotvecs : torch.FloatTensor(num_samples, 3)
    """
    # random directions on the sphere
    directions = torch.randn(size=(num_samples, 3))
    direction_norms = directions.norm(dim=1)

    # rotation magnitudes, bounded away from zero
    magnitudes = (torch.rand(num_samples) * 2 * torch.pi).clip(min=0.05)

    rotvecs = directions / direction_norms[:, None] * magnitudes[:, None]

    # reflect the lower half-sphere onto the upper one (positive z)
    rotvecs[:, 2] = rotvecs[:, 2].abs()

    return rotvecs
def embed_vector_to_rank3(v):
    """
    Embed a batch of 3-vectors as fully symmetric rank-3 tensors:

        P[n, i, j, k] = (delta_ij v_k + delta_ik v_j + delta_jk v_i) / 3

    Parameters
    ----------
    v : torch.Tensor(n, 3)

    Returns
    -------
    P : torch.Tensor(n, 3, 3, 3)
    """
    # build the Kronecker delta alongside v so GPU and float64 inputs work
    delta = torch.eye(3, device=v.device, dtype=v.dtype)
    P = torch.einsum('ij,nk->nijk', delta, v) + \
        torch.einsum('ik,nj->nijk', delta, v) + \
        torch.einsum('jk,ni->nijk', delta, v)
    return P / 3  # Normalization factor
def fractional_transform(coords, transform_matrix):
    """
    Transform between fractional/cartesian bases, dispatching on the
    container type of the coordinates.

    Assumes the fractional->cartesian transform is the transpose of the box vectors.

    Args:
        coords: numpy array or torch tensor of coordinates
        transform_matrix: transform matrix (or batch of matrices) of the same library

    Returns:
        transformed_coords
    """
    if isinstance(coords, np.ndarray):
        return fractional_transform_np(coords, transform_matrix)
    if torch.is_tensor(coords):
        return fractional_transform_torch(coords, transform_matrix)
    # neither a numpy array nor a torch tensor
    assert False
def fractional_transform_np(coords, transform_matrix):
    """
    Apply a fractional/cartesian transform to numpy coordinate arrays.

    Dispatches on the dimensionality of coords and transform_matrix:
        (n, 3)    + (3, 3)    -> one matrix for all points
        (n, m, 3) + (3, 3)    -> one matrix for all atoms of all molecules
        (n, 3)    + (n, 3, 3) -> one matrix per point
        (n, m, 3) + (n, 3, 3) -> one matrix per molecule, applied to its m atoms
    In every case out[..., i] = sum_j transform_matrix[..., i, j] * coords[..., j].

    Parameters
    ----------
    coords : np.ndarray
    transform_matrix : np.ndarray

    Returns
    -------
    transformed : np.ndarray

    Raises
    ------
    ValueError
        For any other combination of dimensionalities (previously None was
        returned silently).
    """
    dims = (coords.ndim, transform_matrix.ndim)
    if dims == (2, 2):
        return np.einsum('nj,ij->ni', coords, transform_matrix)
    if dims == (3, 2):
        return np.einsum('nmj,ij->nmi', coords, transform_matrix)
    if dims == (2, 3):
        return np.einsum('nj,nij->ni', coords, transform_matrix)
    if dims == (3, 3):
        return np.einsum('nmj,nij->nmi', coords, transform_matrix)
    raise ValueError(
        f"Unsupported dimensionalities: coords.ndim={coords.ndim}, "
        f"transform_matrix.ndim={transform_matrix.ndim}")
def fractional_transform_torch(coords, transform_matrix):
    """
    Apply a fractional/cartesian transform to torch coordinate tensors.

    Dispatches on the dimensionality of coords and transform_matrix:
        (n, 3)    + (3, 3)    -> one matrix for all points
        (n, m, 3) + (3, 3)    -> one matrix for all atoms of all molecules
        (n, 3)    + (n, 3, 3) -> one matrix per point
        (n, m, 3) + (n, 3, 3) -> one matrix per molecule, applied to its m atoms
    In every case out[..., i] = sum_j transform_matrix[..., i, j] * coords[..., j].

    Parameters
    ----------
    coords : torch.FloatTensor
    transform_matrix : torch.FloatTensor

    Returns
    -------
    transformed : torch.FloatTensor

    Raises
    ------
    ValueError
        For any other combination of dimensionalities (previously None was
        returned silently).
    """
    dims = (coords.ndim, transform_matrix.ndim)
    if dims == (2, 2):
        return torch.einsum('nj,ij->ni', coords, transform_matrix)
    if dims == (3, 2):
        return torch.einsum('nmj,ij->nmi', coords, transform_matrix)
    if dims == (2, 3):
        return torch.einsum('nj,nij->ni', coords, transform_matrix)
    if dims == (3, 3):
        return torch.einsum('nmj,nij->nmi', coords, transform_matrix)
    raise ValueError(
        f"Unsupported dimensionalities: coords.ndim={coords.ndim}, "
        f"transform_matrix.ndim={transform_matrix.ndim}")
def compute_ellipsoid_volume(e):
    """
    Volume of ellipsoids described by their semi-axis vectors.

    Parameters
    ----------
    e : torch.FloatTensor(..., 3, 3)
        Each row is a semi-axis vector

    Returns
    -------
    volumes : torch.FloatTensor(...)
        (4/3) * pi * product of the semi-axis lengths
    """
    semi_axis_lengths = e.norm(dim=-1)
    return 4 / 3 * torch.pi * semi_axis_lengths.prod(dim=-1)
def compute_cosine_similarity_matrix(e1, e2):
    """
    Row-to-row dot products for batches of ellipsoids [n, i, j].

    Entry [n, i, k] of the result is the dot product of row i of e1[n] with
    row k of e2[n]. The rows are not normalized here, so the values are cosines
    only if the inputs hold unit vectors. Swapping e1 and e2 transposes the
    overlap matrix.

    :param e1: tensor [n, i, j]
    :param e2: tensor [n, k, j]
    :return: tensor [n, i, k]
    """
    overlaps = torch.einsum('nij, nkj -> nik', e1, e2)
    return overlaps
def safe_batched_eigh(covs, chunk=10000):
    """
    Chunked symmetric eigendecomposition with a CPU fallback for CUSOLVER failures.

    The batch is processed ``chunk`` matrices at a time with torch.linalg.eigh.
    If a chunk fails with a RuntimeError mentioning CUSOLVER_STATUS_INVALID_VALUE,
    that chunk is recomputed on the CPU and moved back to the input device.
    Out-of-memory errors and all other errors propagate unchanged.

    Parameters
    ----------
    covs : torch.FloatTensor(n, d, d)
        Batch of symmetric matrices
    chunk : int
        Number of matrices per kernel call

    Returns
    -------
    eigenvalues : torch.FloatTensor(n, d)
    eigenvectors : torch.FloatTensor(n, d, d)
        Both are cast to float32. For an empty batch (n == 0), empty tensors
        of the matching shapes are returned.
    """
    if covs.shape[0] == 0:
        # torch.cat cannot concatenate an empty list
        d = covs.shape[-1]
        return covs.new_zeros((0, d)).float(), covs.new_zeros((0, d, d)).float()

    out_vals, out_vecs = [], []
    for i in range(0, covs.shape[0], chunk):
        cchunk = covs[i:i + chunk]
        try:
            ev, evec = torch.linalg.eigh(cchunk)
        except torch.cuda.OutOfMemoryError:
            raise
        except RuntimeError as e:
            if "CUSOLVER_STATUS_INVALID_VALUE" not in str(e):
                raise
            print("Invalid matrix to eigh! Switching to CPU.")
            ev, evec = torch.linalg.eigh(cchunk.cpu())
            ev, evec = ev.to(covs.device), evec.to(covs.device)
        out_vals.append(ev)
        out_vecs.append(evec)

    return torch.cat(out_vals, dim=0).float(), torch.cat(out_vecs, dim=0).float()
def lat2sph_rotvec(lat_orientations, z_prime):
    """
    Map latent orientation parameters (normalized to [-1, 1]) to spherical
    rotation vectors. Inverse of sph_rotvec2lat.

    For each asymmetric unit the three latent dimensions map as:
        theta: [-1, 1] -> [0, π/2]    (lat * π/4 + π/4)
        phi:   [-1, 1] -> [-π, π]     (lat * π)
        r:     [-1, 1] -> [0, 2π]     (lat * π + π)

    Parameters
    ----------
    lat_orientations : torch.FloatTensor(..., z_prime * 3)
    z_prime : int
        Number of asymmetric units

    Returns
    -------
    sph_rotvec : torch.FloatTensor(..., z_prime * 3)
        Spherical rotation vectors [theta, phi, r] for each asymmetric unit
    """
    # split the last dimension into (asymmetric unit, component)
    per_unit = lat_orientations.view(*lat_orientations.shape[:-1], z_prime, 3)

    quarter_pi = torch.pi / 4
    theta = per_unit[..., 0] * quarter_pi + quarter_pi
    phi = per_unit[..., 1] * torch.pi
    r = per_unit[..., 2] * torch.pi + torch.pi

    sph = torch.stack((theta, phi, r), dim=-1)
    return sph.view(*lat_orientations.shape)
def simple_latent_distance(l1: torch.Tensor, l2: torch.Tensor) -> torch.Tensor:
    """
    Euclidean distances between latent vectors, with the angular dimensions
    wrapped onto [-1, 1).

    The latent layout is 6 box parameters, then 3 position dims per asymmetric
    unit, then 3 orientation dims per asymmetric unit. Only the last two
    orientation dims of each unit (phi and r) are treated as periodic.

    Parameters
    ----------
    l1, l2 : torch.Tensor(n, 6 + 6 * z_prime)

    Returns
    -------
    dist : torch.Tensor(n)
    """
    max_z_prime = (l1.shape[-1] - 6) // 6

    box_flags = [False] * 6
    position_flags = [False, False, False] * max_z_prime
    orientation_flags = [False, True, True] * max_z_prime  # phi and r are periodic
    periodic_mask = torch.tensor(box_flags + position_flags + orientation_flags, device=l1.device)

    delta = l1 - l2
    # wrap periodic dims: a difference of 2 is a full period
    wrapped = ((delta[:, periodic_mask] + 1) % 2) - 1
    delta[:, periodic_mask] = wrapped

    return delta.norm(dim=-1)
def compute_latent_distance(latents1: torch.Tensor, latents2: torch.Tensor) -> torch.Tensor:
    """
    Compute a distance metric between crystals in the latent parameterization.

    The latent vector is split into three groups, and a weighted sum of the
    per-group distances is returned:
      * box: dims 0-5 (log lengths 0-2 and angles 3-5), euclidean distance
      * positions: 3 dims per asymmetric unit, defined on [-1, 1]^3; treated
        as non-periodic (periodicity differs in every space group)
      * orientations: 3 dims per asymmetric unit, a spherical rotation vector
        renormalized to [-1, 1]; the distance is the angle of the relative
        rotation between the two orientations, summed over asymmetric units

    :param latents1: tensor (n, 6 + 6 * z_prime)
    :param latents2: tensor with the same trailing dimension as latents1
    :return: tensor of distances, one per row
    """
    n_params = latents1.shape[1]
    assert latents1.shape[-1] == latents2.shape[-1]
    z_prime = (n_params - 6) // 6
    positions_end = 6 + 3 * z_prime

    # box parameters
    box_dist = (latents1[..., :6] - latents2[..., :6]).norm(dim=-1)

    # positions
    positions_dist = (latents1[..., 6:positions_end] - latents2[..., 6:positions_end]).norm(dim=-1)

    # orientations: [polar, azimuthal, length] per asymmetric unit
    lat_sph_rotvec1 = lat2sph_rotvec(latents1[..., positions_end:], z_prime)
    lat_sph_rotvec2 = lat2sph_rotvec(latents2[..., positions_end:], z_prime)

    per_unit_angles = []
    for zp in range(z_prime):  # this should be replaced with a proper vector distance
        cols = slice(3 * zp, 3 * zp + 3)
        rmat1 = rotvec2rotmat(lat_sph_rotvec1[..., cols], 'spherical')
        rmat2 = rotvec2rotmat(lat_sph_rotvec2[..., cols], 'spherical')
        relative_rotation = rmat1 @ rmat2.transpose(-1, -2)
        trace = relative_rotation.diagonal(dim1=-2, dim2=-1).sum(-1)
        # clamp just inside [-1, 1] to keep acos finite
        cos_theta = ((trace - 1) / 2).clamp(-1 + 1e-7, 1 - 1e-7)
        per_unit_angles.append(torch.acos(cos_theta))
    rot_dists = torch.stack(per_unit_angles).sum(0)

    # overall distance metric
    # dists = 0.5 * box_dist + 0.25 * (positions_dist / z_prime / 2.5) + 0.25 * (rot_dists / z_prime / 2)
    # scales = [2 * sqrt(6), 2*sqrt(3), torch.pi]  # maximum variation per dist
    scales = [1, 0.836, 0.293]  # [0.0127, 0.0152, 0.0433]  # empirical std over CSD samples
    dists = scales[0] * 0.5 * box_dist + scales[1] * 0.25 * (positions_dist / z_prime) + scales[2] * 0.25 * (
            rot_dists / z_prime)

    return dists
# Legacy implementation of crystal_parameter_distmat, kept for reference only:
#
# def crystal_parameter_distmat(latents, max_batch_size = 2000):
#     """
#     Compute distance to self for a set of crystal latent vectors
#     :param latents:
#     :return:
#     """
#     N, K = latents.shape
#     if N < max_batch_size:
#         lat1 = latents[:, None, :].expand(N, N, K).reshape(N * N, K)
#         lat2 = latents[None, :, :].expand(N, N, K).reshape(N * N, K)
#
#         d = crystal_parameter_distance(lat1, lat2)  # (N*N,)
#         distmat = d.reshape(N, N)
#     else:
#         distmat = torch.zeros((N, N), device=latents.device)
#
#         for i in range(N):
#             lat1 = latents[i:i + 1].expand(N, -1)  # (N, K)
#             lat2 = latents  # (N, K)
#
#             distmat[i] = crystal_parameter_distance(lat1, lat2)
#
#
#     return distmat
def crystal_parameter_distmat(
        latents,
        target_entries=5_000_000,  # ≈ distances per call
        min_block=1,
        max_block=2048,
):
    """
    Blockwise all-pairs distance matrix between crystal latent vectors,
    computed with compute_latent_distance.

    The rows are processed in blocks whose size is chosen so that each call
    evaluates roughly target_entries distances, bounded by
    [min_block, max_block].

    Parameters
    ----------
    latents : torch.Tensor(N, K)
    target_entries : int
        Approximate number of distances per call
    min_block : int
    max_block : int

    Returns
    -------
    distmat : torch.Tensor(N, N)
        distmat[i, j] is the distance between latents[i] and latents[j]
    """
    num_samples, num_params = latents.shape

    # adaptive block size
    rows_per_block = max(min_block, min(max_block, target_entries // max(num_samples, 1)))

    distmat = torch.empty((num_samples, num_samples), device=latents.device)

    for start in range(0, num_samples, rows_per_block):
        block = latents[start:start + rows_per_block]  # (b, K)
        b = block.shape[0]

        # pair every row of the block with every row of the full set
        left = block.repeat_interleave(num_samples, dim=0)  # (b * N, K)
        right = latents.repeat(b, 1)  # (b * N, K)

        block_dists = compute_latent_distance(left, right)
        distmat[start:start + b] = block_dists.view(b, num_samples)

    return distmat
def sph_rotvec2lat(sph_rotvec, z_prime):
    """
    Map spherical rotation vectors to latent orientation parameters
    normalized to [-1, 1]. Inverse of lat2sph_rotvec.

    For each asymmetric unit the three components map as:
        theta: [0, π/2]  -> [-1, 1]   ((theta - π/4) / (π/4))
        phi:   [-π, π]   -> [-1, 1]   (phi / π)
        r:     [0, 2π]   -> [-1, 1]   ((r - π) / π)

    Parameters
    ----------
    sph_rotvec : torch.FloatTensor(..., z_prime * 3)
        Spherical rotation vectors [theta, phi, r] for each asymmetric unit
    z_prime : int
        Number of asymmetric units

    Returns
    -------
    lat_orientations : torch.FloatTensor(..., z_prime * 3)
    """
    # split the last dimension into (asymmetric unit, component)
    per_unit = sph_rotvec.view(*sph_rotvec.shape[:-1], z_prime, 3)

    quarter_pi = torch.pi / 4
    lat_theta = (per_unit[..., 0] - quarter_pi) / quarter_pi
    lat_phi = per_unit[..., 1] / torch.pi
    lat_r = (per_unit[..., 2] - torch.pi) / torch.pi

    lat = torch.stack((lat_theta, lat_phi, lat_r), dim=-1)
    return lat.view(*sph_rotvec.shape)