"""Source code for mxtaltools.common.ase_interface: helpers for converting crystal data batches into ASE objects."""

import os
import tempfile

import numpy as np
import torch
from ase import Atoms
from ase.build import niggli_reduce
from ase.geometry import cell_to_cellpar
from ase.io import read
from ase.spacegroup import crystal as ase_crystal
from ase.visualize import view


def data_batch_to_ase_mols_list(crystaldata_batch, max_ind: int = np.inf, specific_inds=None,
                                show_mols: bool = False, **kwargs):
    """
    Helper function for converting Crystaldata batches into lists of ase objects.

    Parameters
    ----------
    crystaldata_batch : batch of Crystaldata objects
        A batch exposing ``num_graphs``; objects without that attribute are
        treated as a single sample.
    max_ind : int
        Exclusive upper bound on the batch indices converted (indices
        ``0 .. max_ind - 1``, capped at the batch size). Default is no limit.
        Ignored when ``specific_inds`` is given.
    specific_inds : iterable of int, optional
        Explicit batch indices to convert, in the given order.
    show_mols : bool
        whether to visualize the list of mol objects with ``ase.visualize.view``
    **kwargs
        Forwarded unchanged to ``ase_mol_from_crystaldata``.

    Returns
    -------
    list of ase mol objects
    """
    # unbatched samples have no ``num_graphs``; treat them as a batch of one
    num_graphs = getattr(crystaldata_batch, 'num_graphs', 1)

    if specific_inds is None:
        # max_ind defaults to np.inf (a float) and callers may pass floats too;
        # range() only accepts integers, so clamp to the batch size and cast
        specific_inds = range(int(min(max_ind, num_graphs)))

    mols = [ase_mol_from_crystaldata(crystaldata_batch, ii, **kwargs) for ii in specific_inds]

    if show_mols:
        view(mols)

    return mols
def _select_atoms_for_mode(crystal_batch, index, mode, cutoff):
    """Return ``(atom_inds, coords)`` for the atoms kept by visualization ``mode``; coords as numpy."""
    if crystal_batch.batch is not None:  # more than one crystal in the datafile
        atom_inds = torch.where(crystal_batch.batch == index)[0]
    else:
        atom_inds = torch.arange(len(crystal_batch.z))

    if mode == 'conformer':
        # only the canonical conformer itself; a boolean mask keeps the original
        # atom order and, unlike torch.stack over a list, works when nothing matches
        atom_inds = atom_inds[crystal_batch.aux_ind[atom_inds] == 0]
        coords = crystal_batch.pos[atom_inds]
    elif mode == 'unit cell':
        # indices here address the unit_cell_* arrays, not pos / z
        atom_inds = torch.argwhere(crystal_batch.unit_cell_batch == index).flatten()
        coords = crystal_batch.unit_cell_pos[atom_inds]
    elif mode == 'inside cell':
        # convert this crystal's positions to fractional coordinates with the inverse of T_fc
        fractional_coords = torch.inner(torch.linalg.inv(crystal_batch.T_fc[index]),
                                        crystal_batch.pos[crystal_batch.batch == index]).T
        inside_coords = torch.prod((fractional_coords < 1) * (fractional_coords > 0), dim=-1)
        # per-crystal indices -> global indices into the flattened batch
        atom_inds = torch.where(inside_coords)[0] + crystal_batch.ptr[index]
        coords = crystal_batch.pos[atom_inds]
    elif mode == 'convolve with':
        # atoms potentially in the convolutional field
        atom_inds = atom_inds[crystal_batch.aux_ind[atom_inds] < 2]
        coords = crystal_batch.pos[atom_inds]
    elif mode == 'distance':
        # atoms within (conformer radius + cutoff) of the canonical conformer centroid
        crystal_coords = crystal_batch.pos[atom_inds]
        canonical_coords = crystal_coords[crystal_batch.aux_ind[atom_inds] == 0]
        mol_centroid = canonical_coords.mean(0)
        mol_radius = torch.max(torch.cdist(mol_centroid[None], canonical_coords, p=2))
        centroid_dists = torch.cdist(mol_centroid[None], crystal_coords, p=2)[0]
        in_range_inds = torch.where(centroid_dists < (mol_radius + cutoff))[0]
        atom_inds = atom_inds[in_range_inds]
        coords = crystal_coords[in_range_inds]
    else:
        coords = crystal_batch.pos[atom_inds]

    return atom_inds, coords.cpu().detach().numpy()


def _atomic_numbers_for_atoms(crystal_batch, index, atom_inds, mode,
                              highlight_canonical_conformer, highlight_mol_ind):
    """Return the element numbers (numpy) to assign to the selected atoms."""
    # NOTE(review): in 'unit cell' mode atom_inds address unit_cell_* arrays, so the two
    # highlight options below index aux_ind / mol_ind with mismatched indices — confirm intent
    if highlight_canonical_conformer:
        # offset by 6 so aux index 0 maps to carbon and higher indices to other elements/colors
        numbers = crystal_batch.aux_ind[atom_inds] + 6
    elif highlight_mol_ind:
        numbers = crystal_batch.mol_ind[atom_inds] + 6
    elif mode == 'unit cell':
        # this crystal's z slice tiled sym_mult times — presumably one copy per
        # symmetry image in unit_cell_pos; TODO confirm ordering matches
        start = crystal_batch.ptr[index]
        stop = start + crystal_batch.num_atoms[index]
        numbers = crystal_batch.z[start:stop].repeat(int(crystal_batch.sym_mult[index]))
    else:
        numbers = crystal_batch.z[atom_inds]
    return numbers.cpu().detach().numpy()


def _cell_matrix(crystal_batch, index):
    """Return the transposed T_fc of crystal ``index`` (falling back to entry 0), or None without T_fc."""
    if not hasattr(crystal_batch, "T_fc"):
        return None
    if index is not None:
        try:
            return crystal_batch.T_fc[index].T.cpu().detach().numpy()
        except IndexError:
            pass  # e.g. a single sample whose T_fc has one entry; use entry 0
    return crystal_batch.T_fc[0].T.cpu().detach().numpy()


def ase_mol_from_crystaldata(crystal_batch, index: int = None, highlight_canonical_conformer: bool = False,
                             mode=None, cutoff: float = 4, return_crystal: bool = False,
                             highlight_mol_ind: bool = False):
    """
    Extract one structure from a Crystaldata batch and convert it into an ase ``Atoms`` object.

    Several options are provided for paring a crystal cluster down for visualization.

    Parameters
    ----------
    crystal_batch : Crystaldata object
        Batch (or single sample) to read from; it is cloned, moved to CPU and detached first.
    index : int, optional
        Batch index of the crystal to extract. Default is None.
    highlight_canonical_conformer : bool, optional
        If True, element numbers are replaced by ``aux_ind + 6`` so that atoms with
        different ``aux_ind`` (e.g. the canonical conformer) are colored differently.
    mode : 'conformer', 'unit cell', 'inside cell', 'convolve with', 'distance', or None (default None)
        Atom selection, assuming the input is a molecule cluster larger than one unit cell:
        'conformer' : atoms with ``aux_ind == 0`` (the canonical conformer)
        'unit cell' : this crystal's precomputed unit-cell atoms (``unit_cell_pos``)
        'inside cell' : atoms whose fractional coordinates lie strictly inside (0, 1)
        'convolve with' : atoms with ``aux_ind < 2`` (potentially within convolution
            range of the canonical conformer)
        'distance' : atoms within (conformer radius + ``cutoff``) of the canonical
            conformer centroid
        None : all atoms of the crystal
    cutoff : float, optional
        Extra distance used by the 'distance' mode.
    return_crystal : bool, optional
        Whether to also build a crystal with ``ase.spacegroup.crystal`` using the batch's
        ``sg_ind``. Does not always work properly: ase does not interpret these crystals correctly.
    highlight_mol_ind : bool, optional
        If True (and ``highlight_canonical_conformer`` is False), element numbers are
        replaced by ``mol_ind + 6``.

    Returns
    -------
    ase.Atoms, or a ``(mol, crystal)`` tuple of ase.Atoms when ``return_crystal`` is True
    """
    crystal_batch = crystal_batch.clone().cpu().detach()

    atom_inds, coords = _select_atoms_for_mode(crystal_batch, index, mode, cutoff)
    numbers = _atomic_numbers_for_atoms(crystal_batch, index, atom_inds, mode,
                                        highlight_canonical_conformer, highlight_mol_ind)
    cell = _cell_matrix(crystal_batch, index)

    if cell is not None:
        mol = Atoms(symbols=numbers, positions=coords, cell=cell)
    else:
        mol = Atoms(symbols=numbers, positions=coords)

    if return_crystal:
        cry = ase_crystal(symbols=mol, cell=cell, setting=1, spacegroup=int(crystal_batch.sg_ind[index]))
        return mol, cry
    return mol
def get_niggli_cell(crystal_batch, index, radians: bool = False):
    """
    Compute Niggli-reduced cell parameters for one crystal of a batch.

    Parameters
    ----------
    crystal_batch : Crystaldata object
        Batch passed to ``ase_mol_from_crystaldata``.
    index : int
        Batch index of the crystal.
    radians : bool, optional
        Return the cell angles in radians instead of degrees.

    Returns
    -------
    numpy.ndarray
        ``[a, b, c, alpha, beta, gamma]`` from ``ase.geometry.cell_to_cellpar``.
    """
    mol = ase_mol_from_crystaldata(crystal_batch, index=index)
    # Round-trip through a CIF file, presumably to obtain a periodic Atoms object that
    # niggli_reduce accepts. A private temp directory avoids clobbering a shared
    # 'temp.cif' between concurrent calls and leaves nothing in the working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        cif_path = os.path.join(tmp_dir, 'temp.cif')
        mol.write(cif_path)
        atoms = read(cif_path)
    niggli_reduce(atoms)
    return cell_to_cellpar(atoms.cell, radians=radians)
def ase_write_cif(batch, inds, path, mode=None):
    """
    Write one CIF file per requested crystal of a batch.

    Parameters
    ----------
    batch : Crystaldata object
        Batch passed to ``ase_mol_from_crystaldata``.
    inds : iterable of int
        Batch indices to export.
    path : str
        Output prefix; crystal ``ind`` is written to ``f"{path}_{ind}.cif"``.
    mode : str or None, optional
        Atom-selection mode forwarded to ``ase_mol_from_crystaldata``.
    """
    for ind in inds:
        mol = ase_mol_from_crystaldata(batch, ind, mode=mode)
        stem = f"{path}_{ind}"
        cif_path = f"{stem}.cif"
        mol.write(cif_path)

        # Rename ASE's default data block. A CIF data-block header must keep the
        # 'data_' prefix and its block code may not contain whitespace, so use the
        # file's basename with whitespace runs replaced by underscores.
        block_code = '_'.join(os.path.basename(stem).split())
        with open(cif_path, 'r') as f:
            content = f.read()
        content = content.replace('data_image0', f'data_{block_code}', 1)
        with open(cif_path, 'w') as f:
            f.write(content)