# Source code for mxtaltools.common.ellipsoid_ops

import torch
from scipy.spatial.transform import Rotation as R
from torch.nn import functional as F
from tqdm import tqdm

from mxtaltools.common.geometry_utils import compute_ellipsoid_volume, compute_cosine_similarity_matrix


def sample_triangular_right(n_samples, start, stop, device='cpu'):
    """
    Draw samples from a right-aligned triangular distribution on [start, stop].

    The density rises linearly from zero at ``start`` to its peak at ``stop``.
    Sampling is by inverse transform: the CDF is ((x - start) / (stop - start))**2,
    so x = start + (stop - start) * sqrt(U) with U drawn uniformly on [0, 1).

    :param n_samples: number of samples to draw
    :param start: lower end of the support (zero density)
    :param stop: upper end of the support (peak density)
    :param device: torch device on which the samples are generated
    :return: 1-D tensor of ``n_samples`` values
    """
    span = stop - start
    uniform_draws = torch.rand(n_samples, device=device)
    return start + span * uniform_draws.sqrt()
def generate_random_eigenvalues(num_samples, device):
    """
    Generate random ellipsoid semi-axis lengths in descending order.

    Returns a [num_samples, 3] tensor whose columns (a, b, c) satisfy
    0.5 <= a < 1 and a >= b >= c >= 0.005:

    - ``a`` is uniform on [0, 1) clipped below at 0.5, so roughly half of the
      samples are exactly 0.5.
    - ``b = a * s_b`` and ``c = b * s_c``. Each shrink factor ``s`` is drawn from
      a right-aligned triangular distribution on [0, 1.1] and clipped to
      [0.1, 1], so some samples have ``s`` exactly 1 (equal consecutive axes).

    :param num_samples: number of axis-length triples to generate
    :param device: torch device on which the tensor is created
    :return: tensor of shape [num_samples, 3], columns ordered (a, b, c)
    """
    a = torch.rand(num_samples, device=device).clip(min=0.5)
    # draw order (b factor before c factor) preserved for reproducibility under a fixed seed
    b_scale = sample_triangular_right(num_samples, 0, 1.1, device).clip(min=0.1, max=1)
    c_scale = sample_triangular_right(num_samples, 0, 1.1, device).clip(min=0.1, max=1)
    b = a * b_scale
    c = b * c_scale
    return torch.stack([a, b, c], dim=1)
def generate_random_rotations(num_samples, device):
    """
    Draw random 3D rotations (scipy ``Rotation.random``) as rotation matrices.

    :param num_samples: number of rotations to draw
    :param device: torch device for the returned tensor
    :return: float32 tensor of shape (num_samples, 3, 3)
    """
    matrices = R.random(num_samples).as_matrix()
    return torch.tensor(matrices, dtype=torch.float32, device=device)
# NOTE(review): this definition shadows the compute_cosine_similarity_matrix
# imported from mxtaltools.common.geometry_utils at the top of this module.
def compute_cosine_similarity_matrix(e1, e2):
    """
    Batched row-by-row dot products between two stacks of matrices.

    For ``e1`` of shape [n, i, j] and ``e2`` of shape [n, k, j], returns a tensor
    of shape [n, i, k] whose entry [b, p, q] is the dot product of row p of
    e1[b] with row q of e2[b]. Swapping the arguments transposes each [i, k]
    slice.

    Rows are not normalized, so the result is a cosine similarity only when
    the input rows already have unit length.

    :param e1: tensor of shape [n, i, j]
    :param e2: tensor of shape [n, k, j]
    :return: tensor of shape [n, i, k]
    """
    return torch.einsum('bpd,bqd->bpq', e1, e2)
def trailing_relative_diffs(record, eps: float = 1e-5):
    """
    Successive changes in the running mean of ``record``, in relative units.

    The running mean after k entries is cumsum(record)[k - 1] / k. The running
    means are divided by their own average (plus ``eps`` to avoid division by
    zero), and the first-order differences of that normalized sequence are
    returned. For a 1-D ``record`` of length L the result has length L - 1.

    :param record: tensor of recorded values, accumulated along dim 0
    :param eps: small constant added to the normalizing average
    :return: tensor of normalized differences between consecutive running means
    """
    counts = torch.arange(1, len(record) + 1, device=record.device)
    running_mean = record.cumsum(dim=0) / counts
    scale = running_mean.mean() + eps
    return (running_mean / scale).diff()
# NOTE(review): the name looks like a truncation of "ellipsoid_volume"; kept
# as-is so existing callers keep working — verify against the upstream source.
def llipsoid_volume(e):
    """
    Volume of ellipsoid(s) whose semi-axis vectors are the rows of ``e``.

    Computes V = 4/3 * pi * |a1| * |a2| * |a3|, where |ai| are the Euclidean norms
    of the vectors along the last dimension. This equals the true ellipsoid
    volume only when those vectors are mutually orthogonal.

    :param e: tensor [..., 3, 3] of row-wise semi-axis vectors
    :return: tensor of volumes with the leading (batch) dimensions of ``e``
    """
    axis_product = torch.prod(torch.norm(e, dim=-1), dim=-1)
    return (4 / 3 * torch.pi) * axis_product
def _ellipsoid_bounding_box(centers, axes):
    """Return the (min, max) corners of an axis-aligned box enclosing every ellipsoid.

    Ellipsoid k is centered at ``centers[k]`` with row-wise axis vectors
    ``axes[k]``. Summing the absolute axis components per coordinate gives a
    conservative half-extent along each Cartesian direction.
    """
    extents = axes.abs().sum(dim=1)  # (n_ellipsoids, 3)
    bbox_min = (centers - extents).min(dim=0).values
    bbox_max = (centers + extents).max(dim=0).values
    return bbox_min, bbox_max


def _inside_ellipsoid(points, center, metric):
    """Return a float mask (1.0 / 0.0) of points with (x - c)^T A (x - c) <= 1."""
    dx = points - center
    return (torch.einsum('ni,ij,nj->n', dx, metric, dx) <= 1).float()


def _convergence_criteria(volume_record, v1_true, v2_true, window):
    """Stack the five relative convergence measures used by compute_ellipsoid_overlap."""
    recs = torch.tensor(volume_record)
    v1_rec, v2_rec, ov_rec = recs[:, 0], recs[:, 1], recs[:, 2]
    # stability: mean |relative change| of each running average over the trailing window
    stability = [trailing_relative_diffs(rec)[-window:].abs().mean()
                 for rec in (v1_rec, v2_rec, ov_rec)]
    # accuracy: relative error of single-ellipsoid estimates vs. the reference volumes
    accuracy = [(v1_rec.mean() - v1_true).abs() / (v1_true + 1e-5),
                (v2_rec.mean() - v2_true).abs() / (v2_true + 1e-5)]
    return torch.tensor([*stability, *accuracy])


def compute_ellipsoid_overlap(e1, e2, v1_true, v2_true, r,
                              num_probes: int = 100000,
                              eps: float = 1e-3,
                              max_iters: int = 1000,
                              min_iters: int = 10,
                              show_tqdm: bool = False,
                              ):
    """
    Estimate the volume of the region shared by two ellipsoids via Monte Carlo.

    Ellipsoid 1 is centered at the origin and ellipsoid 2 at ``r``. Each
    iteration draws ``num_probes`` uniform points in an axis-aligned box that
    encloses both ellipsoids. It counts how many fall inside each ellipsoid
    and inside both. Sampling stops when two conditions hold, each to within
    relative tolerance ``eps``:

    (a) the running means of the three volume estimates have stabilised;
    (b) the single-ellipsoid estimates agree with the supplied reference volumes.

    Otherwise sampling stops once ``max_iters`` iterations have run.

    :param e1: [3, 3] tensor of row-wise axis vectors defining ellipsoid 1
    :param e2: [3, 3] tensor of row-wise axis vectors defining ellipsoid 2
    :param v1_true: reference volume of ellipsoid 1, used to judge sampling accuracy
    :param v2_true: reference volume of ellipsoid 2, used to judge sampling accuracy
    :param r: separation vector from the center of ellipsoid 1 to that of ellipsoid 2
    :param num_probes: number of random points drawn per iteration
    :param eps: relative tolerance applied to every convergence criterion
    :param max_iters: maximum number of sampling iterations (must be >= 1)
    :param min_iters: iterations completed before convergence is first checked;
        also sets (capped at 10) the trailing window of the stability criteria
    :param show_tqdm: display a tqdm progress bar
    :return: ``(overlap_volume, converged)``. ``overlap_volume`` is the mean of
        the per-iteration overlap estimates, a 0-d float32 CPU tensor. It is the
        int ``0`` when the ellipsoids are too far apart to intersect; in that
        case ``converged`` is True.
    :raises ValueError: if ``e1``, ``e2`` and ``r`` differ in length, or if
        ``max_iters < 1`` when sampling is required.
    """
    if not (len(e1) == len(e2) == len(r)):
        raise ValueError(
            f"e1, e2 and r must have matching lengths, got {len(e1)}, {len(e2)} and {len(r)}"
        )

    # For orthogonal axis vectors, no point of an ellipsoid is farther from its
    # center than its longest axis vector, so disjoint bounding spheres => no overlap.
    if (e1.norm(dim=1).max() + e2.norm(dim=1).max()) <= r.norm():
        return 0, True

    if max_iters < 1:
        raise ValueError(f"max_iters must be >= 1, got {max_iters}")

    device = e1.device
    # Metric tensors: x lies inside an ellipsoid with axis rows e iff x^T (e^T e)^{-1} x <= 1
    A1 = torch.linalg.inv(e1.T @ e1)
    A2 = torch.linalg.inv(e2.T @ e2)
    c1 = torch.zeros(3, device=device)[None, ...]  # ellipsoid 1 stays at the origin
    c2 = r[None, ...]  # ellipsoid 2 is moved to point r

    bbox_min, bbox_max = _ellipsoid_bounding_box(torch.cat([c1, c2], dim=0),
                                                 torch.stack([e1, e2]))
    bbox_span = bbox_max - bbox_min
    bounding_volume = torch.prod(bbox_span)  # loop-invariant

    # guard against min_iters == 0, where a [-0:] slice would cover the whole history
    window = max(1, min(10, min_iters))

    volume_record = []
    converged = False
    iteration = 0
    with tqdm(total=max_iters, desc="Optimizing", unit="iter", disable=not show_tqdm) as pbar:
        while not converged and iteration < max_iters:
            iteration += 1
            points = bbox_min + bbox_span * torch.rand((num_probes, 3), device=device)

            in_e1 = _inside_ellipsoid(points, c1, A1)
            in_e2 = _inside_ellipsoid(points, c2, A2)

            # volume estimate = box volume * fraction of probes inside
            v1 = bounding_volume * in_e1.mean()
            v2 = bounding_volume * in_e2.mean()
            v_ov = bounding_volume * (in_e1 * in_e2).mean()
            volume_record.append([float(v1), float(v2), float(v_ov)])

            if iteration > min_iters:
                criteria = _convergence_criteria(volume_record, v1_true, v2_true, window)
                if torch.all(criteria < eps):
                    converged = True
                pbar.set_description(f"Iter {iteration} | loss={criteria.mean():.4f}")
            pbar.update(1)

    # Computed from the full record so a result exists even if convergence was
    # never checked (max_iters <= min_iters).
    overlap_record = torch.tensor(volume_record)[:, 2]
    return overlap_record.mean(), converged