import sys
from math import sqrt
from typing import Optional
import numpy as np
import torch
from torch import Tensor
from torch_scatter import scatter, scatter_max
from mxtaltools.common.sym_utils import init_sym_info
from mxtaltools.constants.atom_properties import VDW_RADII
from mxtaltools.models.functions.radial_graph import radius
[docs]
def compute_principal_axes_np(coords):
    """
    Compute the principal inertial axes for a given set of particle coordinates, ignoring particle mass.

    Use our overlap rules to ensure a fixed direction for all axes under almost all circumstances,
    excepting e.g., certain symmetric molecules.

    Parameters
    ----------
    coords : np.array(n, 3)
        Particle coordinates; they are centred on their mean internally.

    Returns
    -------
    Ip : np.array(3, 3)
        Principal inertial axes, one unit vector per row, sorted by ascending moment.
    Ipm : np.array(3)
        Principal inertial moments, ascending.
    I : np.array(3, 3)
        Inertial tensor in the original frame.
    """  # todo harmonize with torch version - currently disagrees ~0.5% of the time
    points = coords - coords.mean(0)
    x, y, z = points.T
    Ixx = np.sum(y ** 2 + z ** 2)
    Iyy = np.sum(x ** 2 + z ** 2)
    Izz = np.sum(x ** 2 + y ** 2)
    Ixy = -np.sum(x * y)
    Iyz = -np.sum(y * z)
    Ixz = -np.sum(x * z)
    I = np.array([[Ixx, Ixy, Ixz], [Ixy, Iyy, Iyz], [Ixz, Iyz, Izz]])  # inertial tensor

    # I is real symmetric: eigh returns real eigenvalues and an orthonormal eigenbasis even when
    # eigenvalues are degenerate (symmetric / spherical tops), which the general eig does not guarantee
    Ipm, Ip = np.linalg.eigh(I)
    sort_inds = np.argsort(Ipm)
    Ipm = Ipm[sort_inds]
    Ip = Ip.T[sort_inds]  # want eigenvectors sorted row-wise (rather than column-wise)

    # cardinal direction is the vector from the centroid to the farthest point;
    # if several points are equidistant (to 8 decimals), pick the one with the lowest index
    dists = np.linalg.norm(points, axis=1)
    max_equivs = np.flatnonzero(np.round(dists, 8) == np.round(dists.max(), 8))
    direction = points[int(max_equivs.min())]
    direction_norm = np.linalg.norm(direction)
    if direction_norm > 0:  # all points coincide -> keep the zero vector rather than divide by zero
        direction = direction / direction_norm

    # flip each axis so it points towards the cardinal direction; exactly-zero overlaps count as positive
    overlaps = Ip.dot(direction)
    signs = np.sign(overlaps)
    signs[signs == 0] = 1
    Ip = Ip * signs[:, None]

    # a vanishing overlap leaves that axis' sign arbitrary: fix it by enforcing a right-handed frame
    # (if two overlaps are vanishing, this will not work)
    if np.any(np.abs(overlaps) < 1e-3):
        fix_ind = np.argmin(np.abs(overlaps))  # vector with vanishing overlap
        if np.linalg.det(Ip) < 0:  # sign of the scalar triple product of the rows = handedness
            Ip[fix_ind] = -Ip[fix_ind]
    return Ip, Ipm, I
[docs]
def compute_inertial_tensor_torch(x: torch.tensor, y: torch.tensor, z: torch.tensor):
    """
    Build the geometric (unit-mass) inertial tensor of a point set and eigendecompose it.

    Coordinates are used as given; no centring is done here.

    Parameters
    ----------
    x : torch.tensor(n)
    y : torch.tensor(n)
    z : torch.tensor(n)

    Returns
    -------
    I : torch.tensor(3, 3)
        Inertial tensor in the original frame; keeps the dtype/device of the inputs and stays
        attached to their autograd graph.
    Ip : torch.tensor(3, 3), complex
        Eigenvectors from torch.linalg.eig, one per column, in no guaranteed order.
    Ipm : torch.tensor(3), complex
        Eigenvalues (principal moments) from torch.linalg.eig, in no guaranteed order.
    """  # todo harmonize with numpy version - currently disagrees ~0.5% of the time
    Ixy = -torch.sum(x * y)
    Iyz = -torch.sum(y * z)
    Ixz = -torch.sum(x * z)
    # torch.stack (rather than torch.tensor([...])) preserves gradients, dtype and device
    I = torch.stack((
        torch.stack((torch.sum(y ** 2 + z ** 2), Ixy, Ixz)),
        torch.stack((Ixy, torch.sum(x ** 2 + z ** 2), Iyz)),
        torch.stack((Ixz, Iyz, torch.sum(x ** 2 + y ** 2))),
    ))  # inertial tensor
    Ipm, Ip = torch.linalg.eig(I)  # principal moments and axes
    return I, Ip, Ipm
[docs]
def single_molecule_principal_axes_torch(coords: torch.tensor, masses=None, return_direction=False):
    """
    Compute the principal inertial axes for a given set of particle coordinates, ignoring particle mass.

    Use our overlap rules to ensure a fixed direction for all axes under almost all circumstances,
    excepting e.g., certain symmetric molecules.

    Parameters
    ----------
    coords : torch.tensor(n, 3)
        Coordinates, expected to be centred already: the farthest-point search measures from the origin.
    masses : None
        not used; a message is printed if given
    return_direction : bool
        whether to add the direction between centroid and most distant coordinate to the output

    Returns
    -------
    Ip : torch.tensor(3, 3)
        Principal inertial axes, one per row, sorted by ascending moment.
    Ipm : torch.tensor(3)
        Principal inertial moments, ascending.
    I : torch.tensor(3, 3)
        Inertial tensor in original frame
    direction : torch.tensor(3), only if return_direction
        Unnormalized vector from the origin to the farthest coordinate.
    """
    if masses is not None:
        print('Inertial tensor is purely geometric! Calculation will not account for varying masses')
    x, y, z = coords.T
    I, Ip, Ipm = compute_inertial_tensor_torch(x, y, z)
    Ipm, Ip = torch.real(Ipm), torch.real(Ip)
    sort_inds = torch.argsort(Ipm)
    Ipm = Ipm[sort_inds]
    Ip = Ip.T[sort_inds]  # want eigenvectors to be sorted row-wise (rather than column-wise)

    # cardinal direction is the vector from the origin (assumed centroid) to the farthest point;
    # if several points are exactly equidistant, pick the one with the lowest index
    dists = torch.linalg.norm(coords, dim=1)
    max_equivs = torch.where(dists == dists[torch.argmax(dists)])[0]
    max_ind = int(torch.amin(max_equivs))
    direction = coords[max_ind]  # magnitude doesn't matter, only the sign of the overlaps

    # check if the principal axes point towards or away from the farthest point
    overlaps = torch.inner(Ip, direction)
    overlaps[overlaps == 0] = 1e-9  # exactly zero is invalid: treat as a tiny positive overlap
    Ip = (Ip.T * torch.sign(overlaps)).T  # if the vectors have negative overlap, flip the direction

    if torch.any(torch.abs(overlaps) < 1e-8):
        # a vanishing overlap leaves that axis' sign arbitrary: rebuild it from the other two with the
        # cyclic right-hand rule Ip[k] = Ip[k+1] x Ip[k+2] (indices mod 3), which is right-handed for
        # every k (a plain ascending-index cross product is left-handed for k == 1).
        # If two overlaps are vanishing, this will not work.
        fix_ind = int(torch.argmin(torch.abs(overlaps)))
        rows = list(Ip.unbind(0))
        rows[fix_ind] = torch.cross(rows[(fix_ind + 1) % 3], rows[(fix_ind + 2) % 3], dim=0)
        Ip = torch.stack(rows)  # new tensor instead of an in-place row overwrite (autograd-safe)

    if return_direction:
        return Ip, Ipm, I, direction
    return Ip, Ipm, I
[docs]
def list_molecule_principal_axes_torch(coords_list: list = None,
                                       skip_centring=False):
    """
    Canonically oriented principal inertial axes for a list of point clouds, computed in parallel.

    Parameters
    ----------
    coords_list : list(torch.tensor(n_i, 3))
        One coordinate tensor per point cloud.
    skip_centring : bool
        If True, the clouds are assumed to be centred already and are used as given.

    Returns
    -------
    Ip_fin : torch.tensor(num_graphs, 3, 3)
        Direction-corrected principal axes, one per row.
    Ipm_fin : torch.tensor(num_graphs, 3)
        Principal moments, ascending.
    I : torch.tensor(num_graphs, 3, 3)
        Inertial tensors.
    """
    # todo accelerate centring with scatter
    if skip_centring:
        stacked = torch.cat(coords_list)
    else:
        stacked = torch.cat([cloud - cloud.mean(0) for cloud in coords_list])

    # todo pass batch info as an argument instead of calculating it here
    batch, _ = extract_batching_info(coords_list, stacked.device)
    Ip, Ipm_fin, I = scatter_compute_Ip(stacked, batch)

    # cardinal direction: vector from the centroid (origin) to the farthest point of each cloud
    anchor = batch_get_furthest_node_vector(stacked, batch, num_graphs=len(coords_list))
    anchor_unit = anchor / torch.linalg.norm(anchor, dim=1)[:, None]
    overlaps, signs = get_overlaps(Ip, anchor_unit)

    # somehow, fails for mirror planes, on top of symmetric and spherical tops
    return correct_Ip_directions(Ip, overlaps, signs), Ipm_fin, I
[docs]
def batch_molecule_principal_axes_torch(coords_i: torch.FloatTensor,
                                        batch: torch.LongTensor,
                                        num_graphs: int,
                                        nodes_per_graph: torch.LongTensor,
                                        heavy_atoms_only: bool = True,
                                        atom_types: Optional[torch.LongTensor] = None, ):
    """
    Parallel computation of canonically oriented principal inertial axes for a batch of molecules.

    Parameters
    ----------
    coords_i : torch.FloatTensor(n, 3)
        Flat coordinates of all atoms in the batch.
    batch : torch.LongTensor(n)
        Graph index of each atom.
    num_graphs : int
    nodes_per_graph : torch.LongTensor(num_graphs)
        Must equal bincount(batch); checked on entry.
    heavy_atoms_only : bool
        If True, centring, axes and the cardinal direction use only atoms with atom_types > 1
        (i.e. no hydrogens, assuming atom_types are atomic numbers).
    atom_types : torch.LongTensor(n), optional
        Required when heavy_atoms_only is True.

    Returns
    -------
    Ip_fin : torch.tensor(B, 3, 3)
        Direction-corrected principal axes, one per row.
    Ipm_fin : torch.tensor(B, 3)
        Principal moments, ascending.
    I : torch.tensor(B, 3, 3)
        Inertial tensors.

    Raises
    ------
    ValueError
        If heavy_atoms_only is True but atom_types is None.
    """
    if heavy_atoms_only and atom_types is None:
        raise ValueError("atom_types must be provided when heavy_atoms_only=True")

    bc = torch.bincount(batch, minlength=num_graphs)
    assert torch.equal(bc, nodes_per_graph), \
        f"bincount(batch) vs num_atoms mismatch:\n bincount={bc.tolist()}\n num_atoms={nodes_per_graph.tolist()}"
    coords = center_batch(coords_i, batch, num_graphs, nodes_per_graph,
                          center_on_heavy_atoms=heavy_atoms_only, atom_types=atom_types)

    if heavy_atoms_only:
        mask = atom_types > 1
        coords, batch = coords[mask], batch[mask]

    Ip, Ipm_fin, I = scatter_compute_Ip(coords, batch)
    # cardinal direction is vector from CoM to the farthest atom
    # todo ON NEXT REPARAMETERIZATION SET THIS TO A MORE STABLE ANCHOR like median atom or some inner product
    direction = batch_get_furthest_node_vector(coords, batch, num_graphs)
    normed_direction = direction / torch.linalg.norm(direction, dim=1)[:, None]
    overlaps, signs = get_overlaps(Ip, normed_direction)
    Ip_fin = correct_Ip_directions(Ip, overlaps, signs)  # fails for mirror planes, on top of symmetric and spherical tops
    return Ip_fin, Ipm_fin, I
# Debugging aid, stored as an inert module-level string literal (it is never executed).
# It plots one molecule (index `ind`) together with its cardinal direction and its principal axes
# before (Ip) and after (Ip_fin) direction correction, using the local variable names of
# batch_molecule_principal_axes_torch. `go` is presumably plotly.graph_objects, which is not among the
# imports at the top of this file - paste the snippet into that function's scope to use it.
''' visualize clouds and axes for testing
ind = 16
x, y, z = coords[batch == ind].T.cpu().detach().numpy()
types = atom_types[batch==ind].cpu().detach().numpy()
fig = go.Figure(go.Scatter3d(x=x, y=y, z=z, mode='markers'))
a, b, c = torch.stack([torch.zeros_like(direction[ind]), direction[ind]]).T.cpu().detach().numpy()
fig.add_trace(go.Scatter3d(x=a, y=b, z=c, marker_color=types))
colors = ['red', 'green', 'blue']
for i in range(3):
a, b, c = torch.stack([torch.zeros_like(direction[ind]), Ip[ind, i]]).T.cpu().detach().numpy()
fig.add_trace(go.Scatter3d(x=a, y=b, z=c, marker_color=colors[i], name='Initial principal axes',
legendgroup='Initial principal axes', showlegend=i == 0))
for i in range(3):
a, b, c = torch.stack([torch.zeros_like(direction[ind]), Ip_fin[ind, i]]).T.cpu().detach().numpy()
fig.add_trace(go.Scatter3d(x=a, y=b, z=c, marker_color=colors[i], name='Fixed principal axes',
legendgroup='Fixed principal axes', showlegend=i == 0))
fig.show(renderer='browser')
'''
def correct_Ip_directions_loop(Ip, overlaps, signs, overlap_threshold: float = 1e-5):
    """
    Per-sample (Python loop) version of the principal-axis direction correction.

    Slower than a vectorized implementation; kept for readability and cross-checking.
    Each sample's axes are multiplied row-wise by their signs. If any overlap of a sample is
    vanishing, the axis with the smallest overlap is flipped when needed so that the
    sign-corrected frame is right-handed.

    Parameters
    ----------
    Ip : torch.tensor(B, 3, 3)
        Principal axes, one per row.
    overlaps : torch.tensor(B, 3)
        Overlap of each axis with the canonical direction.
    signs : torch.tensor(B, 3)
        Sign (+/-1) to apply to each axis.
    overlap_threshold : float
        Overlaps with magnitude below this count as vanishing.

    Returns
    -------
    Ip_fin : torch.tensor(B, 3, 3)
    """
    corrected = []
    for Ip_i, overlaps_i, signs_i in zip(Ip, overlaps, signs):
        # if the vectors have negative overlap, flip the direction
        Ip_fixed = Ip_i * signs_i[:, None]
        # a vanishing overlap means the cardinal direction is perpendicular to that axis, so its sign
        # is arbitrary; determine it via the right-hand rule (if two overlaps vanish, this will not work)
        if torch.any(overlaps_i.abs() < overlap_threshold):
            fix_ind = int(torch.argmin(overlaps_i.abs()))
            # handedness is judged on the sign-corrected axes: Ip[0] . (Ip[1] x Ip[2])
            if torch.dot(Ip_fixed[0], torch.cross(Ip_fixed[1], Ip_fixed[2], dim=0)) < 0:
                row_flips = torch.ones_like(Ip_fixed[:, 0])
                row_flips[fix_ind] = -1
                Ip_fixed = Ip_fixed * row_flips[:, None]
        corrected.append(Ip_fixed)
    if not corrected:  # empty batch: torch.stack cannot stack an empty list
        return Ip.clone()
    return torch.stack(corrected)
[docs]
def correct_Ip_directions(Ip, overlaps, signs, overlap_threshold: float = 1e-5):
    """
    Enforce positive overlaps for given inertial principal axes with a given canonical direction,
    given their overlaps, and enforce right-handedness where an overlap vanishes.

    Parameters
    ----------
    Ip : torch.tensor(B, 3, 3)
        Principal axes, one per row.
    overlaps : torch.tensor(B, 3)
        Overlap of each axis with the canonical direction.
    signs : torch.tensor(B, 3)
        Sign (+/-1) to apply to each axis, e.g. sign(overlaps) with zeros mapped to +1.
    overlap_threshold : float
        Overlaps with magnitude below this count as vanishing.

    Returns
    -------
    Ip_fin : torch.tensor(B, 3, 3)
        Sign-corrected principal axes. For samples with a vanishing overlap, the axis with the
        smallest overlap is flipped if needed so that the final frame is right-handed.
    """
    # if the vectors have negative overlap, flip the direction (row-wise)
    Ip_fin = Ip * signs[:, :, None]

    # an overlap vanishes (up to 32 bit precision) when the cardinal direction is perpendicular to a
    # principal axis; that axis' sign is then arbitrary and is fixed by the right-hand rule instead
    # (if two overlaps are vanishing, this will not work)
    small_overlaps_bool = (overlaps.abs() < overlap_threshold).any(dim=1)
    fix_ind = torch.argmin(overlaps.abs(), dim=-1)  # [B] axis with the smallest overlap

    # handedness must be judged AFTER the sign flips: scalar triple product Ip[0] . (Ip[1] x Ip[2])
    triple_product = (Ip_fin[:, 0] * torch.cross(Ip_fin[:, 1], Ip_fin[:, 2], dim=1)).sum(1)
    to_flip_bools = (triple_product < 0) & small_overlaps_bool

    # apply the extra flips as a multiplier instead of editing Ip_fin in place (autograd-friendly)
    row_flips = torch.ones_like(Ip_fin[..., 0])
    row_flips[to_flip_bools, fix_ind[to_flip_bools]] = -1
    return Ip_fin * row_flips[:, :, None]
[docs]
def get_overlaps(Ip, direction):
    """
    Project each sample's canonical direction onto that sample's principal axes.

    Parameters
    ----------
    Ip : torch.tensor(n, 3, 3)
        Principal axes, one per row.
    direction : torch.tensor(n, 3)

    Returns
    -------
    overlaps : torch.tensor(n, 3)
        Dot product of each axis with the sample's direction.
    signs : torch.tensor(n, 3)
        sign(overlaps), with exactly-zero overlaps reported as +1.
    """
    # do the principal components point towards or away from the cardinal direction?
    overlaps = torch.einsum('nij,nj->ni', Ip, direction)
    signs = torch.sign(overlaps).masked_fill(overlaps == 0, 1)
    return overlaps, signs
[docs]
def batch_get_furthest_node_vector(all_coords: torch.FloatTensor, batch: torch.LongTensor, num_graphs: int) -> Tensor:
    """
    Per-graph cardinal direction: the coordinate vector of each graph's point farthest from the origin.

    Callers are expected to pass coordinates already centred per graph, so this is the vector from
    the centroid to the farthest point. Output is not unique for certain symmetric inputs.

    Parameters
    ----------
    all_coords : torch.tensor(n, 3)
    batch : torch.tensor(n)
        Graph index of each point.
    num_graphs : int

    Returns
    -------
    direction : torch.tensor(num_graphs, 3)

    NOTE(review): for a graph with no points, torch_scatter's scatter_max reports an out-of-range
    argmax, so the indexing below would fail - confirm callers never pass empty graphs.
    """
    radial_dist = torch.linalg.norm(all_coords, dim=1)
    _, argmax_per_graph = scatter_max(radial_dist, batch, dim_size=num_graphs)
    return all_coords[argmax_per_graph]
[docs]
def nan_hook(name, tensor_ref, batch):
    """
    Return a backward hook that prints debug info and halts on NaN gradients.

    Parameters
    ----------
    name : str
        Label printed with the diagnostics.
    tensor_ref : torch.tensor
        Tensor whose value is printed when NaNs are found.
    batch : torch.tensor
        Batch indices printed when NaNs are found.

    Returns
    -------
    _hook : callable(grad) -> torch.tensor
        Raises AssertionError on any NaN in grad; otherwise returns torch.nan_to_num(grad),
        which replaces infinities with large finite values.
    """
    def _hook(grad):
        if torch.isnan(grad).any():
            print(f"NaNs in grad of {name}")
            print(f"Original tensor: {tensor_ref}")
            print(f"Batch: {batch}")
            # explicit raise: unlike `assert False`, this still halts under `python -O`
            raise AssertionError("Stop the code!!")
        return torch.nan_to_num(grad)
    return _hook
[docs]
def scatter_compute_Ip(all_coords, batch,
                       eps: float = 0.05,
                       add_noise: bool = False,
                       num_graphs: Optional[int] = None):
    """
    Parallel function to compute inertial tensors and principal axes for a batch of unequal sized
    sets of coordinates.

    Parameters
    ----------
    all_coords : torch.tensor(n, 3)
        Coordinates of all points, used as given (no centring here).
    batch : torch.tensor(n)
        Set index of each point.
    eps : float
        Standard deviation of the Gaussian jitter added to the coordinates (see comment below).
    add_noise : bool
        Apply the jitter even when all_coords does not require grad.
    num_graphs : int, optional
        Number of sets, passed to scatter as dim_size so trailing empty sets still get a (zero)
        inertial tensor. By default the size is inferred from batch.max() + 1.

    Returns
    -------
    Ip : torch.tensor(num_graphs, 3, 3)
        Principal axes, one per row, sorted by ascending moment.
    Ipm : torch.tensor(num_graphs, 3)
        Principal moments, ascending.
    I : torch.tensor(num_graphs, 3, 3)
        Inertial tensors (of the jittered coordinates when jitter is applied).

    Raises
    ------
    RuntimeError
        If the eigendecomposition fails off-CUDA, or fails on CUDA and again on the CPU fallback.
    """
    # The inertial tensor is symmetric by construction. When backpropagating we need to ensure no
    # degenerate eigenvalues (the eigh gradient is ill-defined there); empirically the jitter has to
    # be bigger than ~0.0127 to lift degeneracies, so the default is 0.05 for safety.
    if all_coords.requires_grad or add_noise:
        coords = all_coords + torch.randn_like(all_coords) * eps
    else:
        coords = all_coords
    x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]

    def _sum(values):
        # per-set sum of a per-point quantity
        return scatter(values, batch, dim=0, dim_size=num_graphs, reduce='sum')

    Ixy = -_sum(x * y)
    Iyz = -_sum(y * z)
    Ixz = -_sum(x * z)
    Ixx = _sum(y ** 2 + z ** 2)
    Iyy = _sum(x ** 2 + z ** 2)
    Izz = _sum(x ** 2 + y ** 2)
    inertial_tensor = torch.stack((
        torch.stack((Ixx, Ixy, Ixz), dim=-1),
        torch.stack((Ixy, Iyy, Iyz), dim=-1),
        torch.stack((Ixz, Iyz, Izz), dim=-1),
    ), dim=-2)  # [num_graphs, 3, 3]

    try:
        Ipms, Ip_cols = torch.linalg.eigh(inertial_tensor)
    except RuntimeError:
        if inertial_tensor.device.type != 'cuda':
            raise  # keep the original error instead of masking it
        # retry on CPU, then move the results back to the tensor's own device (not just 'cuda')
        Ipms, Ip_cols = torch.linalg.eigh(inertial_tensor.cpu())
        Ipms = Ipms.to(inertial_tensor.device)
        Ip_cols = Ip_cols.to(inertial_tensor.device)

    Ips = Ip_cols.transpose(1, 2)  # switch to row-wise eigenvectors
    sort_inds = torch.argsort(Ipms, dim=1)
    Ipm = torch.gather(Ipms, dim=1, index=sort_inds)
    Ip = torch.gather(Ips, dim=1, index=sort_inds.unsqueeze(2).expand(-1, -1, Ips.shape[2]))
    return Ip, Ipm, inertial_tensor
[docs]
def sph2cart_rotvec(angles):
    """
    Transform from axis-angle in polar coordinates to rotation vector.

    Parameters
    ----------
    angles : np.ndarray or torch.tensor, shape (3,) or (n, 3)
        theta (polar angle), phi (azimuthal angle), r (rotation angle = vector length)

    Returns
    -------
    rotvec : same type as the input, shape (3,) or (n, 3)
        x, y, z = r * (sin(theta) cos(phi), sin(theta) sin(phi), cos(theta)).
        None (after printing a message) for unsupported input types.
    """
    if isinstance(angles, np.ndarray):
        if angles.ndim > 1:
            theta, phi, r = angles.T
            rotvec = r[:, None] * np.stack((np.sin(theta) * np.cos(phi), np.sin(theta) * np.sin(phi), np.cos(theta))).T
        else:
            theta, phi, r = angles
            rotvec = r * np.asarray((np.sin(theta) * np.cos(phi), np.sin(theta) * np.sin(phi), np.cos(theta)))
        return rotvec
    elif torch.is_tensor(angles):
        if angles.ndim > 1:
            theta, phi, r = angles.T
            rotvec = r[:, None] * torch.stack((theta.sin() * phi.cos(), theta.sin() * phi.sin(), theta.cos())).T
        else:
            theta, phi, r = angles
            # torch.stack, not the legacy torch.Tensor(...) constructor, which reads positional args as sizes
            rotvec = r * torch.stack((theta.sin() * phi.cos(), theta.sin() * phi.sin(), theta.cos()))
        return rotvec
    else:
        print("Array type not supported! Must be np.ndarray or torch.tensor")
        return None
[docs]
def cart2sph_rotvec(rotvec):
    """
    Transform rotation vector with axis rotvec/norm(rotvec) and angle ||rotvec||
    to spherical coordinates theta, phi and r=||rotvec||.

    Parameters
    ----------
    rotvec : np.ndarray or torch.tensor, shape (3,) or (n, 3)
        x, y, z

    Returns
    -------
    angles : same type as the input, shape (3,) or (n, 3)
        theta (polar), phi (azimuthal), r (applied rotation). A zero vector maps to (0, 0, 0).
        None (after printing a message) for unsupported input types.
    """
    if isinstance(rotvec, np.ndarray):
        single = rotvec.ndim == 1
        vecs = rotvec[None, :] if single else rotvec
        r = np.linalg.norm(vecs, axis=-1)
        # arctan2 is invariant to a common positive scale, so no division by r is needed
        # (this also avoids NaN for a zero vector, where arctan2(0, 0) = 0)
        theta = np.arctan2(np.sqrt(vecs[:, 0] ** 2 + vecs[:, 1] ** 2), vecs[:, 2])
        phi = np.arctan2(vecs[:, 1], vecs[:, 0])
        angles = np.stack((theta, phi, r), axis=-1)  # polar, azimuthal, applied rotation
        return angles[0] if single else angles
    elif torch.is_tensor(rotvec):
        single = rotvec.ndim == 1
        vecs = rotvec[None, :] if single else rotvec
        r = torch.linalg.norm(vecs, dim=-1)
        theta = torch.arctan2(torch.sqrt(vecs[:, 0] ** 2 + vecs[:, 1] ** 2), vecs[:, 2])
        phi = torch.arctan2(vecs[:, 1], vecs[:, 0])
        angles = torch.stack((theta, phi, r), dim=-1)  # polar, azimuthal, applied rotation
        return angles[0] if single else angles
    else:
        print("Array type not supported! Must be np.ndarray or torch.tensor")
        return None
[docs]
def cell_vol_torch(v: torch.tensor, a: torch.tensor):
    """
    Volume of a parallelepiped from its edge lengths and inter-edge angles.

    Parameters
    ----------
    v : torch.tensor(3)
        edge lengths [a, b, c]
    a : torch.tensor(3)
        angles [alpha, beta, gamma] in radians

    Returns
    -------
    cell_volume : torch.tensor (0-d)
        a*b*c*sqrt(|1 - cos^2(alpha) - cos^2(beta) - cos^2(gamma) + 2 cos(alpha) cos(beta) cos(gamma)|)
    """
    cosines = torch.cos(a)  # in natural units (radians)
    c_alpha, c_beta, c_gamma = cosines[0], cosines[1], cosines[2]
    angular_term = 1.0 - c_alpha ** 2 - c_beta ** 2 - c_gamma ** 2 + 2.0 * c_alpha * c_beta * c_gamma
    return v[0] * v[1] * v[2] * torch.sqrt(torch.abs(angular_term))
[docs]
def batch_cell_vol_torch(v: torch.tensor, a: torch.tensor):
    """
    Unit cell volumes for a batch of cells.

    Parameters
    ----------
    v : torch.tensor(n, 3)
        per-cell edge lengths [a, b, c]
    a : torch.tensor(n, 3)
        per-cell angles [alpha, beta, gamma] in radians

    Returns
    -------
    cell_volumes : torch.tensor(n)
    """
    cosines = torch.cos(a)  # in natural units (radians)
    c_alpha, c_beta, c_gamma = cosines[:, 0], cosines[:, 1], cosines[:, 2]
    angular_term = 1.0 - c_alpha ** 2 - c_beta ** 2 - c_gamma ** 2 + 2.0 * c_alpha * c_beta * c_gamma
    edge_product = v[:, 0] * v[:, 1] * v[:, 2]
    return edge_product * torch.sqrt(torch.abs(angular_term))
[docs]
def cell_vol_angle_factor(cell_angles):
    """
    Angular part of the unit cell volume, sqrt(|1 - cos²α - cos²β - cos²γ + 2cosα cosβ cosγ|).

    Parameters
    ----------
    cell_angles : torch.tensor(..., 3)
        [alpha, beta, gamma] in radians; any number of leading batch dimensions

    Returns
    -------
    factor : torch.tensor(...)
    """
    cosines = torch.cos(cell_angles)  # in natural units (radians)
    c_alpha, c_beta, c_gamma = cosines[..., 0], cosines[..., 1], cosines[..., 2]
    angular_term = 1.0 - c_alpha ** 2 - c_beta ** 2 - c_gamma ** 2 + 2.0 * c_alpha * c_beta * c_gamma
    return angular_term.abs().sqrt()
[docs]
def compute_Ip_handedness(Ip):
    """
    Determine the right or left handedness of principal inertial axes from their scalar triple product.

    np.array or torch.tensor input, single or multiple samples.

    Parameters
    ----------
    Ip : (3, 3) or (n, 3, 3)
        principal axes, one per row

    Returns
    -------
    handedness : scalar or (n)
        +1 for right-handed, -1 for left-handed (0 if the axes are degenerate)

    Raises
    ------
    TypeError
        If Ip is neither a numpy array nor a torch tensor.
    ValueError
        If Ip is not 2- or 3-dimensional.
    """
    if isinstance(Ip, np.ndarray):
        if Ip.ndim == 2:
            return np.sign(np.dot(Ip[0], np.cross(Ip[1], Ip[2])).sum())
        if Ip.ndim == 3:
            # per-sample (row-wise) triple product; np.dot here would mix all samples into an (n, n) matrix
            return np.sign(np.sum(Ip[:, 0] * np.cross(Ip[:, 1], Ip[:, 2], axis=1), axis=1))
    elif torch.is_tensor(Ip):
        if Ip.ndim == 2:
            return torch.sign(torch.mul(Ip[0], torch.cross(Ip[1], Ip[2], dim=0)).sum()).float()
        if Ip.ndim == 3:
            return torch.sign(torch.mul(Ip[:, 0], torch.cross(Ip[:, 1], Ip[:, 2], dim=1)).sum(1))
    else:
        raise TypeError("Ip handedness calculation failed! Inputs were neither torch.tensor or numpy.array")
    raise ValueError(f"Ip must have 2 or 3 dimensions, got {Ip.ndim}")
[docs]
def cell_vol_np(v, a):
    """
    Volume of a parallelepiped from its edge lengths and inter-edge angles (numpy version).

    Parameters
    ----------
    v : np.array(3)
        edge lengths [a, b, c]
    a : np.array(3)
        angles [alpha, beta, gamma] in radians

    Returns
    -------
    cell_volume : float
    """
    cosines = np.cos(a)  # in natural units (radians)
    c_alpha, c_beta, c_gamma = cosines[0], cosines[1], cosines[2]
    angular_term = 1.0 - c_alpha ** 2 - c_beta ** 2 - c_gamma ** 2 + 2.0 * c_alpha * c_beta * c_gamma
    # the angular term is technically a signed quantity; only its magnitude is used here
    return v[0] * v[1] * v[2] * np.sqrt(np.abs(angular_term))
[docs]
def coor_trans_matrix_np(opt, v, a, return_vol=False):
    """
    Compute the fractional->cartesian or cartesian->fractional transform matrix of a unit cell.

    Parameters
    ----------
    opt : str 'c_to_f' or 'f_to_c'
        which direction to transform between fractional and cartesian
    v : np.array(3)
        a, b, c
    a : np.array(3)
        alpha, beta, gamma in radians (a warning is printed if any exceeds pi)
    return_vol : bool, optional
        also return the absolute value of the cell volume

    Returns
    -------
    transform : np.array(3, 3)
        upper-triangular transform matrix
    cell_volume : float, only if return_vol

    Raises
    ------
    ValueError
        If opt is neither 'c_to_f' nor 'f_to_c' (previously a zero matrix was returned silently).
    """
    if opt not in ('c_to_f', 'f_to_c'):
        raise ValueError(f"opt must be 'c_to_f' or 'f_to_c', got {opt!r}")
    # todo test - enforce this agrees with the torch version
    if np.amax(a) > np.pi:
        print('Warning - large angles! Remember to convert to natural units!')
    cos_a = np.cos(a)
    sin_a = np.sin(a)

    # signed cell volume: the angular term is negative for geometrically impossible angle sets
    val = 1.0 - cos_a[0] ** 2 - cos_a[1] ** 2 - cos_a[2] ** 2 + 2.0 * cos_a[0] * cos_a[1] * cos_a[2]
    vol = np.sign(val) * v[0] * v[1] * v[2] * np.sqrt(np.abs(val))

    m = np.zeros((3, 3), dtype=np.float64)
    if opt == 'c_to_f':
        # converting from cartesian to fractional (inverse of the f_to_c matrix below)
        m[0, 0] = 1.0 / v[0]
        m[0, 1] = -cos_a[2] / v[0] / sin_a[2]
        m[0, 2] = v[1] * v[2] * (cos_a[0] * cos_a[2] - cos_a[1]) / vol / sin_a[2]
        m[1, 1] = 1.0 / v[1] / sin_a[2]
        m[1, 2] = v[0] * v[2] * (cos_a[1] * cos_a[2] - cos_a[0]) / vol / sin_a[2]
        m[2, 2] = v[0] * v[1] * sin_a[2] / vol
    else:
        # converting from fractional to cartesian
        m[0, 0] = v[0]
        m[0, 1] = v[1] * cos_a[2]
        m[0, 2] = v[2] * cos_a[1]
        m[1, 1] = v[1] * sin_a[2]
        m[1, 2] = v[2] * (cos_a[0] - cos_a[1] * cos_a[2]) / sin_a[2]
        m[2, 2] = vol / v[0] / v[1] / sin_a[2]

    if return_vol:
        return m, np.abs(vol)
    return m
[docs]
def mol_batch_vdW_volume(mol_batch):
    """
    Convenience wrapper around batch_compute_molecule_volume for a batch object.

    Reads z (atom types), pos, batch and num_graphs from mol_batch and builds the radius lookup
    table from the values of VDW_RADII (float32, on the device of mol_batch.z).
    NOTE(review): the lookup assumes VDW_RADII values are ordered so that position == atom type.
    """
    radii_lookup = torch.tensor(list(VDW_RADII.values()),
                                device=mol_batch.z.device,
                                dtype=torch.float32)
    return batch_compute_molecule_volume(mol_batch.z,
                                         mol_batch.pos,
                                         mol_batch.batch,
                                         mol_batch.num_graphs,
                                         radii_lookup)
[docs]
def batch_compute_molecule_volume(
        atom_types: torch.LongTensor,
        pos: torch.FloatTensor,
        batch: torch.LongTensor,
        num_graphs: int,
        vdw_radii_tensor: torch.Tensor
):
    """
    Estimate molecular vdW volumes for a batch of molecules using a sphere-overlap correction.

    Atomic sphere volumes are summed per molecule. Pairwise sphere-sphere intersection (lens)
    volumes are then subtracted, scaled by an empirical factor (0.7272) that stands in for the
    neglected triple overlaps.

    Parameters
    ----------
    atom_types : torch.LongTensor(n)
        Indices into vdw_radii_tensor (atomic numbers)
    pos : torch.FloatTensor(n, 3)
        Atomic coordinates
    batch : torch.LongTensor(n)
        Graph index for each atom
    num_graphs : int
    vdw_radii_tensor : torch.Tensor
        Lookup table of vdW radii indexed by atomic number

    Returns
    -------
    corrected_mol_volume : torch.FloatTensor(num_graphs)
    """
    atom_radii = vdw_radii_tensor[atom_types]
    sphere_volumes = 4 / 3 * torch.pi * atom_radii ** 3
    summed_sphere_volumes = scatter(sphere_volumes, batch, dim=0, dim_size=num_graphs, reduce='sum')

    # candidate pairs: atoms of the same molecule within twice the largest radius
    src, dst = radius(pos, pos,
                      r=2 * vdw_radii_tensor.max(),
                      batch_x=batch,
                      batch_y=batch,
                      max_num_neighbors=100)
    unique_pair = src < dst  # drops self-pairs and the mirrored duplicate of every pair
    src, dst = src[unique_pair], dst[unique_pair]

    d = torch.linalg.norm(pos[src] - pos[dst], dim=1)
    r_i, r_j = atom_radii[src], atom_radii[dst]

    # lens volume of two intersecting spheres (https://mathworld.wolfram.com/Sphere-SphereIntersection.html):
    # V = pi (r1 + r2 - d)^2 (d^2 + 2 d r1 + 2 d r2 + 6 r1 r2 - 3 r1^2 - 3 r2^2) / (12 d)
    lens_prefactor = torch.pi * (r_i + r_j - d) ** 2
    lens_polynomial = d ** 2 + 2 * d * (r_i + r_j) - 3 * r_i ** 2 - 3 * r_j ** 2 + 6 * r_i * r_j
    lens_volumes = (lens_prefactor * lens_polynomial / (12 * d)).masked_fill(d > (r_i + r_j), 0)
    overlap_per_mol = scatter(lens_volumes, batch[src], dim=0, dim_size=num_graphs, reduce='sum')

    # coarse correction factor (omits triple overlaps), estimated by comparing a representative
    # molecule batch against a well-converged probe method
    return summed_sphere_volumes - overlap_per_mol * 0.7272
[docs]
def probe_compute_molecule_volume(
        atom_types: torch.LongTensor,
        pos: torch.FloatTensor,
        batch: torch.LongTensor,
        num_graphs: int,
        vdw_radii_tensor: torch.Tensor,
        probes_per_mol: int = 100,
        eps: float = 1e-2,
        max_iters: int = 1000,
        min_iters: int = 5
):
    """
    Estimate molecular vdW volumes via Monte Carlo probe sampling, iterating until convergence.

    Random probes are scattered within each molecule's bounding box (padded by the largest vdW
    radius). The fraction of probes inside any atomic vdW sphere, times the box volume, gives one
    volume estimate per iteration. After more than min_iters iterations, sampling stops once the
    mean absolute relative change of the running mean over the last min_iters steps drops below
    eps for every molecule, or after max_iters iterations.

    Parameters
    ----------
    atom_types : torch.LongTensor(n)
    pos : torch.FloatTensor(n, 3)
    batch : torch.LongTensor(n)
    num_graphs : int
    vdw_radii_tensor : torch.Tensor
        Lookup table of vdW radii indexed by atomic number
    probes_per_mol : int
        Number of random probes per molecule per iteration
    eps : float
        Convergence threshold on relative change in running mean
    max_iters : int
        Maximum number of sampling iterations (must be >= 1)
    min_iters : int
        Iterations required before convergence is checked; also the window of the check

    Returns
    -------
    volumes : torch.FloatTensor(num_graphs)
        Running-mean volume estimates after the last iteration

    Raises
    ------
    ValueError
        If max_iters < 1.
    """
    if max_iters < 1:
        raise ValueError(f"max_iters must be at least 1, got {max_iters}")

    def _running_means(record):
        # running mean of the per-iteration estimates, one row per iteration
        volumes = torch.stack(record)
        cum_iters = torch.arange(1, len(volumes) + 1, device=volumes.device)[:, None]
        return torch.cumsum(volumes, dim=0) / cum_iters

    max_radius = vdw_radii_tensor.max()
    # bounding boxes, their volumes and probe ownership do not change between iterations
    mol_min_corner = scatter(pos, batch, dim=0, dim_size=num_graphs, reduce='min') - max_radius
    mol_max_corner = scatter(pos, batch, dim=0, dim_size=num_graphs, reduce='max') + max_radius
    box_extent = mol_max_corner - mol_min_corner
    bounding_volumes = torch.prod(box_extent, dim=1)
    num_probes = probes_per_mol * num_graphs
    probe_batch = torch.arange(num_graphs, device=batch.device).repeat_interleave(probes_per_mol)

    volume_record = []
    converged = False
    iteration = 0
    while not converged and iteration < max_iters:
        iteration += 1
        random_probes = mol_min_corner[probe_batch] + box_extent[probe_batch] * torch.rand(
            (num_probes, 3), device=pos.device)
        # probe-atom pairs within the largest vdW radius; edge_i indexes probes, edge_j indexes atoms
        edge_i, edge_j = radius(x=pos,
                                y=random_probes,
                                r=max_radius,
                                batch_x=batch,
                                batch_y=probe_batch,
                                max_num_neighbors=probes_per_mol)
        dists = torch.linalg.norm(random_probes[edge_i] - pos[edge_j], dim=1)
        inside_any_sphere = (dists < vdw_radii_tensor[atom_types[edge_j]]).float()
        # count each probe at most once, however many spheres contain it
        probe_has_a_contact = scatter(inside_any_sphere, edge_i, reduce='max', dim_size=num_probes, dim=0)
        inside_sphere_frac = scatter(probe_has_a_contact, probe_batch, reduce='sum', dim=0,
                                     dim_size=num_graphs) / probes_per_mol
        volume_record.append(bounding_volumes * inside_sphere_frac)

        if iteration > min_iters:
            cum_means = _running_means(volume_record)
            rel_diffs = torch.diff(cum_means / cum_means.mean(0)[None, :], dim=0)
            criteria = rel_diffs[-min_iters:].abs().mean(0)
            # relative change in running average less than eps
            converged = bool(torch.all(criteria < eps))

    # defined for any max_iters >= 1, including when the convergence check never ran
    return _running_means(volume_record)[-1, :]
[docs]
def grid_compute_molecule_volume(atom_types, pos, vdw_radii_tensor, eps):
    """
    Brute-force grid estimate of the vdW volume of a single molecule.

    A regular grid covering the molecule's bounding box (padded by the largest vdW radius) is
    refined step by step: the spacing starts at 0.1 * 0.75 and is multiplied by 0.75 each step.
    The fraction of grid points lying within the vdW radius of any atom, times the box volume,
    gives the estimate. Refinement stops once the relative change between consecutive estimates
    drops below eps (checked from the third grid on), or after 11 grids.

    Parameters
    ----------
    atom_types : torch.LongTensor(n)
        Indices into vdw_radii_tensor.
    pos : torch.FloatTensor(n, 3)
        Atomic coordinates.
    vdw_radii_tensor : torch.Tensor
        Lookup table of vdW radii.
    eps : float
        Convergence threshold on the relative change between consecutive estimates.

    Returns
    -------
    occupied_volume : torch.Tensor (0-d)
        Volume estimate from the last grid evaluated.
    """
    max_radius = vdw_radii_tensor.amax()
    # the padded bounding box depends only on the inputs, so compute it once
    xmin, ymin, zmin = pos.amin(0) - max_radius
    xmax, ymax, zmax = pos.amax(0) + max_radius
    box_volume = (xmax - xmin) * (ymax - ymin) * (zmax - zmin)

    convergence_history = []
    dx = 0.1
    max_iters = 10
    occupied_volume = None
    for ind in range(max_iters + 1):
        dx *= 0.75
        num_x = int((xmax - xmin) / dx)
        num_y = int((ymax - ymin) / dx)
        num_z = int((zmax - zmin) / dx)
        grid = torch.stack(torch.meshgrid(torch.linspace(xmin, xmax, num_x),
                                          torch.linspace(ymin, ymax, num_y),
                                          torch.linspace(zmin, zmax, num_z),
                                          indexing='xy'))
        grid_flat = grid.reshape(3, num_x * num_y * num_z).T
        edges = radius(x=grid_flat, y=pos, r=max_radius,
                       max_num_neighbors=int(1e12))
        # edges[0] indexes atoms (pos), edges[1] indexes grid points
        dists = torch.linalg.norm(pos[edges[0]] - grid_flat[edges[1]], dim=1)
        close_enough = dists <= vdw_radii_tensor[atom_types[edges[0]]]
        overlapped_points = len(edges[1, close_enough].unique())
        overlapping_fraction = overlapped_points / len(grid_flat)
        occupied_volume = box_volume * overlapping_fraction
        convergence_history.append(float(occupied_volume))

        if ind > 1:
            # NOTE(review): the change is normalized by the second estimate, convergence_history[1],
            # rather than the latest one; kept as-is - confirm which was intended
            conv = abs(convergence_history[-2] - convergence_history[-1]) / convergence_history[1]
            if conv < eps:
                break
    return occupied_volume
[docs]
def grid_compute_molecule_volume_pointwise(atom_types, pos, vdw_radii_tensor, eps):
    """
    Brute-force grid estimate of the van der Waals volume of a single molecule.

    This function was a line-for-line copy of ``grid_compute_molecule_volume``.
    It now delegates to that function so the two implementations cannot drift apart.
    It is kept under this name for backward compatibility.

    Parameters
    ----------
    atom_types : torch.LongTensor(n)
        Per-atom indices into ``vdw_radii_tensor``.
    pos : torch.FloatTensor(n, 3)
        Atom positions.
    vdw_radii_tensor : torch.FloatTensor
        Lookup table of vdW radii indexed by ``atom_types``.
    eps : float
        Relative-change convergence tolerance.

    Returns
    -------
    occupied_volume : torch.FloatTensor scalar
    """
    return grid_compute_molecule_volume(atom_types, pos, vdw_radii_tensor, eps)
[docs]
def norm_circular_components(components: torch.tensor):
    """
    Rescale (sin, cos) component pairs onto the unit circle.

    Parameters
    ----------
    components : torch.tensor(..., 2)
        Component pairs on the last axis; any number of leading dimensions.

    Returns
    -------
    normed_components : torch.tensor(..., 2)
        Each pair divided by its Euclidean norm. All-zero pairs yield NaN (0 / 0).
    """
    # keepdim=True broadcasts the norm back over the pair axis for any leading shape
    return components / torch.sqrt(torch.sum(components ** 2, dim=-1, keepdim=True))
[docs]
def components2angle(components: torch.tensor, norm_components=True):
    """
    Recover angles from (sin, cos) component pairs.

    Follows https://ai.stackexchange.com/questions/38045/how-can-i-encode-angle-data-to-train-neural-networks

    Parameters
    ----------
    components : torch.tensor(..., 2)
        Non-normalized pairs; ``[..., 0]`` is the sin-like and ``[..., 1]`` the cos-like component.
    norm_components : bool, optional
        If True, first project the pairs onto the unit circle with ``norm_circular_components``.
        atan2 is invariant to rescaling both arguments by the same positive factor, so this
        does not change the angle. The only difference is that all-zero pairs become NaN
        instead of 0.

    Returns
    -------
    angles : torch.tensor(...)
        Angles in [-pi, pi], one per component pair.
    """
    if norm_components:
        components = norm_circular_components(components)
    return torch.atan2(components[..., 0], components[..., 1])
[docs]
def angle2components(angle: torch.tensor):
    """
    Decompose angles into (sin, cos) pairs.

    Parameters
    ----------
    angle : torch.tensor(...)
        Angles in radians, any shape (including 0-d).

    Returns
    -------
    components : torch.tensor(..., 2)
        ``[..., 0] = sin(angle)`` and ``[..., 1] = cos(angle)``.
    """
    # stacking on a new last axis matches the old (n,) -> (n, 2) result and works for any shape
    return torch.stack((torch.sin(angle), torch.cos(angle)), dim=-1)
[docs]
def enforce_crystal_system(lattice_lengths,
                           lattice_angles,
                           sg_inds,
                           symmetries_dict: Optional[dict] = None
                           ):
    """
    Project cell parameters onto the constraints of each sample's crystal system.

    Constrained parameters are copied from the first free one: ``b`` (and ``c`` where
    applicable) take the value of ``a``. For rhombohedral cells, all angles take the value
    of ``alpha``. No averaging is done; see ``enforce_crystal_system2`` for the averaging
    variant.
    https://en.wikipedia.org/wiki/Crystal_system

    Parameters
    ----------
    lattice_lengths : torch.tensor(n, 3)
        Cell lengths (a, b, c).
    lattice_angles : torch.tensor(n, 3)
        Cell angles (alpha, beta, gamma) in radians.
    sg_inds : torch.tensor(n) or 0-d tensor / int
        Space group indices. A 0-d value applies the same space group to every sample.
    symmetries_dict : dict, optional
        Must contain ``'lattice_type'``, indexable by space group index and returning a lattice
        name. Defaults to ``init_sym_info()``.

    Returns
    -------
    fixed_lengths, fixed_angles : torch.tensor(n, 3)

    Raises
    ------
    ValueError
        If a space group maps to an unrecognized lattice type.
    """  # todo vectorize this function
    num_samples = len(lattice_lengths)
    if num_samples == 0:
        return torch.zeros_like(lattice_lengths), torch.zeros_like(lattice_angles)
    if symmetries_dict is None:
        symmetries_dict = init_sym_info()
    lattice_types = symmetries_dict['lattice_type']
    sg_inds = torch.as_tensor(sg_inds)
    if sg_inds.ndim == 0:
        # one space group for the whole batch
        lattices = [lattice_types[int(sg_inds)]] * num_samples
    else:
        lattices = [lattice_types[int(sg_inds[n])] for n in range(num_samples)]

    pi_tensor = torch.ones_like(lattice_lengths[0, 0]) * torch.pi
    half_pi = pi_tensor / 2
    right_angles = torch.stack((half_pi, half_pi, half_pi), dim=-1)
    hexagonal_angles = torch.stack((half_pi, half_pi, pi_tensor * 2 / 3), dim=-1)

    fixed_lengths = torch.zeros_like(lattice_lengths)
    fixed_angles = torch.zeros_like(lattice_angles)
    for i in range(num_samples):
        lengths = lattice_lengths[i]
        angles = lattice_angles[i]
        lattice = lattices[i].lower()
        a = lengths[0]
        if lattice == 'triclinic':  # anything goes
            fixed_lengths[i] = lengths
            fixed_angles[i] = angles
        elif lattice == 'monoclinic':  # alpha = gamma = pi/2, beta free
            fixed_lengths[i] = lengths
            fixed_angles[i] = torch.stack((half_pi, angles[1], half_pi), dim=-1)
        elif lattice == 'orthorhombic':  # all angles pi/2
            fixed_lengths[i] = lengths
            fixed_angles[i] = right_angles
        elif lattice == 'tetragonal':  # all angles pi/2, b := a
            fixed_lengths[i] = torch.stack((a, a, lengths[2]), dim=-1)
            fixed_angles[i] = right_angles
        elif lattice == 'hexagonal':  # b := a; alpha = beta = pi/2, gamma = 2pi/3
            fixed_lengths[i] = torch.stack((a, a, lengths[2]), dim=-1)
            fixed_angles[i] = hexagonal_angles
        elif lattice == 'rhombohedral':  # b, c := a; beta, gamma := alpha
            alpha = angles[0]
            fixed_lengths[i] = torch.stack((a, a, a), dim=-1)
            fixed_angles[i] = torch.stack((alpha, alpha, alpha), dim=-1)
        elif lattice == 'cubic':  # all angles pi/2, b, c := a
            fixed_lengths[i] = torch.stack((a, a, a), dim=-1)
            fixed_angles[i] = right_angles
        else:
            raise ValueError(lattices[i] + ' is not a valid crystal lattice!')
    return fixed_lengths, fixed_angles
[docs]
def enforce_crystal_system2(lattice_lengths, lattice_angles, lattices):
    """
    Project cell parameters onto the constraints of each sample's crystal system by averaging.

    Tied lengths and angles are replaced by their mean. For tetragonal and hexagonal cells,
    a and b take mean(a, b). For rhombohedral cells, all lengths and all angles take their
    respective means. For cubic cells, all lengths take mean(a, b, c).
    https://en.wikipedia.org/wiki/Crystal_system

    Parameters
    ----------
    lattice_lengths : torch.tensor(n, 3)
        Cell lengths (a, b, c).
    lattice_angles : torch.tensor(n, 3)
        Cell angles (alpha, beta, gamma) in radians.
    lattices : sequence of str, length n
        Lattice type name per sample (case-insensitive).

    Returns
    -------
    fixed_lengths, fixed_angles : torch.tensor(n, 3)

    Raises
    ------
    ValueError
        If a lattice name is not recognized.
    """  # todo double check these limits
    num_samples = len(lattice_lengths)
    if num_samples == 0:
        return torch.zeros_like(lattice_lengths), torch.zeros_like(lattice_angles)
    pi_tensor = torch.ones_like(lattice_lengths[0, 0]) * torch.pi
    half_pi = pi_tensor / 2
    right_angles = torch.stack((half_pi, half_pi, half_pi), dim=-1)
    hexagonal_angles = torch.stack((half_pi, half_pi, pi_tensor * 2 / 3), dim=-1)

    fixed_lengths = torch.zeros_like(lattice_lengths)
    fixed_angles = torch.zeros_like(lattice_angles)
    for i in range(num_samples):
        lengths = lattice_lengths[i]
        angles = lattice_angles[i]
        lattice = lattices[i].lower()
        if lattice == 'triclinic':  # anything goes
            fixed_lengths[i] = lengths
            fixed_angles[i] = angles
        elif lattice == 'monoclinic':  # alpha = gamma = pi/2, beta free
            fixed_lengths[i] = lengths
            fixed_angles[i] = torch.stack((half_pi, angles[1], half_pi), dim=-1)
        elif lattice == 'orthorhombic':  # all angles pi/2
            fixed_lengths[i] = lengths
            fixed_angles[i] = right_angles
        elif lattice == 'tetragonal':  # all angles pi/2, a = b = mean(a, b)
            ab_mean = torch.mean(lengths[0:2])
            fixed_lengths[i] = torch.stack((ab_mean, ab_mean, lengths[2]), dim=-1)
            fixed_angles[i] = right_angles
        elif lattice == 'hexagonal':  # a = b = mean(a, b), c free; alpha = beta = pi/2, gamma = 2pi/3
            ab_mean = torch.mean(lengths[0:2])
            fixed_lengths[i] = torch.stack((ab_mean, ab_mean, lengths[2]), dim=-1)
            fixed_angles[i] = hexagonal_angles
        elif lattice == 'rhombohedral':  # all lengths equal their mean, all angles equal their mean
            length_mean = torch.mean(lengths)
            angle_mean = torch.mean(angles)
            fixed_lengths[i] = torch.stack((length_mean, length_mean, length_mean), dim=-1)
            fixed_angles[i] = torch.stack((angle_mean, angle_mean, angle_mean), dim=-1)
        elif lattice == 'cubic':  # all angles pi/2, all lengths equal their mean
            length_mean = torch.mean(lengths)
            fixed_lengths[i] = torch.stack((length_mean, length_mean, length_mean), dim=-1)
            fixed_angles[i] = right_angles
        else:
            raise ValueError(lattices[i] + ' is not a valid crystal lattice!')
    return fixed_lengths, fixed_angles
[docs]
def cell_parameters_to_box_vectors(opt: str,
                                   cell_lengths: torch.tensor,
                                   cell_angles: torch.tensor,
                                   return_vol: bool = False):
    """
    Build the fractional<->cartesian transform matrix for a single unit cell.

    Initially borrowed from Nikos.  # TODO I believe this is a duplicate function
    Uses the upper-triangular convention: ``a`` along x and ``b`` in the xy plane.
    In the 'f_to_c' matrix, the columns are the cartesian cell vectors a, b, c.

    Parameters
    ----------
    opt : str
        'f_to_c' for fractional -> cartesian, 'c_to_f' for cartesian -> fractional (the inverse).
    cell_lengths : torch.tensor(3)
        Cell lengths (a, b, c).
    cell_angles : torch.tensor(3)
        Cell angles (alpha, beta, gamma) in radians.
    return_vol : bool
        If True, also return the absolute cell volume.

    Returns
    -------
    m : torch.FloatTensor(3, 3)
        Transform matrix (always float32).
    vol : torch.tensor scalar
        Absolute cell volume; only returned when ``return_vol`` is True.

    Raises
    ------
    ValueError
        If ``opt`` is neither 'c_to_f' nor 'f_to_c' (previously a zero matrix was returned silently).
    """
    if opt not in ('c_to_f', 'f_to_c'):
        raise ValueError(f"opt must be 'c_to_f' or 'f_to_c', got {opt!r}")
    a, b, c = cell_lengths[0], cell_lengths[1], cell_lengths[2]
    cos_a = torch.cos(cell_angles)
    sin_a = torch.sin(cell_angles)
    cos_alpha, cos_beta, cos_gamma = cos_a[0], cos_a[1], cos_a[2]
    sin_gamma = sin_a[2]

    # (V / abc)^2; negative for angle triples that cannot close a cell
    val = 1.0 - cos_alpha ** 2 - cos_beta ** 2 - cos_gamma ** 2 + 2.0 * cos_alpha * cos_beta * cos_gamma
    # signed volume: the sign of val is kept so invalid angle combinations stay detectable
    vol = torch.sign(val) * a * b * c * torch.sqrt(torch.abs(val))

    m = torch.zeros((3, 3), dtype=torch.float32, device=cell_lengths.device)  # todo create m in a single step
    if opt == 'c_to_f':
        m[0, 0] = 1.0 / a
        m[0, 1] = -cos_gamma / a / sin_gamma
        m[0, 2] = b * c * (cos_alpha * cos_gamma - cos_beta) / vol / sin_gamma
        m[1, 1] = 1.0 / b / sin_gamma
        m[1, 2] = a * c * (cos_beta * cos_gamma - cos_alpha) / vol / sin_gamma
        m[2, 2] = a * b * sin_gamma / vol
    else:  # 'f_to_c'
        m[0, 0] = a
        m[0, 1] = b * cos_gamma
        m[0, 2] = c * cos_beta
        m[1, 1] = b * sin_gamma
        m[1, 2] = c * (cos_alpha - cos_beta * cos_gamma) / sin_gamma
        m[2, 2] = vol / a / b / sin_gamma

    if return_vol:
        return m, torch.abs(vol)
    return m
[docs]
def compute_mol_radius(coords: torch.FloatTensor) -> Tensor:
    """
    Radius of one molecule: the largest distance from its centroid to any atom.

    Parameters
    ----------
    coords : torch.FloatTensor(n, 3)
        Atom positions.

    Returns
    -------
    radius : torch.FloatTensor scalar
    """
    offsets = coords - coords.mean(dim=0)
    distances = torch.linalg.norm(offsets, dim=-1)
    return distances.amax()
[docs]
def batch_compute_mol_radius(coords: torch.FloatTensor,
                             batch: torch.LongTensor,
                             num_graphs: int,
                             nodes_per_graph: torch.LongTensor,
                             ) -> Tensor:
    """
    Batched version of compute_mol_radius: per-graph max distance from centroid to any node.

    Parameters
    ----------
    coords : torch.FloatTensor(n, 3)
    batch : torch.LongTensor(n)
        Graph index of each node. Nodes do not need to be grouped contiguously by graph.
    num_graphs : int
    nodes_per_graph : torch.LongTensor(num_graphs)
        Unused; kept for backward compatibility. Per-node centroids are gathered via ``batch``.

    Returns
    -------
    radii : torch.FloatTensor(num_graphs)
    """
    centroids = get_batch_centroids(coords, batch, num_graphs)
    # Gathering by batch index is correct for any node ordering, unlike
    # repeat_interleave(nodes_per_graph), which assumes contiguous, sorted graphs.
    dists = torch.linalg.norm(coords - centroids[batch], dim=-1)
    return scatter(dists, batch, dim=0, dim_size=num_graphs, reduce='max')
[docs]
def get_batch_centroids(coords: torch.FloatTensor,
                        batch: torch.LongTensor,
                        num_graphs: int,
                        ) -> Tensor:
    """
    Mean position of the nodes belonging to each graph.

    Parameters
    ----------
    coords : torch.FloatTensor(n, 3)
        Node positions.
    batch : torch.LongTensor(n)
        Graph index for each node.
    num_graphs : int
        Number of graphs (number of output rows).

    Returns
    -------
    centroids : torch.FloatTensor(num_graphs, 3)
    """
    centroids = scatter(src=coords, index=batch, dim=0, dim_size=num_graphs, reduce='mean')
    return centroids
[docs]
def center_batch(coords: torch.FloatTensor,
                 batch: torch.LongTensor,
                 num_graphs: int,
                 nodes_per_graph: torch.LongTensor,
                 center_on_heavy_atoms: bool = False,
                 atom_types: Optional[torch.LongTensor] = None,
                 ) -> Tensor:
    """
    Subtract each graph's centroid from its node coordinates.

    Parameters
    ----------
    coords : torch.FloatTensor(n, 3)
    batch : torch.LongTensor(n)
        Graph index of each node. Nodes do not need to be grouped contiguously by graph.
    num_graphs : int
    nodes_per_graph : torch.LongTensor(num_graphs)
        Unused; kept for backward compatibility. Centroids are gathered via ``batch``.
    center_on_heavy_atoms : bool
        If True, compute the centroid from heavy atoms only (``atom_types > 1``) but translate
        all atoms. A graph with no heavy atoms falls back to its all-atom centroid.
    atom_types : torch.LongTensor(n), optional
        Required when ``center_on_heavy_atoms`` is True.

    Returns
    -------
    coords_centered : torch.FloatTensor(n, 3)

    Raises
    ------
    ValueError
        If ``center_on_heavy_atoms`` is True but ``atom_types`` is None.
    """
    if center_on_heavy_atoms:
        if atom_types is None:
            raise ValueError("atom_types is required when center_on_heavy_atoms is True")
        heavy = atom_types > 1
        centroids = get_batch_centroids(coords[heavy], batch[heavy], num_graphs)
        # A mean over an empty group gives a zero centroid (torch_scatter fills empty groups
        # with 0), which would leave H-only graphs uncentered; use the all-atom centroid for them.
        has_heavy = torch.bincount(batch[heavy], minlength=num_graphs) > 0
        if not bool(has_heavy.all()):
            all_atom_centroids = get_batch_centroids(coords, batch, num_graphs)
            centroids = torch.where(has_heavy[:, None], centroids, all_atom_centroids)
    else:
        centroids = get_batch_centroids(coords, batch, num_graphs)
    # gather per node by graph index (valid for any node ordering)
    return coords - centroids[batch]
[docs]
def batch_compute_mol_mass(z: torch.LongTensor,
                           batch: torch.LongTensor,
                           masses_tensor: torch.FloatTensor,
                           num_graphs: int) -> Tensor:
    """
    Total atomic mass of each graph in a batch.

    Parameters
    ----------
    z : torch.LongTensor(n)
        Atomic numbers, used as indices into ``masses_tensor``.
    batch : torch.LongTensor(n)
        Graph index of each atom.
    masses_tensor : torch.FloatTensor
        Atomic-mass lookup table indexed by atomic number.
    num_graphs : int

    Returns
    -------
    masses : torch.FloatTensor(num_graphs)
    """
    per_atom_mass = masses_tensor[z]
    return scatter(src=per_atom_mass, index=batch, dim=0, dim_size=num_graphs, reduce='sum')
[docs]
def compute_mol_mass(z: torch.LongTensor,
                     masses_tensor: torch.FloatTensor) -> Tensor:
    """
    Total atomic mass of a single molecule.

    Parameters
    ----------
    z : torch.LongTensor(n)
        Atomic numbers, used as indices into ``masses_tensor``.
    masses_tensor : torch.FloatTensor
        Atomic-mass lookup table indexed by atomic number.

    Returns
    -------
    mass : torch.FloatTensor scalar
    """
    atom_masses = masses_tensor[z]
    return atom_masses.sum()
[docs]
def rotvec2rotmat(mol_rotation: torch.tensor, basis='cartesian'):
    """
    Convert a batch of rotation vectors to rotation matrices with Rodrigues' formula.

    R = I + sin(r) K + (1 - cos(r)) K @ K, where K is the cross-product matrix of the
    unit rotation axis and r the rotation angle
    (see https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula).
    A zero-angle rotation returns the identity rather than NaN.

    Parameters
    ----------
    mol_rotation : torch.tensor(n, 3)
        'cartesian': the vector direction is the axis and its norm is the angle.
        'spherical': the last column is the angle; rows are converted to cartesian with
        ``sph2cart_rotvec``.
    basis : str
        'cartesian' or 'spherical'.

    Returns
    -------
    rotation_matrices : torch.tensor(n, 3, 3)

    Raises
    ------
    ValueError
        If ``basis`` is not recognized.
    """
    if basis == 'cartesian':
        r = torch.linalg.norm(mol_rotation, dim=1)
    elif basis == 'spherical':  # (theta, phi, r): the third entry is the rotation angle
        r = mol_rotation[:, -1]
        mol_rotation = sph2cart_rotvec(mol_rotation)
    else:
        raise ValueError(f'{basis} is not a valid orientation parameterization!')

    # A (near-)zero angle has no defined axis. Dividing by 1 instead of ~0 leaves a
    # (near-)zero axis, so K -> 0 and R -> I instead of 0/0 = NaN. The sign of r is
    # preserved for all other entries.
    safe_r = torch.where(r.abs() < 1e-8, torch.ones_like(r), r)
    unit_vector = mol_rotation / safe_r[:, None]

    zeros = torch.zeros_like(unit_vector[:, 0])
    K = torch.stack((  # cross-product matrix of the rotation axis
        torch.stack((zeros, -unit_vector[:, 2], unit_vector[:, 1]), dim=1),
        torch.stack((unit_vector[:, 2], zeros, -unit_vector[:, 0]), dim=1),
        torch.stack((-unit_vector[:, 1], unit_vector[:, 0], zeros), dim=1)
    ), dim=1)
    identity_batch = torch.eye(3, device=r.device, dtype=torch.float32)[None, :, :].tile(len(r), 1, 1)
    return identity_batch + torch.sin(r[:, None, None]) * K + (1 - torch.cos(r[:, None, None])) * (K @ K)
[docs]
def apply_rotation_to_batch(elems, rotations, batch):
    """
    Rotate each vector by the rotation matrix of the graph it belongs to.

    Parameters
    ----------
    elems : torch.FloatTensor(n, 3)
        Vectors to rotate.
    rotations : torch.FloatTensor(num_graphs, 3, 3)
        One rotation matrix per graph.
    batch : torch.LongTensor(n)
        Graph index of each vector.

    Returns
    -------
    rotated : torch.FloatTensor(n, 3)
        ``rotated[n] = rotations[batch[n]] @ elems[n]``.
    """
    per_node_rotations = rotations[batch]
    column_vectors = elems.unsqueeze(-1)
    return (per_node_rotations @ column_vectors).squeeze(-1)
[docs]
def rotmat2rotvec(rotation_matrix_list, warn_on_bad_determinant=True):
    """
    Convert a batch of rotation matrices to rotation vectors (axis * angle).

    The axis comes from the skew-symmetric part (R - R^T = 2 sin(theta) [k]_x), and the
    angle from the trace (cos(theta) = (tr R - 1) / 2).

    Degenerate rows are mapped to a fixed fallback: a rotation of pi about [1, 1, 1]/sqrt(3).
    A row is degenerate when:
      * |(tr R - 1) / 2| >= 1, i.e. an exact identity or exact pi rotation up to float error;
      * the angle is NaN;
      * the skew-symmetric part is exactly zero;
      * the axis contains NaN.
    This fallback is not a faithful inverse for those rows (e.g. the identity maps to a pi rotation).

    Parameters
    ----------
    rotation_matrix_list : torch.FloatTensor(n, 3, 3)
    warn_on_bad_determinant : bool
        Print a warning if any matrix has det < 0 (a reflection, not in SO(3)).

    Returns
    -------
    rotvecs : torch.FloatTensor(n, 3)
    """
    if warn_on_bad_determinant:
        det = torch.linalg.det(rotation_matrix_list)
        bad_dets = det < 0.0
        if bad_dets.any():
            num_bad = bad_dets.sum().item()
            print(f"[rotmat2rotvec] WARNING: {num_bad} matrices have determinant < 0. "
                  "These are reflections (not in SO(3)) and cannot be converted to rotation vectors.")
    R = rotation_matrix_list
    # unnormalized axis: 2 sin(theta) * k
    direction_vector_list = torch.stack([
        R[:, 2, 1] - R[:, 1, 2],
        R[:, 0, 2] - R[:, 2, 0],
        R[:, 1, 0] - R[:, 0, 1]],
        dim=1)
    trace = torch.einsum('nii->n', R)
    r_arg = (trace - 1) / 2
    r = torch.arccos(r_arg)
    bad_inds = (
        (r_arg.abs() >= 1)
        | torch.isnan(r)
        # test the vector's norm, not the sum of its components: axes such as (1, -1, 0)
        # sum to zero but are perfectly valid
        | (direction_vector_list.norm(dim=1) == 0)
        | torch.isnan(direction_vector_list).any(dim=1)
    )
    direction_vector_list[bad_inds] = 1.0
    r[bad_inds] = torch.pi
    unit_axes = direction_vector_list / direction_vector_list.norm(dim=1, keepdim=True).clamp(min=1e-8)
    return unit_axes * r[:, None]
[docs]
def sample_random_valid_rotvecs(num_samples):
    """
    Draw random rotation vectors whose z component is non-negative.

    Axis directions come from normalized standard-normal samples, which makes them
    isotropic on the sphere. The rotation angles (vector norms) are drawn uniformly from
    [0, 2*pi) and then clipped to at least 0.05, so no vector is zero. The z component is
    then replaced by its absolute value.

    Note: a uniform angle is not the Haar (uniform-on-SO(3)) distribution, so these are
    not uniformly random rotations.

    Parameters
    ----------
    num_samples : int

    Returns
    -------
    rotvecs : torch.FloatTensor(num_samples, 3)
    """
    directions = torch.randn(size=(num_samples, 3))
    directions = directions / directions.norm(dim=1, keepdim=True)
    # norms must stay away from zero (zero-angle rotation vectors have no axis)
    angles = (torch.rand(num_samples) * 2 * torch.pi).clip(min=0.05)
    rotvecs = directions * angles[:, None]
    rotvecs[:, 2] = rotvecs[:, 2].abs()
    return rotvecs
[docs]
def embed_vector_to_rank3(v):
    """
    Embed a batch of 3-vectors as fully symmetric rank-3 tensors.

    P[n, i, j, k] = (delta_ij v[n, k] + delta_ik v[n, j] + delta_jk v[n, i]) / 3

    Parameters
    ----------
    v : torch.tensor(n, 3)

    Returns
    -------
    P : torch.tensor(n, 3, 3, 3)
        Symmetric under any permutation of (i, j, k); same device and dtype as ``v``.
    """
    # match v's device/dtype so GPU and float64 inputs work
    delta = torch.eye(3, device=v.device, dtype=v.dtype)
    P = (torch.einsum('ij,nk->nijk', delta, v)
         + torch.einsum('ik,nj->nijk', delta, v)
         + torch.einsum('jk,ni->nijk', delta, v))
    return P / 3  # normalization factor
[docs]
def compute_ellipsoid_volume(e):
    """
    Volume of ellipsoids given their semi-axis vectors.

    Parameters
    ----------
    e : torch.FloatTensor(..., 3, 3)
        Each row of the last two axes is one semi-axis vector.

    Returns
    -------
    volumes : torch.FloatTensor(...)
        (4/3) * pi * (product of the three semi-axis lengths).
    """
    semi_axis_lengths = torch.linalg.norm(e, dim=-1)
    return (4 / 3) * torch.pi * torch.prod(semi_axis_lengths, dim=-1)
[docs]
def compute_cosine_similarity_matrix(e1, e2):
    """
    Row-to-row dot products between two batches of ellipsoid axis matrices.

    Rows are not normalized here; the result equals cosine similarity only when the rows
    of ``e1`` and ``e2`` are unit vectors.

    Parameters
    ----------
    e1 : torch.tensor(n, i, j)
    e2 : torch.tensor(n, k, j)

    Returns
    -------
    overlaps : torch.tensor(n, i, k)
        ``overlaps[n, a, b] = <e1[n, a], e2[n, b]>``. Swapping ``e1`` and ``e2`` transposes
        the last two axes.
    """
    return e1 @ e2.transpose(-1, -2)
[docs]
def safe_batched_eigh(covs, chunk=10000):
    """
    Chunked symmetric eigendecomposition with a CPU fallback for CUSOLVER failures.

    Parameters
    ----------
    covs : torch.FloatTensor(n, d, d)
        Batch of symmetric matrices.
    chunk : int
        Number of matrices per ``torch.linalg.eigh`` call.

    Returns
    -------
    eigenvalues : torch.FloatTensor(n, d)
        Ascending, as returned by ``torch.linalg.eigh``; cast to float32.
    eigenvectors : torch.FloatTensor(n, d, d)
        Column eigenvectors; cast to float32.

    Raises
    ------
    torch.cuda.OutOfMemoryError
        Propagated unchanged.
    RuntimeError
        Any eigh failure other than CUSOLVER_STATUS_INVALID_VALUE.
    """
    if covs.shape[0] == 0:
        # torch.cat on an empty list would raise; return correctly-shaped empties instead
        d = covs.shape[-1]
        return covs.new_empty((0, d)).float(), covs.new_empty((0, d, d)).float()
    out_vals, out_vecs = [], []
    for start in range(0, covs.shape[0], chunk):
        cchunk = covs[start:start + chunk]
        try:
            ev, evec = torch.linalg.eigh(cchunk)
        except torch.cuda.OutOfMemoryError:  # subclass of RuntimeError; never retry on CPU
            raise
        except RuntimeError as e:
            if "CUSOLVER_STATUS_INVALID_VALUE" not in str(e):
                raise
            print("Invalid matrix to eigh! Switching to CPU.")
            ev, evec = torch.linalg.eigh(cchunk.cpu())
            ev, evec = ev.to(covs.device), evec.to(covs.device)
        out_vals.append(ev)
        out_vecs.append(evec)
    return torch.cat(out_vals, dim=0).float(), torch.cat(out_vecs, dim=0).float()
[docs]
def lat2sph_rotvec(lat_orientations, z_prime):
    """
    Map latent orientation parameters (normalized to [-1, 1]) to spherical rotation vectors.

    Inverse of sph_rotvec2lat. For each asymmetric unit, the three latent dimensions map
    linearly as:
        theta ∈ [-1, 1] → [0, π/2]   (polar angle, upper half-sphere)
        phi   ∈ [-1, 1] → [-π, π]    (azimuthal angle)
        r     ∈ [-1, 1] → [0, 2π]    (rotation magnitude)

    Parameters
    ----------
    lat_orientations : torch.FloatTensor(..., z_prime * 3)
    z_prime : int
        Number of asymmetric units.

    Returns
    -------
    sph_rotvec : torch.FloatTensor(..., z_prime * 3)
        Spherical rotation vectors [theta, phi, r] for each asymmetric unit.
    """
    # reshape (not view) so inputs whose strides cannot be viewed are also accepted
    lat = lat_orientations.reshape(*lat_orientations.shape[:-1], z_prime, 3)
    theta = lat[..., 0] * (torch.pi / 4) + torch.pi / 4
    phi = lat[..., 1] * torch.pi
    r = lat[..., 2] * torch.pi + torch.pi
    return torch.stack((theta, phi, r), dim=-1).reshape(lat_orientations.shape)
[docs]
def simple_latent_distance(l1: torch.Tensor,
                           l2: torch.Tensor) -> torch.Tensor:
    """
    Euclidean distance between latent vectors, with angular dimensions wrapped.

    The layout is taken as 6 box parameters, then 3 * z' positions, then 3 * z'
    orientations (theta, phi, r), where z' = (K - 6) // 6. Only the phi and r orientation
    dimensions are periodic; their differences are wrapped into [-1, 1) (period 2 on the
    normalized [-1, 1] scale).

    Parameters
    ----------
    l1, l2 : torch.Tensor(..., K)

    Returns
    -------
    dist : torch.Tensor(...)
    """
    max_z_prime = (l1.shape[-1] - 6) // 6
    # box + positions never wrap; in each orientation triplet only phi and r do
    angs = [False] * (6 + 3 * max_z_prime) + [False, True, True] * max_z_prime
    periodic_mask = torch.tensor(angs, dtype=torch.bool, device=l1.device)
    diff = l1 - l2
    diff[..., periodic_mask] = ((diff[..., periodic_mask] + 1) % 2) - 1
    return diff.norm(dim=-1)
[docs]
def compute_latent_distance(latents1: torch.Tensor,
                            latents2: torch.Tensor) -> torch.Tensor:
    """
    Weighted distance between crystals in the latent parameterization.

    Row layout, with z' = (K - 6) // 6:
      [0:6]            box parameters (log lengths 0-2, angles 3-5)
      [6:6 + 3z']      positions on [-1, 1]^3 per asymmetric unit. These are not treated as
                       periodic, because the periodicity differs per space group.
      [6 + 3z':]       orientations: spherical rotation vectors on [0, pi/2], [-pi, pi],
                       [0, 2pi], renormalized to [-1, 1].
    The result is a weighted sum of three terms:
      * the box distance;
      * the position distance, divided by z';
      * the summed geodesic rotation angles between corresponding orientations, divided by z'.

    Parameters
    ----------
    latents1, latents2 : torch.Tensor(n, K)

    Returns
    -------
    dists : torch.Tensor(n)

    Raises
    ------
    ValueError
        If the last dimensions differ, or K leaves room for fewer than one asymmetric unit.
    """
    if latents1.shape[-1] != latents2.shape[-1]:
        raise ValueError(f"latent widths differ: {latents1.shape[-1]} vs {latents2.shape[-1]}")
    n_params = latents1.shape[-1]
    z_prime = (n_params - 6) // 6
    if z_prime < 1:
        raise ValueError(f"latent width {n_params} is too small for even one asymmetric unit")

    box_dist = (latents1[..., :6] - latents2[..., :6]).norm(dim=-1)

    position_slice = slice(6, 6 + 3 * z_prime)
    positions_dist = (latents1[..., position_slice] - latents2[..., position_slice]).norm(dim=-1)

    lat_sph_rotvec1 = lat2sph_rotvec(latents1[..., 6 + 3 * z_prime:], z_prime)  # [polar, azimuthal, length]
    lat_sph_rotvec2 = lat2sph_rotvec(latents2[..., 6 + 3 * z_prime:], z_prime)
    rot_dists = []
    for zp in range(z_prime):  # this should be replaced with a proper vector distance
        unit_slice = slice(3 * zp, 3 * zp + 3)
        rmat1 = rotvec2rotmat(lat_sph_rotvec1[..., unit_slice], 'spherical')
        rmat2 = rotvec2rotmat(lat_sph_rotvec2[..., unit_slice], 'spherical')
        # geodesic angle of the relative rotation R1 R2^T
        R_delta = rmat1 @ rmat2.transpose(-1, -2)
        trace = R_delta.diagonal(dim1=-2, dim2=-1).sum(-1)
        # clamp strictly inside [-1, 1] so acos and its gradient stay finite
        cos_theta = ((trace - 1) / 2).clamp(-1 + 1e-7, 1 - 1e-7)
        rot_dists.append(torch.acos(cos_theta))
    rot_dists = torch.stack(rot_dists).sum(0)

    scales = [1, 0.836, 0.293]  # empirical std over CSD samples
    dists = (scales[0] * 0.5 * box_dist
             + scales[1] * 0.25 * (positions_dist / z_prime)
             + scales[2] * 0.25 * (rot_dists / z_prime))
    return dists
# def crystal_parameter_distmat(latents, max_batch_size = 2000):
# """
# Compute distance to self for a set of crystal latent vectors
# :param latents:
# :return:
# """
# N, K = latents.shape
# if N < max_batch_size:
# lat1 = latents[:, None, :].expand(N, N, K).reshape(N * N, K)
# lat2 = latents[None, :, :].expand(N, N, K).reshape(N * N, K)
#
# d = crystal_parameter_distance(lat1, lat2) # (N*N,)
# distmat = d.reshape(N, N)
# else:
# distmat = torch.zeros((N, N), device=latents.device)
#
# for i in range(N):
# lat1 = latents[i:i + 1].expand(N, -1) # (N, K)
# lat2 = latents # (N, K)
#
# distmat[i] = crystal_parameter_distance(lat1, lat2)
#
#
# return distmat
[docs]
def crystal_parameter_distmat(
        latents,
        target_entries=5_000_000,  # ≈ distances per call
        min_block=1,
        max_block=2048,
):
    """
    Full pairwise distance matrix of latent crystal vectors, computed in row blocks.

    The block height is chosen so that each ``compute_latent_distance`` call evaluates about
    ``target_entries`` pairs, clipped to [min_block, max_block] rows.

    Parameters
    ----------
    latents : torch.Tensor(N, K)
    target_entries : int
        Approximate number of pairwise distances per call.
    min_block, max_block : int
        Bounds on the number of rows per block.

    Returns
    -------
    distmat : torch.Tensor(N, N)
        ``distmat[i, j] = compute_latent_distance(latents[i], latents[j])``.
    """
    num_latents, num_params = latents.shape
    block_rows = max(min_block, min(max_block, target_entries // max(num_latents, 1)))
    distmat = torch.empty((num_latents, num_latents), device=latents.device)
    for start in range(0, num_latents, block_rows):
        rows = latents[start:start + block_rows]
        n_rows = rows.shape[0]
        left = rows.unsqueeze(1).expand(n_rows, num_latents, num_params).reshape(-1, num_params)
        right = latents.unsqueeze(0).expand(n_rows, num_latents, num_params).reshape(-1, num_params)
        block_dists = compute_latent_distance(left, right)
        distmat[start:start + n_rows] = block_dists.view(n_rows, num_latents)
    return distmat
[docs]
def sph_rotvec2lat(sph_rotvec, z_prime):
    """
    Map spherical rotation vectors to latent orientation parameters on [-1, 1].

    Inverse of lat2sph_rotvec. For each asymmetric unit:
        theta: [0, π/2]  → [-1, 1]
        phi:   [-π, π]   → [-1, 1]
        r:     [0, 2π]   → [-1, 1]

    Parameters
    ----------
    sph_rotvec : torch.FloatTensor(..., z_prime * 3)
        Spherical rotation vectors [theta, phi, r] for each asymmetric unit.
    z_prime : int
        Number of asymmetric units.

    Returns
    -------
    lat_orientations : torch.FloatTensor(..., z_prime * 3)
    """
    sph = sph_rotvec.view(*sph_rotvec.shape[:-1], z_prime, 3)
    quarter_pi = torch.pi / 4
    lat_theta = (sph[..., 0] - quarter_pi) / quarter_pi
    lat_phi = sph[..., 1] / torch.pi
    lat_r = (sph[..., 2] - torch.pi) / torch.pi
    return torch.stack((lat_theta, lat_phi, lat_r), dim=-1).view(*sph_rotvec.shape)