from typing import Optional
import torch
from torch_geometric import nn as gnn
def radius(x: torch.Tensor, y: torch.Tensor, r: float,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None,
max_num_neighbors: int = 32,
num_workers: int = 1) -> torch.Tensor:
r"""Finds for each element in :obj:`y` all points in :obj:`x` within
distance :obj:`r`.
Args:
x (Tensor): Node feature matrix
:math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
y (Tensor): Node feature matrix
:math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`.
r (float): The radius.
batch_x (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. :obj:`batch_x` needs to be sorted.
(default: :obj:`None`)
batch_y (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
node to a specific example. :obj:`batch_y` needs to be sorted.
(default: :obj:`None`)
max_num_neighbors (int, optional): The maximum number of neighbors to
return for each element in :obj:`y`.
If the number of actual neighbors is greater than
:obj:`max_num_neighbors`, returned neighbors are picked randomly.
(default: :obj:`32`)
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch_x` or :obj:`batch_y` is not
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
.. code-block:: python
import torch
from torch_cluster import radius
x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch_x = torch.tensor([0, 0, 0, 0])
y = torch.Tensor([[-1, 0], [1, 0]])
batch_y = torch.tensor([0, 0])
assign_index = radius(x, y, 1.5, batch_x, batch_y)
"""
x = x.view(-1, 1) if x.dim() == 1 else x
y = y.view(-1, 1) if y.dim() == 1 else y
x, y = x.contiguous(), y.contiguous()
# ptr_x: Optional[torch.Tensor] = None
# if batch_x is not None:
# assert x.size(0) == batch_x.numel()
# batch_size = int(batch_x.max()) + 1
#
# deg = x.new_zeros(batch_size, dtype=torch.long)
# deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
#
# ptr_x = deg.new_zeros(batch_size + 1)
# torch.cumsum(deg, 0, out=ptr_x[1:])
#
# ptr_y: Optional[torch.Tensor] = None
# if batch_y is not None:
# assert y.size(0) == batch_y.numel()
# batch_size = int(batch_y.max()) + 1
#
# deg = y.new_zeros(batch_size, dtype=torch.long)
# deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
#
# ptr_y = deg.new_zeros(batch_size + 1)
# torch.cumsum(deg, 0, out=ptr_y[1:])
return radius(
x,
y,
r,
batch_x,
batch_y,
max_num_neighbors,
num_workers
)
# update usage of torch_cluster.radius
# return torch.ops.torch_cluster.radius(x,
# y,
# ptr_x,
# ptr_y,
# r,
# max_num_neighbors,
# num_workers)
from torch_cluster import radius
# @torch.jit.script
[docs]
def asymmetric_radius_graph(x: torch.Tensor,
r: float,
inside_inds: torch.Tensor,
convolve_inds: torch.Tensor,
batch: torch.Tensor,
loop: bool = False,
max_num_neighbors: int = 32,
flow: str = 'source_to_target',
num_workers: int = 1) -> torch.Tensor:
r"""Computes graph edges to all points within a given distance.
Args:
x (Tensor): Node feature matrix
:math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
r (float): The radius.
batch (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. :obj:`batch` needs to be sorted.
(default: :obj:`None`)
loop (bool, optional): If :obj:`True`, the graph will contain
self-loops. (default: :obj:`False`)
max_num_neighbors (int, optional): The maximum number of neighbors to
return for each element.
If the number of actual neighbors is greater than
:obj:`max_num_neighbors`, returned neighbors are picked randomly.
(default: :obj:`32`)
flow (string, optional): The flow direction when used in combination
with message passing (:obj:`"source_to_target"` or
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch` is not :obj:`None`, or the input lies
on the GPU. (default: :obj:`1`)
inside_inds (Tensor): original indices for the nodes in the y subgraph
:rtype: :class:`LongTensor`
.. code-block:: python
import torch
from torch_cluster import radius_graph
x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch = torch.tensor([0, 0, 0, 0])
edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)
"""
if convolve_inds is None: # indexes of items within x to convolve against y
convolve_inds = torch.arange(len(x))
assert flow in ['source_to_target', 'target_to_source']
if batch is not None:
edge_index = radius(x[convolve_inds], x[inside_inds], r, batch[convolve_inds], batch[inside_inds],
max_num_neighbors if loop else max_num_neighbors + 1,
num_workers)
else:
edge_index = radius(x[convolve_inds], x[inside_inds], r, None, None,
max_num_neighbors if loop else max_num_neighbors + 1,
num_workers)
target, source = edge_index[0], edge_index[1]
# edge_index[1] = inside_inds[edge_index[1, :]] # reindex
target = inside_inds[target] # contains correct indexes
source = convolve_inds[source]
if flow == 'source_to_target':
row, col = source, target
else:
row, col = target, source
if not loop: # now properly deletes self-loops
mask = row != col
row, col = row[mask], col[mask]
return torch.stack([row, col], dim=0)
[docs]
def build_radial_graph(pos: torch.FloatTensor,
batch: torch.LongTensor,
ptr: torch.LongTensor,
cutoff: float,
max_num_neighbors: int,
aux_ind: torch.LongTensor=None,
mol_ind: torch.LongTensor=None,
):
r"""
Construct edge indices over a radial graph.
Optionally, compute intra (within ref_mol_inds) and inter (between ref_mol_inds and outside inds) edges.
Args:
pos: node positions
batch: index of graph to which each node belongs
ptr: edges of batch
cutoff: maximum edge length
max_num_neighbors: maximum number of neighbors per node
aux_ind: optional auxiliary index for identifying "inside" and "outside" nodes
mol_ind: optional index for the identity of the molecule a given atom is inside, for when there are multiple molecules per asymmetric unit, or in a cluster of molecules
Returns:
dict: dictionary of edge information
"""
if aux_ind is not None: # there is an 'inside' 'outside' distinction
assert aux_ind.dtype in (torch.int64, torch.int32, torch.bool)
assert (aux_ind == 0).sum() + (aux_ind == 1).sum() + (aux_ind == 2).sum() == aux_ind.numel()
inside_bool = aux_ind == 0
outside_bool = aux_ind == 1
inside_inds = torch.where(inside_bool)[0]
# atoms which are not in the asymmetric unit but which we will convolve - pre-excluding many from outside the cutoff
outside_inds = torch.where(outside_bool)[0]
inside_batch = batch[inside_inds] # get the feature vectors we want to repeat
# counts of atoms per molecule in full batch
batch_counts = torch.bincount(batch, minlength=len(ptr) - 1)
# counts of atoms per molecule in the inside region
inside_counts = torch.bincount(inside_batch, minlength=len(ptr) - 1)
n_repeats = (batch_counts // inside_counts).tolist()
# within aunit edges - includes intermolecular edges when Z'>1
edge_index = asymmetric_radius_graph(pos, batch=batch, r=cutoff,
# intramolecular interactions - stack over range 3 convolutions
max_num_neighbors=max_num_neighbors, flow='source_to_target',
inside_inds=inside_inds, convolve_inds=inside_inds)
# intermolecular edges
edge_index_inter = asymmetric_radius_graph(pos, batch=batch, r=cutoff,
# extra radius for intermolecular graph convolution
max_num_neighbors=max_num_neighbors, flow='source_to_target',
inside_inds=inside_inds, convolve_inds=outside_inds)
# disaggregate within-aunit intermolecular edges from genuine intramolecular edges
# necessary for Z'>1 structures
intra_edge = mol_ind[edge_index[0]] == mol_ind[edge_index[1]]
if torch.any(~intra_edge):
# if there are any within aunit intra edges
edge_index_inter = torch.cat([edge_index_inter, edge_index[:, ~intra_edge]], dim=1)
edge_index = edge_index[:, intra_edge]
return {'edge_index': edge_index,
'edge_index_inter': edge_index_inter,
'inside_inds': inside_inds,
'outside_inds': outside_inds,
'inside_batch': inside_batch,
'n_repeats': n_repeats}
else:
edge_index = gnn.radius_graph(pos,
r=cutoff,
batch=batch,
max_num_neighbors=max_num_neighbors,
flow='source_to_target') # note - requires batch be monotonically increasing
return {'edge_index': edge_index}