Source code for mxtaltools.models.modules.graph_convolution

import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import (
    Adj,
    OptTensor,
)

from mxtaltools.models.modules.augmented_softmax_aggregator import AugSoftmaxAggregation, VectorAugSoftmaxAggregation
from mxtaltools.models.modules.components import Normalization, Activation


class MConv(MessagePassing):
    """
    Scalar message-passing layer with a residual node update.

    For every edge, three inputs are each projected to ``message_dim``
    channels: the two endpoint node features and the edge embedding. The
    projections are concatenated, mixed by a linear layer, and passed through
    ``Normalization`` and ``Activation``. Messages are aggregated per receiving
    node with ``AugSoftmaxAggregation``. The aggregate is then projected back
    to ``node_dim`` channels and added to the input node features.

    Args:
        message_dim: Width of the per-edge message.
        node_dim: Width of the node features (the input and output width
            are equal). NOTE: this is a feature size. It is unrelated to
            ``MessagePassing``'s own ``node_dim`` argument (the node axis),
            which is left at its default here.
        edge_embedding_dim: Width of ``edge_attr``.
        norm: Normalization spec, forwarded to ``Normalization``.
        activation_fn: Activation spec, forwarded to ``Activation``.
        aggr_temperature: Forwarded to ``AugSoftmaxAggregation`` as
            ``temperature``.
        aggr_learn: Forwarded to ``AugSoftmaxAggregation`` as ``learn``.
        aggr_bias: Forwarded to ``AugSoftmaxAggregation`` as ``bias``.

    The three ``aggr_*`` defaults reproduce the values that were previously
    hard-coded.
    """

    def __init__(
            self,
            message_dim,
            node_dim,
            edge_embedding_dim,
            norm=None,
            activation_fn='gelu',
            *,
            aggr_temperature=1,
            aggr_learn=True,
            aggr_bias=0.1,
    ):
        super().__init__(aggr=AugSoftmaxAggregation(temperature=aggr_temperature,
                                                    learn=aggr_learn,
                                                    bias=aggr_bias,
                                                    channels=message_dim))
        self.in_channels = node_dim
        self.out_channels = node_dim
        self.edge_dim = edge_embedding_dim
        self.message_dim = message_dim

        # Scalar transforms: one projection per message ingredient, then a
        # mixing layer over their concatenation.
        self.edge2message = nn.Linear(edge_embedding_dim, message_dim, bias=False)
        self.source_node2message = nn.Linear(node_dim, message_dim, bias=False)
        self.tgt_node2message = nn.Linear(node_dim, message_dim, bias=False)
        self.generate_message = nn.Linear(int(3 * message_dim), message_dim, bias=False)
        self.norm = Normalization(norm, message_dim)
        self.activation = Activation(activation_fn, message_dim)

        # Maps the aggregated message back to node width for the residual add.
        self.message2node = nn.Linear(message_dim, node_dim, bias=False)

        self.reset_parameters()  # inherited from MessagePassing

    def forward(
            self,
            x: Tensor,
            edge_index: Adj,
            edge_attr: Tensor,
    ) -> Tensor:
        r"""
        Run one round of message passing with a residual update.

        Args:
            x: Node features with ``node_dim`` channels in the last axis.
            edge_index: Graph connectivity.
            edge_attr: Edge embeddings with ``edge_embedding_dim`` channels
                in the last axis.

        Returns:
            ``x`` plus the projected aggregated messages. The shape matches ``x``.
        """
        out = self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr, num_nodes=x.size(0))
        return x + self.message2node(out)

    def message(self, x_i: Tensor, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        """Build one ``message_dim``-wide message per edge."""
        edge_attr = self.edge2message(edge_attr)
        # NOTE(review): under PyG's default flow (not overridden here), x_i is
        # the receiving node and x_j the sending node. The `source_` / `tgt_`
        # prefixes therefore look swapped. The names are kept so existing
        # checkpoints (state_dict keys) still load and map to the same inputs.
        msg_i = self.source_node2message(x_i)
        msg_j = self.tgt_node2message(x_j)
        # The concatenation order must stay fixed: generate_message's weights
        # depend on it.
        return self.activation(
            self.norm(
                self.generate_message(
                    torch.cat([msg_i, msg_j, edge_attr], dim=-1))))
class v_MConv(MessagePassing):
    """
    Vector-feature message-passing layer with a residual node update.

    Node features are indexed by node along axis 0 (``node_dim=0`` is passed
    to ``MessagePassing``) and carry ``node_depth`` channels in the last axis.
    They are presumably ``(num_nodes, 3, node_depth)`` Cartesian vectors;
    TODO confirm against callers.

    Per edge, the two endpoint features are projected to ``message_depth``
    channels and summed. The sum is gated (multiplied) by a projected edge
    embedding broadcast over axis 1, then passed through ``Normalization``.
    Messages are aggregated per receiving node with
    ``VectorAugSoftmaxAggregation``, projected back to ``node_depth``
    channels, and added to the input.

    All linear maps are bias-free and no activation is applied. This is
    presumably to keep the vector features rotation-equivariant.
    NOTE(review): confirm that ``Normalization`` with the chosen ``norm``
    preserves that property.

    Args:
        message_depth: Channel width of the per-edge vector message.
        node_depth: Channel width of the node features (the input and output
            widths are equal).
        edge_embedding_dim: Width of ``edge_attr``.
        norm: Normalization spec, forwarded to ``Normalization``.
        aggr_temperature: Forwarded to ``VectorAugSoftmaxAggregation`` as
            ``temperature``.
        aggr_learn: Forwarded to ``VectorAugSoftmaxAggregation`` as ``learn``.
        aggr_bias: Forwarded to ``VectorAugSoftmaxAggregation`` as ``bias``.

    The three ``aggr_*`` defaults reproduce the values that were previously
    hard-coded.
    """

    def __init__(
            self,
            message_depth,
            node_depth,
            edge_embedding_dim,
            norm=None,
            *,
            aggr_temperature=1,
            aggr_learn=True,
            aggr_bias=0.1,
    ):
        # node_dim=0: propagate/gather/scatter along axis 0 of the node tensor.
        super().__init__(aggr=VectorAugSoftmaxAggregation(temperature=aggr_temperature,
                                                          learn=aggr_learn,
                                                          bias=aggr_bias,
                                                          channels=message_depth),
                         node_dim=0)
        self.in_channels = node_depth
        self.out_channels = node_depth
        self.edge_dim = edge_embedding_dim
        self.message_dim = message_depth

        # Channel-mixing transforms (act on the last axis only), all bias-free.
        self.edge2message = nn.Linear(edge_embedding_dim, message_depth, bias=False)
        self.source_node2message = nn.Linear(node_depth, message_depth, bias=False)
        self.tgt_node2message = nn.Linear(node_depth, message_depth, bias=False)
        self.norm = Normalization(norm, message_depth)

        # Maps the aggregated message back to node width for the residual add.
        self.update2node = nn.Linear(message_depth, node_depth, bias=False)

        self.reset_parameters()  # inherited from MessagePassing

    def forward(
            self,
            x: Tensor,
            edge_index: Adj,
            edge_attr: Tensor,
    ) -> Tensor:
        r"""
        Run one round of vector message passing with a residual update.

        Args:
            x: Node features with the node index on axis 0 and ``node_depth``
                channels in the last axis.
            edge_index: Graph connectivity.
            edge_attr: Edge embeddings with ``edge_embedding_dim`` channels
                in the last axis. Presumably ``(num_edges, edge_embedding_dim)``.

        Returns:
            ``x`` plus the projected aggregated messages. The shape matches ``x``.
        """
        out = self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr, num_nodes=x.size(0))
        return x + self.update2node(out)

    def message(self, x_i: Tensor, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        """Build one ``message_depth``-wide vector message per edge.

        ``edge_attr`` is required. It is typed ``Tensor`` (previously
        ``OptTensor``) because a ``None`` value cannot be passed through
        ``edge2message``.
        """
        edge_attr = self.edge2message(edge_attr)
        # NOTE(review): under PyG's default flow, x_i is the receiving node and
        # x_j the sending node, so the `source_` / `tgt_` prefixes look swapped.
        # The names are kept for state_dict compatibility.
        msg_i = self.source_node2message(x_i)
        msg_j = self.tgt_node2message(x_j)
        # The vector terms are combined by addition. The scalar edge term is
        # applied multiplicatively (a gate), broadcast over axis 1. Adding the
        # scalar to the vectors is deliberately avoided, presumably because it
        # would break rotation equivariance.
        out = (msg_i + msg_j) * edge_attr[:, None, :]
        return self.norm(out)