mxtaltools.models.modules.graph_convolution

class mxtaltools.models.modules.graph_convolution.MConv(message_dim, node_dim, edge_embedding_dim, norm=None, activation_fn='gelu')

Bases: MessagePassing

Message-passing layer with an optional vector channel. Messages are embedded with a linear operator and aggregated with a softmax operator.
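The description above can be illustrated with a minimal plain-Python sketch. It assumes that "softmax operator" means channel-wise softmax aggregation in the sense of PyTorch Geometric's SoftmaxAggregation with temperature 1, and that the linear message acts on the concatenation [x_i, x_j, edge_attr]. The helper names (linear, softmax_aggregate, mconv_sketch) and the weight layout are hypothetical and not part of mxtaltools. Normalization, the activation function and the vector channel are omitted.

```python
import math

def linear(weights, bias, v):
    """Dense layer y = W v + b on plain lists (weights is out_dim x in_dim)."""
    return [sum(w * x for w, x in zip(row, v)) + b for row, b in zip(weights, bias)]

def softmax_aggregate(messages, dim):
    """Channel-wise softmax-weighted sum of incoming messages.

    For each channel c: out[c] = sum_k softmax_k(m_k[c]) * m_k[c].
    """
    if not messages:
        return [0.0] * dim  # no incoming edges: zero vector
    out = []
    for c in range(dim):
        column = [m[c] for m in messages]
        top = max(column)  # shift by the max for numerical stability
        exps = [math.exp(v - top) for v in column]
        total = sum(exps)
        out.append(sum(e / total * v for e, v in zip(exps, column)))
    return out

def mconv_sketch(x, edges, edge_attr, weights, bias):
    """Hypothetical MConv-like update (no norm, activation or vector channel).

    edges is a list of (j, i) pairs meaning "node j sends a message to node i".
    """
    inbox = {i: [] for i in range(len(x))}
    for (j, i), e in zip(edges, edge_attr):
        # list concatenation builds the message input [x_i, x_j, e_ij]
        inbox[i].append(linear(weights, bias, x[i] + x[j] + e))
    return {i: softmax_aggregate(msgs, len(bias)) for i, msgs in inbox.items()}

# one edge 0 -> 1, scalar node and edge features, 1-channel message
print(mconv_sketch([[1.0], [2.0]], [(0, 1)], [[0.5]], [[1.0, 1.0, 1.0]], [0.0]))
# → {0: [0.0], 1: [3.5]}
```

Compared with a plain sum or mean, the softmax weighting lets the largest message in each channel dominate smoothly, which is one common motivation for this aggregator.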

forward(x: Tensor, edge_index: Adj, edge_attr: Tensor) → Tensor

Runs the forward pass of the module.
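In PyTorch Geometric, an edge_index passed as Adj is usually a [2, num_edges] COO index, and under MessagePassing's default flow="source_to_target" the arguments of message() are gathered as x_j = x[edge_index[0]] (sender) and x_i = x[edge_index[1]] (receiver). The plain-Python sketch below reproduces that gathering; it assumes MConv keeps the default flow, which this reference does not state, and the helper name lift is hypothetical.

```python
def lift(x, edge_index):
    """Gather per-edge endpoint features, mimicking PyTorch Geometric's
    default flow="source_to_target"."""
    src, dst = edge_index       # [2, num_edges]: row 0 sources, row 1 targets
    x_j = [x[s] for s in src]   # features of the sending (neighbor) node j
    x_i = [x[d] for d in dst]   # features of the receiving node i
    return x_i, x_j

# node features, one row per node
x = [[10.0], [20.0], [30.0]]
# two edges, 0 -> 1 and 2 -> 1, both received by node 1
edge_index = [[0, 2], [1, 1]]
x_i, x_j = lift(x, edge_index)
print(x_i)  # [[20.0], [20.0]]
print(x_j)  # [[10.0], [30.0]]
```

The per-edge messages built from x_i and x_j are then aggregated by target index (edge_index[1]), so node 1 in this example receives two messages. edge_attr is likewise expected to hold one row per edge.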

message(x_i: Tensor, x_j: Tensor, edge_attr: Tensor) → Tensor
message_dim

Initialize scalar transforms.

class mxtaltools.models.modules.graph_convolution.v_MConv(message_depth, node_depth, edge_embedding_dim, norm=None)

Bases: MessagePassing

Message-passing layer with an optional vector channel. Messages are embedded with a linear operator and aggregated with a softmax operator.

forward(x: Tensor, edge_index: Adj, edge_attr: Tensor) → Tensor

Runs the forward pass of the module.

message(x_i: Tensor, x_j: Tensor, edge_attr: OptTensor) → Tensor
message_dim

Initialize scalar transforms.
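Unlike MConv.message, v_MConv.message types edge_attr as OptTensor (PyTorch Geometric's alias for Optional[Tensor]), so a caller may pass None. The sketch below shows one way a concatenation-based message input can tolerate a missing edge_attr. It is an illustration on plain lists, assuming concatenation of [x_i, x_j, edge_attr]; the function message_inputs is hypothetical and does not reproduce v_MConv's actual logic.

```python
from typing import List, Optional

Rows = List[List[float]]  # one feature row per edge

def message_inputs(x_i: Rows, x_j: Rows, edge_attr: Optional[Rows] = None) -> Rows:
    """Hypothetical per-edge message input for an optional edge_attr:
    concatenate [x_i, x_j] and append edge features only when given."""
    if edge_attr is None:
        return [a + b for a, b in zip(x_i, x_j)]
    return [a + b + e for a, b, e in zip(x_i, x_j, edge_attr)]

print(message_inputs([[1.0]], [[2.0]]))           # [[1.0, 2.0]]
print(message_inputs([[1.0]], [[2.0]], [[3.0]]))  # [[1.0, 2.0, 3.0]]
```

Note that the input width differs with and without edge features, so a linear message layer sized for one case would reject the other; an implementation has to settle on one convention, for example substituting zero edge features when edge_attr is None.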