mxtaltools.common.geometry_utils

mxtaltools.common.geometry_utils.angle2components(angle: tensor)

Decompose an angle into sin(angle) and cos(angle)

Parameters:

angle (torch.tensor(n))

Returns:

sin(angle), cos(angle)

Return type:

torch.tensor, torch.tensor

mxtaltools.common.geometry_utils.apply_rotation_to_batch(elems, rotations, batch)

Apply per-graph rotation matrices to a batch of vectors.

Parameters:
  • elems (torch.FloatTensor(n, 3))

  • rotations (torch.FloatTensor(num_graphs, 3, 3))

  • batch (torch.LongTensor(n))

Returns:

rotated

Return type:

torch.FloatTensor(n, 3)
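A minimal NumPy sketch of the per-graph rotation, assuming each vector is rotated as R @ v using the matrix of its own graph (selected through batch). The function name apply_rotation_to_batch_np is hypothetical; the library's torch implementation may differ in detail.

```python
import numpy as np


def apply_rotation_to_batch_np(elems, rotations, batch):
    """Rotate each vector in elems by the rotation matrix of the graph it belongs to.

    elems: (n, 3); rotations: (num_graphs, 3, 3); batch: (n,) graph index per vector.
    """
    per_node_rot = rotations[batch]  # (n, 3, 3): gather each node's graph rotation
    # out[i] = per_node_rot[i] @ elems[i]
    return np.einsum('nij,nj->ni', per_node_rot, elems)


# Two graphs: identity for graph 0, a 90-degree rotation about z for graph 1
rot_z = np.array([[0.0, -1.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 0.0, 1.0]])
rotations = np.stack([np.eye(3), rot_z])
elems = np.array([[1.0, 0.0, 0.0],
                  [1.0, 0.0, 0.0]])
batch = np.array([0, 1])
rotated = apply_rotation_to_batch_np(elems, rotations, batch)
# rotated -> [[1, 0, 0], [0, 1, 0]]
```

Gathering rotations[batch] trades memory (one 3x3 matrix per node) for a single vectorized contraction.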

mxtaltools.common.geometry_utils.batch_cell_vol_torch(v: tensor, a: tensor)

Batched computation of unit cell volumes given basis vector lengths and internal angles.

Parameters:
  • v (torch.tensor(n, 3)) – batch of [a, b, c] lengths

  • a (torch.tensor(n, 3)) – batch of [alpha, beta, gamma] angles in radians

Returns:

cell_volumes

Return type:

torch.tensor(n)
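For reference, the volume these functions evaluate is V = a·b·c·sqrt(1 - cos²α - cos²β - cos²γ + 2cosα cosβ cosγ), the angular factor documented under cell_vol_angle_factor. A NumPy sketch, assuming angles in radians; the _np names are hypothetical stand-ins, not library functions.

```python
import numpy as np


def cell_vol_angle_factor_np(cell_angles):
    """sqrt(1 - cos^2(alpha) - cos^2(beta) - cos^2(gamma) + 2 cos(alpha) cos(beta) cos(gamma))."""
    ca = np.cos(cell_angles[..., 0])
    cb = np.cos(cell_angles[..., 1])
    cg = np.cos(cell_angles[..., 2])
    return np.sqrt(1.0 - ca**2 - cb**2 - cg**2 + 2.0 * ca * cb * cg)


def batch_cell_vol_np(v, a):
    """Cell volumes for lengths v (n, 3) and angles a (n, 3) in radians."""
    return np.prod(v, axis=-1) * cell_vol_angle_factor_np(a)


lengths = np.array([[2.0, 3.0, 4.0],
                    [1.0, 1.0, 1.0]])
angles = np.array([[np.pi / 2, np.pi / 2, np.pi / 2],        # orthorhombic cell
                   [np.pi / 2, np.pi / 2, 2 * np.pi / 3]])   # hexagonal-type cell
volumes = batch_cell_vol_np(lengths, angles)
# volumes ≈ [24.0, 0.866]  (sqrt(3)/2 for the second cell)
```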

mxtaltools.common.geometry_utils.batch_compute_fractional_transform(cell_lengths, cell_angles)

Compute the fractional->Cartesian (f->c) and Cartesian->fractional (c->f) transforms, as well as the cell volumes, in a vectorized, differentiable way.

Parameters:
  • cell_lengths (torch.tensor(n, 3)) – a, b, c

  • cell_angles (torch.tensor(n, 3)) – alpha, beta, gamma

Returns:

  • fc_transform (torch.tensor(n,3,3))

  • cf_transform (torch.tensor(n,3,3))

  • cell_volumes (torch.tensor(n))
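A single-cell NumPy sketch of one common convention (a along x, b in the xy-plane), assuming radian angles and column vectors, so that r_cart = fc @ r_frac. The function name is hypothetical, and the library's batched torch version may store the matrix transposed (fractional_transform notes that the f->c transform is the transpose of the box vectors).

```python
import numpy as np


def fractional_transforms_np(cell_lengths, cell_angles):
    """Return (fc, cf, volume) for one cell; columns of fc are the Cartesian box vectors."""
    a, b, c = cell_lengths
    alpha, beta, gamma = cell_angles
    ca, cb, cg = np.cos(alpha), np.cos(beta), np.cos(gamma)
    sg = np.sin(gamma)
    volume = a * b * c * np.sqrt(1 - ca**2 - cb**2 - cg**2 + 2 * ca * cb * cg)
    fc = np.array([
        [a, b * cg, c * cb],
        [0.0, b * sg, c * (ca - cb * cg) / sg],
        [0.0, 0.0, volume / (a * b * sg)],
    ])
    cf = np.linalg.inv(fc)  # Cartesian -> fractional
    return fc, cf, volume


fc, cf, vol = fractional_transforms_np(np.array([5.0, 6.0, 7.0]),
                                       np.radians([80.0, 95.0, 105.0]))
# det(fc) equals the cell volume, each column has the matching cell length, and cf undoes fc
```

Because fc is upper triangular, its determinant is the product of the diagonal, which reduces to the cell volume.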

mxtaltools.common.geometry_utils.batch_compute_mol_mass(z: LongTensor, batch: LongTensor, masses_tensor: FloatTensor, num_graphs: int) → Tensor

Sum atomic masses for each graph in a batch.

Parameters:
  • z (torch.LongTensor(n)) – Atomic numbers used to index masses_tensor

  • batch (torch.LongTensor(n))

  • masses_tensor (torch.FloatTensor) – Lookup table of atomic masses indexed by atomic number

  • num_graphs (int)

Returns:

masses

Return type:

torch.FloatTensor(num_graphs)
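The per-graph sum is a scatter-add over the batch index. A NumPy sketch using np.bincount as the scatter; the function name and the small mass table are hypothetical illustrations.

```python
import numpy as np


def batch_compute_mol_mass_np(z, batch, masses_tensor, num_graphs):
    """Sum per-atom masses (looked up by atomic number) into per-graph totals."""
    atom_masses = masses_tensor[z]  # (n,) mass of each atom
    return np.bincount(batch, weights=atom_masses, minlength=num_graphs)


# Hypothetical lookup table: index = atomic number (only H, C, O filled in here)
masses_table = np.zeros(10)
masses_table[1], masses_table[6], masses_table[8] = 1.008, 12.011, 15.999

z = np.array([8, 1, 1, 6, 8, 8])  # water, then CO2
batch = np.array([0, 0, 0, 1, 1, 1])
masses = batch_compute_mol_mass_np(z, batch, masses_table, num_graphs=2)
# masses ≈ [18.015, 44.009]
```

Passing minlength=num_graphs keeps the output length fixed even if the last graphs contain no atoms.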

mxtaltools.common.geometry_utils.batch_compute_mol_radius(coords: FloatTensor, batch: LongTensor, num_graphs: int, nodes_per_graph: LongTensor) → Tensor

Batched version of compute_mol_radius: for each graph, the maximum distance from its centroid to any of its atoms.

Parameters:
  • coords (torch.FloatTensor(n, 3))

  • batch (torch.LongTensor(n))

  • num_graphs (int)

  • nodes_per_graph (torch.LongTensor(num_graphs))

Returns:

radii

Return type:

torch.FloatTensor(num_graphs)

mxtaltools.common.geometry_utils.batch_compute_molecule_volume(atom_types: LongTensor, pos: FloatTensor, batch: LongTensor, num_graphs: int, vdw_radii_tensor: Tensor)

Estimate molecular vdW volumes for a batch of molecules using a sphere-overlap correction.

Sums the atomic sphere volumes, then subtracts the pairwise sphere-sphere intersection volumes, scaled by an empirical correction factor (~0.73) that approximates the effect of triple overlaps.

Parameters:
  • atom_types (torch.LongTensor(n)) – Atomic numbers used to index vdw_radii_tensor

  • pos (torch.FloatTensor(n, 3)) – Atomic coordinates

  • batch (torch.LongTensor(n)) – Graph index for each atom

  • num_graphs (int)

  • vdw_radii_tensor (torch.Tensor) – Lookup table of vdW radii indexed by atomic number

Returns:

corrected_mol_volume

Return type:

torch.FloatTensor(num_graphs)
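A pure-Python sketch for a single molecule. The lens formula is the standard intersection volume of two spheres; the pairwise loop and the way the ~0.73 factor is applied are assumptions based on the description above, and both function names are hypothetical.

```python
import math


def sphere_overlap_volume(r1, r2, d):
    """Volume shared by two spheres of radii r1, r2 whose centres are d apart."""
    if d >= r1 + r2:
        return 0.0  # disjoint spheres
    if d <= abs(r1 - r2):
        return 4.0 / 3.0 * math.pi * min(r1, r2) ** 3  # smaller sphere fully enclosed
    return (math.pi * (r1 + r2 - d) ** 2
            * (d**2 + 2 * d * r2 - 3 * r2**2 + 2 * d * r1 + 6 * r1 * r2 - 3 * r1**2)
            / (12 * d))


def corrected_molecule_volume(radii, positions, correction=0.73):
    """Sum of sphere volumes minus correction * sum of pairwise overlaps."""
    total = sum(4.0 / 3.0 * math.pi * r**3 for r in radii)
    overlap = 0.0
    for i in range(len(radii)):
        for j in range(i + 1, len(radii)):
            d = math.dist(positions[i], positions[j])
            overlap += sphere_overlap_volume(radii[i], radii[j], d)
    return total - correction * overlap


# Two unit spheres one unit apart share a lens of volume 5*pi/12
vol = corrected_molecule_volume([1.0, 1.0], [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)])
```

Subtracting every pairwise lens removes regions shared by three or more atoms too many times, which is why a factor below 1 is applied.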

mxtaltools.common.geometry_utils.batch_get_furthest_node_vector(all_coords: FloatTensor, batch: LongTensor, num_graphs: int) → Tensor

Compute a cardinal direction for each graph in a batch of coordinate sets, defined as the vector from the graph's centroid to its furthest coordinate. The output is not unique for certain symmetric inputs.

Parameters:
  • all_coords (torch.tensor(n,3))

  • batch (torch.tensor(n))

  • num_graphs (int)

Returns:

direction

Return type:

torch.tensor(num_graphs, 3)

mxtaltools.common.geometry_utils.batch_molecule_principal_axes_torch(coords_i: FloatTensor, batch: LongTensor, num_graphs: int, nodes_per_graph: LongTensor, heavy_atoms_only: bool = True, atom_types: LongTensor | None = None)

Parallel computation of principal inertial axes for a batch of molecules, given as concatenated coordinates with a per-atom graph index.

Parameters:
  • coords_i (torch.FloatTensor(n, 3))

  • batch (torch.LongTensor(n))

  • num_graphs (int)

  • nodes_per_graph (torch.LongTensor(num_graphs))

  • heavy_atoms_only (bool)

  • atom_types (torch.LongTensor(n), optional)

Returns:

  • Ip_fin (list(torch.tensor(3,3)))

  • Ipm_fin (list(torch.tensor(3)))

  • I (list(torch.tensor(3,3)))

mxtaltools.common.geometry_utils.cart2sph_rotvec(rotvec)

Transform a rotation vector, with axis rotvec/||rotvec|| and angle ||rotvec||, to spherical coordinates (theta, phi, r), where r = ||rotvec||.

Parameters:

rotvec ((nx3)) – x, y, z

Returns:

angles – theta, phi, r

Return type:

(nx3)
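A NumPy sketch of this map and its inverse (sph2cart_rotvec), assuming the physics convention: theta is the polar angle from +z and phi the azimuth in (-pi, pi]. That convention matches the "upper half-sphere (z ≥ 0)" wording used elsewhere in this module, but it is still an assumption, and both function names are hypothetical. Rotation vectors are assumed nonzero.

```python
import numpy as np


def cart2sph_rotvec_np(rotvec):
    """(n, 3) rotation vectors -> (n, 3) [theta, phi, r]."""
    r = np.linalg.norm(rotvec, axis=-1)
    theta = np.arccos(np.clip(rotvec[:, 2] / r, -1.0, 1.0))  # polar angle from +z
    phi = np.arctan2(rotvec[:, 1], rotvec[:, 0])             # azimuth in (-pi, pi]
    return np.stack([theta, phi, r], axis=-1)


def sph2cart_rotvec_np(angles):
    """Inverse map: [theta, phi, r] -> r * unit_axis(theta, phi)."""
    theta, phi, r = angles[:, 0], angles[:, 1], angles[:, 2]
    return np.stack([r * np.sin(theta) * np.cos(phi),
                     r * np.sin(theta) * np.sin(phi),
                     r * np.cos(theta)], axis=-1)


rotvecs = np.array([[0.0, 0.0, 1.5],
                    [1.0, 1.0, 0.5]])
sph = cart2sph_rotvec_np(rotvecs)
# the first vector lies on +z, so theta = 0 and r = 1.5
```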

mxtaltools.common.geometry_utils.cell_parameters_to_box_vectors(opt: str, cell_lengths: tensor, cell_angles: tensor, return_vol: bool = False)

Quickly convert cell lengths and angles to a fractional->Cartesian or Cartesian->fractional transform matrix, with the direction selected by opt. Initially borrowed from Nikos. Note: this may duplicate another function in this module.

mxtaltools.common.geometry_utils.cell_vol_angle_factor(cell_angles)

Compute the angular factor sqrt(1 - cos²α - cos²β - cos²γ + 2cosα cosβ cosγ) used in cell volume calculations.

Parameters:

cell_angles (torch.tensor(..., 3)) – [alpha, beta, gamma] in radians; leading batch dimensions are supported

Returns:

factor

Return type:

torch.tensor(…)

mxtaltools.common.geometry_utils.cell_vol_np(v, a)

Compute the volume of a parallelepiped given basis vector lengths and internal angles.

Parameters:
  • v (np.array(3)) – [a, b, c]

  • a (np.array(3)) – [alpha, beta, gamma]

Returns:

cell_volume

Return type:

float

mxtaltools.common.geometry_utils.cell_vol_torch(v: tensor, a: tensor)

Compute the volume of a parallelepiped given basis vector lengths and internal angles.

Parameters:
  • v (torch.tensor(3)) – [a, b, c]

  • a (torch.tensor(3)) – [alpha, beta, gamma]

Returns:

cell_volume

Return type:

float

mxtaltools.common.geometry_utils.center_batch(coords: FloatTensor, batch: LongTensor, num_graphs: int, nodes_per_graph: LongTensor, center_on_heavy_atoms: bool = False, atom_types: LongTensor | None = None) → Tensor

Subtract the centroid from each graph’s coordinates, returning zero-centered positions.

Parameters:
  • coords (torch.FloatTensor(n, 3))

  • batch (torch.LongTensor(n))

  • num_graphs (int)

  • nodes_per_graph (torch.LongTensor(num_graphs))

  • center_on_heavy_atoms (bool) – If True, compute the centroid using only heavy atoms (atom_types > 1) but translate all atoms

  • atom_types (torch.LongTensor(n), optional) – Required when center_on_heavy_atoms is True

Returns:

coords_centered

Return type:

torch.FloatTensor(n, 3)
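A NumPy sketch of the centering, including the heavy-atom option: the centroid is a weighted scatter-mean where non-heavy atoms get weight 0, yet every atom is translated. The function name is hypothetical, and nodes_per_graph is omitted because the counts are recomputed from the weights.

```python
import numpy as np


def center_batch_np(coords, batch, num_graphs, center_on_heavy_atoms=False, atom_types=None):
    """Subtract each graph's centroid from all of its coordinates."""
    weights = np.ones(len(coords))
    if center_on_heavy_atoms:
        weights = (atom_types > 1).astype(float)  # heavy atoms only define the centroid
    counts = np.bincount(batch, weights=weights, minlength=num_graphs)
    sums = np.stack(
        [np.bincount(batch, weights=weights * coords[:, k], minlength=num_graphs) for k in range(3)],
        axis=-1,
    )
    centroids = sums / counts[:, None]  # (num_graphs, 3)
    return coords - centroids[batch]    # translate every atom, hydrogens included


coords = np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0],   # graph 0
                   [1.0, 1.0, 1.0], [3.0, 1.0, 1.0]])  # graph 1
batch = np.array([0, 0, 1, 1])
centered = center_batch_np(coords, batch, num_graphs=2)
heavy_centered = center_batch_np(coords, batch, 2, center_on_heavy_atoms=True,
                                 atom_types=np.array([6, 1, 6, 1]))
```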

mxtaltools.common.geometry_utils.components2angle(components: tensor, norm_components=True)

Take non-normalized components of shape (n, 2), representing sin(angle) and cos(angle), and compute the resulting angles, following https://ai.stackexchange.com/questions/38045/how-can-i-encode-angle-data-to-train-neural-networks

Optionally normalize the components so that their sum of squares is 1; this does not appear to make much difference in practice.

Parameters:
  • components (torch.tensor(n, 2))

  • norm_components (bool, optional)

Returns:

angles

Return type:

torch.tensor(n)

mxtaltools.common.geometry_utils.compute_Ip_handedness(Ip)

Determine the right- or left-handedness from the cross products of the principal inertial axes. Accepts np.array or torch.tensor input, for a single sample or multiple samples.

Parameters:

Ip ((n, 3, 3) or (3, 3)) – principal inertial axes

Returns:

handedness – +/- 1, the handedness of the cross products of principal inertial axes

Return type:

mxtaltools.common.geometry_utils.compute_cosine_similarity_matrix(e1, e2)

Compute the row-to-row dot products for batches of ellipsoids of shape (n, i, j), returning the (n, i, i) matrix of dot products between the rows of e1 and the rows of e2.

Swapping e1 and e2 transposes the overlap matrix.

Parameters:
  • e1 ((n, i, j))

  • e2 ((n, i, j))

mxtaltools.common.geometry_utils.compute_ellipsoid_volume(e)

Compute ellipsoid volumes from semi-axis vectors.

Parameters:

e (torch.FloatTensor(..., 3, 3)) – Each row is a semi-axis vector; volume = (4/3)π * product of semi-axis lengths

Returns:

volumes

Return type:

torch.FloatTensor(…)

mxtaltools.common.geometry_utils.compute_inertial_tensor_torch(x: tensor, y: tensor, z: tensor)

Compute the principal inertial axes for a given set of particle coordinates, ignoring particle mass.

Parameters:
  • x (torch.tensor)

  • y (torch.tensor)

  • z (torch.tensor)

Returns:

  • Ip (np.array(3,3)) – Principal inertial axes.

  • Ipm (np.array(3)) – Principal inertial moments

  • I (np.array(3,3)) – Inertial tensor in original frame

mxtaltools.common.geometry_utils.compute_latent_distance(latents1: Tensor, latents2: Tensor) → Tensor

Compute a distance metric between crystals in the latent parameterization.

Parameters:
  • latents1 (torch.Tensor)

  • latents2 (torch.Tensor)

mxtaltools.common.geometry_utils.compute_mol_mass(z: LongTensor, masses_tensor: FloatTensor) → Tensor

Sum atomic masses for a single molecule.

Parameters:
  • z (torch.LongTensor(n)) – Atomic numbers

  • masses_tensor (torch.FloatTensor) – Lookup table of atomic masses indexed by atomic number

Returns:

mass

Return type:

torch.FloatTensor scalar

mxtaltools.common.geometry_utils.compute_mol_radius(coords: FloatTensor) → Tensor

Compute the radius of a single molecule as the maximum distance from the centroid to any atom.

Parameters:

coords (torch.FloatTensor(n, 3))

Returns:

radius

Return type:

torch.FloatTensor scalar

mxtaltools.common.geometry_utils.compute_principal_axes_np(coords)

Compute the principal inertial axes for a given set of particle coordinates, ignoring particle mass.

Overlap rules are used to fix the direction of each axis in almost all cases; exceptions include certain symmetric molecules.

Parameters:

coords

Returns:

  • Ip (np.array(3,3)) – Principal inertial axes.

  • Ipm (np.array(3)) – Principal inertial moments

  • I (np.array(3,3)) – Inertial tensor in original frame

mxtaltools.common.geometry_utils.coor_trans_matrix_np(opt, v, a, return_vol=False)

Compute the fractional->Cartesian or Cartesian->fractional transform for a single cell, as selected by opt, and optionally the cell volume.

Parameters:
  • opt (str 'c_to_f' or 'f_to_c') – which direction to transform between fractional and cartesian

  • v (np.array(3)) – a, b, c

  • a (np.array(3)) – alpha, beta, gamma

  • return_vol (bool, optional) – return the absolute value of the cell volume

Returns:

  • transform (np.array(3,3))

  • cell_volumes (float)

mxtaltools.common.geometry_utils.correct_Ip_directions(Ip, overlaps, signs, overlap_threshold: float = 1e-05)

Enforce positive overlaps between the given inertial principal axes and a canonical direction, using their precomputed overlaps and signs.

Parameters:
  • Ip (torch.tensor(3,3))

  • overlaps (torch.tensor(3))

  • signs (torch.tensor(3))

  • overlap_threshold (float)

Returns:

Ip_fin – Inertial principal axes with positive overlaps to the given canonical direction

Return type:

torch.tensor(3,3)

mxtaltools.common.geometry_utils.crystal_parameter_distmat(latents, target_entries=5000000, min_block=1, max_block=2048)

Compute a blockwise distance matrix with an adaptive block size, aiming for roughly target_entries distances per kernel call.

mxtaltools.common.geometry_utils.embed_vector_to_rank3(v)

Embed an (n, k) vector as a symmetric 3-tensor.

mxtaltools.common.geometry_utils.enforce_crystal_system(lattice_lengths, lattice_angles, sg_inds, symmetries_dict: dict | None = None)

Enforce physical bounds on cell parameters according to the crystal system (see https://en.wikipedia.org/wiki/Crystal_system).

mxtaltools.common.geometry_utils.enforce_crystal_system2(lattice_lengths, lattice_angles, lattices)

Enforce physical bounds on cell parameters according to the crystal system (see https://en.wikipedia.org/wiki/Crystal_system).

mxtaltools.common.geometry_utils.extract_batching_info(nodes_list, device='cpu')

Extract batch and ptr info from a list of sets of coordinates.

Parameters:
  • nodes_list (list(torch.tensor(n, 3))) – coordinate sets; n may differ between entries

  • device (str)

Returns:

  • batch (torch.tensor(num_nodes))

  • ptr (torch.tensor(num_graphs + 1))
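A NumPy sketch of how batch and ptr can be derived from the list lengths, assuming the PyTorch Geometric-style convention where graph g owns nodes ptr[g]:ptr[g+1]. The function name is hypothetical and the device argument is omitted.

```python
import numpy as np


def extract_batching_info_np(nodes_list):
    """Per-node graph index (batch) and CSR-style offsets (ptr) for variable-size arrays."""
    sizes = np.array([len(nodes) for nodes in nodes_list])
    batch = np.repeat(np.arange(len(nodes_list)), sizes)  # graph index of every node
    ptr = np.concatenate([[0], np.cumsum(sizes)])         # graph g spans ptr[g]:ptr[g+1]
    return batch, ptr


nodes_list = [np.zeros((3, 3)), np.zeros((1, 3)), np.zeros((2, 3))]
batch, ptr = extract_batching_info_np(nodes_list)
# batch -> [0, 0, 0, 1, 2, 2], ptr -> [0, 3, 4, 6]
```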

mxtaltools.common.geometry_utils.extract_rotmat(target_position: FloatTensor, original_position: FloatTensor) → Tensor

Compute the rotation matrix R such that R @ original_position ≈ target_position.

Parameters:
  • target_position (torch.FloatTensor(n, 3, 3) or (3, 3))

  • original_position (torch.FloatTensor(n, 3, 3) or (3, 3))

Returns:

rotmat

Return type:

torch.FloatTensor(n, 3, 3) or (3, 3)
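The source does not say how R is obtained. One standard way is the orthogonal Procrustes (Kabsch) solution via SVD, sketched here in NumPy for a single (3, 3) pair; a direct solve such as target @ inv(original) would be an alternative. The function name is hypothetical.

```python
import numpy as np


def extract_rotmat_np(target_position, original_position):
    """Proper rotation R (det +1) minimizing ||R @ original - target||_F."""
    u, _, vt = np.linalg.svd(target_position @ original_position.T)
    # Flip the last singular direction if needed so R is a rotation, not a reflection
    d = np.sign(np.linalg.det(u @ vt))
    return u @ np.diag([1.0, 1.0, d]) @ vt


rng = np.random.default_rng(0)
q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
true_rot = q * np.sign(np.linalg.det(q))  # negate if needed so det = +1
original = rng.normal(size=(3, 3))
target = true_rot @ original
recovered = extract_rotmat_np(target, original)
```

Unlike a plain matrix inverse, the SVD route always returns an orthogonal matrix, even when target is only approximately a rotation of original.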

mxtaltools.common.geometry_utils.fractional_transform(coords, transform_matrix)

Transform between fractional and Cartesian bases. Assumes the fractional->Cartesian transform is the transpose of the box vectors.

Parameters:
  • coords

  • transform_matrix

Returns:

transformed_coords

mxtaltools.common.geometry_utils.fractional_transform_np(coords, transform_matrix)

Apply a fractional/cartesian transform to numpy coordinate arrays.

Dispatches on the combination of coords and transform_matrix dimensionality: (n,3)+(3,3) → per-point transform; (n,m,3)+(3,3) → per-atom-in-molecule; (n,3)+(n,3,3) → per-graph batched transform.

Parameters:
  • coords (np.ndarray)

  • transform_matrix (np.ndarray)

Returns:

transformed

Return type:

np.ndarray
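The three shape cases listed for this function can each be written as one einsum. A sketch assuming the column-vector convention r' = T @ r (the library may apply the transpose); the function name is hypothetical.

```python
import numpy as np


def fractional_transform_np_sketch(coords, transform_matrix):
    """Apply r' = T @ r for the (n,3)+(3,3), (n,m,3)+(3,3) and (n,3)+(n,3,3) cases."""
    if coords.ndim == 2 and transform_matrix.ndim == 2:   # one matrix for every point
        return np.einsum('ij,nj->ni', transform_matrix, coords)
    if coords.ndim == 3 and transform_matrix.ndim == 2:   # one matrix for every atom of every molecule
        return np.einsum('ij,nmj->nmi', transform_matrix, coords)
    if coords.ndim == 2 and transform_matrix.ndim == 3:   # one matrix per row (per graph)
        return np.einsum('nij,nj->ni', transform_matrix, coords)
    raise ValueError(f'unsupported shapes {coords.shape} and {transform_matrix.shape}')


T = np.diag([2.0, 3.0, 4.0])  # e.g. an orthorhombic fractional->Cartesian matrix
cart = fractional_transform_np_sketch(np.array([[0.5, 0.5, 0.5]]), T)        # -> [[1.0, 1.5, 2.0]]
batched = fractional_transform_np_sketch(np.ones((2, 3)), np.stack([T, np.eye(3)]))
per_mol = fractional_transform_np_sketch(np.ones((2, 4, 3)), T)
```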

mxtaltools.common.geometry_utils.fractional_transform_torch(coords, transform_matrix)

Apply a fractional/cartesian transform to torch coordinate tensors.

Same dispatch logic as fractional_transform_np.

Parameters:
  • coords (torch.FloatTensor)

  • transform_matrix (torch.FloatTensor)

Returns:

transformed

Return type:

torch.FloatTensor

mxtaltools.common.geometry_utils.get_batch_centroids(coords: FloatTensor, batch: LongTensor, num_graphs: int) → Tensor

Compute the centroid (mean position) for each graph in a batch.

Parameters:
  • coords (torch.FloatTensor(n, 3))

  • batch (torch.LongTensor(n))

  • num_graphs (int)

Returns:

centroids

Return type:

torch.FloatTensor(num_graphs, 3)

mxtaltools.common.geometry_utils.get_overlaps(Ip, direction)

Compute the overlaps and signs of the given inertial principal axes with a given canonical direction.

Parameters:
  • Ip (torch.tensor(3,3))

  • direction (torch.tensor(3))

Returns:

  • overlaps (torch.tensor(3))

  • signs (torch.tensor(3))

mxtaltools.common.geometry_utils.grid_compute_molecule_volume(atom_types, pos, vdw_radii_tensor, eps)

Brute-force grid approach to computing the vdW volume of a single molecule.

Parameters:
  • atom_types

  • pos

  • vdw_radii_tensor

  • eps

mxtaltools.common.geometry_utils.grid_compute_molecule_volume_pointwise(atom_types, pos, vdw_radii_tensor, eps)

Brute-force grid approach to computing the vdW volume of a single molecule.

Parameters:
  • atom_types

  • pos

  • vdw_radii_tensor

  • eps

mxtaltools.common.geometry_utils.lat2sph_rotvec(lat_orientations, z_prime)

Map latent orientation parameters (normalized to [-1, 1]) to spherical rotation vectors.

Inverse of sph_rotvec2lat. The three latent dimensions map as:
  • theta ∈ [-1, 1] → [π/4, 3π/4] (upper half-sphere polar angle)

  • phi ∈ [-1, 1] → [-π, π] (azimuthal angle)

  • r ∈ [-1, 1] → [0, 2π] (rotation magnitude)

Parameters:
  • lat_orientations (torch.FloatTensor(..., z_prime * 3))

  • z_prime (int) – Number of asymmetric units

Returns:

sph_rotvec – Spherical rotation vectors [theta, phi, r] for each asymmetric unit

Return type:

torch.FloatTensor(…, z_prime * 3)
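A NumPy sketch of this map and its inverse (sph_rotvec2lat). The documentation gives only the interval endpoints, so the linear scaling and the [theta, phi, r] layout repeated for each asymmetric unit are assumptions; both function names are hypothetical.

```python
import numpy as np


def lat2sph_rotvec_np(lat, z_prime):
    """Latents in [-1, 1] -> [theta, phi, r] per asymmetric unit (assumed linear)."""
    lat = lat.reshape(*lat.shape[:-1], z_prime, 3)
    theta = np.pi / 4 + (lat[..., 0] + 1) * np.pi / 4  # [-1, 1] -> [pi/4, 3pi/4]
    phi = lat[..., 1] * np.pi                           # [-1, 1] -> [-pi, pi]
    r = (lat[..., 2] + 1) * np.pi                       # [-1, 1] -> [0, 2pi]
    return np.stack([theta, phi, r], axis=-1).reshape(*lat.shape[:-2], z_prime * 3)


def sph_rotvec2lat_np(sph, z_prime):
    """Inverse of lat2sph_rotvec_np."""
    sph = sph.reshape(*sph.shape[:-1], z_prime, 3)
    lat_theta = (sph[..., 0] - np.pi / 4) * 4 / np.pi - 1
    lat_phi = sph[..., 1] / np.pi
    lat_r = sph[..., 2] / np.pi - 1
    return np.stack([lat_theta, lat_phi, lat_r], axis=-1).reshape(*sph.shape[:-2], z_prime * 3)


lat = np.array([[-1.0, 0.0, 1.0, 1.0, -1.0, 0.0]])  # one sample, z_prime = 2
sph = lat2sph_rotvec_np(lat, z_prime=2)
```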

mxtaltools.common.geometry_utils.list_molecule_principal_axes_torch(coords_list: list = None, skip_centring=False)

Parallel computation of principal inertial axes from a list of coordinate lists.

Parameters:
  • coords_list (list(torch.tensor(n,3)))

  • skip_centring (bool) – Whether to skip centering each point cloud - e.g., if the input is already centered

Returns:

  • Ip_fin (list(torch.tensor(3,3)))

  • Ipm_fin (list(torch.tensor(3)))

  • I (list(torch.tensor(3,3)))

mxtaltools.common.geometry_utils.mol_batch_vdW_volume(mol_batch)

Wrapper for batch_compute_vdW_volume.

mxtaltools.common.geometry_utils.nan_hook(name, tensor_ref, batch)

Return a backward hook that prints debug info and halts on NaN gradients.

mxtaltools.common.geometry_utils.norm_circular_components(components: tensor)

Normalize 2D components onto the unit circle by dividing each row by the square root of its sum of squares (Pythagoras).

Parameters:

components (torch.tensor(n, 2))

Returns:

normed_components

Return type:

torch.tensor(n, 2)
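A NumPy sketch of the sin/cos encoding round trip (angle2components, norm_circular_components, components2angle), assuming components are ordered [sin, cos] as documented and that the angle is decoded with atan2, as in the linked answer. If atan2 is used, rescaling a row by a positive factor does not change its angle, which would explain why the optional normalization makes little difference. Function names are hypothetical.

```python
import numpy as np


def norm_circular_components_np(components):
    """Scale each (sin, cos) row onto the unit circle."""
    return components / np.linalg.norm(components, axis=-1, keepdims=True)


def components2angle_np(components, norm_components=True):
    """Recover angles in (-pi, pi] from (sin, cos) pairs with atan2."""
    if norm_components:
        components = norm_circular_components_np(components)
    return np.arctan2(components[:, 0], components[:, 1])


angles = np.array([0.3, -2.0, 3.0])
components = np.stack([np.sin(angles), np.cos(angles)], axis=-1)  # encode as (sin, cos)
noisy = components * np.array([[2.0], [0.5], [1.7]])               # non-normalized network outputs
recovered = components2angle_np(noisy)
recovered_raw = components2angle_np(noisy, norm_components=False)
```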

mxtaltools.common.geometry_utils.probe_compute_molecule_volume(atom_types: LongTensor, pos: FloatTensor, batch: LongTensor, num_graphs: int, vdw_radii_tensor: Tensor, probes_per_mol: int = 100, eps: float = 0.01, max_iters: int = 1000, min_iters: int = 5)

Estimate molecular vdW volumes via Monte Carlo probe sampling, iterating until convergence.

Random probes are scattered within each molecule’s bounding box; the fraction inside any atomic vdW sphere gives the volume estimate. Runs until the relative change in the running mean drops below eps for min_iters consecutive iterations.

Parameters:
  • atom_types (torch.LongTensor(n))

  • pos (torch.FloatTensor(n, 3))

  • batch (torch.LongTensor(n))

  • num_graphs (int)

  • vdw_radii_tensor (torch.Tensor) – Lookup table of vdW radii indexed by atomic number

  • probes_per_mol (int) – Number of random probes per molecule per iteration

  • eps (float) – Convergence threshold on relative change in running mean

  • max_iters (int) – Maximum number of sampling iterations

  • min_iters (int) – Minimum iterations before convergence is checked

Returns:

volumes – Converged volume estimates

Return type:

torch.FloatTensor(num_graphs)
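A single-molecule NumPy sketch of the Monte Carlo estimate: probe fraction inside any sphere, times the bounding-box volume, averaged until the running mean changes by less than eps for min_iters consecutive iterations. The helper name, its per-iteration probe count (larger than the library's default of 100 so the example converges tightly), and the seed argument are hypothetical; radii are passed directly instead of via a lookup table.

```python
import numpy as np


def probe_volume_single(radii, pos, probes_per_iter=10000, eps=0.01,
                        max_iters=1000, min_iters=5, seed=0):
    """Monte Carlo vdW volume of one molecule (hypothetical helper, not the library API)."""
    rng = np.random.default_rng(seed)
    lo = (pos - radii[:, None]).min(axis=0)  # bounding box enclosing every vdW sphere
    hi = (pos + radii[:, None]).max(axis=0)
    box_volume = np.prod(hi - lo)
    estimates, prev_mean, stable = [], None, 0
    for _ in range(max_iters):
        probes = rng.uniform(lo, hi, size=(probes_per_iter, 3))
        d2 = ((probes[:, None, :] - pos[None, :, :]) ** 2).sum(axis=-1)  # (probes, atoms)
        inside = (d2 <= radii[None, :] ** 2).any(axis=1)  # probe lies in at least one sphere
        estimates.append(inside.mean() * box_volume)
        mean = float(np.mean(estimates))  # running mean over all iterations so far
        if prev_mean is not None and abs(mean - prev_mean) / mean < eps:
            stable += 1
        else:
            stable = 0
        if stable >= min_iters:  # converged for min_iters consecutive iterations
            break
        prev_mean = mean
    return mean


vol = probe_volume_single(np.array([1.0]), np.zeros((1, 3)))  # close to 4*pi/3
```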

mxtaltools.common.geometry_utils.rotmat2rotvec(rotation_matrix_list, warn_on_bad_determinant=True)

Convert a batch of rotation matrices to rotation vectors (axis-angle).

Uses the Rodrigues formula: axis from the skew-symmetric part, angle from the trace. Degenerate cases (near-identity or near-pi rotations) are handled by clamping to a fallback rotation of pi around [1,1,1].

Parameters:
  • rotation_matrix_list (torch.FloatTensor(n, 3, 3))

  • warn_on_bad_determinant (bool) – Print a warning if any matrix has det < 0 (i.e. is a reflection, not a rotation)

Returns:

rotvecs

Return type:

torch.FloatTensor(n, 3)
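A single-matrix NumPy sketch of both directions of Rodrigues' formula (this function and rotvec2rotmat): R = I + sin(t)K + (1 - cos(t))K², with the angle recovered from the trace (tr R = 1 + 2cos t) and the axis from the skew-symmetric part (R - Rᵀ = 2 sin(t) K). It covers only angles strictly between 0 and π and omits the degenerate-case fallback and determinant warning described above. Function names are hypothetical.

```python
import numpy as np


def rotvec2rotmat_np(rotvec):
    """Rodrigues' formula for one rotation vector (angle = norm, axis = direction)."""
    theta = np.linalg.norm(rotvec)
    if theta < 1e-12:
        return np.eye(3)
    kx, ky, kz = rotvec / theta
    K = np.array([[0.0, -kz, ky],
                  [kz, 0.0, -kx],
                  [-ky, kx, 0.0]])  # skew-symmetric cross-product matrix of the axis
    return np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * (K @ K)


def rotmat2rotvec_np(R):
    """Inverse for angles in (0, pi): angle from the trace, axis from the skew part."""
    theta = np.arccos(np.clip((np.trace(R) - 1) / 2, -1.0, 1.0))
    axis = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]]) / (2 * np.sin(theta))
    return theta * axis


rotvec = np.array([0.3, -0.4, 1.2])  # norm 1.3 rad
R = rotvec2rotmat_np(rotvec)
```

Near 0 and π, sin(t) vanishes and the axis division becomes unstable, which is why a dedicated fallback is needed in those regimes.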

mxtaltools.common.geometry_utils.rotvec2rotmat(mol_rotation: tensor, basis='cartesian')

Get the applied rotation matrices from mol_rotation, a batch of rotation vectors of shape (n_samples, 3). The rotation vectors are converted directly to rotation matrices with Rodrigues' rotation formula (see https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula); the older route went rotvec -> quaternion -> rotation matrix.

mxtaltools.common.geometry_utils.safe_batched_eigh(covs, chunk=10000)

Chunked symmetric eigendecomposition with a CPU fallback for CUSOLVER failures.

Parameters:
  • covs (torch.FloatTensor(n, d, d)) – Batch of symmetric matrices

  • chunk (int) – Number of matrices per kernel call

Returns:

  • eigenvalues (torch.FloatTensor(n, d))

  • eigenvectors (torch.FloatTensor(n, d, d))

mxtaltools.common.geometry_utils.sample_random_valid_rotvecs(num_samples)

Sample uniformly random rotation vectors with theta restricted to the upper half-sphere (z ≥ 0).

Directions are drawn from a Gaussian (giving uniform spherical coverage) and norms are drawn uniformly from (0, 2π].

Parameters:

num_samples (int)

Returns:

rotvecs

Return type:

torch.FloatTensor(num_samples, 3)
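A NumPy sketch of the sampling scheme described above. How the library restricts directions to z ≥ 0 is not stated; this sketch reflects the z-component, which by symmetry keeps directions uniform over the upper hemisphere. The function name and seed argument are hypothetical.

```python
import numpy as np


def sample_random_valid_rotvecs_np(num_samples, seed=None):
    """Gaussian directions folded into z >= 0, scaled by norms uniform in (0, 2pi]."""
    rng = np.random.default_rng(seed)
    directions = rng.normal(size=(num_samples, 3))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)  # uniform on the sphere
    directions[:, 2] = np.abs(directions[:, 2])                      # fold onto the upper half-sphere
    norms = 2 * np.pi * (1.0 - rng.uniform(size=num_samples))        # 1 - U(0, 1) lies in (0, 1]
    return directions * norms[:, None]


rotvecs = sample_random_valid_rotvecs_np(1000, seed=0)
```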

mxtaltools.common.geometry_utils.scatter_compute_Ip(all_coords, batch, eps: float = 0.05, add_noise: bool = False)

Parallel computation of the inertial tensors and principal inertial axes for a batch of unequally sized coordinate sets.

Parameters:
  • all_coords (torch.tensor(n,3))

  • batch (torch.tensor(n))

Returns:

  • Ip (torch.tensor(num_graphs, 3, 3))

  • Ipm (torch.tensor(num_graphs, 3))

  • I (torch.tensor(num_graphs, 3, 3))
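A NumPy sketch of the quantities returned, using unit masses (consistent with "ignoring particle mass" elsewhere in this module): I = Σ(|r|²·Id − r rᵀ) about each centroid, followed by a batched eigendecomposition. It loops over graphs for clarity rather than using scatter operations, and omits eps and add_noise; the function name is hypothetical.

```python
import numpy as np


def scatter_compute_Ip_np(all_coords, batch, num_graphs):
    """Per-graph unit-mass inertia tensors, principal axes (rows of Ip) and moments."""
    I = np.zeros((num_graphs, 3, 3))
    for g in range(num_graphs):
        x = all_coords[batch == g]
        x = x - x.mean(axis=0)                  # inertia about the centroid
        I[g] = (x**2).sum() * np.eye(3) - x.T @ x
    Ipm, vecs = np.linalg.eigh(I)               # stacked eigh; moments in ascending order
    Ip = np.transpose(vecs, (0, 2, 1))          # eigenvectors as rows
    return Ip, Ipm, I


coords = np.array([[-1.0, 0, 0], [1, 0, 0],                        # graph 0: rod along x
                   [1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0]])  # graph 1: square in xy
batch = np.array([0, 0, 1, 1, 1, 1])
Ip, Ipm, I = scatter_compute_Ip_np(coords, batch, num_graphs=2)
```

Each eigenvector is only defined up to sign (and up to rotation within degenerate eigenspaces), which is what the overlap rules in this module are meant to resolve.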

mxtaltools.common.geometry_utils.simple_latent_distance(l1: Tensor, l2: Tensor) → Tensor

euclidean distances, but with wrapped angular dimensions
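A NumPy sketch of a Euclidean distance with wrapped angular dimensions. Which latent dimensions are angular, and their period, are not documented here, so angular_mask and period are hypothetical parameters, as is the function name.

```python
import numpy as np


def wrapped_latent_distance(l1, l2, angular_mask, period=2 * np.pi):
    """Euclidean distance where dimensions flagged in angular_mask use the shortest periodic difference."""
    diff = l1 - l2
    wrapped = (diff + period / 2) % period - period / 2  # map into [-period/2, period/2)
    diff = np.where(angular_mask, wrapped, diff)
    return np.linalg.norm(diff, axis=-1)


mask = np.array([False, True])  # hypothetical: the second dimension is an angle in radians
l1 = np.array([[0.0, 0.1]])
l2 = np.array([[3.0, 2 * np.pi - 0.1]])
d = wrapped_latent_distance(l1, l2, mask)  # angular gap counts as 0.2, not about 6.08
```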

mxtaltools.common.geometry_utils.single_molecule_principal_axes_torch(coords: tensor, masses=None, return_direction=False)

Compute the principal inertial axes for a given set of particle coordinates, ignoring particle mass.

Overlap rules are used to fix the direction of each axis in almost all cases; exceptions include certain symmetric molecules.

Parameters:
  • coords (torch.tensor(n,3))

  • masses (None) – not used

  • return_direction (bool) – whether to add the direction between centroid and most distant coordinate to the output

Returns:

  • Ip (torch.tensor(3,3)) – Principal inertial axes.

  • Ipm (torch.tensor(3)) – Principal inertial moments

  • I (torch.tensor(3,3)) – Inertial tensor in original frame

mxtaltools.common.geometry_utils.sph2cart_rotvec(angles)

Transform an axis-angle representation in spherical coordinates (theta, phi, r) to a rotation vector.

Parameters:

angles ((nx3)) – theta, phi, r

Returns:

rotvec – x, y, z

Return type:

(nx3)

mxtaltools.common.geometry_utils.sph_rotvec2lat(sph_rotvec, z_prime)

Map spherical rotation vectors to latent orientation parameters normalized to [-1, 1].

Inverse of lat2sph_rotvec.

Parameters:
  • sph_rotvec (torch.FloatTensor(..., z_prime * 3)) – Spherical rotation vectors [theta, phi, r] for each asymmetric unit

  • z_prime (int) – Number of asymmetric units

Returns:

lat_orientations

Return type:

torch.FloatTensor(…, z_prime * 3)