Source code for mxtaltools.models.task_models.embedding_regression_models

from typing import Optional

import torch

from mxtaltools.models.graph_models.base_graph_model import BaseGraphModel
from mxtaltools.models.modules.components import vectorMLP, scalarMLP


[docs] class EquivariantEmbeddingRegressor(BaseGraphModel): """single property prediction head for pretrained embeddings""" def __init__(self, seed, config, num_targets: int = 1, conditions_dim: int = 0, prediction_type: str = 'scalar'): super(EquivariantEmbeddingRegressor, self).__init__() torch.manual_seed(seed) self.prediction_type = prediction_type self.output_dim = int(1 * num_targets) # regression model self.model = vectorMLP(layers=config.num_layers, filters=config.hidden_dim, norm=config.norm, dropout=config.dropout, input_dim=config.bottleneck_dim + conditions_dim, output_dim=self.output_dim, vector_input_dim=config.bottleneck_dim, vector_output_dim=self.output_dim, seed=seed, vector_norm=config.vector_norm )
[docs] def forward(self, x: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """no need to do standardization, inputs are raw outputs from autoencoder model """ x, v = self.model(x=x, v=v) return x, v
[docs] class InvariantEmbeddingRegressor(BaseGraphModel): """single property prediction head for pretrained embeddings""" def __init__(self, seed, config, num_targets: int = 1, conditions_dim: int = 0, target_standardization_tensor: Optional[torch.Tensor] = None, ): super(InvariantEmbeddingRegressor, self).__init__() torch.manual_seed(seed) self.output_dim = int(1 * num_targets) if target_standardization_tensor is not None: self.register_buffer('target_mean', target_standardization_tensor[0]) self.register_buffer('target_std', target_standardization_tensor[1]) else: self.register_buffer('target_mean', torch.ones(1)[0]) self.register_buffer('target_std', torch.ones(1)[0]) # regression model self.model = scalarMLP(layers=config.num_layers, filters=config.hidden_dim, norm=config.norm, dropout=config.dropout, input_dim=config.bottleneck_dim + conditions_dim, output_dim=self.output_dim, seed=seed, )
[docs] def forward(self, x: torch.Tensor, ) -> torch.Tensor: """no need to do standardization, inputs are raw outputs from autoencoder model """ x = self.model(x=x) return x