mxtaltools.models.modules.augmented_softmax_aggregator

class mxtaltools.models.modules.augmented_softmax_aggregator.AugSoftmaxAggregation(temperature: float = 1.0, learn: bool = True, semi_grad: bool = False, channels: int = 1, bias: float = 0.1)

Bases: Aggregation

The softmax aggregation operator based on a temperature term, as described in the “DeeperGCN: All You Need to Train Deeper GCNs” paper.

Modified with a learnable bias term.
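As a rough illustration of the DeeperGCN operator, the sketch below computes, for each group, a softmax over the temperature-scaled values and returns the softmax-weighted sum of those values. It is plain Python over scalars rather than tensors, and it deliberately omits the learnable bias, because this page does not say where the bias enters. `softmax_aggregate` is a hypothetical helper, not part of mxtaltools.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence


def softmax_aggregate(
    x: Sequence[float], index: Sequence[int], temperature: float = 1.0
) -> Dict[int, float]:
    """Hypothetical sketch of DeeperGCN-style softmax aggregation on scalars.

    Each output group i receives sum_j softmax(temperature * x_j) * x_j over the
    entries j with index[j] == i. The mxtaltools bias term is NOT modeled here.
    """
    groups: Dict[int, List[float]] = defaultdict(list)
    for value, group in zip(x, index):
        groups[group].append(value)

    out: Dict[int, float] = {}
    for group, values in groups.items():
        logits = [temperature * v for v in values]
        shift = max(logits)  # subtract the max logit for numerical stability
        exps = [math.exp(logit - shift) for logit in logits]
        total = sum(exps)
        out[group] = sum(e / total * v for e, v in zip(exps, values))
    return out
```

For example, `softmax_aggregate([1.0, 1.0, 2.0], [0, 0, 1])` returns `{0: 1.0, 1: 2.0}`. A temperature of 0 reduces the operator to the mean, while large temperatures approach the max, which is why a learnable temperature lets the layer interpolate between the two.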

forward(x: Tensor, index: Tensor | None = None, ptr: Tensor | None = None, dim_size: int | None = None, dim: int = -2) → Tensor
reset_parameters()
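The `forward` signature accepts either `index` or `ptr` to describe how rows of `x` are grouped. Assuming these follow the PyG `Aggregation` convention, `index` assigns each row to an output group, while `ptr` is a CSR-style pointer over rows sorted by group, in which group `i` covers rows `ptr[i]` through `ptr[i+1] - 1`. The hypothetical helper `ptr_to_index` below shows how a `ptr` expands into the equivalent `index`; it is an illustration, not an mxtaltools function.

```python
from typing import List, Sequence


def ptr_to_index(ptr: Sequence[int]) -> List[int]:
    """Hypothetical helper: expand a CSR-style pointer into a per-row group index.

    Group g spans rows ptr[g] .. ptr[g + 1] - 1, so it contributes
    (ptr[g + 1] - ptr[g]) copies of g to the index.
    """
    index: List[int] = []
    for group, (start, end) in enumerate(zip(ptr[:-1], ptr[1:])):
        index.extend([group] * (end - start))
    return index
```

For example, `ptr_to_index([0, 2, 5])` gives `[0, 0, 1, 1, 1]`. With `ptr`, the number of output groups is `len(ptr) - 1`, and an empty group (two equal consecutive pointers) contributes no rows.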
class mxtaltools.models.modules.augmented_softmax_aggregator.VectorAugSoftmaxAggregation(temperature: float = 1.0, learn: bool = True, semi_grad: bool = False, channels: int = 1, bias: float = 0.1)

Bases: Aggregation

Adjusted to weight entries by vector length (norm) rather than by raw value.
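One plausible reading of "weight by vector length" is sketched below: each vector's softmax score comes from its Euclidean norm, and those scores then weight the raw vectors. This is an assumption about the mechanism, not the library's code. The plain-Python helper `vector_softmax_aggregate` is hypothetical, handles a single channel, and treats each entry of `x` as one Cartesian vector (presumably the axis that `cart_dim` selects in the real signature).

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Vector = Tuple[float, ...]


def vector_softmax_aggregate(
    x: Sequence[Vector], index: Sequence[int], temperature: float = 1.0
) -> Dict[int, Vector]:
    """Hypothetical sketch: softmax weights from vector norms, applied to raw vectors."""
    groups: Dict[int, List[Vector]] = defaultdict(list)
    for vec, group in zip(x, index):
        groups[group].append(vec)

    out: Dict[int, Vector] = {}
    for group, vecs in groups.items():
        # Score each vector by its Euclidean length, not its raw components.
        logits = [temperature * math.sqrt(sum(c * c for c in v)) for v in vecs]
        shift = max(logits)  # numerical stability
        exps = [math.exp(logit - shift) for logit in logits]
        total = sum(exps)
        weights = [e / total for e in exps]
        dim = len(vecs[0])
        # Weighted sum of the original vectors, component by component.
        out[group] = tuple(
            sum(w * v[k] for w, v in zip(weights, vecs)) for k in range(dim)
        )
    return out
```

Because the score depends only on length, a long vector pointing in a negative direction, such as `(0.0, 0.0, -5.0)`, still dominates at high temperature, whereas a softmax over raw component values would favor positive components regardless of magnitude.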

forward(x: Tensor, index: Tensor | None = None, ptr: Tensor | None = None, dim_size: int | None = None, dim: int = 0, cart_dim: int = 1) → Tensor
reset_parameters()