mxtaltools.common.geometry_utils
- mxtaltools.common.geometry_utils.angle2components(angle: tensor)[source]
Decompose an angle into sin(angle) and cos(angle)
- Parameters:
angle (torch.tensor(n))
- Returns:
sin(angle), cos(angle)
- Return type:
torch.tensor, torch.tensor
- mxtaltools.common.geometry_utils.apply_rotation_to_batch(elems, rotations, batch)[source]
Apply per-graph rotation matrices to a batch of vectors.
- Parameters:
elems (torch.FloatTensor(n, 3))
rotations (torch.FloatTensor(num_graphs, 3, 3))
batch (torch.LongTensor(n))
- Returns:
rotated
- Return type:
torch.FloatTensor(n, 3)
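A minimal plain-Python sketch of per-graph rotation, assuming each vector is rotated as R @ v using the matrix selected by its batch index. The name rotate_per_graph is hypothetical, and the explicit loop only stands in for whatever vectorized torch code the library uses.

```python
def rotate_per_graph(elems, rotations, batch):
    """Rotate each 3-vector by the rotation matrix of the graph it belongs to."""
    rotated = []
    for vec, graph_idx in zip(elems, batch):
        rot = rotations[graph_idx]
        # Assumed convention: column-vector product R @ v.
        rotated.append([sum(rot[i][j] * vec[j] for j in range(3)) for i in range(3)])
    return rotated


identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
quarter_turn_z = [[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]

# Two copies of the x unit vector, one in graph 0 and one in graph 1.
elems = [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
rotated = rotate_per_graph(elems, [identity, quarter_turn_z], batch=[0, 1])
```

Only the second vector moves, because only graph 1 carries a non-identity rotation.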
- mxtaltools.common.geometry_utils.batch_cell_vol_torch(v: tensor, a: tensor)[source]
Batched computation of unit cell volumes given basis vector lengths and internal angles.
- Parameters:
v (torch.tensor(n, 3)) – batch of [a, b, c] lengths
a (torch.tensor(n, 3)) – batch of [alpha, beta, gamma] angles in radians
- Returns:
cell_volumes
- Return type:
torch.tensor(n)
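The standard triclinic cell volume formula can be sketched in plain Python as follows. The function names are hypothetical and the per-cell loop is an assumption made for readability; it is not the library's batched torch code.

```python
import math


def cell_volume(lengths, angles):
    """Volume of one cell from [a, b, c] and [alpha, beta, gamma] in radians."""
    a, b, c = lengths
    cos_a, cos_b, cos_g = (math.cos(x) for x in angles)
    factor = math.sqrt(
        1.0 - cos_a**2 - cos_b**2 - cos_g**2 + 2.0 * cos_a * cos_b * cos_g
    )
    return a * b * c * factor


def batch_cell_volume(v, a):
    """Apply cell_volume to each (lengths, angles) pair in the batch."""
    return [cell_volume(lengths, angles) for lengths, angles in zip(v, a)]


right = math.pi / 2
volumes = batch_cell_volume(
    [[2.0, 3.0, 4.0], [1.0, 1.0, 1.0]],
    [[right, right, right], [right, right, math.radians(120.0)]],
)
```

The first cell is orthorhombic, so its volume is simply a*b*c; the second has gamma = 120 degrees, which scales the volume by sin(120°).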
- mxtaltools.common.geometry_utils.batch_compute_fractional_transform(cell_lengths, cell_angles)[source]
Compute the fractional-to-cartesian (f->c) and cartesian-to-fractional (c->f) transforms, as well as the cell volume, in a vectorized, differentiable way.
- Parameters:
cell_lengths (torch.tensor(nx3)) – a, b, c
cell_angles (torch.tensor(nx3)) – alpha, beta, gamma
- Returns:
fc_transform (torch.tensor(n,3,3))
cf_transform (torch.tensor(n,3,3))
cell_volumes (torch.tensor(n))
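A hedged sketch of the fractional-to-cartesian matrix for a single cell, using the common crystallographic convention (a along x, b in the xy plane, cell vectors as matrix columns). That convention and the helper names are assumptions; the library may store the transpose or use a different orientation.

```python
import math


def frac_to_cart_matrix(lengths, angles):
    """Return (matrix, volume) with cart = matrix @ frac; angles in radians."""
    a, b, c = lengths
    cos_a, cos_b, cos_g = (math.cos(x) for x in angles)
    sin_g = math.sin(angles[2])
    factor = math.sqrt(
        1.0 - cos_a**2 - cos_b**2 - cos_g**2 + 2.0 * cos_a * cos_b * cos_g
    )
    volume = a * b * c * factor
    matrix = [
        [a, b * cos_g, c * cos_b],
        [0.0, b * sin_g, c * (cos_a - cos_b * cos_g) / sin_g],
        [0.0, 0.0, volume / (a * b * sin_g)],
    ]
    return matrix, volume


def apply_matrix(matrix, vec):
    """Plain 3x3 matrix times 3-vector product."""
    return [sum(matrix[i][j] * vec[j] for j in range(3)) for i in range(3)]


angles = [math.radians(80.0), math.radians(95.0), math.radians(110.0)]
matrix, volume = frac_to_cart_matrix([5.0, 6.0, 7.0], angles)
c_axis = apply_matrix(matrix, [0.0, 0.0, 1.0])  # cartesian image of the c axis
```

Two quick consistency checks follow from the construction: the image of the fractional c axis has length c, and the determinant of the (upper-triangular) matrix equals the cell volume.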
- mxtaltools.common.geometry_utils.batch_compute_mol_mass(z: LongTensor, batch: LongTensor, masses_tensor: FloatTensor, num_graphs: int) Tensor[source]
Sum atomic masses for each graph in a batch.
- Parameters:
z (torch.LongTensor(n)) – Atomic numbers used to index masses_tensor
batch (torch.LongTensor(n))
masses_tensor (torch.FloatTensor) – Lookup table of atomic masses indexed by atomic number
num_graphs (int)
- Returns:
masses
- Return type:
torch.FloatTensor(num_graphs)
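The lookup-and-accumulate pattern described above can be sketched in plain Python. The mass table below is a small hypothetical stand-in for masses_tensor, and the function name is illustrative.

```python
def sum_masses_per_graph(z, batch, masses, num_graphs):
    """Add up masses[z[i]] into the slot of the graph that atom i belongs to."""
    totals = [0.0] * num_graphs
    for atomic_number, graph_idx in zip(z, batch):
        totals[graph_idx] += masses[atomic_number]
    return totals


# Hypothetical lookup table indexed by atomic number (only H, C, O filled in).
masses = [0.0] * 9
masses[1], masses[6], masses[8] = 1.008, 12.011, 15.999

z = [8, 1, 1, 6, 1, 1, 1, 1]      # water followed by methane
batch = [0, 0, 0, 1, 1, 1, 1, 1]  # graph index of each atom
mol_masses = sum_masses_per_graph(z, batch, masses, num_graphs=2)
```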
- mxtaltools.common.geometry_utils.batch_compute_mol_radius(coords: FloatTensor, batch: LongTensor, num_graphs: int, nodes_per_graph: LongTensor) Tensor[source]
Batched version of compute_mol_radius.
- Parameters:
coords (torch.FloatTensor(n, 3))
batch (torch.LongTensor(n))
num_graphs (int)
nodes_per_graph (torch.LongTensor(num_graphs))
- Returns:
radii
- Return type:
torch.FloatTensor(num_graphs)
- mxtaltools.common.geometry_utils.batch_compute_molecule_volume(atom_types: LongTensor, pos: FloatTensor, batch: LongTensor, num_graphs: int, vdw_radii_tensor: Tensor)[source]
Estimate molecular vdW volumes for a batch of molecules using a sphere-overlap correction.
Sums atomic sphere volumes then subtracts pairwise sphere-sphere intersection volumes scaled by an empirical correction factor (~0.73) to approximate triple overlaps.
- Parameters:
atom_types (torch.LongTensor(n)) – Atomic numbers used to index vdw_radii_tensor
pos (torch.FloatTensor(n, 3)) – Atomic coordinates
batch (torch.LongTensor(n)) – Graph index for each atom
num_graphs (int)
vdw_radii_tensor (torch.Tensor) – Lookup table of vdW radii indexed by atomic number
- Returns:
corrected_mol_volume
- Return type:
torch.FloatTensor(num_graphs)
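A single-molecule sketch of the sphere-overlap idea, assuming the correction is applied as (sum of sphere volumes) minus 0.73 times (sum of pairwise lens volumes). The exact way the library applies its factor is not stated here, so treat overlap_scale and the function names as illustrative.

```python
import math


def sphere_volume(r):
    return 4.0 / 3.0 * math.pi * r**3


def lens_volume(r1, r2, d):
    """Intersection volume of two spheres whose centres are a distance d apart."""
    if d >= r1 + r2:
        return 0.0  # disjoint spheres
    if d <= abs(r1 - r2):
        return sphere_volume(min(r1, r2))  # smaller sphere fully inside
    return (
        math.pi
        * (r1 + r2 - d) ** 2
        * (d * d + 2.0 * d * (r1 + r2) - 3.0 * (r1 - r2) ** 2)
        / (12.0 * d)
    )


def corrected_volume(pos, radii, overlap_scale=0.73):
    """Sum of atomic spheres minus scaled pairwise overlaps (assumed form)."""
    total = sum(sphere_volume(r) for r in radii)
    overlap = 0.0
    for i in range(len(pos)):
        for j in range(i + 1, len(pos)):
            overlap += lens_volume(radii[i], radii[j], math.dist(pos[i], pos[j]))
    return total - overlap_scale * overlap


far_apart = corrected_volume([[0.0, 0.0, 0.0], [10.0, 0.0, 0.0]], [1.0, 1.0])
bonded = corrected_volume([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], [1.0, 1.0])
```

Scaling the pairwise term below 1 compensates for regions shared by three or more spheres, which a plain pairwise subtraction would remove more than once.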
- mxtaltools.common.geometry_utils.batch_get_furthest_node_vector(all_coords: FloatTensor, batch: LongTensor, num_graphs: int) Tensor[source]
Compute cardinal direction for a list of sets of coordinates, defined as the vector from the centroid to the furthest coordinate. Output is not unique for certain symmetric inputs.
- Parameters:
all_coords (torch.tensor(n,3))
batch (torch.tensor(n))
num_graphs (int)
- Returns:
direction
- Return type:
torch.tensor(num_graphs, 3)
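A plain-Python sketch of the centroid-to-furthest-point direction, assuming the returned vector is left unnormalized and every graph has at least one point. The function name is hypothetical.

```python
import math


def furthest_node_vectors(coords, batch, num_graphs):
    """For each graph, return the vector from its centroid to its furthest point."""
    groups = [[] for _ in range(num_graphs)]
    for point, graph_idx in zip(coords, batch):
        groups[graph_idx].append(point)

    directions = []
    for points in groups:
        centroid = [sum(p[k] for p in points) / len(points) for k in range(3)]
        # On ties, max() keeps the first candidate, so the result is not unique.
        furthest = max(points, key=lambda p: math.dist(p, centroid))
        directions.append([furthest[k] - centroid[k] for k in range(3)])
    return directions


coords = [
    [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0],  # graph 0
    [0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 5.0, 0.0],  # graph 1
]
directions = furthest_node_vectors(coords, [0, 0, 0, 1, 1, 1], num_graphs=2)
```

The tie-breaking comment illustrates why the output is not unique for symmetric inputs: two points equidistant from the centroid are equally valid choices.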
- mxtaltools.common.geometry_utils.batch_molecule_principal_axes_torch(coords_i: FloatTensor, batch: LongTensor, num_graphs: int, nodes_per_graph: LongTensor, heavy_atoms_only: bool = True, atom_types: LongTensor | None = None)[source]
Parallel computation of principal inertial axes from a list of coordinate lists.
- Parameters:
skip_centring (bool) – Whether to skip centering each point cloud - e.g., if the input is already centered
- Returns:
Ip_fin (list(torch.tensor(3,3)))
Ipm_fin (list(torch.tensor(3)))
I (list(torch.tensor(3,3)))
- mxtaltools.common.geometry_utils.cart2sph_rotvec(rotvec)[source]
Transform a rotation vector, with axis rotvec/norm(rotvec) and angle ||rotvec||, to spherical coordinates theta, phi and r=||rotvec||.
- Parameters:
rotvec ((nx3)) – x, y, z
- Returns:
angles – theta, phi, r
- Return type:
(nx3)
- mxtaltools.common.geometry_utils.cell_parameters_to_box_vectors(opt: str, cell_lengths: tensor, cell_angles: tensor, return_vol: bool = False)[source]
TODO: I believe this is a duplicate function. Initially borrowed from Nikos.
Quickly convert from cell lengths and angles to fractional transform matrices (fractional->cartesian or cartesian->fractional).
- mxtaltools.common.geometry_utils.cell_vol_angle_factor(cell_angles)[source]
Compute the angular factor sqrt(1 - cos²α - cos²β - cos²γ + 2cosα cosβ cosγ) used in cell volume calculations.
- Parameters:
cell_angles (torch.tensor(..., 3)) – [alpha, beta, gamma] in radians; leading batch dimensions are supported
- Returns:
factor
- Return type:
torch.tensor(…)
- mxtaltools.common.geometry_utils.cell_vol_np(v, a)[source]
Compute the volume of a parallelepiped given basis vector lengths and internal angles.
- Parameters:
v (np.array(3)) – [a b c]
a (np.array(3)) – [alpha beta gamma]
- Returns:
cell_volume
- Return type:
- mxtaltools.common.geometry_utils.cell_vol_torch(v: tensor, a: tensor)[source]
Compute the volume of a parallelepiped given basis vector lengths and internal angles.
- Parameters:
v (torch.tensor(3)) – [a b c]
a (torch.tensor(3)) – [alpha beta gamma]
- Returns:
cell_volume
- Return type:
- mxtaltools.common.geometry_utils.center_batch(coords: FloatTensor, batch: LongTensor, num_graphs: int, nodes_per_graph: LongTensor, center_on_heavy_atoms: bool = False, atom_types: LongTensor | None = None) Tensor[source]
Subtract the centroid from each graph’s coordinates, returning zero-centered positions.
- Parameters:
coords (torch.FloatTensor(n, 3))
batch (torch.LongTensor(n))
num_graphs (int)
nodes_per_graph (torch.LongTensor(num_graphs))
center_on_heavy_atoms (bool) – If True, compute the centroid using only heavy atoms (atom_types > 1) but translate all atoms
atom_types (torch.LongTensor(n), optional) – Required when center_on_heavy_atoms is True
- Returns:
coords_centered
- Return type:
torch.FloatTensor(n, 3)
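A plain-Python sketch of per-graph centering, including the heavy-atom option. It assumes every graph contributes at least one atom to its centroid; the function name is hypothetical.

```python
def center_per_graph(coords, batch, num_graphs, heavy_only=False, atom_types=None):
    """Subtract each graph's centroid from all of that graph's coordinates."""
    sums = [[0.0, 0.0, 0.0] for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for i, (point, graph_idx) in enumerate(zip(coords, batch)):
        if heavy_only and atom_types[i] <= 1:
            continue  # hydrogens do not contribute to the centroid
        for k in range(3):
            sums[graph_idx][k] += point[k]
        counts[graph_idx] += 1

    centroids = [[s / counts[g] for s in sums[g]] for g in range(num_graphs)]
    # Every atom is translated, including those skipped when averaging.
    return [
        [point[k] - centroids[g][k] for k in range(3)]
        for point, g in zip(coords, batch)
    ]


coords = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]  # H at the origin, C at x = 2
all_atoms = center_per_graph(coords, [0, 0], num_graphs=1)
heavy = center_per_graph(
    coords, [0, 0], num_graphs=1, heavy_only=True, atom_types=[1, 6]
)
```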
- mxtaltools.common.geometry_utils.components2angle(components: tensor, norm_components=True)[source]
Take two non-normalized components[n, 2] representing sin(angle) and cos(angle), compute the resulting angle, following https://ai.stackexchange.com/questions/38045/how-can-i-encode-angle-data-to-train-neural-networks
Optionally norm the sum of squares - doesn’t appear to do much though.
- Parameters:
components (torch.tensor(n, 2))
norm_components (bool, optional)
- Returns:
angles
- Return type:
torch.tensor(n)
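The sin/cos encoding and its inverse can be sketched for a single angle with the standard library. Recovering the angle with atan2(sin, cos) is an assumption about the implementation, and the function names are illustrative.

```python
import math


def angle_to_components(angle):
    """Encode an angle as (sin, cos)."""
    return math.sin(angle), math.cos(angle)


def components_to_angle(sin_c, cos_c, norm_components=True):
    """Decode possibly non-normalized (sin, cos) components; (0, 0) is invalid."""
    if norm_components:
        norm = math.hypot(sin_c, cos_c)
        sin_c, cos_c = sin_c / norm, cos_c / norm
    return math.atan2(sin_c, cos_c)


sin_c, cos_c = angle_to_components(2.0)
recovered = components_to_angle(sin_c, cos_c)
# Scale the components by 3 to mimic a non-normalized network output.
recovered_scaled = components_to_angle(3.0 * sin_c, 3.0 * cos_c, norm_components=False)
```

Because atan2 is invariant under positive rescaling of both arguments, normalizing first changes nothing in exact arithmetic, which is consistent with the remark that the option does not appear to do much.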
- mxtaltools.common.geometry_utils.compute_Ip_handedness(Ip)[source]
Determine the right- or left-handedness from the cross products of the principal inertial axes. Accepts np.array or torch.tensor input, for single or multiple samples.
- Parameters:
Ip ((opt n, 3, 3)) – principal inertial tensor
- Returns:
handedness – +/- 1, the handedness of the cross products of principal inertial axes
- Return type:
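A single-sample sketch of the handedness test, assuming the axes are stored as the rows of Ip and that handedness is the sign of the scalar triple product (a × b) · c. The function name is hypothetical.

```python
def handedness(axes):
    """Return +1 for a right-handed triad of row vectors, -1 for a left-handed one."""
    a, b, c = axes
    cross = [
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    ]
    triple = sum(cross[k] * c[k] for k in range(3))
    return 1 if triple > 0 else -1


right_handed = handedness([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
left_handed = handedness([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]])
```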
- mxtaltools.common.geometry_utils.compute_cosine_similarity_matrix(e1, e2)[source]
Compute the row-to-row dot products for batches of ellipsoids [n, i, j]. Returns the [n, i, i] matrix of dot products between rows in e1 and e2.
Permuting the order of e1 and e2 transposes the overlap matrix.
- mxtaltools.common.geometry_utils.compute_ellipsoid_volume(e)[source]
Compute ellipsoid volumes from semi-axis vectors.
- Parameters:
e (torch.FloatTensor(..., 3, 3)) – Each row is a semi-axis vector; volume = (4/3)π * product of semi-axis lengths
- Returns:
volumes
- Return type:
torch.FloatTensor(…)
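For a single ellipsoid, the volume formula quoted above can be written out directly. The function name is hypothetical and the input is a plain list of three row vectors.

```python
import math


def ellipsoid_volume(semi_axes):
    """(4/3) * pi * product of the lengths of the three semi-axis row vectors."""
    lengths = [math.sqrt(sum(x * x for x in axis)) for axis in semi_axes]
    return 4.0 / 3.0 * math.pi * lengths[0] * lengths[1] * lengths[2]


volume = ellipsoid_volume([[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 4.0]])
```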
- mxtaltools.common.geometry_utils.compute_inertial_tensor_torch(x: tensor, y: tensor, z: tensor)[source]
Compute the principal inertial axes for a given set of particle coordinates, ignoring particle mass.
- Parameters:
x (torch.tensor)
y (torch.tensor)
z (torch.tensor)
- Returns:
Ip (np.array(3,3)) – Principal inertial axes.
Ipm (np.array(3)) – Principal inertial moments
I (np.array(3)) – Inertial tensor in original frame
- mxtaltools.common.geometry_utils.compute_latent_distance(latents1: Tensor, latents2: Tensor) Tensor[source]
Compute a distance metric between crystals in the latent parameterization.
- mxtaltools.common.geometry_utils.compute_mol_mass(z: LongTensor, masses_tensor: FloatTensor) Tensor[source]
Sum atomic masses for a single molecule.
- Parameters:
z (torch.LongTensor(n)) – Atomic numbers
masses_tensor (torch.FloatTensor) – Lookup table of atomic masses indexed by atomic number
- Returns:
mass
- Return type:
torch.FloatTensor scalar
- mxtaltools.common.geometry_utils.compute_mol_radius(coords: FloatTensor) Tensor[source]
Compute the radius of a single molecule as the maximum distance from the centroid to any atom.
- Parameters:
coords (torch.FloatTensor(n, 3))
- Returns:
radius
- Return type:
torch.FloatTensor scalar
- mxtaltools.common.geometry_utils.compute_principal_axes_np(coords)[source]
Compute the principal inertial axes for a given set of particle coordinates, ignoring particle mass.
Use our overlap rules to ensure a fixed direction for all axes under almost all circumstances, excepting e.g., certain symmetric molecules.
- Parameters:
coords
- Returns:
Ip (np.array(3,3)) – Principal inertial axes.
Ipm (np.array(3)) – Principal inertial moments
I (np.array(3)) – Inertial tensor in original frame
- mxtaltools.common.geometry_utils.coor_trans_matrix_np(opt, v, a, return_vol=False)[source]
compute f->c and c->f transforms as well as cell volume in a vectorized, differentiable way
- Parameters:
opt (str 'c_to_f' or 'f_to_c') – which direction to transform between fractional and cartesian
v (np.array(3)) – a, b, c
a (np.array(3)) – alpha, beta, gamma
return_vol (bool, optional) – return the absolute value of the cell volume
- Returns:
transform (np.array(3,3))
cell_volumes (float)
- mxtaltools.common.geometry_utils.correct_Ip_directions(Ip, overlaps, signs, overlap_threshold: float = 1e-05)[source]
Enforce positive overlaps for given inertial principal axes with a given canonical direction, given their overlaps.
- Parameters:
Ip (torch.tensor(3,3))
overlaps (torch.tensor(3))
signs (torch.tensor(3))
overlap_threshold (float)
- Returns:
Ip_fin – Inertial principal axes with positive overlaps to the given canonical direction
- Return type:
torch.tensor(3,3)
- mxtaltools.common.geometry_utils.crystal_parameter_distmat(latents, target_entries=5000000, min_block=1, max_block=2048)[source]
Blockwise distance matrix with adaptive block size. Tries to keep ~target_entries distances per kernel call.
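One plausible way to pick an adaptive block size is to divide the entry budget by the number of columns and clamp the result. The sketch below assumes that rule and a plain Euclidean metric; both are assumptions, and the function names are hypothetical.

```python
import math


def pick_block_size(n, target_entries=5_000_000, min_block=1, max_block=2048):
    """Rows per block so that block_rows * n stays near target_entries."""
    return max(min_block, min(max_block, target_entries // max(n, 1)))


def blockwise_distmat(points, **block_kwargs):
    """Build the full n x n distance matrix one block of rows at a time."""
    n = len(points)
    block = pick_block_size(n, **block_kwargs)
    rows = []
    for start in range(0, n, block):
        # In a tensor library this inner loop would be a single kernel call.
        for p in points[start:start + block]:
            rows.append([math.dist(p, q) for q in points])
    return rows


distmat = blockwise_distmat([[0.0, 0.0], [3.0, 4.0]])
```

With the default budget, 10,000 samples give blocks of 500 rows, while small inputs are capped at max_block and very large ones fall back to min_block.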
- mxtaltools.common.geometry_utils.embed_vector_to_rank3(v)[source]
embed an nxk vector as a symmetric 3-tensor
- mxtaltools.common.geometry_utils.enforce_crystal_system(lattice_lengths, lattice_angles, sg_inds, symmetries_dict: dict | None = None)[source]
Enforce physical bounds on cell parameters according to the crystal system; see https://en.wikipedia.org/wiki/Crystal_system
- mxtaltools.common.geometry_utils.enforce_crystal_system2(lattice_lengths, lattice_angles, lattices)[source]
Enforce physical bounds on cell parameters according to the crystal system; see https://en.wikipedia.org/wiki/Crystal_system
- mxtaltools.common.geometry_utils.extract_batching_info(nodes_list, device='cpu')[source]
Extract batch and ptr info from a list of sets of coordinates.
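A sketch of what batch and ptr usually mean in graph batching: batch maps every node to its graph, and ptr holds the cumulative node offsets. This follows the PyTorch Geometric convention, which is an assumption about this function's output; the name batching_info is hypothetical.

```python
def batching_info(nodes_list):
    """Return (batch, ptr) for a list of per-graph coordinate sets."""
    batch, ptr = [], [0]
    for graph_idx, nodes in enumerate(nodes_list):
        batch.extend([graph_idx] * len(nodes))
        ptr.append(ptr[-1] + len(nodes))
    return batch, ptr


nodes_list = [
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],                   # graph 0: 2 nodes
    [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],  # graph 1: 3 nodes
]
batch, ptr = batching_info(nodes_list)
```

Nodes of graph g occupy the slice ptr[g]:ptr[g + 1] of the concatenated coordinate array.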
- mxtaltools.common.geometry_utils.extract_rotmat(target_position: FloatTensor, original_position: FloatTensor) Tensor[source]
Compute the rotation matrix R such that R @ original_position ≈ target_position.
- Parameters:
target_position (torch.FloatTensor(n, 3, 3) or (3, 3))
original_position (torch.FloatTensor(n, 3, 3) or (3, 3))
- Returns:
rotmat
- Return type:
torch.FloatTensor(n, 3, 3) or (3, 3)
- mxtaltools.common.geometry_utils.fractional_transform(coords, transform_matrix)[source]
Transform between fractional/cartesian bases. Assumes the fractional->cartesian transform is the transpose of the box vectors.
- Parameters:
coords
transform_matrix
- Returns:
transformed_coords
- mxtaltools.common.geometry_utils.fractional_transform_np(coords, transform_matrix)[source]
Apply a fractional/cartesian transform to numpy coordinate arrays.
Dispatches on the combination of coords and transform_matrix dimensionality: (n,3)+(3,3) → per-point transform; (n,m,3)+(3,3) → per-atom-in-molecule; (n,3)+(n,3,3) → per-graph batched transform.
- Parameters:
coords (np.ndarray)
transform_matrix (np.ndarray)
- Returns:
transformed
- Return type:
np.ndarray
- mxtaltools.common.geometry_utils.fractional_transform_torch(coords, transform_matrix)[source]
Apply a fractional/cartesian transform to torch coordinate tensors.
Same dispatch logic as fractional_transform_np.
- Parameters:
coords (torch.FloatTensor)
transform_matrix (torch.FloatTensor)
- Returns:
transformed
- Return type:
torch.FloatTensor
- mxtaltools.common.geometry_utils.get_batch_centroids(coords: FloatTensor, batch: LongTensor, num_graphs: int) Tensor[source]
Compute the centroid (mean position) for each graph in a batch.
- Parameters:
coords (torch.FloatTensor(n, 3))
batch (torch.LongTensor(n))
num_graphs (int)
- Returns:
centroids
- Return type:
torch.FloatTensor(num_graphs, 3)
- mxtaltools.common.geometry_utils.get_overlaps(Ip, direction)[source]
Compute overlaps and signs for given inertial principal axes with a given canonical direction
- Parameters:
Ip (torch.tensor(3,3))
direction (torch.tensor(3))
- Returns:
overlaps (torch.tensor(3))
signs (torch.tensor(3))
- mxtaltools.common.geometry_utils.grid_compute_molecule_volume(atom_types, pos, vdw_radii_tensor, eps)[source]
Brute-force grid approach to computing the vdW volume for a single molecule.
- Parameters:
atom_types
pos
vdw_radii_tensor
- mxtaltools.common.geometry_utils.grid_compute_molecule_volume_pointwise(atom_types, pos, vdw_radii_tensor, eps)[source]
Brute-force grid approach to computing the vdW volume for a single molecule.
- Parameters:
atom_types
pos
vdw_radii_tensor
- mxtaltools.common.geometry_utils.lat2sph_rotvec(lat_orientations, z_prime)[source]
Map latent orientation parameters (normalized to [-1, 1]) to spherical rotation vectors.
Inverse of sph_rotvec2lat. The three latent dimensions map as:
theta ∈ [-1,1] → [π/4, 3π/4] (upper half-sphere polar angle)
phi ∈ [-1,1] → [-π, π] (azimuthal angle)
r ∈ [-1,1] → [0, 2π] (rotation magnitude)
- Parameters:
lat_orientations (torch.FloatTensor(..., z_prime * 3))
z_prime (int) – Number of asymmetric units
- Returns:
sph_rotvec – Spherical rotation vectors [theta, phi, r] for each asymmetric unit
- Return type:
torch.FloatTensor(…, z_prime * 3)
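The range mapping listed above can be sketched for a single [theta, phi, r] triple, assuming each dimension is mapped linearly onto its target interval. The linearity and the function names are assumptions; only the interval endpoints come from the description.

```python
import math


def lat_to_sph(lat):
    """Map latent values in [-1, 1] onto the [theta, phi, r] ranges."""
    theta = math.pi / 4 + (lat[0] + 1.0) / 2.0 * (math.pi / 2)  # [pi/4, 3pi/4]
    phi = lat[1] * math.pi                                        # [-pi, pi]
    r = (lat[2] + 1.0) * math.pi                                  # [0, 2pi]
    return [theta, phi, r]


def sph_to_lat(sph):
    """Inverse of lat_to_sph under the same linear-map assumption."""
    return [
        (sph[0] - math.pi / 4) / (math.pi / 2) * 2.0 - 1.0,
        sph[1] / math.pi,
        sph[2] / math.pi - 1.0,
    ]


midpoint = lat_to_sph([0.0, 0.0, 0.0])
lower = lat_to_sph([-1.0, -1.0, -1.0])
round_trip = sph_to_lat(lat_to_sph([0.3, -0.7, 0.5]))
```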
- mxtaltools.common.geometry_utils.list_molecule_principal_axes_torch(coords_list: list = None, skip_centring=False)[source]
Parallel computation of principal inertial axes from a list of coordinate lists.
- mxtaltools.common.geometry_utils.mol_batch_vdW_volume(mol_batch)[source]
wrapper for batch_compute_vdW_volume
- mxtaltools.common.geometry_utils.nan_hook(name, tensor_ref, batch)[source]
Return a backward hook that prints debug info and halts on NaN gradients.
- mxtaltools.common.geometry_utils.norm_circular_components(components: tensor)[source]
Use Pythagoras to norm the sum of squares to the unit circle.
- Parameters:
components (torch.tensor(n, 2))
- Returns:
normed_components
- Return type:
torch.tensor(n, 2)
- mxtaltools.common.geometry_utils.probe_compute_molecule_volume(atom_types: LongTensor, pos: FloatTensor, batch: LongTensor, num_graphs: int, vdw_radii_tensor: Tensor, probes_per_mol: int = 100, eps: float = 0.01, max_iters: int = 1000, min_iters: int = 5)[source]
Estimate molecular vdW volumes via Monte Carlo probe sampling, iterating until convergence.
Random probes are scattered within each molecule’s bounding box; the fraction inside any atomic vdW sphere gives the volume estimate. Runs until the relative change in the running mean drops below eps for min_iters consecutive iterations.
- Parameters:
atom_types (torch.LongTensor(n))
pos (torch.FloatTensor(n, 3))
batch (torch.LongTensor(n))
num_graphs (int)
vdw_radii_tensor (torch.Tensor) – Lookup table of vdW radii indexed by atomic number
probes_per_mol (int) – Number of random probes per molecule per iteration
eps (float) – Convergence threshold on relative change in running mean
max_iters (int) – Maximum number of sampling iterations
min_iters (int) – Minimum iterations before convergence is checked
- Returns:
volumes – Converged volume estimates
- Return type:
torch.FloatTensor(num_graphs)
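A single-molecule sketch of the probe-sampling loop, assuming an axis-aligned bounding box padded by each atom's radius and a convergence test on the running mean of per-iteration estimates. The seed argument and the function name are hypothetical additions for reproducibility.

```python
import random


def mc_vdw_volume(pos, radii, probes_per_iter=1000, eps=0.01,
                  max_iters=1000, min_iters=5, seed=0):
    """Estimate the union volume of spheres by uniform sampling in a bounding box."""
    rng = random.Random(seed)
    lo = [min(p[k] - r for p, r in zip(pos, radii)) for k in range(3)]
    hi = [max(p[k] + r for p, r in zip(pos, radii)) for k in range(3)]
    box_volume = (hi[0] - lo[0]) * (hi[1] - lo[1]) * (hi[2] - lo[2])

    running_mean, stable = 0.0, 0
    for iteration in range(1, max_iters + 1):
        hits = 0
        for _ in range(probes_per_iter):
            probe = [rng.uniform(lo[k], hi[k]) for k in range(3)]
            # A probe counts if it lies inside at least one atomic sphere.
            if any(sum((probe[k] - p[k]) ** 2 for k in range(3)) <= r * r
                   for p, r in zip(pos, radii)):
                hits += 1
        estimate = box_volume * hits / probes_per_iter
        new_mean = running_mean + (estimate - running_mean) / iteration
        rel_change = abs(new_mean - running_mean) / max(abs(new_mean), 1e-12)
        stable = stable + 1 if rel_change < eps else 0
        running_mean = new_mean
        if stable >= min_iters:
            break  # running mean stayed within eps for min_iters iterations
    return running_mean


unit_sphere = mc_vdw_volume([[0.0, 0.0, 0.0]], [1.0])
```

For a single unit sphere the estimate should land near 4π/3 ≈ 4.19, with scatter that shrinks as probes_per_iter grows or eps tightens.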
- mxtaltools.common.geometry_utils.rotmat2rotvec(rotation_matrix_list, warn_on_bad_determinant=True)[source]
Convert a batch of rotation matrices to rotation vectors (axis-angle).
Uses the Rodrigues formula: axis from the skew-symmetric part, angle from the trace. Degenerate cases (near-identity or near-pi rotations) are handled by clamping to a fallback rotation of pi around [1,1,1].
- Parameters:
rotation_matrix_list (torch.FloatTensor(n, 3, 3))
warn_on_bad_determinant (bool) – Print a warning if any matrix has det < 0 (i.e. is a reflection, not a rotation)
- Returns:
rotvecs
- Return type:
torch.FloatTensor(n, 3)
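A single-matrix sketch of the trace and skew-symmetric-part construction. It does not reproduce the library's fallback for degenerate cases: near-identity input returns the zero vector, and near-π input raises instead. The function name and tolerance are hypothetical.

```python
import math


def rotmat_to_rotvec(rot, tol=1e-8):
    """Convert a 3x3 rotation matrix (nested lists) to an axis-angle vector."""
    trace = rot[0][0] + rot[1][1] + rot[2][2]
    cos_angle = max(-1.0, min(1.0, (trace - 1.0) / 2.0))  # clamp rounding error
    angle = math.acos(cos_angle)
    if angle < tol:
        return [0.0, 0.0, 0.0]
    if math.pi - angle < tol:
        raise ValueError("rotation by pi: axis cannot be read from the skew part")
    scale = angle / (2.0 * math.sin(angle))
    # The skew-symmetric part R - R^T encodes 2 * sin(angle) * axis.
    return [
        scale * (rot[2][1] - rot[1][2]),
        scale * (rot[0][2] - rot[2][0]),
        scale * (rot[1][0] - rot[0][1]),
    ]


quarter_turn_z = [[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
rotvec = rotmat_to_rotvec(quarter_turn_z)
```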
- mxtaltools.common.geometry_utils.rotvec2rotmat(mol_rotation: tensor, basis='cartesian')[source]
Get the applied rotation matrix. mol_rotation here is a list of rotation vectors [n_samples, 3].
Converts rotvec -> rotation matrix directly (see https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula); the old way was rotvec -> quat -> rotation matrix.
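Rodrigues' formula, R = I + sin(θ) K + (1 − cos(θ)) K², can be sketched for a single rotation vector in plain Python. The function name is hypothetical, and the basis argument of the library function is not modelled here.

```python
import math


def rotvec_to_rotmat(rotvec, tol=1e-12):
    """Build a 3x3 rotation matrix from an axis-angle vector via Rodrigues' formula."""
    angle = math.sqrt(sum(x * x for x in rotvec))
    identity = [[1.0 if i == j else 0.0 for j in range(3)] for i in range(3)]
    if angle < tol:
        return identity  # zero rotation
    x, y, z = (component / angle for component in rotvec)
    # K is the cross-product matrix of the unit rotation axis.
    K = [[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]]
    K2 = [[sum(K[i][m] * K[m][j] for m in range(3)) for j in range(3)]
          for i in range(3)]
    s, c = math.sin(angle), 1.0 - math.cos(angle)
    return [[identity[i][j] + s * K[i][j] + c * K2[i][j] for j in range(3)]
            for i in range(3)]


rotmat = rotvec_to_rotmat([0.0, 0.0, math.pi / 2])  # quarter turn about z
```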
- mxtaltools.common.geometry_utils.safe_batched_eigh(covs, chunk=10000)[source]
Chunked symmetric eigendecomposition with a CPU fallback for CUSOLVER failures.
- Parameters:
covs (torch.FloatTensor(n, d, d)) – Batch of symmetric matrices
chunk (int) – Number of matrices per kernel call
- Returns:
eigenvalues (torch.FloatTensor(n, d))
eigenvectors (torch.FloatTensor(n, d, d))
- mxtaltools.common.geometry_utils.sample_random_valid_rotvecs(num_samples)[source]
Sample uniformly random rotation vectors with theta restricted to the upper half-sphere (z ≥ 0).
Directions are drawn from a Gaussian (giving uniform spherical coverage) and norms are drawn uniformly from (0, 2π].
- Parameters:
num_samples (int)
- Returns:
rotvecs
- Return type:
torch.FloatTensor(num_samples, 3)
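A stdlib sketch of the sampling recipe, assuming the upper half-sphere restriction is enforced by reflecting z to its absolute value. The reflection step, the seed argument, and the function name are assumptions.

```python
import math
import random


def sample_rotvecs(num_samples, seed=0):
    """Sample rotation vectors with z >= 0 and norm in (0, 2*pi]."""
    rng = random.Random(seed)
    rotvecs = []
    while len(rotvecs) < num_samples:
        # A normalized isotropic Gaussian vector is uniform on the sphere.
        direction = [rng.gauss(0.0, 1.0) for _ in range(3)]
        length = math.sqrt(sum(x * x for x in direction))
        if length == 0.0:
            continue  # practically never happens; resample
        direction[2] = abs(direction[2])  # fold onto the upper half-sphere
        # rng.random() lies in [0, 1), so the angle lies in (0, 2*pi].
        angle = 2.0 * math.pi * (1.0 - rng.random())
        rotvecs.append([angle * x / length for x in direction])
    return rotvecs


samples = sample_rotvecs(100)
norms = [math.sqrt(sum(x * x for x in v)) for v in samples]
```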
- mxtaltools.common.geometry_utils.scatter_compute_Ip(all_coords, batch, eps: float = 0.05, add_noise: bool = False)[source]
Parallel function to compute the inertial tensors, principal axes, and principal moments for a list of unequal-sized sets of coordinates.
- Parameters:
all_coords (torch.tensor(n,3))
batch (torch.tensor(n))
- Returns:
Ip (torch.tensor(num_graphs, 3, 3))
Ipm (torch.tensor(num_graphs, 3))
I (torch.tensor(num_graphs, 3, 3))
- mxtaltools.common.geometry_utils.simple_latent_distance(l1: Tensor, l2: Tensor) Tensor[source]
Euclidean distances, but with wrapped angular dimensions.
- mxtaltools.common.geometry_utils.single_molecule_principal_axes_torch(coords: tensor, masses=None, return_direction=False)[source]
Compute the principal inertial axes for a given set of particle coordinates, ignoring particle mass.
Use our overlap rules to ensure a fixed direction for all axes under almost all circumstances, excepting e.g., certain symmetric molecules.
- Parameters:
coords (torch.tensor(n,3))
masses (None) – not used
return_direction (bool) – whether to add the direction between centroid and most distant coordinate to the output
- Returns:
Ip (torch.tensor(3,3)) – Principal inertial axes.
Ipm (torch.tensor(3)) – Principal inertial moments
I (torch.tensor(3)) – Inertial tensor in original frame
- mxtaltools.common.geometry_utils.sph2cart_rotvec(angles)[source]
Transform from axis-angle in polar coordinates to rotation vector
- Parameters:
angles ((nx3)) – theta, phi, r
- Returns:
rotvec – x, y, z
- Return type:
(nx3)
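The two conversions between cartesian rotation vectors and their spherical form can be sketched for a single vector, assuming theta is the polar angle measured from +z and phi the azimuth in the xy plane. That convention and the function names are assumptions.

```python
import math


def sph_to_cart(angles):
    """[theta, phi, r] -> [x, y, z] with theta polar and phi azimuthal."""
    theta, phi, r = angles
    return [
        r * math.sin(theta) * math.cos(phi),
        r * math.sin(theta) * math.sin(phi),
        r * math.cos(theta),
    ]


def cart_to_sph(rotvec):
    """[x, y, z] -> [theta, phi, r]; the zero vector maps to [0, 0, 0]."""
    x, y, z = rotvec
    r = math.sqrt(x * x + y * y + z * z)
    if r == 0.0:
        return [0.0, 0.0, 0.0]
    return [math.acos(z / r), math.atan2(y, x), r]


on_x_axis = sph_to_cart([math.pi / 2, 0.0, 2.0])
round_trip = sph_to_cart(cart_to_sph([1.0, -2.0, 3.0]))
```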
- mxtaltools.common.geometry_utils.sph_rotvec2lat(sph_rotvec, z_prime)[source]
Map spherical rotation vectors to latent orientation parameters normalized to [-1, 1].
Inverse of lat2sph_rotvec.
- Parameters:
sph_rotvec (torch.FloatTensor(..., z_prime * 3)) – Spherical rotation vectors [theta, phi, r] for each asymmetric unit
z_prime (int) – Number of asymmetric units
- Returns:
lat_orientations
- Return type:
torch.FloatTensor(…, z_prime * 3)