mxtaltools.models.utils
- mxtaltools.models.utils.clean_cell_params(samples, sg_inds, lattice_means, lattice_stds, symmetries_dict, asym_unit_dict, rescale_asymmetric_unit=True, destandardize=False, mode='soft', fractional_basis='asymmetric_unit', skip_angular_dof=False)
TODO: deprecate. Enforces physical limits on the cell parameterization of randomly generated samples from different sources.
- Parameters:
samples (torch.Tensor)
sg_inds (torch.LongTensor)
lattice_means (torch.Tensor)
lattice_stds (torch.Tensor)
symmetries_dict (dict)
asym_unit_dict (dict)
rescale_asymmetric_unit (bool)
destandardize (bool)
mode (str, "hard" or "soft")
fractional_basis (str, default 'asymmetric_unit')
skip_angular_dof (bool)
- mxtaltools.models.utils.clean_generator_output(samples=None, lattice_lengths=None, lattice_angles=None, mol_positions=None, mol_orientations=None, lattice_means=None, lattice_stds=None, destandardize=True, mode='soft', skip_angular_dof=False)
TODO: rewrite; this function is important but currently poorly structured. Converts raw model output to the actual cell parameters with appropriate bounds: the raw outputs are treated as being in the standardized basis, so they are first destandardized and then bounded.
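The destandardize-then-bound pipeline described above can be sketched on plain Python lists. This is a minimal illustration, not the library's implementation: the helper names (`destandardize`, `hard_bound`, `clean_cell_sketch`) and the bounds used (a minimum length and angles clipped to [0, π]) are hypothetical, and the real function works on batched torch tensors and also handles positions and orientations.

```python
import math


def destandardize(samples, means, stds):
    """Map standardized values back to physical units: x * std + mean."""
    return [x * s + m for x, m, s in zip(samples, means, stds)]


def hard_bound(x, lo, hi):
    """Clip x into the closed interval [lo, hi]."""
    return min(max(x, lo), hi)


def clean_cell_sketch(raw_lengths, raw_angles, length_means, length_stds,
                      angle_means, angle_stds, min_length=1.0):
    # hypothetical helper: raw generator outputs are assumed to be standardized,
    # so undo the standardization first
    lengths = destandardize(raw_lengths, length_means, length_stds)
    angles = destandardize(raw_angles, angle_means, angle_stds)
    # then enforce physical bounds (these particular limits are placeholders)
    lengths = [max(a, min_length) for a in lengths]
    angles = [hard_bound(t, 0.0, math.pi) for t in angles]
    return lengths, angles
```

Ordering matters here: bounds are stated in physical units, so they can only be applied after destandardization.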
- mxtaltools.models.utils.compute_prior_loss(norm_factors: Tensor, sg_inds: LongTensor, generator_raw_samples: Tensor, prior: Tensor, variation_factor: Tensor) → tuple[Tensor, Tensor]
Takes the norm of the scaled distances between the prior and the generated samples, and applies a quadratic penalty when that norm exceeds variation_factor.
- Parameters:
norm_factors
sg_inds
generator_raw_samples
prior
variation_factor
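A single-sample sketch of the penalty described above, assuming the quadratic term applies only to the part of the distance that exceeds variation_factor. The name `prior_loss_sketch` is hypothetical; one normalization vector is passed in directly, standing in for whatever per-space-group selection `sg_inds` performs in the real function. The real function returns a tuple of tensors whose contents are not documented here; returning `(loss, dist)` is an illustrative choice.

```python
import math


def prior_loss_sketch(generated, prior, norm_factors, variation_factor):
    """Quadratic penalty on the scaled prior-to-sample distance beyond variation_factor."""
    # scale each coordinate of the difference by its normalization factor
    scaled = [(g - p) / n for g, p, n in zip(generated, prior, norm_factors)]
    dist = math.sqrt(sum(d * d for d in scaled))
    # no penalty inside the allowed radius, quadratic growth outside it
    excess = max(dist - variation_factor, 0.0)
    return excess ** 2, dist
```

For example, a sample at scaled distance 5 from its prior with `variation_factor=1` incurs a penalty of (5 - 1)² = 16, while anything within distance 1 costs nothing.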
- mxtaltools.models.utils.compute_reduced_volume_fraction(cell_lengths: tensor, cell_angles: tensor, atom_radii: tensor, batch: tensor, crystal_multiplicity: tensor)
TODO: deprecate in favour of the packing coefficient.
- Parameters:
cell_lengths
cell_angles
atom_radii
batch
crystal_multiplicity
- Returns:
asymmetric unit volume / sum of van der Waals volumes, the so-called ‘reduced volume fraction’
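The ratio above can be worked through for a single crystal. This sketch assumes the asymmetric unit volume is the triclinic cell volume divided by the crystal multiplicity, and that the vdW volume is a naive sum of atomic spheres with overlaps ignored; the library's batched version (grouping atoms via `batch`) may compute the vdW term differently. The function names are hypothetical.

```python
import math


def cell_volume(a, b, c, alpha, beta, gamma):
    """Triclinic cell volume from lengths and angles (angles in radians)."""
    ca, cb, cg = math.cos(alpha), math.cos(beta), math.cos(gamma)
    return a * b * c * math.sqrt(1 - ca**2 - cb**2 - cg**2 + 2 * ca * cb * cg)


def reduced_volume_fraction_sketch(lengths, angles, atom_radii, multiplicity):
    # volume of one asymmetric unit = cell volume / number of symmetry copies
    asym_volume = cell_volume(*lengths, *angles) / multiplicity
    # naive sum of atomic vdW sphere volumes (overlaps between atoms are ignored)
    vdw_volume = sum(4.0 / 3.0 * math.pi * r**3 for r in atom_radii)
    return asym_volume / vdw_volume
```

For a cubic cell with a = 2 and a single atom of radius 1, the fraction is 8 / (4π/3) ≈ 1.91.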
- mxtaltools.models.utils.decode_to_sph_rotvec(mol_orientations)
Each angle is predicted from two parameters. The encodings for theta are bounded on [0, 1] to restrict theta to the range [0, π/2].
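One way to see why bounding the theta encodings on [0, 1] restricts theta to [0, π/2]: if an angle is decoded from a (sin-like, cos-like) pair with atan2, two non-negative components always give a result in the first quadrant. This is a sketch under that assumption; the helper names, the atan2 decoding, and the spherical-to-Cartesian convention (polar theta, azimuth phi, magnitude r) are hypothetical stand-ins for the library's own mapping.

```python
import math


def decode_angle(enc_sin, enc_cos):
    """Recover an angle from a two-component (sin-like, cos-like) encoding."""
    return math.atan2(enc_sin, enc_cos)


def decode_theta(enc_sin, enc_cos):
    # clamp both components to [0, 1] so atan2 lands in [0, pi/2]
    s = min(max(enc_sin, 0.0), 1.0)
    c = min(max(enc_cos, 0.0), 1.0)
    return math.atan2(s, c)


def sph_to_rotvec(theta, phi, r):
    """Spherical (polar theta, azimuth phi, magnitude r) to a Cartesian rotation vector."""
    return (
        r * math.sin(theta) * math.cos(phi),
        r * math.sin(theta) * math.sin(phi),
        r * math.cos(theta),
    )
```

Predicting two components per angle, rather than the angle itself, avoids the wrap-around discontinuity at ±π that a single scalar output would have to learn.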
- mxtaltools.models.utils.decode_to_sph_rotvec2(mol_orientation_components)
TODO: decide whether to keep or deprecate this function. Each angle is predicted from two parameters; the encodings for theta are bounded on [0, 1] to restrict theta to the range [0, π/2].
Identical to decode_to_sph_rotvec, but theta is treated as a single scalar, so an [n, 5] input maps to an [n, 3] output.
- mxtaltools.models.utils.denormalize_generated_cell_params(normed_cell_samples: FloatTensor, mol_data, asym_unit_dict: dict)
- mxtaltools.models.utils.embed_crystal_list(batch_size: int, crystal_list: list, embedding_type: str, encoder_checkpoint_path: Optional = None, device: str | None = 'cpu', redo_crystal_analysis: bool | None = False) → list
- mxtaltools.models.utils.enforce_1d_bound(x: tensor, x_span, x_center, mode='soft')
Constrains x to the range x_center ± x_span.
- Parameters:
x
x_span
x_center
mode (str, default 'soft')
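A scalar sketch of the two bounding modes. The "hard" branch is a plain clip; for the "soft" branch this uses a tanh squashing, which is an assumed stand-in because the library's soft function is not documented here. The name `enforce_1d_bound_sketch` is hypothetical.

```python
import math


def enforce_1d_bound_sketch(x, x_span, x_center, mode="soft"):
    """Keep x within x_center +/- x_span."""
    if mode == "hard":
        # clip: values outside the window are pinned to its edges
        return min(max(x, x_center - x_span), x_center + x_span)
    if mode == "soft":
        # smooth squashing: close to the identity near the centre, saturating at the edges
        return x_center + x_span * math.tanh((x - x_center) / x_span)
    raise ValueError(f"unknown mode {mode!r}")
```

The usual reason to prefer a soft bound on generator outputs is that a hard clip has zero gradient outside the window, whereas tanh still passes a (shrinking) gradient back to the model.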
- mxtaltools.models.utils.get_mol_embedding_for_proxy(crystal_batch, embedding_type, encoder: Optional = None)
- mxtaltools.models.utils.norm_scores(score, tracking_features, dataDims)
Normalizes the incoming score by a feature of the molecule (generally its size).
- mxtaltools.models.utils.renormalize_generated_cell_params(generator_raw_samples, mol_data, asym_unit_dict)
- mxtaltools.models.utils.softmax_and_score(raw_classwise_output, temperature=1, old_method=False, correct_discontinuity=True) → Tensor | ndarray
- Parameters:
raw_classwise_output (numpy array or torch tensor of shape [n, 2], holding the unnormalized [false, true] probabilities)
temperature (softmax temperature)
old_method (if True, use the more complicated method from the first paper)
correct_discontinuity (correct the discontinuity at 0; applies only to the old method)
- Returns:
score: the input probabilities linearized from (0, 1) to (-inf, inf) for easier visualization
- Return type:
Tensor | ndarray
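A common way to linearize a two-class probability from (0, 1) to (-inf, inf) is the log-odds (logit) of p(true) after a temperature-scaled softmax. The sketch below shows that transform on a single [false, true] pair. It is an illustration of the general idea under that assumption, not necessarily what either the new or the old method in softmax_and_score computes. The names `softmax` and `linearized_score` are hypothetical.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of raw scores."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def linearized_score(raw_pair, temperature=1.0):
    """Map a [false, true] raw pair to (-inf, inf) via the log-odds of p(true)."""
    p_false, p_true = softmax(raw_pair, temperature)
    # log(p / (1 - p)); for two classes this equals (z_true - z_false) / temperature
    return math.log(p_true / p_false)
```

For two classes the softmax and logit cancel, so the score reduces to the temperature-scaled difference of the raw outputs. Computing that difference directly avoids a division by zero when p_false underflows for very confident inputs.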