Examples
MXtalTools includes several analysis utilities for molecular crystal tasks.
Runnable example scripts are in the examples/ directory.
Crystal Density Prediction
Predict the crystal packing coefficient from molecular SMILES using a pre-trained regressor. The packing coefficient is defined as \(c_p = V_{mol} / V_{aunit}\), where \(V_{mol}\) is the molecule volume and \(V_{aunit} = V_{cell} / Z\) is the asymmetric unit volume.
The model is trained on CSD data, so its predictions are most reliable for CSD-like molecules. Prediction uncertainty can be estimated by Monte Carlo sampling with dropout enabled.
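As a worked illustration of the definition above, the following minimal sketch converts a packing coefficient into an asymmetric unit volume and a density. The function name and the input numbers are hypothetical and chosen only for illustration. It assumes one molecule per asymmetric unit, and the conversion factor follows from 1 Da = 1.66054e-24 g and 1 Å^3 = 1e-24 cm^3.

```python
DA_PER_A3_TO_G_PER_CM3 = 1.66054  # 1 Da / Å^3 expressed in g / cm^3


def density_from_packing_coeff(packing_coeff, mol_volume_A3, mol_mass_Da):
    """Return (asymmetric unit volume in Å^3, density in g/cm^3).

    Assumes one molecule per asymmetric unit, so V_aunit = V_mol / c_p.
    """
    if packing_coeff <= 0:
        raise ValueError("packing coefficient must be positive")
    aunit_volume = mol_volume_A3 / packing_coeff
    density = mol_mass_Da / aunit_volume * DA_PER_A3_TO_G_PER_CM3
    return aunit_volume, density


# Hypothetical inputs: a 100 Å^3, 120 Da molecule with c_p = 0.7
aunit_volume, density = density_from_packing_coeff(0.7, 100.0, 120.0)
print(f"V_aunit = {aunit_volume:.2f} Å^3, density = {density:.3f} g/cm^3")
```

A larger packing coefficient means a smaller asymmetric unit volume for the same molecule, and therefore a higher density.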
See examples/crystal_density_prediction.py:
"""configs"""
device = 'cpu'
checkpoint = Path(r"../checkpoints/cp_regressor.pt")
num_samples = 50
"""
First we load some molecules.
Here RDKit builds and minimizes molecules generated from SMILES strings.
This is automated in our MolData class.
One can also directly input the atom types and coordinates.
"""
base_molData = MolData()
num_mols = len(test_smiles)
mols = [base_molData.from_smiles(test_smiles[ind],
                                 compute_partial_charges=True,
                                 minimize=True,
                                 protonate=True,
                                 ) for ind in range(num_mols)]
mols = [mol for mol in mols if mol is not None] # sometimes the embedding fails
mol_batch = collate_data_list(mols).to(device)
"""load model"""
model = load_molecule_scalar_regressor(
    checkpoint,
    device
)
"""
Now we make predictions with the model. The predicted packing coefficient
converts easily to the asymmetric unit volume `V_{aunit}`, and from there to the density.
"""
with torch.no_grad():
    """predict crystal packing coefficient - single-point"""
    # undo the target standardization applied during training
    packing_coeff_pred = model(mol_batch).flatten() * model.target_std + model.target_mean
    aunit_volume_pred = mol_batch.mol_volume / packing_coeff_pred  # A^3
    density_pred = mol_batch.mass / aunit_volume_pred * 1.66054  # g/cm^3 (1 Da/A^3 = 1.66054 g/cm^3)

    """get prediction with uncertainty via resampling with dropout"""
    predictions = []
    model = enable_dropout(model)
    for _ in range(num_samples):
        predictions.append(model(mol_batch).flatten() * model.target_std + model.target_mean)
    predictions = torch.stack(predictions)  # shape: (num_samples, num_mols)
    packing_coeff_mean = predictions.mean(0)
    packing_coeff_std = predictions.std(0)
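The resampling loop above can be illustrated without the trained model. The sketch below is a toy stand-in, not the MXtalTools regressor: `toy_forward` is a hypothetical one-layer linear model with inverted dropout on its inputs, and the weights and inputs are made up. It only shows the pattern of repeating stochastic forward passes and summarizing them with a mean and a standard deviation.

```python
import random
import statistics


def toy_forward(x, weights, drop_prob, rng):
    """Toy linear model with inverted dropout on its inputs (hypothetical stand-in)."""
    keep_scale = 1.0 / (1.0 - drop_prob)
    total = 0.0
    for xi, wi in zip(x, weights):
        if rng.random() >= drop_prob:
            # keep this input, rescaled so the expected output is unchanged
            total += wi * xi * keep_scale
    return total


def mc_dropout_predict(x, weights, drop_prob=0.2, num_samples=50, seed=0):
    """Repeat stochastic forward passes and return (mean, standard deviation)."""
    rng = random.Random(seed)
    samples = [toy_forward(x, weights, drop_prob, rng) for _ in range(num_samples)]
    # statistics.stdev uses the n-1 estimator, as torch.std does by default
    return statistics.mean(samples), statistics.stdev(samples)


x = [1.0, 2.0, 3.0, 4.0]
weights = [0.5, -0.25, 0.1, 0.3]
mean, std = mc_dropout_predict(x, weights)
print(f"prediction = {mean:.3f} +/- {std:.3f}")
```

With `drop_prob=0.0` every pass is identical and the standard deviation collapses to zero, which is why dropout has to be explicitly re-enabled at inference time for this estimate to carry any information.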
Molecule Encoding
Encode molecules into equivariant vector and scalar representations using a pre-trained Mo3ENet autoencoder. The model has been trained on QM9-like molecules (up to 9 heavy atoms, containing H, C, N, O, F). Performance may degrade for fluorine-rich or highly symmetric molecules.
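Before encoding, it can be useful to check whether a molecule falls inside the training domain described above. The helper below is a hypothetical pre-filter, not part of MXtalTools: it takes a plain list of element symbols and applies only the two limits stated in the text (allowed elements and at most 9 heavy atoms).

```python
ALLOWED_ELEMENTS = {"H", "C", "N", "O", "F"}
MAX_HEAVY_ATOMS = 9


def in_training_domain(symbols):
    """Return True if the element symbols fit the QM9-like limits stated above."""
    heavy_atoms = [s for s in symbols if s != "H"]
    only_allowed = set(symbols) <= ALLOWED_ELEMENTS
    return only_allowed and 0 < len(heavy_atoms) <= MAX_HEAVY_ATOMS


ethanol = ["C", "C", "O"] + ["H"] * 6             # 3 heavy atoms, allowed elements
chlorobenzene = ["C"] * 6 + ["Cl"] + ["H"] * 5    # contains Cl
decane = ["C"] * 10 + ["H"] * 22                  # 10 heavy atoms

for name, mol in [("ethanol", ethanol), ("chlorobenzene", chlorobenzene), ("decane", decane)]:
    print(name, in_training_domain(mol))
```

Passing this check does not guarantee a good reconstruction; it only flags inputs that are clearly outside the stated training distribution.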
See examples/molecule_autoencoder.py:
"""configs"""
device = 'cpu'
checkpoint = Path(r"../checkpoints/autoencoder.pt")
"""
First we load some molecules, and ensure they are each centered on the origin.
Here RDKit builds and minimizes molecules generated
from SMILES strings. This is automated in our MolData class.
One can also directly input the atom types and coordinates.
"""
base_molData = MolData()
num_mols = len(test_smiles)
mols = [base_molData.from_smiles(test_smiles[ind],
                                 compute_partial_charges=True,
                                 minimize=True,
                                 protonate=True,
                                 ) for ind in range(num_mols)]
mols = [mol for mol in mols if mol is not None] # sometimes the embedding fails
mol_batch = collate_data_list(mols).to(device)
mol_batch.recenter_molecules()
"""load pre-trained model"""
model = load_molecule_autoencoder(
    checkpoint,
    device
)
with torch.no_grad():
    """get vector and scalar embeddings"""
    vector_encoding = model.encode(mol_batch.clone())
    scalar_encoding = model.scalarizer(vector_encoding)

    """
    Check the quality of the embedding for this batch of
    molecules, and visualize the reconstruction
    """
    reconstruction_loss, rmsd, matched_molecule = (
        model.check_embedding_quality(mol_batch, visualize=True))
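The example above centers each molecule on the origin before encoding. As a minimal sketch of what that step means, the function below subtracts the centroid from a list of atomic positions. It assumes the unweighted geometric centroid; whether `recenter_molecules()` uses the centroid or the center of mass is not stated here, and the coordinates are made up.

```python
def recenter(coords):
    """Translate (x, y, z) positions so their geometric centroid sits at the origin."""
    n = len(coords)
    centroid = [sum(p[k] for p in coords) / n for k in range(3)]
    return [tuple(p[k] - centroid[k] for k in range(3)) for p in coords]


# Hypothetical three-atom geometry, in Å
coords = [(1.0, 0.0, 0.0), (2.0, 1.0, 0.0), (3.0, 2.0, 3.0)]
centered = recenter(coords)
print(centered)
```

Centering removes the arbitrary translation of the input, so the encoder only has to handle rotations of the molecule.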
Crystal Analysis & Scoring
Build and analyze molecular crystals, including Lennard-Jones potential evaluation, radial distribution functions, and a CSD-trained crystal score model.
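For readers unfamiliar with the Lennard-Jones potential mentioned above, here is a minimal sketch of the standard 12-6 form summed over atom pairs. It is not the MXtalTools implementation: the single `epsilon` and `sigma`, the cutoff value, and the function names are illustrative assumptions, and the library's per-element parameters and scaled variant are not reproduced.

```python
import math


def lennard_jones(r, epsilon=1.0, sigma=1.0):
    """Standard 12-6 Lennard-Jones pair energy at separation r."""
    if r <= 0:
        raise ValueError("separation must be positive")
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6 ** 2 - sr6)


def total_lj_energy(coords, epsilon=1.0, sigma=1.0, cutoff=6.0):
    """Sum pair energies over unique atom pairs closer than the cutoff."""
    energy = 0.0
    for i in range(len(coords)):
        for j in range(i + 1, len(coords)):
            r = math.dist(coords[i], coords[j])
            if r < cutoff:
                energy += lennard_jones(r, epsilon, sigma)
    return energy


r_min = 2 ** (1 / 6)  # separation of the energy minimum when sigma = 1
pair = [(0.0, 0.0, 0.0), (r_min, 0.0, 0.0)]
print(total_lj_energy(pair))  # close to -epsilon
```

The energy is zero at `r = sigma`, reaches its minimum of `-epsilon` at `r = 2^(1/6) * sigma`, and rises steeply at shorter separations, which is what penalizes overlapping molecules in a trial crystal.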
See examples/crystal_analysis.py:
device = 'cpu'
mini_dataset_path = '../mini_datasets/mini_CSD_dataset.pt'
checkpoint = r"../checkpoints/crystal_score.pt"
space_groups_to_sample = ["P1", "P-1", "P21/c", "C2/c", "P212121"]
sym_info = init_sym_info()
"load and batch example crystals"
example_crystals = torch.load(mini_dataset_path)
crystal_batch = collate_data_list(example_crystals[:10])
"""
A core function of our code is crystal parameterization and construction,
and so we show a simple example of building crystals, starting from the same
molecules as before, but with random space groups and lattice parameters.
"""
# copy the example batch; the copy will receive new space groups and cell parameters
crystal_batch2 = crystal_batch.detach().clone()
# pick space groups to sample
sgs_to_build = np.random.choice(space_groups_to_sample,
                                replace=True,
                                size=crystal_batch.num_graphs)
# .index() returns a 0-based list position; +1 converts it to the 1-based space group number
sg_rand_inds = torch.tensor(
    [list(sym_info['space_groups'].values()).index(SG) + 1 for SG in sgs_to_build],
    dtype=torch.long,
    device=device)
# assign SG info to crystals - critical to do this before resampling cell parameters
crystal_batch2.reset_sg_info(sg_rand_inds)
# sample random cell parameters
crystal_batch2.sample_random_crystal_parameters()
"""load crystal score model"""
model = load_crystal_score_model(checkpoint, device)
"""
We then analyze both sets of crystals.
Here we present a very basic analysis, computing a simple
Lennard-Jones-type potential and a short-range electrostatic potential.
We also show the outputs of the crystal scoring model:
(1) its classification confidence between "real" CSD samples
and "fake" samples not from the CSD, and
(2) the predicted distance in RDF space from the given crystal
to the "correct" crystal for the given molecule.
"""
lj_pot, es_pot, scaled_lj_pot, cluster_batch = (
    crystal_batch.build_and_analyze(return_cluster=True))
model_output = model(cluster_batch)
model_score = softmax_and_score(model_output[:, :2])
rdf_dist_pred = F.softplus(model_output[:, 2])
packing_coeff = crystal_batch.packing_coeff
cluster_batch.visualize([1, 2, 3, 4], mode='convolve with')
lj_pot2, es_pot2, scaled_lj_pot2, cluster_batch2 = (
    crystal_batch2.build_and_analyze(return_cluster=True))
model_output2 = model(cluster_batch2)
model_score2 = softmax_and_score(model_output2[:, :2])
rdf_dist_pred2 = F.softplus(model_output2[:, 2])
packing_coeff2 = crystal_batch2.packing_coeff
cluster_batch2.visualize([1, 2, 3, 4], mode='convolve with')
print("Finished Crystal Analysis!")
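The model output above is split into two class logits and one distance value. The sketch below shows, in plain Python, the two standard transforms involved: a softmax that turns the two logits into probabilities, and a softplus that keeps the predicted distance positive. The raw numbers are made up, and the library's `softmax_and_score` may apply a further transform to the probabilities that is not reproduced here.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def softplus(x):
    """Numerically stable softplus, log(1 + exp(x)); the result is always positive."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


# Hypothetical raw output for one crystal: two class logits plus one distance value
raw_output = [0.3, 2.1, -1.5]
class_probs = softmax(raw_output[:2])
rdf_distance = softplus(raw_output[2])
print(class_probs, rdf_distance)
```

Softplus behaves like the identity for large positive inputs and decays smoothly toward zero for negative ones, so a distance prediction can never come out negative.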