mxtaltools.common.adaptive_batching

mxtaltools.common.adaptive_batching.adaptive_batched_analysis(batch, analyses: list | str, state: dict, *, initial_batch_size: int = 1000, max_batch_size: int = 100000, grow_factor: float = 0.01, shrink_factor: float = 0.65, oom_sleep: float = 0.1, return_state: bool = False, device='cuda', show_tqdm: bool = False, **kwargs)

Run batch.analyze(analyses, assign_outputs=True, **kwargs) over the full batch, splitting it into adaptively sized mini-batches so that GPU out-of-memory (OOM) errors are handled gracefully rather than aborting the run.
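The docstring does not spell out the update rules, so the sketch below illustrates the general technique under stated assumptions and is not the mxtaltools implementation. It assumes the mini-batch size is multiplied by shrink_factor after an OOM (the failed chunk is then retried), grows by roughly grow_factor after each successful mini-batch up to max_batch_size, and is stored in the caller's dict under a hypothetical "batch_size" key. adaptive_map, SimulatedOOM and square_all are hypothetical names; SimulatedOOM stands in for a CUDA out-of-memory error so the example runs on a CPU.

```python
import time
from typing import Callable, List


class SimulatedOOM(RuntimeError):
    """Hypothetical stand-in for a CUDA out-of-memory error."""


def adaptive_map(
    items: List[int],
    fn: Callable[[List[int]], List[int]],
    state: dict,
    *,
    initial_batch_size: int = 1000,
    max_batch_size: int = 100000,
    grow_factor: float = 0.01,
    shrink_factor: float = 0.65,
    oom_sleep: float = 0.1,
) -> List[int]:
    """Apply fn to items in mini-batches, shrinking on OOM and growing on success.

    Assumed update rules for illustration, not necessarily those used by mxtaltools.
    """
    # Start from a batch size remembered in `state`, if the caller reused the dict.
    batch_size = state.get("batch_size", initial_batch_size)
    outputs: List[int] = []
    start = 0
    while start < len(items):
        chunk = items[start:start + batch_size]
        try:
            outputs.extend(fn(chunk))
        except SimulatedOOM:
            if batch_size == 1:
                raise  # a single item does not fit; nothing left to shrink
            # Shrink geometrically, pause briefly, then retry the same chunk.
            batch_size = max(1, int(batch_size * shrink_factor))
            state["batch_size"] = batch_size
            time.sleep(oom_sleep)
            continue
        start += len(chunk)
        # Grow slowly (by at least one item) after each success, capped at max_batch_size.
        batch_size = min(max_batch_size, max(batch_size + 1, int(batch_size * (1 + grow_factor))))
        state["batch_size"] = batch_size
    return outputs


def square_all(chunk: List[int]) -> List[int]:
    # Pretend the device can hold at most 300 items at once.
    if len(chunk) > 300:
        raise SimulatedOOM(f"chunk of {len(chunk)} items does not fit")
    return [x * x for x in chunk]


state: dict = {}
result = adaptive_map(list(range(2000)), square_all, state, oom_sleep=0.0)
```

Because the batch size lives in a caller-owned dict, passing the same dict to a later call starts from the size that last worked instead of rediscovering it from initial_batch_size; passing a fresh {} resets it.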

Parameters:
  • batch – Any batch object with .batch_to_list() and .analyze().

  • analyses (list | str) – Passed as the first argument to batch.analyze().

  • state (dict) – Mutable dict owned by the caller; used to carry batch_size across retries within a call and, if the same dict is reused, across calls. Pass a fresh {} on each call if you don’t want persistence across calls.

  • **kwargs – Forwarded to batch.analyze() (e.g. predictor, temperature).
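The batch parameter is duck-typed: any object providing .batch_to_list() and .analyze() qualifies. One way to state that requirement explicitly is a typing.Protocol. BatchLike and ToyBatch below are hypothetical names, and the analyze() signature is inferred from the call pattern in the summary line rather than taken from the mxtaltools source.

```python
from typing import Any, Dict, List, Protocol, runtime_checkable


@runtime_checkable
class BatchLike(Protocol):
    # Hypothetical protocol: the two methods the docstring says `batch` must provide.
    def batch_to_list(self) -> List[Any]: ...

    def analyze(self, analysis_name: Any, assign_outputs: bool = True, **kwargs: Any) -> Any: ...


class ToyBatch:
    """Hypothetical batch of plain numbers that satisfies BatchLike."""

    def __init__(self, values: List[float]) -> None:
        self.values = list(values)
        self.outputs: Dict[str, List[float]] = {}

    def batch_to_list(self) -> List[float]:
        # Split into per-sample items that an adaptive loop could regroup.
        return list(self.values)

    def analyze(self, analysis_name: str, assign_outputs: bool = True, **kwargs: Any) -> List[float]:
        # Toy "analysis": double every value; extra kwargs are accepted and ignored.
        result = [2.0 * v for v in self.values]
        if assign_outputs:
            self.outputs[analysis_name] = result
        return result


toy = ToyBatch([1.0, 2.5])
# Structural check: ToyBatch never inherits from BatchLike, yet it matches.
assert isinstance(toy, BatchLike)
```

Note that runtime_checkable only verifies that the named methods exist, not that their signatures match.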

Returns:

Collated batch object with outputs assigned.