Source code for mxtaltools.common.clustering

from mxtaltools.common.geometry_utils import simple_latent_distance
import torch
from tqdm import tqdm


def greedy_bottom_up_anchors(params: torch.Tensor,
                             cps: torch.Tensor,
                             ens: torch.Tensor,
                             d_cut,
                             e_cut) -> torch.Tensor:
    """Greedily pick well-separated anchor samples, lowest energy first.

    Samples are visited in order of ascending ``ens``. A sample becomes an
    anchor when it passes the validity filters below and its
    ``simple_latent_distance`` to every previously accepted anchor is
    strictly greater than ``d_cut``. The first valid sample is always
    accepted.

    Validity filters:
      * ``0.55 < cps < 0.95`` (density window; ``cps`` is presumably a
        packing coefficient -- confirm against callers)
      * ``ens <= e_cut``

    Args:
        params: Per-sample parameter tensor, indexed along dim 0.
        cps: Per-sample density-like values used by the density window.
        ens: Per-sample energies used for the ordering and the cutoff.
            NOTE(review): ``torch.argsort`` sorts along the last dim, so
            this is expected to be 1-D -- confirm.
        d_cut: Minimum distance a candidate must exceed to every anchor.
        e_cut: Maximum allowed energy for a candidate.

    Returns:
        A ``torch.long`` tensor, on ``params.device``, of the indices of the
        accepted anchors in order of acceptance (ascending energy). Empty
        when no sample is valid.
    """
    valid_density = (cps > 0.55) & (cps < 0.95)
    valid_energy = ens <= e_cut
    valid = valid_density & valid_energy

    en_sort_inds = torch.argsort(ens, descending=False).flatten()

    # Seed lazily with the first *valid* sample rather than unconditionally
    # with the global energy minimum: that sample may fail the filters, and
    # indexing element 0 would raise on empty input.
    anchors = []
    anchor_tensor = None

    for ind in tqdm(en_sort_inds.tolist()):
        if not valid[ind]:
            continue

        # Slice (not index) so the leading sample dimension is kept.
        sample = params[ind:ind + 1]

        if anchor_tensor is None:
            anchors.append(ind)
            anchor_tensor = sample
            continue

        diff = simple_latent_distance(anchor_tensor, sample)
        # Keep only candidates farther than d_cut from every anchor.
        if (diff > d_cut).all():
            anchors.append(ind)
            anchor_tensor = torch.cat([anchor_tensor, sample], dim=0)

    return torch.tensor(anchors, dtype=torch.long, device=params.device)