Examples

MXtalTools includes several analysis utilities for molecular crystal tasks. Runnable example scripts are in the examples/ directory.

Crystal Density Prediction

Predict the crystal packing coefficient from a molecule's SMILES string using a pre-trained regressor. The packing coefficient is defined as \(c_p = V_{mol} / V_{aunit}\), where \(V_{mol}\) is the molecular volume and \(V_{aunit} = V_{cell} / Z\) is the asymmetric unit volume (this identity holds for crystals with one molecule in the asymmetric unit, \(Z' = 1\)).
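Going from \(c_p\) to a density is plain arithmetic: \(V_{aunit} = V_{mol} / c_p\), and the density is the molecular mass divided by \(V_{aunit}\), converted with 1 Da/Å³ = 1.66054 g/cm³. Below is a minimal plain-Python sketch, assuming one molecule per asymmetric unit. The function names are hypothetical (not MXtalTools API) and the input numbers are illustrative only.

```python
# Standard unit conversion: 1 Da / Å^3 = 1.66054 g / cm^3
DA_PER_A3_TO_G_PER_CM3 = 1.66054


def aunit_volume_from_packing_coeff(mol_volume: float, packing_coeff: float) -> float:
    """V_aunit = V_mol / c_p, in Å^3 (assumes one molecule per asymmetric unit)."""
    return mol_volume / packing_coeff


def density_from_packing_coeff(mol_mass: float, mol_volume: float, packing_coeff: float) -> float:
    """Crystal density in g/cm^3 from molecular mass (Da), molecular volume (Å^3) and c_p."""
    aunit_volume = aunit_volume_from_packing_coeff(mol_volume, packing_coeff)
    return mol_mass / aunit_volume * DA_PER_A3_TO_G_PER_CM3


# Illustrative numbers only: a 128.17 Da molecule occupying 125 Å^3 at c_p = 0.70
print(round(aunit_volume_from_packing_coeff(125.0, 0.70), 2))  # 178.57
print(round(density_from_packing_coeff(128.17, 125.0, 0.70), 3))  # 1.192
```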

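The model was trained on Cambridge Structural Database (CSD) data and is most accurate for CSD-like molecules. Prediction uncertainty can be estimated with Monte Carlo (MC) dropout: dropout stays active at inference, and the spread over repeated predictions serves as the uncertainty estimate.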
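A generic MC dropout sketch in plain PyTorch: the model is put in eval mode, only its `nn.Dropout` layers are switched back to training mode, and the mean and standard deviation over repeated forward passes are returned. This illustrates the general technique under that assumption; it is not a description of how MXtalTools' `enable_dropout` is implemented, and `mc_dropout_predict` and the toy network are hypothetical.

```python
import torch
import torch.nn as nn


def mc_dropout_predict(model, x, num_samples=50):
    """Mean and std of num_samples stochastic forward passes (MC dropout).

    Everything runs in eval mode except nn.Dropout layers, which are switched
    back to train mode so a fresh dropout mask is drawn on every pass.
    """
    model.eval()
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()
    with torch.no_grad():
        samples = torch.stack([model(x).flatten() for _ in range(num_samples)])
    return samples.mean(0), samples.std(0)


torch.manual_seed(0)
toy_model = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Dropout(p=0.2), nn.Linear(32, 1))
mean, std = mc_dropout_predict(toy_model, torch.randn(8, 4))
print(mean.shape, std.shape)  # torch.Size([8]) torch.Size([8])
```

Keeping the rest of the network in eval mode matters for layers such as batch normalization, which would otherwise switch to per-batch statistics.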

See examples/crystal_density_prediction.py:

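    import torch
    from pathlib import Path
    # MolData, collate_data_list, load_molecule_scalar_regressor, enable_dropout
    # and the test_smiles list are imported/defined in examples/crystal_density_prediction.py

    """configs"""
    device = 'cpu'
    checkpoint = Path(r"../checkpoints/cp_regressor.pt")
    num_samples = 50  # number of MC dropout forward passes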

    """
    First we load some molecules.
    RDKit builds and energy-minimizes 3D structures from SMILES strings;
    this is automated in our :class:`MolData` class.
    One can also directly input the atom types and coordinates.
    """
    base_molData = MolData()
    num_mols = len(test_smiles)
    mols = [base_molData.from_smiles(test_smiles[ind],
                                     compute_partial_charges=True,
                                     minimize=True,
                                     protonate=True,
                                     ) for ind in range(num_mols)]
    mols = [mol for mol in mols if mol is not None]  # sometimes the embedding fails
    mol_batch = collate_data_list(mols).to(device)

    """load model"""
    model = load_molecule_scalar_regressor(
        checkpoint,
        device
    )
    """
    Now we make predictions with the model. The predicted packing coefficient
    converts directly to the asymmetric unit volume, and from there to the density.
    """
    with torch.no_grad():
        """predict crystal packing coefficient - single-point (undo target normalization)"""
        packing_coeff_pred = model(mol_batch).flatten() * model.target_std + model.target_mean
        aunit_volume_pred = mol_batch.mol_volume / packing_coeff_pred  # A^3
        density_pred = mol_batch.mass / aunit_volume_pred * 1.66054  # amu/A^3 -> g/cm^3

        """get prediction with uncertainty via resampling with dropout"""
        predictions = []
        model = enable_dropout(model)
        for _ in range(num_samples):
            predictions.append(model(mol_batch).flatten() * model.target_std + model.target_mean)

        predictions = torch.stack(predictions)
        packing_coeff_mean = predictions.mean(0)
        packing_coeff_std = predictions.std(0)
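Because \(\rho = 1.66054\, M\, c_p / V_{mol}\) is linear in \(c_p\) (with \(M\) in Da and \(V_{mol}\) in Å³), the MC dropout samples can be converted to density samples directly, and the density standard deviation is the packing-coefficient standard deviation scaled by \(1.66054\, M / V_{mol}\). A minimal sketch with made-up sample values (not model output); `density_samples` is a hypothetical helper, and the sample tensor has the same `(num_samples, num_mols)` layout as `predictions` above.

```python
import torch

DA_PER_A3_TO_G_PER_CM3 = 1.66054  # 1 Da / Å^3 in g / cm^3


def density_samples(cp_samples, mass, mol_volume):
    """Convert packing-coefficient samples (num_samples, num_mols) to densities in g/cm^3.

    mass (Da) and mol_volume (Å^3) have shape (num_mols,) and broadcast over samples.
    """
    aunit_volume = mol_volume / cp_samples  # Å^3
    return mass / aunit_volume * DA_PER_A3_TO_G_PER_CM3


# made-up inputs for two molecules and three MC samples (not model output)
cp_samples = torch.tensor([[0.68, 0.72], [0.70, 0.74], [0.72, 0.70]], dtype=torch.float64)
mass = torch.tensor([128.17, 180.16], dtype=torch.float64)
mol_volume = torch.tensor([125.0, 160.0], dtype=torch.float64)

rho = density_samples(cp_samples, mass, mol_volume)
rho_mean, rho_std = rho.mean(0), rho.std(0)
# linearity: the density spread is the c_p spread times a per-molecule constant
scale = DA_PER_A3_TO_G_PER_CM3 * mass / mol_volume
print(torch.allclose(rho_std, scale * cp_samples.std(0)))  # True
```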

Molecule Encoding

Encode molecules into rotation-equivariant vector embeddings and invariant scalar embeddings using a pre-trained Mo3ENet autoencoder. The model was trained on QM9-like molecules (up to 9 heavy atoms, composed of H, C, N, O and F). Performance may degrade for fluorine-rich or highly symmetric molecules.
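To make "equivariant" and "invariant" concrete: rotating the input coordinates should rotate the vector embedding by the same rotation, while scalar embeddings should not change. The toy sketch below checks both properties for a hypothetical linear encoder and a norm-based scalarizer. It only illustrates the symmetry properties; it is not Mo3ENet's architecture or its `scalarizer`.

```python
import torch


def random_rotation(seed=0):
    """Random proper rotation matrix (det = +1) from a QR decomposition."""
    g = torch.Generator().manual_seed(seed)
    q, _ = torch.linalg.qr(torch.randn(3, 3, generator=g, dtype=torch.float64))
    if torch.det(q) < 0:
        q[:, 0] = -q[:, 0]  # flip one axis to turn a reflection into a rotation
    return q


def toy_vector_encoder(coords, weights):
    """Hypothetical equivariant encoder: (k, n) weights times centred (n, 3) coords -> (k, 3)."""
    centred = coords - coords.mean(0, keepdim=True)
    return weights @ centred


def toy_scalarizer(vectors):
    """Hypothetical invariant map: the length of each embedding vector."""
    return vectors.norm(dim=-1)


torch.manual_seed(0)
coords = torch.randn(9, 3, dtype=torch.float64)  # 9 "atoms"
weights = torch.randn(4, 9, dtype=torch.float64)
R = random_rotation()

v = toy_vector_encoder(coords, weights)
v_rot = toy_vector_encoder(coords @ R.T, weights)  # encode the rotated molecule
print(torch.allclose(v_rot, v @ R.T))                            # True: equivariant
print(torch.allclose(toy_scalarizer(v_rot), toy_scalarizer(v)))  # True: invariant
```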

See examples/molecule_autoencoder.py:

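    import torch
    from pathlib import Path
    # MolData, collate_data_list, load_molecule_autoencoder and the test_smiles
    # list are imported/defined in examples/molecule_autoencoder.py

    """configs"""
    device = 'cpu'
    checkpoint = Path(r"../checkpoints/autoencoder.pt")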

    """
    First we load some molecules and center each one on the origin.
    RDKit builds and energy-minimizes 3D structures from SMILES strings;
    this is automated in our MolData class.
    One can also directly input the atom types and coordinates.
    """
    base_molData = MolData()
    num_mols = len(test_smiles)
    mols = [base_molData.from_smiles(test_smiles[ind],
                                     compute_partial_charges=True,
                                     minimize=True,
                                     protonate=True,
                                     ) for ind in range(num_mols)]
    mols = [mol for mol in mols if mol is not None]  # sometimes the embedding fails
    mol_batch = collate_data_list(mols).to(device)
    mol_batch.recenter_molecules()

    """load pre-trained model"""
    model = load_molecule_autoencoder(
        checkpoint,
        device
    )

    with torch.no_grad():
        """get vector and scalar embeddings"""
        vector_encoding = model.encode(mol_batch.clone())
        scalar_encoding = model.scalarizer(vector_encoding)

        """
        Check the quality of the embedding for this batch of
        molecules, and visualize the reconstruction
        """
        reconstruction_loss, rmsd, matched_molecule = (
            model.check_embedding_quality(mol_batch, visualize=True))
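The `rmsd` value returned by `check_embedding_quality` presumably compares the reconstructed molecule with the input one. How MXtalTools matches reconstructed atoms to input atoms is not shown here; the sketch below assumes a one-to-one correspondence is already known and shows only origin recentering and the RMSD formula. `recenter` and `rmsd` are hypothetical helpers.

```python
import torch


def recenter(coords):
    """Shift an (n_atoms, 3) coordinate array so its centroid sits at the origin."""
    return coords - coords.mean(0, keepdim=True)


def rmsd(a, b):
    """RMSD between two (n_atoms, 3) arrays whose rows are already matched one-to-one."""
    return torch.sqrt(((a - b) ** 2).sum(-1).mean())


mol = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=torch.float64)
shifted = mol + torch.tensor([2.0, -1.0, 0.5], dtype=torch.float64)
print(rmsd(mol, shifted).item())                      # sqrt(5.25): same shape, different position
print(rmsd(recenter(mol), recenter(shifted)).item())  # ~0 after recentering
```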

Crystal Analysis & Scoring

Build and analyze molecular crystals, including Lennard-Jones and short-range electrostatic potential evaluation, radial distribution function (RDF) comparisons, and a CSD-trained crystal score model.
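For reference, the textbook 12-6 Lennard-Jones energy is \(E_{LJ} = \sum_{i<j} 4\varepsilon\left[(\sigma/r_{ij})^{12} - (\sigma/r_{ij})^{6}\right]\). The sketch below evaluates it for one set of atoms in reduced units, with a single \(\varepsilon\) and \(\sigma\) for all pairs and no cutoff. This is an illustration under those assumptions: MXtalTools' `build_and_analyze` uses its own parameters, cutoffs and scaling (e.g. for `scaled_lj_pot`), which are not documented here, and in a crystal one would typically sum only intermolecular pairs. `lj_energy` is a hypothetical helper.

```python
import torch


def lj_energy(coords, epsilon=1.0, sigma=1.0):
    """Total 12-6 Lennard-Jones energy over unique atom pairs (i < j).

    coords: (n_atoms, 3) tensor; epsilon and sigma are shared by all pairs
    (reduced units, no cutoff).
    """
    i, j = torch.triu_indices(len(coords), len(coords), offset=1)
    r = (coords[i] - coords[j]).norm(dim=-1)
    sr6 = (sigma / r) ** 6
    return (4.0 * epsilon * (sr6 ** 2 - sr6)).sum()


# two atoms at the potential minimum r = 2^(1/6) * sigma give E = -epsilon
dimer = torch.tensor([[0.0, 0.0, 0.0], [2.0 ** (1 / 6), 0.0, 0.0]], dtype=torch.float64)
print(lj_energy(dimer).item())  # approximately -1.0
```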

See examples/crystal_analysis.py:

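    import numpy as np
    import torch
    import torch.nn.functional as F
    # init_sym_info, collate_data_list, load_crystal_score_model and
    # softmax_and_score are imported in examples/crystal_analysis.py

    device = 'cpu'
    mini_dataset_path = '../mini_datasets/mini_CSD_dataset.pt'
    checkpoint = r"../checkpoints/crystal_score.pt"
    space_groups_to_sample = ["P1", "P-1", "P21/c", "C2/c", "P212121"]
    sym_info = init_sym_info()

    # load and batch example crystals
    # (weights_only=False is needed to unpickle full data objects on PyTorch >= 2.6)
    example_crystals = torch.load(mini_dataset_path, weights_only=False)
    crystal_batch = collate_data_list(example_crystals[:10]).to(device)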

    """
    Crystal parameterization and construction are core functions of our code,
    so we show a simple example of building crystals from the same molecules
    as the loaded crystals, but with random space groups and lattice parameters.
    """
    # copy the batch so the original crystals are left unchanged
    crystal_batch2 = crystal_batch.detach().clone()
    # pick space groups to sample
    sgs_to_build = np.random.choice(space_groups_to_sample,
                                    replace=True,
                                    size=crystal_batch.num_graphs)
    sg_rand_inds = torch.tensor(
        [list(sym_info['space_groups'].values()).index(SG) + 1 for SG in sgs_to_build],
        dtype=torch.long,
        device=device)  # list.index() is 0-based; +1 gives 1-based space group numbers
    # assign SG info to crystals - critical to do this before resampling cell parameters
    crystal_batch2.reset_sg_info(sg_rand_inds)
    # sample random cell parameters
    crystal_batch2.sample_random_crystal_parameters()

    """load crystal score model"""
    model = load_crystal_score_model(checkpoint, device)

    """
    We then analyze both sets of crystals.
    This is a very basic analysis, computing a simple Lennard-Jones-type
    potential and a short-range electrostatic potential.
    We also show the outputs of the crystal scoring model:
    (1) its classification confidence between "real" samples from the CSD
    and "fake" samples not from the CSD, and
    (2) the predicted distance in RDF space from the given crystal
    to the "correct" crystal for the given molecule.
    """
    lj_pot, es_pot, scaled_lj_pot, cluster_batch = (
        crystal_batch.build_and_analyze(return_cluster=True))
    model_output = model(cluster_batch)
    model_score = softmax_and_score(model_output[:, :2])
    rdf_dist_pred = F.softplus(model_output[:, 2])
    packing_coeff = crystal_batch.packing_coeff
    cluster_batch.visualize([1, 2, 3, 4], mode='convolve with')

    lj_pot2, es_pot2, scaled_lj_pot2, cluster_batch2 = (
        crystal_batch2.build_and_analyze(return_cluster=True))
    model_output2 = model(cluster_batch2)
    model_score2 = softmax_and_score(model_output2[:, :2])
    rdf_dist_pred2 = F.softplus(model_output2[:, 2])
    packing_coeff2 = crystal_batch2.packing_coeff
    cluster_batch2.visualize([1, 2, 3, 4], mode='convolve with')

    print("Finished Crystal Analysis!")
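The score model's second output is a predicted distance in RDF space; in the script, `F.softplus` maps that raw output to a positive value, as a distance should be. As background, below is a minimal sketch of the textbook pair radial distribution function \(g(r)\) for points in a cubic periodic box, using the minimum-image convention. MXtalTools computes RDFs of molecular crystal clusters with its own settings, and the RDF distance metric itself is not reproduced here. `radial_distribution` is a hypothetical helper, and `r_max` must not exceed half the box length.

```python
import math

import torch


def radial_distribution(coords, box_length, r_max, num_bins=50):
    """Pair RDF g(r) for points in a cubic periodic box (minimum-image convention).

    coords: (n, 3) tensor of positions inside [0, box_length)^3.
    Requires r_max <= box_length / 2. Returns bin centres and g(r) per bin.
    """
    n = coords.shape[0]
    # all pairwise displacement vectors, wrapped to the nearest periodic image
    diff = coords[:, None, :] - coords[None, :, :]
    diff = diff - box_length * torch.round(diff / box_length)
    i, j = torch.triu_indices(n, n, offset=1)  # unique pairs i < j
    dists = diff[i, j].norm(dim=-1)

    # histc ignores distances outside [0, r_max]
    counts = torch.histc(dists, bins=num_bins, min=0.0, max=r_max)
    edges = torch.linspace(0.0, r_max, num_bins + 1, dtype=coords.dtype)
    shell_volumes = 4.0 / 3.0 * math.pi * (edges[1:] ** 3 - edges[:-1] ** 3)
    # expected number of unique pairs per shell for uniformly random points
    ideal_pairs = 0.5 * n * (n - 1) * shell_volumes / box_length ** 3
    centres = 0.5 * (edges[1:] + edges[:-1])
    return centres, counts / ideal_pairs


torch.manual_seed(0)
points = torch.rand(500, 3, dtype=torch.float64) * 10.0
r, g = radial_distribution(points, box_length=10.0, r_max=5.0, num_bins=10)
print(g)  # values scatter around 1 for uncorrelated (ideal-gas) points
```

For a real crystal, \(g(r)\) instead shows sharp peaks at characteristic contact distances, which is what makes RDFs useful for comparing packings.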