Source code for mxtaltools.common.adaptive_batching

import gc
import time
from typing import Union

import numpy as np
import torch
from tqdm import tqdm

from mxtaltools.common.utils import is_cuda_oom
from mxtaltools.dataset_utils.utils import collate_data_list


def _release_cuda_memory() -> None:
    """Collect garbage and, when CUDA is usable, flush the caching allocator."""
    gc.collect()
    # torch.cuda.synchronize() initialises CUDA and raises on machines without
    # it, so CPU-only runs (device='cpu') must skip the CUDA calls entirely.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


def adaptive_batched_analysis(
        batch,
        analyses: Union[list, str],
        state: dict,
        *,
        initial_batch_size: int = 1000,
        max_batch_size: int = 100_000,
        grow_factor: float = 0.01,
        shrink_factor: float = 0.65,
        oom_sleep: float = 0.1,
        return_state: bool = False,
        device='cuda',
        show_tqdm: bool = False,
        **kwargs,
):
    """
    Run ``analyze(analyses, assign_outputs=True, **kwargs)`` over the full
    batch in adaptively sized mini-batches, shrinking the mini-batch size and
    retrying whenever a CUDA out-of-memory error is detected.

    Parameters
    ----------
    batch
        Batch object providing ``.batch_to_list()``. The mini-batches rebuilt
        from its elements with ``collate_data_list`` must provide ``.to()``,
        ``.analyze()``, ``.cpu()`` and ``.batch_to_list()``.
    analyses
        Name or list of names of the analyses to run. A single string is
        wrapped in a list before being passed as the first argument to
        ``.analyze()``.
    state
        Mutable dict owned by the caller. The current mini-batch size is kept
        under the ``"batch_size"`` key: an existing value is reused, otherwise
        it is initialised to ``initial_batch_size``. Pass a fresh ``{}`` each
        call if you don't want the size to persist across calls.
    initial_batch_size
        Mini-batch size used when ``state`` has no ``"batch_size"`` entry.
    max_batch_size
        Upper bound that growth of the mini-batch size will not exceed.
    grow_factor
        Fractional increase (at least 1 sample) applied after each successful
        mini-batch, as long as no OOM has occurred during this call.
    shrink_factor
        Multiplier applied to the mini-batch size after an OOM.
    oom_sleep
        Seconds to sleep after an OOM before retrying.
    return_state
        If True, return ``(result, state)`` instead of just ``result``.
    device
        Device each mini-batch is moved to before analysis.
    show_tqdm
        If True, display a progress bar over the samples.
    **kwargs
        Forwarded to ``.analyze()`` (e.g. predictor, temperature).

    Returns
    -------
    Collated batch object with outputs assigned, or a ``(batch, state)``
    tuple when ``return_state`` is True.

    Raises
    ------
    RuntimeError
        If an OOM occurs while the mini-batch size is already 1. Any
        ``RuntimeError``/``ValueError`` that ``is_cuda_oom`` does not
        recognise is re-raised unchanged.
    """
    # `state` is a dict, so persistence has to be a key lookup; an attribute
    # check is always False and would reset the size on every call.
    # Clamp to >= 1 so an empty slice can never stall the loop.
    state["batch_size"] = max(state.get("batch_size", initial_batch_size), 1)

    if isinstance(analyses, str):
        analyses = [analyses]

    data_list = batch.batch_to_list()
    n_samples = len(data_list)
    outputs_list = [None] * n_samples
    cursor = 0
    already_oomed = False

    # The context manager closes the bar even if an error propagates.
    with tqdm(total=n_samples, disable=not show_tqdm) as pbar:
        while cursor < n_samples:
            batch_size = state["batch_size"]
            stop = min(n_samples, cursor + batch_size)
            sub_batch = collate_data_list(data_list[cursor:stop])

            hit_oom = False
            try:
                # The device transfer can itself run out of memory, so it is
                # covered by the same retry logic as the analysis.
                sub_batch = sub_batch.to(device)
                sub_batch.analyze(analyses, assign_outputs=True, **kwargs)
                outputs_list[cursor:stop] = sub_batch.cpu().batch_to_list()
            except (RuntimeError, ValueError) as e:
                if not is_cuda_oom(e):
                    raise
                if batch_size == 1:
                    raise RuntimeError(
                        "Cascading OOM failure: batch_size already 1"
                    ) from e
                hit_oom = True

            if hit_oom:
                # Recovery happens outside the `except` block: while the
                # handler is active the traceback still references the failed
                # frames and their tensors, so the cache could not be freed.
                del sub_batch
                # Always shrink by at least one sample so retries terminate
                # even for shrink factors >= 1.
                state["batch_size"] = max(
                    min(int(batch_size * shrink_factor), batch_size - 1), 1
                )
                _release_cuda_memory()
                already_oomed = True
                time.sleep(oom_sleep)
                continue  # retry from the same cursor with a smaller batch

            pbar.update(stop - cursor)
            cursor = stop

            # Grow only while no OOM has been seen in this call, and never
            # past the configured ceiling.
            if (
                    batch_size < max_batch_size
                    and batch_size < n_samples
                    and not already_oomed
            ):
                increment = max(int(batch_size * grow_factor), 1)
                state["batch_size"] = min(batch_size + increment, max_batch_size)

    _release_cuda_memory()

    result = collate_data_list(outputs_list)
    if return_state:
        return result, state
    return result