mxtaltools.common.training_utils

class mxtaltools.common.training_utils.OOMRetry(bs_ref, factor=0.75, min_bs=1, context='')

Bases: object

Purely functional OOM handler. Expects a mutable reference to the batch size, e.g. a one-element list [batch_size]. On an out-of-memory (OOM) error it scales the batch size in place (bs_ref[0] *= factor), cleans up CUDA memory, and retries.
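The retry pattern described for OOMRetry can be sketched as a plain function. This is an illustrative sketch, not the library's implementation: run_with_oom_retry, _cleanup_cuda and max_retries are hypothetical names, an OOM is assumed to surface as a RuntimeError whose message contains "out of memory" (as PyTorch's CUDA OOM errors do), and min_bs is assumed to act as a floor on the batch size.

```python
import gc
import sys
from typing import Callable, List, TypeVar

T = TypeVar("T")


def _cleanup_cuda() -> None:
    """Hypothetical helper: collect garbage and, if torch is already loaded with CUDA, empty its cache."""
    gc.collect()
    torch = sys.modules.get("torch")  # avoid importing torch just to clean up
    if torch is not None and torch.cuda.is_available():
        torch.cuda.empty_cache()


def run_with_oom_retry(
    step: Callable[[int], T],
    bs_ref: List[int],
    factor: float = 0.75,
    min_bs: int = 1,
    max_retries: int = 10,
) -> T:
    """Hypothetical sketch: call step(bs_ref[0]); on OOM shrink bs_ref[0] in place and retry."""
    for _ in range(max_retries + 1):
        try:
            return step(bs_ref[0])
        except RuntimeError as err:
            if "out of memory" not in str(err).lower():
                raise  # not an OOM: propagate unchanged
            new_bs = max(min_bs, int(bs_ref[0] * factor))
            if new_bs == bs_ref[0]:
                raise  # already at the minimum batch size: give up
            bs_ref[0] = new_bs  # the caller sees the reduced batch size
            _cleanup_cuda()
    raise RuntimeError(f"still out of memory after {max_retries} retries")


# Usage with a simulated step that "runs out of memory" above 40 samples.
def fake_step(batch_size: int) -> int:
    if batch_size > 40:
        raise RuntimeError("CUDA out of memory (simulated)")
    return batch_size


bs = [100]
result = run_with_oom_retry(fake_step, bs)  # batch size goes 100 -> 75 -> 56 -> 42 -> 31
```

Passing the batch size as a one-element list is what lets the handler change it in place, so later epochs keep using the reduced value.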

mxtaltools.common.training_utils.check_convergence(test_record, history, convergence_eps, epoch, minimum_epochs, overfit_tolerance, train_record=None)

Check whether training has converged. The convergence condition is that the test loss has increased or levelled out over the last several epochs.

Returns:

convergence flag
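One way to express "increased or levelled out over the last several epochs" is to compare the best recent test loss with the best earlier one. This is a hypothetical sketch: has_converged and its relative-improvement rule are assumptions, it treats history as a window length in epochs, and it does not model overfit_tolerance or train_record.

```python
from typing import Sequence


def has_converged(
    test_losses: Sequence[float],
    history: int,
    convergence_eps: float,
    epoch: int,
    minimum_epochs: int,
) -> bool:
    """Hypothetical check: True if the best loss in the last `history` epochs
    improves on the best earlier loss by less than `convergence_eps` (relative)."""
    if history < 1:
        raise ValueError("history must be at least 1")
    if epoch < minimum_epochs or len(test_losses) <= history:
        return False  # too early to judge
    best_before = min(test_losses[:-history])
    best_recent = min(test_losses[-history:])
    # Relative improvement; negative when the loss has gone up (e.g. overfitting).
    improvement = (best_before - best_recent) / max(abs(best_before), 1e-12)
    return improvement < convergence_eps


improving = has_converged([1.0, 0.8, 0.6, 0.4, 0.2], history=2, convergence_eps=0.01, epoch=5, minimum_epochs=3)
plateau = has_converged([1.0, 0.5, 0.5, 0.5, 0.5], history=2, convergence_eps=0.01, epoch=5, minimum_epochs=3)
rising = has_converged([1.0, 0.5, 0.6, 0.7], history=2, convergence_eps=0.01, epoch=4, minimum_epochs=3)
too_early = has_converged([1.0, 0.5, 0.5], history=1, convergence_eps=0.01, epoch=2, minimum_epochs=3)
```

A rising loss gives a negative improvement, so both the "increased" and "levelled out" cases fall below the threshold.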

mxtaltools.common.training_utils.enable_dropout(model)

Enable dropout layers in evaluation mode.
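A common use of this pattern is Monte Carlo dropout: the model is put in eval mode, then only its dropout layers are switched back to training mode so repeated forward passes are stochastic. The sketch below is a hypothetical PyTorch version (enable_dropout_sketch and DROPOUT_TYPES are illustrative names); the library may identify dropout layers differently, and the tuple can be extended with other dropout variants as needed.

```python
import torch
from torch import nn

# Dropout classes this sketch re-enables; extend as needed.
DROPOUT_TYPES = (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout)


def enable_dropout_sketch(model: nn.Module) -> None:
    """Put only the dropout submodules back into training mode."""
    for module in model.modules():
        if isinstance(module, DROPOUT_TYPES):
            module.train()


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(8, 1))
model.eval()                  # every submodule now has training == False
enable_dropout_sketch(model)  # ...except the dropout layer

x = torch.randn(2, 4)
with torch.no_grad():
    # Monte Carlo dropout: repeated stochastic passes give a predictive spread.
    samples = torch.stack([model(x) for _ in range(10)])
mean, std = samples.mean(dim=0), samples.std(dim=0)
```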

mxtaltools.common.training_utils.flatten_wandb_params(config)

Build a “flat” (non-nested) config for wandb parameter logging.
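Flattening usually means turning nested keys into single composite keys so each setting appears as its own logged parameter. A minimal sketch, assuming the config is a nested dict and that keys are joined with "."; flatten_config and the separator are hypothetical, and the library's function may accept other config types or use a different separator.

```python
from typing import Any, Dict


def flatten_config(config: Dict[str, Any], parent_key: str = "", sep: str = ".") -> Dict[str, Any]:
    """Hypothetical sketch: recursively join nested keys with `sep`."""
    flat: Dict[str, Any] = {}
    for key, value in config.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
        if isinstance(value, dict):
            flat.update(flatten_config(value, new_key, sep))  # descend into sub-configs
        else:
            flat[new_key] = value
    return flat


flat = flatten_config({"lr": 1e-3, "model": {"depth": 4, "head": {"dropout": 0.1}}})
```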

mxtaltools.common.training_utils.get_model_nans(model)

mxtaltools.common.training_utils.get_n_config(model)

Count the parameters of a PyTorch model.

Parameters:
  • model
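Counting parameters in PyTorch typically sums numel() over model.parameters(). The sketch below is a hypothetical helper (count_parameters, trainable_only); the documentation above does not say whether get_n_config counts all parameters or only trainable ones, so both options are shown.

```python
from torch import nn


def count_parameters(model: nn.Module, trainable_only: bool = False) -> int:
    """Total number of scalar parameters, optionally only those with requires_grad."""
    return sum(
        p.numel() for p in model.parameters() if p.requires_grad or not trainable_only
    )


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
total = count_parameters(model)  # (4*8 + 8) + (8*1 + 1) = 49
for p in model[0].parameters():
    p.requires_grad_(False)  # freeze the first layer
trainable = count_parameters(model, trainable_only=True)  # 8*1 + 1 = 9
```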

mxtaltools.common.training_utils.init_optimizer(model_name, optim_config, model, amsgrad=False, freeze_params=False)

Initialize the optimizer for a model.

Parameters:
  • optim_config: config for a given optimizer

  • model: model with parameters to be optimized

  • freeze_params: whether parameters without requires_grad should be frozen

Returns:

optimizer
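The parameters above map naturally onto torch.optim constructors. This is a hypothetical sketch: build_optimizer and the config keys "optimizer", "init_lr", "weight_decay" and "momentum" are assumptions, not the library's config schema. "Freezing" is modelled by leaving parameters with requires_grad == False out of the optimizer.

```python
import torch
from torch import nn


def build_optimizer(optim_config: dict, model: nn.Module, amsgrad: bool = False,
                    freeze_params: bool = False) -> torch.optim.Optimizer:
    """Hypothetical sketch: build an optimizer from a small config dict."""
    if freeze_params:
        # Leave parameters with requires_grad == False out of the optimizer entirely.
        params = [p for p in model.parameters() if p.requires_grad]
    else:
        params = list(model.parameters())
    name = optim_config.get("optimizer", "adam").lower()  # hypothetical key
    lr = optim_config["init_lr"]                           # hypothetical key
    weight_decay = optim_config.get("weight_decay", 0.0)   # hypothetical key
    if name == "adam":
        return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay, amsgrad=amsgrad)
    if name == "adamw":
        return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay, amsgrad=amsgrad)
    if name == "sgd":
        return torch.optim.SGD(params, lr=lr, weight_decay=weight_decay,
                               momentum=optim_config.get("momentum", 0.0))
    raise ValueError(f"unsupported optimizer: {name!r}")


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
model[0].weight.requires_grad_(False)
model[0].bias.requires_grad_(False)
opt = build_optimizer({"optimizer": "adam", "init_lr": 1e-3}, model, amsgrad=True, freeze_params=True)
```

Only the second layer's weight and bias end up in the optimizer's parameter group.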

mxtaltools.common.training_utils.init_scheduler(optimizer, optimizer_config)

Initialize a series of learning-rate (LR) schedulers.

mxtaltools.common.training_utils.load_crystal_score_model(checkpoint_path, device)

Reload a crystal score model from the checkpoint at checkpoint_path.

mxtaltools.common.training_utils.load_molecule_autoencoder(checkpoint_path, device)

mxtaltools.common.training_utils.load_molecule_scalar_regressor(checkpoint_path, device)

Reload a regression model for molecule scalar properties from the checkpoint at checkpoint_path.

mxtaltools.common.training_utils.make_sequential_directory(yaml_path, workdir)

Make a new working directory labelled by the current time and date, so that it is unlikely to overlap with any other working directory.
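A time-and-date label can still collide when two runs start in the same second, so a sketch can fall back to a numeric suffix. This is a hypothetical helper (make_timestamped_dir); the directory-name format is an assumption, and the role of yaml_path in the real function is not modelled.

```python
import tempfile
from datetime import datetime
from pathlib import Path


def make_timestamped_dir(workdir: str) -> Path:
    """Hypothetical sketch: create workdir/<date>_<time>, adding _1, _2, ... on a clash."""
    stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    base = Path(workdir)
    candidate, n = base / stamp, 1
    while True:
        try:
            candidate.mkdir(parents=True, exist_ok=False)  # fails if the path already exists
            return candidate
        except FileExistsError:
            candidate = base / f"{stamp}_{n}"
            n += 1


with tempfile.TemporaryDirectory() as root:
    first = make_timestamped_dir(root)
    second = make_timestamped_dir(root)  # within the same second, this gets a "_1" suffix
    both_exist = first.is_dir() and second.is_dir()
    distinct = first != second
```

Letting mkdir(exist_ok=False) detect the clash, rather than checking exists() first, avoids a race between two processes creating directories at the same moment.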

mxtaltools.common.training_utils.reload_model(model, device, optimizer, path, reload_optimizer=False)

Load the model and its state dict from path. Includes a fix for a potential DataParallel issue.
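A well-known DataParallel issue is that a model wrapped in torch.nn.DataParallel saves its state dict with every key prefixed by "module.", so the keys no longer match an unwrapped model. The sketch below shows one common fix (strip_dataparallel_prefix is a hypothetical name); the library's actual fix may differ.

```python
from collections import OrderedDict
from typing import Any, Mapping


def strip_dataparallel_prefix(state_dict: Mapping[str, Any], prefix: str = "module.") -> OrderedDict:
    """Drop the leading 'module.' that DataParallel adds to state-dict keys."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


saved = {"module.encoder.weight": 1.0, "module.encoder.bias": 0.0, "head.weight": 2.0}
fixed = strip_dataparallel_prefix(saved)
```

The result can be passed straight to model.load_state_dict; keys without the prefix are left unchanged, so the helper is safe to apply to checkpoints from either wrapped or unwrapped models.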

mxtaltools.common.training_utils.save_checkpoint(epoch: int, model: Module, optimizer, config: dict, save_path: str, dataDims: dict)

Parameters:
  • epoch (int)

  • model (Module)

  • optimizer

  • config (dict)

  • save_path (str)

  • dataDims (dict)

mxtaltools.common.training_utils.set_lr(schedulers, optimizer, optimizer_config, err_tr, hit_max_lr, override_lr=None)

mxtaltools.common.training_utils.slash_batch(train_loader, test_loader, slash_fraction)

mxtaltools.common.training_utils.spoof_gpu_compute()

mxtaltools.common.training_utils.spoof_gpu_memory()

Dynamically allocate memory only when needed.

mxtaltools.common.training_utils.spoof_usage()

mxtaltools.common.training_utils.update_stats_dict(dictionary: dict, keys, values, mode='append')

Update a dictionary of key: list pairs by appending to (or extending) the list stored under each key; keys can be given several at a time or one at a time.

Parameters:
  • dictionary

  • keys

  • values

  • mode ('append' or 'extend')

Returns:

updated_dictionary
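The append/extend behaviour can be sketched in a few lines. This is a hypothetical version (update_stats is an illustrative name); in particular, treating a single string key as "one at a time" is an assumption about how the real function distinguishes the two calling styles.

```python
from typing import Any, Dict, List


def update_stats(dictionary: Dict[str, List[Any]], keys, values, mode: str = "append"):
    """Hypothetical sketch: append/extend per-key lists; accepts one key or many."""
    if mode not in ("append", "extend"):
        raise ValueError(f"mode must be 'append' or 'extend', got {mode!r}")
    if isinstance(keys, str):
        keys, values = [keys], [values]  # single key: wrap so the loop below works
    for key, value in zip(keys, values):
        bucket = dictionary.setdefault(key, [])  # create the list on first use
        if mode == "append":
            bucket.append(value)
        else:
            bucket.extend(value)
    return dictionary


stats: Dict[str, List[Any]] = {}
update_stats(stats, ["loss", "acc"], [0.5, 0.9])          # several keys at once
update_stats(stats, "loss", 0.4)                          # one key at a time
update_stats(stats, "acc", [0.91, 0.92], mode="extend")   # extend with a batch of values
```

Validating mode before the loop keeps an invalid call from leaving empty lists behind in the dictionary.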

mxtaltools.common.training_utils.weight_reset(m)