mxtaltools.common.training_utils
- class mxtaltools.common.training_utils.OOMRetry(bs_ref, factor=0.75, min_bs=1, context='')
Bases: object
Purely functional OOM handler. Expects a mutable reference to the batch size, e.g. [batch_size]. On an out-of-memory (OOM) error it scales bs_ref[0] *= factor, cleans up CUDA memory, and retries.
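The retry pattern described above can be sketched as a standalone function. This is a minimal illustration under stated assumptions, not the OOMRetry implementation: run_with_oom_retry and is_oom_error are hypothetical names, an OOM is assumed to surface as a RuntimeError whose message contains "out of memory" (as PyTorch's CUDA OOM errors do), and the CUDA cleanup step is delegated to a caller-supplied cleanup callable so the sketch runs without a GPU.

```python
def is_oom_error(exc: BaseException) -> bool:
    # Assumption: an OOM is a RuntimeError whose message mentions "out of memory".
    return isinstance(exc, RuntimeError) and "out of memory" in str(exc)


def run_with_oom_retry(step, bs_ref, factor=0.75, min_bs=1, cleanup=None):
    """Hypothetical helper: call step(bs_ref[0]) until it succeeds.

    bs_ref is a one-element list such as [batch_size]; it is mutated in place,
    so the caller sees the reduced batch size after an OOM. Assumes factor < 1.
    """
    while True:
        try:
            return step(bs_ref[0])
        except RuntimeError as exc:
            if not is_oom_error(exc) or bs_ref[0] <= min_bs:
                raise  # not an OOM, or the batch size cannot shrink any further
            # Scale the shared batch size down, never below min_bs.
            bs_ref[0] = max(min_bs, int(bs_ref[0] * factor))
            if cleanup is not None:
                cleanup()  # e.g. gc.collect() followed by torch.cuda.empty_cache()


# Usage: a fake training step that "runs out of memory" above 48 samples.
def fake_step(batch_size):
    if batch_size > 48:
        raise RuntimeError("CUDA out of memory (simulated)")
    return batch_size


bs = [100]
result = run_with_oom_retry(fake_step, bs)  # tries 100, 75, 56, then 42
# result == 42 and bs == [42]
```

Passing the batch size as a one-element list is what lets the reduced value persist after the call returns, matching the bs_ref convention in the docstring.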
- mxtaltools.common.training_utils.check_convergence(test_record, history, convergence_eps, epoch, minimum_epochs, overfit_tolerance, train_record=None)
Check whether training has converged. Convergence condition: the test loss has increased or levelled out over the last several epochs.
- Returns:
convergence flag
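One way to express the "increased or levelled out" criterion is to ask whether the best test loss in the most recent window of epochs improved on the best loss seen before that window by more than convergence_eps. The sketch below assumes test_losses is a plain list of per-epoch test losses and that history is the window length (at least 1); has_converged is a hypothetical name, and overfit_tolerance and train_record are not modelled because their role is not documented here.

```python
def has_converged(test_losses, history, convergence_eps, epoch, minimum_epochs):
    """Hypothetical convergence check on a list of per-epoch test losses (history >= 1)."""
    # Never stop before minimum_epochs, and require losses both before and inside the window.
    if epoch < minimum_epochs or len(test_losses) <= history:
        return False
    best_before = min(test_losses[:-history])  # best loss before the recent window
    best_recent = min(test_losses[-history:])  # best loss inside the recent window
    # Converged if the recent window did not beat the earlier best by more than
    # convergence_eps, i.e. the loss has increased or levelled out.
    return best_recent > best_before - convergence_eps


still_improving = [1.0, 0.8, 0.6, 0.4, 0.2, 0.1]
plateaued = [1.0, 0.5, 0.3, 0.301, 0.302, 0.3005]
print(has_converged(still_improving, 3, 0.01, epoch=6, minimum_epochs=2))  # False
print(has_converged(plateaued, 3, 0.01, epoch=6, minimum_epochs=2))  # True
```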
- mxtaltools.common.training_utils.enable_dropout(model)
Enable dropout layers in evaluation mode.
- mxtaltools.common.training_utils.flatten_wandb_params(config)
Initialize a “flat” config for wandb parameter logging.
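Flattening turns nested settings into single-level keys that can be logged as individual parameters. A minimal sketch, assuming the config is a nested dict and that nested keys are joined with a dot; flatten_config and the dot separator are hypothetical choices, not necessarily what flatten_wandb_params does.

```python
def flatten_config(config, parent_key="", sep="."):
    """Hypothetical helper: flatten a nested dict into {"a.b.c": value} form."""
    flat = {}
    for key, value in config.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
        if isinstance(value, dict):
            # Recurse into sub-configs, carrying the joined key prefix along.
            flat.update(flatten_config(value, new_key, sep))
        else:
            flat[new_key] = value
    return flat


nested = {"optimizer": {"lr": 1e-4, "betas": [0.9, 0.999]}, "seed": 0}
print(flatten_config(nested))
# {'optimizer.lr': 0.0001, 'optimizer.betas': [0.9, 0.999], 'seed': 0}
```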
- mxtaltools.common.training_utils.get_n_config(model)
Count the parameters of a PyTorch model.
- Parameters:
model
- Returns:
the number of parameters
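For a torch.nn.Module, a parameter count is usually computed as sum(p.numel() for p in model.parameters()). The stdlib sketch below reproduces that arithmetic from parameter shapes so the numbers can be checked by hand; count_parameters is a hypothetical helper, and the example shapes are those of a torch.nn.Linear(16, 8) layer (weight of shape (8, 16), bias of shape (8,)).

```python
import math


def count_parameters(param_shapes):
    """Hypothetical helper: total number of scalar parameters across all tensors."""
    # Each tensor contributes the product of its dimensions (math.prod(()) == 1 for scalars).
    return sum(math.prod(shape) for shape in param_shapes)


linear_16_to_8 = [(8, 16), (8,)]  # weight and bias of torch.nn.Linear(16, 8)
print(count_parameters(linear_16_to_8))  # 136
```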
- mxtaltools.common.training_utils.init_optimizer(model_name, optim_config, model, amsgrad=False, freeze_params=False)
Initialize an optimizer.
- Parameters:
optim_config: config for a given optimizer
model: model with params to be optimized
freeze_params: whether parameters without requires_grad should be frozen
- Returns:
optimizer
- mxtaltools.common.training_utils.init_scheduler(optimizer, optimizer_config)
Initialize a series of learning-rate (LR) schedulers.
- mxtaltools.common.training_utils.load_crystal_score_model(checkpoint_path, device)
Reload a crystal score model from a checkpoint.
- mxtaltools.common.training_utils.load_molecule_scalar_regressor(checkpoint_path, device)
Reload a regression model for molecule scalar properties from a checkpoint.
- mxtaltools.common.training_utils.make_sequential_directory(yaml_path, workdir)
Make a new working directory labelled by the current time and date, so that it is unlikely to overlap with any other working directory.
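A sketch of a time-and-date-labelled working directory using only the standard library. Unlike the "unlikely to overlap" behaviour described above, it appends a numeric suffix when the directory already exists, so two calls within the same second still get distinct paths. make_timestamped_workdir and the timestamp format are hypothetical; how the real function uses yaml_path is not documented here, so the sketch omits it.

```python
import datetime
import os
import tempfile


def make_timestamped_workdir(workdir):
    """Hypothetical helper: create and return a new directory named after the current time."""
    # Avoid ':' in the name so the path is also valid on Windows.
    stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    base = os.path.join(workdir, stamp)
    path, suffix = base, 1
    while True:
        try:
            os.makedirs(path)  # raises FileExistsError if the path is already taken
            return path
        except FileExistsError:
            path = f"{base}_{suffix}"  # fall back to base_1, base_2, ...
            suffix += 1


with tempfile.TemporaryDirectory() as root:
    first = make_timestamped_workdir(root)
    second = make_timestamped_workdir(root)
    print(first != second)  # True
```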
- mxtaltools.common.training_utils.reload_model(model, device, optimizer, path, reload_optimizer=False)
Load the model state dict (and optionally the optimizer state dict) from path. Includes a fix for a potential DataParallel issue.
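A common DataParallel issue is that torch.nn.DataParallel stores the wrapped network as its module attribute, so a state dict saved from the wrapper has every key prefixed with "module." and fails to load into an unwrapped model. Assuming that is the issue the docstring refers to, the fix amounts to stripping the prefix before calling load_state_dict; strip_dataparallel_prefix is a hypothetical helper, shown on a plain dict so it runs without PyTorch.

```python
def strip_dataparallel_prefix(state_dict, prefix="module."):
    """Hypothetical helper: drop a DataParallel-style key prefix where present."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


saved = {"module.encoder.weight": 1.0, "module.encoder.bias": 0.5}
print(strip_dataparallel_prefix(saved))  # {'encoder.weight': 1.0, 'encoder.bias': 0.5}
# With PyTorch (assumed usage): model.load_state_dict(strip_dataparallel_prefix(state_dict))
```

Keys without the prefix pass through unchanged, so the same call is safe for checkpoints saved from an unwrapped model.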
- mxtaltools.common.training_utils.save_checkpoint(epoch: int, model: Module, optimizer, config: dict, save_path: str, dataDims: dict)
- Parameters:
epoch
model
optimizer
config
save_path
dataDims
- mxtaltools.common.training_utils.set_lr(schedulers, optimizer, optimizer_config, err_tr, hit_max_lr, override_lr=None)
- mxtaltools.common.training_utils.spoof_gpu_memory()
Dynamically allocate memory only when needed.