import warnings
from functools import reduce
from itertools import chain, product
from math import ceil, floor
from operator import mul
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import numpy
import torch
from torchsparsegradutils.utils import convert_coo_to_csr_indices_values
from torchsparsegradutils.utils.utils import _sort_coo_indices
def _trim_nd(x: torch.Tensor, offsets: Tuple[int, ...]) -> torch.Tensor:
    r"""Slice every axis of ``x`` by a signed, per-axis offset.

    Each entry of ``offsets`` controls one dimension of ``x``:

    * ``k > 0``  keeps ``x[k:]`` along that axis (drops the first ``k`` entries),
    * ``k == 0`` keeps the whole axis,
    * ``k < 0``  keeps ``x[:k]`` along that axis (drops the last ``|k|`` entries).

    Parameters
    ----------
    x : torch.Tensor
        Tensor of arbitrary shape.
    offsets : Tuple[int, ...]
        One integer offset per dimension of ``x``.

    Returns
    -------
    torch.Tensor
        The result of basic slicing on ``x`` (standard PyTorch slicing
        semantics apply; no explicit copy is made here).

    Raises
    ------
    ValueError
        If ``len(offsets) != x.ndim``.

    Examples
    --------
    >>> import torch
    >>> x = torch.arange(6)
    >>> _trim_nd(x, (2,))
    tensor([2, 3, 4, 5])
    >>> _trim_nd(x, (-2,))
    tensor([0, 1, 2, 3])
    >>> y = torch.arange(12).view(3, 4)
    >>> _trim_nd(y, (1, 0))
    tensor([[ 4,  5,  6,  7],
            [ 8,  9, 10, 11]])
    >>> _trim_nd(y, (0, -1))
    tensor([[ 0,  1,  2],
            [ 4,  5,  6],
            [ 8,  9, 10]])
    """
    if len(offsets) != x.ndim:
        raise ValueError(f"Number of dimensions in tensor ({x.ndim}) does not match number of offsets ({len(offsets)})")

    def _axis_slice(off):
        # Negative offsets have an open start, non-negative ones an explicit start;
        # the stop is only bounded once the offset reaches -1 or below.
        begin = None if off < 0 else off
        end = off if off <= -1 else None
        return slice(begin, end)

    per_axis = tuple(_axis_slice(off) for off in offsets)
    return x[per_axis]
def _gen_coords_nd(radius: float, spatial_dims: int) -> Set[Tuple[int, ...]]:
    r"""Enumerate non-zero integer points of a closed :math:`\ell_2` ball.

    Every integer tuple in the hypercube
    :math:`[\lfloor -r \rfloor, \lceil r \rceil]^d` is visited and kept when
    its squared Euclidean norm is at most :math:`r^2`. The origin is always
    left out.

    Parameters
    ----------
    radius : float
        Ball radius ``r`` (need not be an integer). A negative radius yields
        an empty set.
    spatial_dims : int
        Number of dimensions ``d``.

    Returns
    -------
    Set[Tuple[int, ...]]
        Integer tuples of length ``spatial_dims`` inside the ball, origin
        excluded. Being a set, the order is unspecified.

    Raises
    ------
    ValueError
        If ``spatial_dims <= 0``.

    Notes
    -----
    * Only the all-zero tuple is excluded; single components may be zero.
    * The number of visited candidates is
      :math:`(\lceil r \rceil - \lfloor -r \rfloor + 1)^d`.

    Examples
    --------
    >>> _gen_coords_nd(2.0, 1) == {(-2,), (-1,), (1,), (2,)}
    True
    >>> sorted(_gen_coords_nd(1.5, 2))
    [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]
    >>> sorted(_gen_coords_nd(1.0, 3))
    [(-1, 0, 0), (0, -1, 0), (0, 0, -1), (0, 0, 1), (0, 1, 0), (1, 0, 0)]
    """
    if spatial_dims <= 0:
        raise ValueError("spatial_dims must be a positive integer")

    axis_values = range(floor(-radius), ceil(radius) + 1)
    max_sq_norm = radius**2

    inside = set()
    for point in product(axis_values, repeat=spatial_dims):
        if not any(point):
            # The origin is never part of the neighbourhood.
            continue
        if sum(v * v for v in point) <= max_sq_norm:
            inside.add(point)
    return inside
def _gen_coords(radius: float) -> Set[Tuple[int, int, int]]:
    r"""Non-zero integer points of a closed 3D ball (deprecated wrapper).

    .. deprecated:: 0.x
        Call ``_gen_coords_nd(radius, 3)`` instead. This wrapper is kept for
        backward compatibility and simply forwards to :func:`_gen_coords_nd`.

    Parameters
    ----------
    radius : float
        Sphere radius ``r``; a negative value gives an empty set.

    Returns
    -------
    Set[Tuple[int, int, int]]
        All integer triples :math:`(x, y, z)` with
        :math:`x^2 + y^2 + z^2 \le r^2`, excluding :math:`(0, 0, 0)`.

    Notes
    -----
    Number of points for small radii:

    * ``r < 1``                        -> 0
    * ``1 <= r < sqrt(2)``             -> 6  (axis neighbours)
    * ``sqrt(2) <= r < sqrt(3)``       -> 18 (adds edge neighbours)
    * ``sqrt(3) <= r < 2``             -> 26 (adds corner neighbours)
    * ``2 <= r < sqrt(5)``             -> 32 (adds distance-2 axis neighbours)

    See Also
    --------
    _gen_coords_nd : N-D implementation this wrapper forwards to.

    Examples
    --------
    >>> sorted(_gen_coords(1.0))
    [(-1, 0, 0), (0, -1, 0), (0, 0, -1), (0, 0, 1), (0, 1, 0), (1, 0, 0)]
    >>> _gen_coords(1.0) == _gen_coords_nd(1.0, 3)
    True
    """
    points_nd = _gen_coords_nd(radius, 3)
    # Rebuilt as a fresh set of tuples; the N-D helper is typed as
    # Set[Tuple[int, ...]] while this wrapper advertises fixed-length triples.
    return {tuple(point) for point in points_nd}  # type: ignore[return-value]
def _gen_offsets_nd(
    radius: float,
    spatial_dims: int,
    upper: bool | None = None,
    num_channels: int = 1,
    channel_voxel_relation: str = "indep",
) -> list[tuple[int, ...]]:
    r"""Build the sorted list of ``(channel, *spatial)`` neighbourhood offsets.

    Spatial offsets ``(s1, ..., sN)`` are taken from :func:`_gen_coords_nd`
    (all non-zero integer points with :math:`\sum_i s_i^2 \le r^2`) and
    prefixed with a channel offset ``c``. Which channel offsets are added
    depends on ``channel_voxel_relation``. Only channel offsets
    ``1 .. num_channels - 1`` are generated.

    Sign filter (``upper``), based on the first non-zero entry of the full
    tuple ``(c, s1, ..., sN)``:

    * ``upper is False`` -> keep offsets whose first non-zero entry is positive
    * ``upper is True``  -> keep offsets whose first non-zero entry is negative
    * anything else (default ``None``) -> keep every generated offset

    Sort key (applied in this order):

    1. squared norm of ``(10 * c, s1, ..., sN)``,
    2. the tuple of absolute values ``(|c|, |s1|, ..., |sN|)``,
    3. the tuple of ``entry >= 0`` flags, so that on ties a negative entry
       sorts before the matching non-negative one.

    Parameters
    ----------
    radius : float
        Spatial neighbourhood radius (may be non-integer).
    spatial_dims : int
        Number of spatial dimensions ``N``.
    upper : bool or None, optional
        Sign filter described above. Default ``None``.
    num_channels : int, optional
        Number of channels; only used when ``channel_voxel_relation`` is not
        ``'indep'``. Default ``1``.
    channel_voxel_relation : {'indep', 'intra', 'inter'}, optional
        * ``'indep'`` - spatial offsets only, ``(0, s1..sN)``
        * ``'intra'`` - additionally ``(c, 0, ..., 0)``
        * ``'inter'`` - additionally ``(c, 0, ..., 0)`` and ``(c, s1..sN)``

        The value is not validated here: any string other than ``'indep'``
        adds the ``(c, 0, ..., 0)`` offsets.

    Returns
    -------
    list[tuple[int, ...]]
        Sorted offsets of length ``1 + spatial_dims``; the all-zero tuple is
        never included.

    Raises
    ------
    ValueError
        If ``spatial_dims <= 0`` (raised by :func:`_gen_coords_nd`).

    See Also
    --------
    _gen_coords_nd : Enumerates the spatial part of the offsets.

    Examples
    --------
    >>> _gen_offsets_nd(1.0, 2, num_channels=2, channel_voxel_relation='intra')
    [(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (1, 0, 0)]
    >>> offs = _gen_offsets_nd(1.0, 2, num_channels=2, channel_voxel_relation='inter')
    >>> any(o[0] == 1 and o[1:] != (0, 0) for o in offs)
    True
    >>> _gen_offsets_nd(1.0, 1, upper=False, num_channels=1, channel_voxel_relation='indep')
    [(0, 1)]
    """

    def _leading_sign(offset):
        # +1 / -1 according to the first non-zero entry, 0 if there is none.
        for entry in offset:
            if entry != 0:
                return 1 if entry > 0 else -1
        return 0

    def _order_key(offset):
        # The channel entry is weighted by 10 inside the squared norm, so an
        # offset with a channel step has a first key component of at least 100.
        weighted = (10 * offset[0],) + offset[1:]
        sq_norm = sum(v**2 for v in weighted)
        magnitudes = tuple(abs(v) for v in offset)
        non_negative = tuple(v >= 0 for v in offset)
        return sq_norm, magnitudes, non_negative

    spatial_offsets = _gen_coords_nd(radius, spatial_dims)

    # Pure spatial neighbours (same channel).
    candidates = [(0,) + shift for shift in spatial_offsets]

    if channel_voxel_relation != "indep":
        # Other channels at the same voxel.
        same_voxel = (0,) * spatial_dims
        candidates.extend((c,) + same_voxel for c in range(1, num_channels))

    if channel_voxel_relation == "inter":
        # Other channels at neighbouring voxels.
        candidates.extend((c,) + shift for c in range(1, num_channels) for shift in spatial_offsets)

    if upper is False:
        candidates = [o for o in candidates if _leading_sign(o) > 0]
    elif upper is True:
        candidates = [o for o in candidates if _leading_sign(o) < 0]

    # The key is unique per offset, so the result does not depend on set iteration order.
    candidates.sort(key=_order_key)
    return candidates
def _gen_offsets(
    radius: float,
    upper: bool | None = None,
    num_channels: int = 1,
    channel_voxel_relation: str = "indep",
) -> list[tuple[int, int, int, int]]:
    r"""Sorted 4D ``(c, z, y, x)`` offsets for a 3D neighbourhood (deprecated).

    .. deprecated:: 0.x
        Use ``_gen_offsets_nd(radius, 3, upper, num_channels, channel_voxel_relation)``.

    This wrapper forwards to :func:`_gen_offsets_nd` with ``spatial_dims=3``;
    filtering and ordering are exactly those of the N-D function.

    Parameters
    ----------
    radius : float
        Spatial radius ``r``; spatial parts satisfy
        :math:`z^2 + y^2 + x^2 \le r^2`.
    upper : bool or None, optional
        Sign filter on the first non-zero entry. Default ``None``.
    num_channels : int, optional
        Number of channels. Default ``1``.
    channel_voxel_relation : {'indep', 'intra', 'inter'}, optional
        Channel/spatial relation mode. Default ``'indep'``.

    Returns
    -------
    list[tuple[int, int, int, int]]
        Sorted 4D offsets; the all-zero tuple is never included.

    See Also
    --------
    _gen_offsets_nd : N-D generalisation that does the actual work.

    Examples
    --------
    >>> _gen_offsets(1.0, num_channels=2, channel_voxel_relation='intra')
    [(0, 0, 0, -1), (0, 0, 0, 1), (0, 0, -1, 0), (0, 0, 1, 0), (0, -1, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)]
    >>> offs = _gen_offsets(1.0, num_channels=2, channel_voxel_relation='inter')
    >>> any(o[0] == 1 and o[1:] != (0, 0, 0) for o in offs)
    True
    """
    offsets_nd = _gen_offsets_nd(
        radius,
        3,
        upper,
        num_channels,
        channel_voxel_relation,
    )
    # Copy into a new list of tuples so the declared 4-tuple return type holds.
    return list(map(tuple, offsets_nd))  # type: ignore[return-value]
[docs]
def calc_pairwise_coo_indices_nd(
    radius: float,
    volume_shape: Tuple[int, ...],
    diag: bool = False,
    upper: bool | None = None,
    channel_voxel_relation: str = "indep",
    dtype: torch.dtype = torch.int64,
    device: torch.device | None = torch.device("cpu"),
) -> Dict[Tuple[int, ...], torch.Tensor]:
    r"""Per-offset COO index pairs for a ``(C, *spatial)`` volume.

    The volume is numbered in row-major order
    (``torch.arange(prod(volume_shape)).reshape(volume_shape)``). For every
    offset ``o`` produced by :func:`_gen_offsets_nd`, the result holds a
    ``(2, M_o)`` tensor: row 0 contains the linear index of each position
    ``p`` for which ``p - o`` is still inside the volume, and row 1 contains
    the linear index of ``p - o``. Out-of-bounds pairs are dropped, not padded
    or wrapped.

    Parameters
    ----------
    radius : float
        Spatial neighbourhood radius, must be ``>= 1``.
    volume_shape : tuple[int, ...]
        ``(C, *spatial_dims)`` with at least one spatial dimension; all
        entries must be positive ``int``.
    diag : bool, optional
        If exactly ``True``, an all-zero key mapping to ``(i, i)`` pairs for
        every element is inserted first. Default ``False``.
    upper : bool or None, optional
        Sign filter forwarded to :func:`_gen_offsets_nd`. Default ``None``.
    channel_voxel_relation : {'indep', 'intra', 'inter'}, optional
        Channel relation mode. Default ``'indep'``.
    dtype : torch.dtype, optional
        Integer dtype of the index tensors. Default ``torch.int64``.
    device : torch.device or None, optional
        Device of the index tensors. Default CPU.

    Returns
    -------
    dict[tuple[int, ...], torch.Tensor]
        Offsets (in insertion order: optional diagonal, then the sorted
        offsets) mapped to ``(2, M_o)`` index tensors.

    Raises
    ------
    ValueError
        If ``radius < 1``, ``volume_shape`` is malformed,
        ``channel_voxel_relation`` is unknown, or a non-``'indep'`` relation
        is requested for a single-channel volume.

    See Also
    --------
    _gen_offsets_nd : Generates the sorted offsets.
    _trim_nd : Slicing helper used to form the index pairs.

    Examples
    --------
    >>> idxs = calc_pairwise_coo_indices_nd(radius=1.0, volume_shape=(1, 3, 3), diag=True)
    >>> sorted(idxs.keys())[:3]
    [(0, -1, 0), (0, 0, -1), (0, 0, 0)]
    >>> idxs[(0, 0, 0)].shape
    torch.Size([2, 9])
    """
    if radius < 1:
        raise ValueError("radius must be >= 1")
    shape_is_valid = len(volume_shape) >= 2 and all(isinstance(dim, int) and dim > 0 for dim in volume_shape)
    if not shape_is_valid:
        raise ValueError("volume_shape must be a tuple of at least 2 positive integers")
    if channel_voxel_relation not in ("indep", "intra", "inter"):
        raise ValueError("channel_voxel_relation must be 'indep', 'intra', or 'inter'")
    num_channels = volume_shape[0]
    if num_channels == 1 and channel_voxel_relation != "indep":
        raise ValueError("channel_voxel_relation must be 'indep' when number of channels is 1")

    if device is not None:
        device = torch.device(device)

    n_axes = len(volume_shape)
    offsets = _gen_offsets_nd(radius, n_axes - 1, upper, num_channels, channel_voxel_relation)

    # Lattice of linear indices, one entry per element of the volume.
    total = reduce(mul, volume_shape)
    lattice = torch.arange(total, device=device, dtype=dtype).reshape(volume_shape)

    pairs = {}
    if diag is True:
        flat = lattice.flatten()
        pairs[(0,) * n_axes] = torch.stack([flat, flat])

    for offset in offsets:
        mirrored = tuple(-step for step in offset)
        # Trimming by +offset and by -offset yields two equally shaped blocks
        # whose matching entries are exactly `offset` apart in the lattice.
        first = _trim_nd(lattice, offset).flatten()
        second = _trim_nd(lattice, mirrored).flatten()
        pairs[offset] = torch.stack([first, second])

    return pairs
[docs]
def calc_pairwise_coo_indices(
    radius: float,
    volume_shape: Tuple[int, int, int, int],
    diag: bool = False,
    upper: bool | None = None,
    channel_voxel_relation: str = "indep",
    dtype: torch.dtype = torch.int64,
    device: torch.device = torch.device("cpu"),
) -> Dict[Tuple[int, int, int, int], torch.Tensor]:
    r"""4D-only wrapper around :func:`calc_pairwise_coo_indices_nd` (deprecated).

    .. deprecated:: 0.x
        Use :func:`calc_pairwise_coo_indices_nd`.

    Parameters
    ----------
    radius : float
        Spatial radius (``>= 1``).
    volume_shape : tuple[int, int, int, int]
        Four positive integers, ``[C, H, D, W]``.
    diag : bool, optional
        Include the diagonal. Default ``False``.
    upper : bool or None, optional
        Sign filter, see the N-D version. Default ``None``.
    channel_voxel_relation : {'indep', 'intra', 'inter'}, optional
        Channel relation mode. Default ``'indep'``.
    dtype : torch.dtype, optional
        Index dtype. Default ``torch.int64``.
    device : torch.device, optional
        Target device. Default CPU.

    Returns
    -------
    dict[tuple[int, int, int, int], torch.Tensor]
        Per-offset COO index pairs, as returned by the N-D function.

    Raises
    ------
    ValueError
        If ``volume_shape`` is not four positive integers, or if the N-D
        function rejects the remaining arguments.
    """
    # The legacy entry point only ever accepted 4D shapes; keep enforcing that.
    if len(volume_shape) != 4 or not all(isinstance(dim, int) and dim > 0 for dim in volume_shape):
        raise ValueError("`volume_shape` must be a 4D tuple of positive integers, representing [C, H, D, W]")

    per_offset = calc_pairwise_coo_indices_nd(
        radius,
        volume_shape,
        diag,
        upper,
        channel_voxel_relation,
        dtype,
        device,
    )
    # Rebuild the mapping so the keys are plain tuples (all of length 4 here).
    return {tuple(offset): pair_idx for offset, pair_idx in per_offset.items()}  # type: ignore[return-value]
# Backward-compatibility alias: ``calc_pariwise_coo_indices`` is the misspelled
# ("pariwise") name of ``calc_pairwise_coo_indices``. It is kept so that existing
# callers using the typo continue to work; both names refer to the same function.
calc_pariwise_coo_indices = calc_pairwise_coo_indices  # type: ignore[misc]
[docs]
class PairwiseEncoder(torch.nn.Module):
r"""Encode pairwise spatial–channel neighborhoods as sparse tensors.
Precomputes a mapping from local neighborhoods (within spatial radius ``r``)
and optional channel interactions to global sparse matrix indices over the
linearized volume. Useful for graph-like layers, covariance assembly, sparse
attention and any operator exploiting local geometric structure.
Parameters
----------
radius : float
Spatial neighborhood radius ``r``.
volume_shape : tuple[int, ...]
``(C,*spatial_dims)`` (at least one spatial dimension).
diag : bool, optional
Include diagonal (self-edges). Default ``False``.
upper : bool or None, optional
Triangular selection on offset set (first non-zero criterion). ``None`` keeps all.
channel_voxel_relation : {'indep','intra','inter'}, optional
Channel interaction model. Default ``'indep'``.
layout : torch.layout, optional
``torch.sparse_coo`` (default) or ``torch.sparse_csr``.
indices_dtype : torch.dtype, optional
Integer dtype for indices (``int32`` or ``int64``). Default ``int64``.
device : torch.device, optional
Device to store cached indices (default CPU).
Attributes
----------
volume_numel : int
``C * prod(spatial_dims)``.
spatial_dims : int
``len(volume_shape) - 1``.
offsets : list[tuple[int,...]]
Ordered offsets (optionally with diagonal key first if ``diag``).
indices : torch.Tensor
(COO) ``(2, nnz_total)`` tensor of linear index pairs.
crow_indices, col_indices, csr_permutation : torch.Tensor
(CSR) components & permutation to reorder values into CSR order.
Notes
-----
* Input to :meth:`__call__` must have shape ``[(B), N, C, *S]`` where
``N == len(self.offsets)``. Batch dimension optional.
* Edge handling uses trimming (no wrap, no padding).
* CSR values are internally reordered via ``self.csr_permutation``.
* Complexity scales with number of offsets times valid pairs (≈ :math:`O(r^2)` in 2D, :math:`O(r^3)` in 3D).
See Also
--------
calc_pairwise_coo_indices_nd : Build per-offset COO indices.
convert_coo_to_csr_indices_values : COO→CSR conversion + permutation.
_gen_offsets_nd : Construct ordered offset set.
_trim_nd : Bounds-aware slicing for forming value blocks.
Examples
--------
Basic 2D:
>>> from torchsparsegradutils.encoders import PairwiseEncoder
>>> encoder = PairwiseEncoder(
... radius=1.5,
... volume_shape=(3, 8, 8),
... diag=True,
... channel_voxel_relation='indep'
... )
>>> encoder.volume_numel
192
>>> len(encoder.offsets) # doctest: +SKIP
13
Create sparse tensor from values:
>>> values = torch.randn(len(encoder.offsets), 3, 8, 8)
>>> sp = encoder(values)
>>> sp.shape
torch.Size([192, 192])
>>> sp.is_sparse
True
Batched:
>>> values_b = torch.randn(4, len(encoder.offsets), 3, 8, 8)
>>> sp_b = encoder(values_b)
>>> sp_b.shape
torch.Size([4, 192, 192])
3D with inter-channel relations:
>>> encoder3d = PairwiseEncoder(
... radius=2.0,
... volume_shape=(5, 16, 16, 16),
... channel_voxel_relation='inter',
... layout=torch.sparse_csr,
... )
Upper-triangular (symmetric use-case):
>>> sym = PairwiseEncoder(
... radius=1.0,
... volume_shape=(1, 10, 10),
... upper=True,
... diag=True,
... )
>>> v = torch.randn(len(sym.offsets), 1, 10, 10)
>>> _ = sym(v)
"""
[docs]
def __init__(
self,
radius: float,
volume_shape: Tuple[int, ...],
diag: bool = False,
upper: bool | None = None,
channel_voxel_relation: str = "indep",
layout=torch.sparse_coo,
indices_dtype: torch.dtype = torch.int64,
device: torch.device = torch.device("cpu"),
):
super().__init__()
if not ((len(volume_shape) >= 2) and all(isinstance(dim, int) and dim > 0 for dim in volume_shape)):
raise ValueError(
"`volume_shape` must be a tuple of at least 2 positive integers, representing [C, *spatial_dims]"
)
if indices_dtype not in [torch.int64, torch.int32]:
raise ValueError("`indices_dtype` must be torch.int64 or torch.int32 for torch.sparse_coo")
self.radius = radius
self.volume_shape = volume_shape
self.diag = diag
self.upper = upper
self.channel_voxel_relation = channel_voxel_relation
self.layout = layout
self.indices_dtype = indices_dtype
self.volume_numel = reduce(mul, volume_shape)
self.spatial_dims = len(volume_shape) - 1
indices_coo_dict = calc_pairwise_coo_indices_nd(
radius, volume_shape, diag, upper, channel_voxel_relation, indices_dtype, device
)
self.offsets = list(indices_coo_dict.keys()) # dictionary keys are ordered as of Python 3.7
indices_coo = torch.cat([indices_coo_dict[offset] for offset in indices_coo_dict], dim=1)
if layout == torch.sparse_coo:
self.indices = indices_coo
self.csr_permutation = None
elif layout == torch.sparse_csr:
self.crow_indices, self.col_indices, self.csr_permutation = convert_coo_to_csr_indices_values(
indices_coo, num_rows=self.volume_numel, values=None
)
else:
raise ValueError("layout must be either torch.sparse_coo or torch.sparse_csr")
def _apply(self, fn, recurse=True):
# Applying the function to the desired attributes
# This has been implemented to allow using the .to() method
for attr in ["indices", "csr_permutation", "crow_indices", "col_indices"]:
tensor = getattr(self, attr, None)
if tensor is not None:
setattr(self, attr, fn(tensor))
return self
@property
def device(self):
if self.layout == torch.sparse_coo:
return self.indices.device
elif self.layout == torch.sparse_csr:
return self.crow_indices.device
def _calc_values(self, values: torch.Tensor) -> torch.Tensor:
r"""Assemble flattened value vector for one (unbatched) call.
Parameters
----------
values : torch.Tensor
Tensor of shape ``(N, C, *spatial_dims)`` where ``N == len(self.offsets)``.
Returns
-------
torch.Tensor
Flattened concatenation of trimmed per-offset blocks (order matches ``self.indices`` or CSR permutation input order).
"""
values_out = []
for offset, val in zip(self.offsets, values):
trimmed_val = _trim_nd(val, offset).flatten()
values_out.append(trimmed_val)
return torch.cat(values_out)
[docs]
def __call__(self, values: torch.Tensor) -> torch.Tensor:
r"""Construct sparse tensor (COO or CSR) from per-offset value blocks.
Parameters
----------
values : torch.Tensor
Shape ``[(B), N, C, *spatial_dims]`` with optional batch ``B`` and
``N == len(self.offsets)``.
Returns
-------
torch.Tensor
Sparse tensor of shape ``[(B), S, S]`` where ``S = C * prod(spatial_dims)``.
Raises
------
ValueError
If shape, dtype or offset count are inconsistent.
"""
expected_spatial_dims = len(self.volume_shape) - 1
expected_full_dims = expected_spatial_dims + 2 # C + spatial_dims
if len(values.shape) < expected_full_dims or len(values.shape) > expected_full_dims + 1:
raise ValueError(
f"values must have {expected_full_dims} dimensions (N, C, *spatial_dims) "
f"or {expected_full_dims + 1} dimensions (B, N, C, *spatial_dims)"
)
# Check spatial dimensions match
spatial_shape_in_values = values.shape[-expected_spatial_dims:]
expected_spatial_shape = self.volume_shape[-expected_spatial_dims:]
if spatial_shape_in_values != expected_spatial_shape:
raise ValueError(
f"Spatial dimensions do not match: expected {expected_spatial_shape}, " f"got {spatial_shape_in_values}"
)
# Check number of offsets
offset_dim_idx = -expected_full_dims
if values.shape[offset_dim_idx] != len(self.offsets):
raise ValueError(
f"Shape of values at index {offset_dim_idx} ({values.shape[offset_dim_idx]}) "
f"must match number of offsets ({len(self.offsets)})"
)
if values.dtype not in [torch.float32, torch.float64]:
raise ValueError("values must be either torch.float32 or torch.float64 for sparse tensors")
batched = len(values.shape) == expected_full_dims + 1
batch_size: int | None = None
if batched:
batch_size = values.shape[0]
# Calculate values
if batched:
assert batch_size is not None # for type checker
size_batched: tuple[int, int, int] = (batch_size, self.volume_numel, self.volume_numel)
size_any = size_batched # unified name
processed = [self._calc_values(batch) for batch in values] # type: ignore[assignment]
values = torch.stack(processed)
else:
size_unbatched: tuple[int, int] = (self.volume_numel, self.volume_numel)
size_any = size_unbatched
values = self._calc_values(values)
if self.layout == torch.sparse_coo:
if batched:
assert batch_size is not None
sparse_dim_indices = self.indices.repeat(1, batch_size)
batch_dim_indices = (
torch.arange(batch_size, dtype=self.indices.dtype, device=self.indices.device)
.repeat_interleave(self.indices.shape[-1])
.unsqueeze(0)
)
indices = torch.cat([batch_dim_indices, sparse_dim_indices])
values = values.flatten()
else:
indices = self.indices
return torch.sparse_coo_tensor(
indices, values, size=size_any, dtype=values.dtype, device=values.device
).coalesce()
if self.layout == torch.sparse_csr:
if self.csr_permutation is None:
raise RuntimeError("csr_permutation is None; expected a permutation tensor when layout is sparse_csr.")
values = values.index_select(dim=-1, index=self.csr_permutation)
if batched:
assert batch_size is not None
crow_indices = self.crow_indices.repeat(batch_size, 1)
col_indices = self.col_indices.repeat(batch_size, 1)
else:
crow_indices = self.crow_indices
col_indices = self.col_indices
return torch.sparse_csr_tensor(
crow_indices, col_indices, values, size=size_any, dtype=values.dtype, device=values.device
)
raise RuntimeError("Unsupported sparse layout")