mxtaltools.models.modules.graph_convolution
- class mxtaltools.models.modules.graph_convolution.MConv(message_dim, node_dim, edge_embedding_dim, norm=None, activation_fn='gelu')
Bases: MessagePassing
Message-passing layer with an optional vector channel. Messages are embedded with a linear operator and aggregated with a softmax operator.
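The description can be illustrated with a minimal pure-Python sketch. It assumes, for illustration only, that the message for an edge j -> i is a linear map of the concatenated features [x_j, x_i, e_ij] and that "softmax aggregation" means a per-channel softmax-weighted sum of the incoming messages, in the style of PyTorch Geometric's `SoftmaxAggregation`. The helpers `linear`, `softmax_aggregate`, and `message_passing_step` are hypothetical and not part of mxtaltools; the real `MConv` also applies `norm` and `activation_fn`, which are not modeled here.

```python
import math


def linear(vec, weight, bias):
    """Return W @ vec + b for plain lists (weight has shape out_dim x in_dim)."""
    return [sum(w * v for w, v in zip(row, vec)) + b for row, b in zip(weight, bias)]


def softmax_aggregate(messages, targets, num_nodes, t=1.0):
    """Per-channel softmax-weighted sum of the messages arriving at each node.

    For node i and channel c: out[i][c] = sum_j a_jc * m_jc, where the weights
    a_jc = exp(t * m_jc) / sum_k exp(t * m_kc) run over messages sent to i.
    """
    dim = len(messages[0])
    out = [[0.0] * dim for _ in range(num_nodes)]
    for node in range(num_nodes):
        incoming = [m for m, tgt in zip(messages, targets) if tgt == node]
        if not incoming:
            continue  # nodes without incoming edges keep a zero vector
        for c in range(dim):
            logits = [t * m[c] for m in incoming]
            shift = max(logits)  # subtract the max for numerical stability
            weights = [math.exp(logit - shift) for logit in logits]
            total = sum(weights)
            out[node][c] = sum(w / total * m[c] for w, m in zip(weights, incoming))
    return out


def message_passing_step(x, edge_index, edge_attr, weight, bias, t=1.0):
    """Hypothetical single round: linear messages on edges j -> i, softmax aggregation at i."""
    sources, targets = edge_index
    messages = [
        # list concatenation builds the message input [x_j, x_i, e_ij]
        linear(x[j] + x[i] + e, weight, bias)
        for j, i, e in zip(sources, targets, edge_attr)
    ]
    return softmax_aggregate(messages, targets, num_nodes=len(x), t=t)


# Usage: 3 nodes with 1 feature each, edges 0 -> 2 and 1 -> 2, 1-dim edge features.
x = [[1.0], [2.0], [3.0]]
edge_index = [[0, 1], [2, 2]]
edge_attr = [[0.0], [0.0]]
weight = [[1.0, 0.0, 0.0]]  # this toy map simply copies x_j into the message
bias = [0.0]
out = message_passing_step(x, edge_index, edge_attr, weight, bias)
# out[2] == [(1 + 2e) / (1 + e)], about 1.731; nodes 0 and 1 stay [0.0]
```

With `t` near 0 the softmax weights become uniform and the result approaches a plain mean; as `t` grows it approaches a per-channel max, so the temperature interpolates between the two.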
- forward(x: Tensor, edge_index: Adj, edge_attr: Tensor) → Tensor
Runs the forward pass of the module.
- message_dim
Initialize scalar transforms.
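The `forward` signature follows the PyTorch Geometric convention: node features `x`, a COO `edge_index` of shape [2, num_edges] whose first row holds source nodes and second row target nodes, and per-edge features `edge_attr`. The sketch below checks plain-list inputs against those shapes. It assumes, from the constructor parameter names only, that `x` has `node_dim` columns and `edge_attr` has `edge_embedding_dim` columns. `check_forward_inputs` is a hypothetical helper, and PyG's `Adj` type also admits sparse adjacency objects that this sketch does not cover.

```python
def check_forward_inputs(x, edge_index, edge_attr, node_dim, edge_embedding_dim):
    """Validate list-of-lists inputs against the assumed shapes; return (N, E)."""
    num_nodes = len(x)
    if any(len(row) != node_dim for row in x):
        raise ValueError(f"each row of x must have {node_dim} features")
    if len(edge_index) != 2 or len(edge_index[0]) != len(edge_index[1]):
        raise ValueError("edge_index must have shape [2, num_edges]")
    num_edges = len(edge_index[0])
    # every source/target index must point at an existing row of x
    if any(not 0 <= idx < num_nodes for row in edge_index for idx in row):
        raise ValueError("edge_index refers to a node outside x")
    if len(edge_attr) != num_edges or any(len(e) != edge_embedding_dim for e in edge_attr):
        raise ValueError(f"edge_attr must have shape [{num_edges}, {edge_embedding_dim}]")
    return num_nodes, num_edges


# Usage: a 3-node directed cycle with 4 node features and 8 edge features.
x = [[0.1] * 4 for _ in range(3)]
edge_index = [[0, 1, 2], [1, 2, 0]]  # edges 0 -> 1, 1 -> 2, 2 -> 0
edge_attr = [[0.0] * 8 for _ in range(3)]
num_nodes, num_edges = check_forward_inputs(
    x, edge_index, edge_attr, node_dim=4, edge_embedding_dim=8
)
```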
- class mxtaltools.models.modules.graph_convolution.v_MConv(message_depth, node_depth, edge_embedding_dim, norm=None)
Bases: MessagePassing
Message-passing layer with an optional vector channel. Messages are embedded with a linear operator and aggregated with a softmax operator.
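The docstring does not say how the vector channel is processed. One common design, assumed here purely for illustration, is to compute softmax weights from rotation-invariant scalar features and to combine 3-vectors only by scaling and summing them, which keeps the vector output rotation-equivariant. The helpers `gated_vector_aggregate` and `rotate_z` are hypothetical and are not the mxtaltools implementation.

```python
import math


def rotate_z(v, angle):
    """Rotate a 3-vector about the z axis by `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    return [c * v[0] - s * v[1], s * v[0] + c * v[1], v[2]]


def gated_vector_aggregate(scalars, vectors, edge_index):
    """Sum source vectors at each target node, weighted by a softmax over source scalars."""
    sources, targets = edge_index
    out = [[0.0, 0.0, 0.0] for _ in range(len(vectors))]
    for node in range(len(vectors)):
        incoming = [j for j, i in zip(sources, targets) if i == node]
        if not incoming:
            continue  # nodes without incoming edges keep a zero vector
        shift = max(scalars[j] for j in incoming)  # numerical stability
        weights = [math.exp(scalars[j] - shift) for j in incoming]
        total = sum(weights)
        for w, j in zip(weights, incoming):
            for c in range(3):
                # vectors are only scaled and summed, never mixed component-wise
                out[node][c] += w / total * vectors[j][c]
    return out


# Usage: node 2 receives vectors from nodes 0 and 1.
scalars = [0.0, 1.0, 0.5]
vectors = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
edge_index = [[0, 1], [2, 2]]
out = gated_vector_aggregate(scalars, vectors, edge_index)
```

Because the weights depend only on the scalars, rotating every input vector rotates every output vector by the same amount.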
- forward(x: Tensor, edge_index: Adj, edge_attr: Tensor) → Tensor
Runs the forward pass of the module.
- message_dim
Initialize scalar transforms.