mxtaltools.common.adaptive_batching
- mxtaltools.common.adaptive_batching.adaptive_batched_analysis(batch, analyses: list | str, state: dict, *, initial_batch_size: int = 1000, max_batch_size: int = 100000, grow_factor: float = 0.01, shrink_factor: float = 0.65, oom_sleep: float = 0.1, return_state: bool = False, device='cuda', show_tqdm: bool = False, **kwargs)
Run batch.analyze(analysis_name, assign_outputs=True, **kwargs) over the full batch using adaptive mini-batches to handle GPU OOM gracefully.
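The adaptive mini-batching idea can be sketched in plain Python. This is a minimal sketch, not the mxtaltools implementation. It works on a plain list and a caller-supplied `run_chunk` callable (a hypothetical stand-in for collating a mini-batch and calling `batch.analyze()`), treats `MemoryError` as the out-of-memory signal by default, and assumes that `grow_factor` enlarges the chunk size multiplicatively after each success while `shrink_factor` scales it down after an OOM before retrying the same chunk. Storing the size under `state["batch_size"]` is also an assumption; the real key and update rules may differ.

```python
import time
from typing import Any, Callable, Dict, List, Sequence


def adaptive_chunked_map(
    items: Sequence[Any],
    run_chunk: Callable[[Sequence[Any]], List[Any]],
    state: Dict[str, Any],
    *,
    initial_batch_size: int = 1000,
    max_batch_size: int = 100000,
    grow_factor: float = 0.01,
    shrink_factor: float = 0.65,
    oom_sleep: float = 0.1,
    is_oom: Callable[[BaseException], bool] = lambda e: isinstance(e, MemoryError),
) -> List[Any]:
    """Apply run_chunk over items in adaptively sized chunks (sketch only)."""
    # Assumed key: reuse a size left over from an earlier call if present.
    batch_size = state.get("batch_size", initial_batch_size)
    outputs: List[Any] = []
    start = 0
    while start < len(items):
        chunk = items[start:start + batch_size]
        try:
            outputs.extend(run_chunk(chunk))
        except Exception as err:
            if not is_oom(err) or batch_size == 1:
                raise  # non-OOM error, or nothing left to shrink
            # Shrink and retry the same starting position.
            batch_size = max(1, int(batch_size * shrink_factor))
            state["batch_size"] = batch_size
            time.sleep(oom_sleep)  # brief pause before retrying
            continue
        start += len(chunk)
        # Grow by at least 1 so small sizes can recover after a shrink.
        grown = max(batch_size + 1, int(batch_size * (1 + grow_factor)))
        batch_size = min(max_batch_size, grown)
        state["batch_size"] = batch_size
    return outputs
```

Because the shrunken size is written back into the caller's dict, passing the same `state` to a later call starts from a size that is already known to fit, while a fresh `{}` starts again from `initial_batch_size`. With PyTorch, one could pass `is_oom=lambda e: isinstance(e, torch.cuda.OutOfMemoryError)` (available in recent PyTorch releases) instead of the `MemoryError` default.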
- Parameters:
batch – Any batch object with .batch_to_list() and .analyze().
analyses – Analysis to run (a str or a list, per the signature); passed as the first argument to batch.analyze().
state – Mutable dict owned by the caller; used to carry batch_size across retries within a single call. Pass a fresh {} each call if you don't want persistence across calls.
**kwargs – Forwarded to batch.analyze() (e.g. predictor, temperature).
- Returns:
Collated batch object with outputs assigned.
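The `batch` parameter is duck-typed: anything exposing `.batch_to_list()` and `.analyze()` qualifies. The sketch below writes that requirement down as a `typing.Protocol` and pairs it with a toy stand-in, useful for testing calling code without a GPU. Only the two method names and the `analyze(analysis_name, assign_outputs=True, **kwargs)` call shape come from this page; the class names, return types, and the way `ToyBatch` stores outputs are hypothetical.

```python
from typing import Any, Dict, List, Protocol, runtime_checkable


@runtime_checkable
class AnalyzableBatch(Protocol):
    """Hypothetical structural type for the `batch` argument."""

    def batch_to_list(self) -> List[Any]: ...

    def analyze(self, analysis_name: str, assign_outputs: bool = True, **kwargs: Any) -> Any: ...


class ToyBatch:
    """Hypothetical stand-in that records analysis calls instead of running models."""

    def __init__(self, samples: List[Any]):
        self.samples = list(samples)
        self.outputs: Dict[str, List[Any]] = {}

    def batch_to_list(self) -> List[Any]:
        # One entry per sample, so a driver could slice it into mini-batches.
        return list(self.samples)

    def analyze(self, analysis_name: str, assign_outputs: bool = True, **kwargs: Any) -> Any:
        # Echo each sample with the forwarded kwargs (e.g. temperature).
        result = [(analysis_name, s, dict(kwargs)) for s in self.samples]
        if assign_outputs:
            self.outputs[analysis_name] = result  # mimic assign_outputs=True
        return result
```

Note that `runtime_checkable` only checks that the methods exist, not their signatures, so `isinstance(ToyBatch([1, 2]), AnalyzableBatch)` is `True` while a plain `object()` fails the check.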