Source code for mxtaltools.models.task_models.crystal_models

import torch
from torch_geometric.typing import OptTensor

from mxtaltools.models.graph_models.base_graph_model import BaseGraphModel
from mxtaltools.models.graph_models.molecule_graph_model import MolecularCrystalGraphModel


[docs] class MolecularCrystalModel(BaseGraphModel):
[docs] def __init__(self, seed, config, atom_features: list, molecule_features: list, output_dim: int, node_standardization_tensor: OptTensor = None, graph_standardization_tensor: OptTensor = None): """ wrapper for molecule model, with appropriate I/O """ super(MolecularCrystalModel, self).__init__() torch.manual_seed(seed) self.get_data_stats(atom_features, molecule_features, node_standardization_tensor, graph_standardization_tensor) self.model = MolecularCrystalGraphModel( input_node_dim=self.n_atom_feats, num_mol_feats=self.n_mol_feats, output_dim=output_dim, seed=seed, concat_mol_to_node_dim=True, activation=config.activation, fc_config=config.fc, graph_config=config.graph, )
[docs] def forward(self, crystal_batch, return_dists=False, return_latent=False, force_edges_rebuild=False): """overwrites base method""" # on the fly atom property embeddings crystal_batch = self.featurize_input_graph(crystal_batch) # on the fly input standardization crystal_batch = self.standardize(crystal_batch) return self.model(crystal_batch.x, crystal_batch.pos, crystal_batch.batch, crystal_batch.ptr, crystal_batch.mol_x, crystal_batch.num_graphs, crystal_batch.aux_ind, crystal_batch.mol_ind, edges_dict=crystal_batch.edges_dict if hasattr(crystal_batch, 'edges_dict') else None, return_dists=return_dists, return_latent=return_latent, force_edges_rebuild=force_edges_rebuild,)
# class HierarchicalCrystalDiscriminator(nn.Module): # def __init__(self, seed, config, n_atom_types, input_depth): # ''' # wrapper for molecule model, with appropriate I/O # ''' # torch.manual_seed(seed) # super(HierarchicalCrystalDiscriminator, self).__init__() # self.conditioner = molecule_encoder # self.gnn = point_graph # self.crystal_graph=PeriodicCrystalGraph() # def symmetrize(self): # def forward(self, data,): # # if crystalline, encode the data object, then generate pattern it according to the crystal symmetry # # then evaluate material graph # # if it's some bulk, thing, just encode all the separate molecules and evaluate the resulting material graph # return self.model(data, return_dists=return_dists, return_latent=return_latent) # # deprecated # # class MolCrystal(BaseGraphModel): # def __init__(self, seed, config, # atom_features: list, # molecule_features: list, # node_standardization_tensor: OptTensor = None, # graph_standardization_tensor: OptTensor = None): # """ # wrapper for molecule model, with appropriate I/O # """ # super(MolCrystal, self).__init__() # # torch.manual_seed(seed) # self.get_data_stats(atom_features, # molecule_features, # node_standardization_tensor, # graph_standardization_tensor) # # self.model = MoleculeGraphModel( # input_node_dim=self.n_atom_feats, # num_mol_feats=self.n_mol_feats, # output_dim=2 + 1, # 2 for classification and 1 for distance regression # seed=seed, # graph_aggregator=config.graph_aggregator, # concat_pos_to_node_dim=False, # concat_mol_to_node_dim=True, # concat_crystal_to_node_dim=False, # activation=config.activation, # fc_config=config.fc, # graph_config=config.graph, # periodize_inside_nodes=True, # outside_convolution_type=config.periodic_convolution_type # ) # # def forward(self, data, return_dists=False, return_latent=False): # """overwrites base method""" # # on the fly atom property embeddings # data = self.featurize_input_graph(data) # # on the fly input standardization # data = self.standardize(data) # # outputs = self.model(data, return_dists=return_dists, return_latent=return_latent) # # if isinstance(outputs, tuple): # if we have extra outputs, pick out the actual model output and adjust # model_outputs, extra_outputs = outputs # return (torch.cat([model_outputs[:, :2], F.softplus(model_outputs[:, -1, None])], dim=1), extra_outputs) # # else: # return torch.cat([outputs[:, 0], F.softplus(outputs[:, 1])], dim=1)