Source code for mxtaltools.models.task_models.mol_classifier

from typing import Tuple, Optional

import torch

from mxtaltools.models.graph_models.base_graph_model import BaseGraphModel
from mxtaltools.models.graph_models.molecule_graph_model import MoleculeClusterModel


[docs] class MoleculeClusterClassifier(BaseGraphModel): def __init__(self, seed, config, output_dim, atom_features: list, molecule_features: list, node_standardization_tensor: torch.Tensor, graph_standardization_tensor: torch.Tensor, ): super(MoleculeClusterClassifier, self).__init__() torch.manual_seed(seed) self.get_data_stats(atom_features, molecule_features, node_standardization_tensor, graph_standardization_tensor) self.model = MoleculeClusterModel( input_node_dim=self.n_atom_feats, num_mol_feats=0, output_dim=output_dim, seed=seed, activation=config.activation, fc_config=config.fc, graph_config=config.graph, )
[docs] def forward(self, data_batch, return_dists: bool = False, return_latent: bool = False, return_embedding: bool = False, ) -> Tuple[torch.Tensor, Optional[dict]]: # featurize atom properties on the fly data_batch = self.featurize_input_graph(data_batch) # standardize on the fly from model-attached statistics data_batch = self.standardize(data_batch) return self.model(data_batch.x, data_batch.pos, data_batch.ptr, data_batch.mol_x, data_batch.num_graphs, data_batch.mol_ind, data_batch.T_fc, data_batch.edge_index, data_batch.edge_attr, return_dists=return_dists, return_latent=return_latent, return_embedding=return_embedding)
# deprecated # class PolymorphClassifier(BaseGraphModel): # def __init__(self, seed, config, # dataDims: dict, # atom_features: list, # molecule_features: list, # node_standardization_tensor: torch.tensor, # graph_standardization_tensor: torch.tensor, # ): # super(PolymorphClassifier, self).__init__() # # torch.manual_seed(seed) # self.get_data_stats(atom_features, # molecule_features, # node_standardization_tensor, # graph_standardization_tensor) # # self.model = MoleculeGraphModel( # input_node_dim=self.n_atom_feats, # num_mol_feats=0, # output_dim=dataDims['num_polymorphs'] + dataDims['num_topologies'], # seed=seed, # graph_aggregator='molwise', # activation=config.activation, # fc_config=config.fc, # graph_config=config.graph, # outside_convolution_type='none' # ) # # def forward(self, data, return_dists=False, return_latent=False, return_embedding=False, # skip_standardization=False): # if not skip_standardization: # data = self.standardize(data) # # return self.model(data, # return_dists=return_dists, # return_latent=return_latent, # return_embedding=return_embedding)