mxtaltools.models.functions.radial_graph
- mxtaltools.models.functions.radial_graph.asymmetric_radius_graph(x: Tensor, r: float, inside_inds: Tensor, convolve_inds: Tensor, batch: Tensor, loop: bool = False, max_num_neighbors: int = 32, flow: str = 'source_to_target', num_workers: int = 1) → Tensor
Computes graph edges to all points within a given distance.
- Parameters:
x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).
r (float) – The radius.
batch (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. batch needs to be sorted. (default: None)
loop (bool, optional) – If True, the graph will contain self-loops. (default: False)
max_num_neighbors (int, optional) – The maximum number of neighbors to return for each element. If the number of actual neighbors is greater than max_num_neighbors, returned neighbors are picked randomly. (default: 32)
flow (string, optional) – The flow direction when used in combination with message passing ("source_to_target" or "target_to_source"). (default: "source_to_target")
num_workers (int) – Number of workers to use for computation. Has no effect in case batch is not None, or the input lies on the GPU. (default: 1)
inside_inds (Tensor) – Original indices for the nodes in the y subgraph.
- Return type:
LongTensor
import torch
from torch_cluster import radius_graph

x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch = torch.tensor([0, 0, 0, 0])
edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)
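The example above calls the symmetric radius_graph. As a rough illustration of the asymmetric case, here is a minimal brute-force sketch in pure Python. It is not the mxtaltools implementation. It assumes that an edge j → i is created for every target node i in inside_inds and every source node j in convolve_inds that belong to the same graph and lie within r of each other, and that the returned indices refer to the original node numbering. The function name brute_force_asymmetric_radius_graph is hypothetical, and max_num_neighbors, flow and num_workers are deliberately left out.

```python
from math import dist


def brute_force_asymmetric_radius_graph(pos, r, inside_inds, convolve_inds,
                                         batch, loop=False):
    """Reference sketch (hypothetical helper, not the library code).

    Assumption: targets are drawn from inside_inds, sources from
    convolve_inds, and both are reported as original node indices.
    """
    sources, targets = [], []
    for i in inside_inds:            # candidate target nodes
        for j in convolve_inds:      # candidate source nodes
            if batch[i] != batch[j]:
                continue             # never connect nodes of different graphs
            if i == j and not loop:
                continue             # skip self-loops unless requested
            if dist(pos[i], pos[j]) <= r:
                sources.append(j)
                targets.append(i)
    return sources, targets


# Four corners of a square with side length 2, all in one graph.
pos = [(-1.0, -1.0), (-1.0, 1.0), (1.0, -1.0), (1.0, 1.0)]
batch = [0, 0, 0, 0]
inside_inds = [0, 1]            # only these nodes receive edges
convolve_inds = [0, 1, 2, 3]    # every node may act as a source

sources, targets = brute_force_asymmetric_radius_graph(
    pos, 2.5, inside_inds, convolve_inds, batch
)
print(sources, targets)
```

With r = 2.5, each of the two inside nodes is linked to its two side-adjacent corners (distance 2) but not to the diagonal corner (distance about 2.83). With the r = 1.5 used in the radius_graph example, the same points would produce no edges at all.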
- mxtaltools.models.functions.radial_graph.build_radial_graph(pos: FloatTensor, batch: LongTensor, ptr: LongTensor, cutoff: float, max_num_neighbors: int, aux_ind: LongTensor = None, mol_ind: LongTensor = None)
Construct edge indices over a radial graph. Optionally, compute intra (within ref_mol_inds) and inter (between ref_mol_inds and outside inds) edges.
- Parameters:
pos (FloatTensor) – Node positions.
batch (LongTensor) – Index of the graph to which each node belongs.
ptr (LongTensor) – Boundary indices of each graph in the batch.
cutoff (float) – Maximum edge length.
max_num_neighbors (int) – Maximum number of neighbors per node.
aux_ind (LongTensor, optional) – Auxiliary index for identifying "inside" and "outside" nodes. (default: None)
mol_ind (LongTensor, optional) – Index identifying the molecule to which a given atom belongs, for cases with multiple molecules per asymmetric unit or a cluster of molecules. (default: None)
- Returns:
dictionary of edge information
- Return type: