torchsparsegradutils.gather_mm

torchsparsegradutils.gather_mm(a: Tensor, b: Tensor, idx_b: Tensor) → Tensor

Per-row indexed matrix multiplication.

For each row i of a, this computes a[i] @ b[idx_b[i]] and stacks the results along the first dimension of the output.
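
The per-row rule can be written as a dense reference computation in plain PyTorch. This is a minimal sketch for checking results, not the library's implementation; the helper name gather_mm_reference is hypothetical.

```python
import torch


def gather_mm_reference(a, b, idx_b):
    # Hypothetical helper, not part of torchsparsegradutils.
    # b[idx_b] has shape (N, D1, D2): one matrix per row of a.
    # The einsum contracts D1 independently for each row n.
    return torch.einsum("nd,nde->ne", a, b[idx_b])


a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([[[1.0, 0.0], [0.0, 1.0]],   # identity
                  [[2.0, 0.0], [0.0, 2.0]]])  # 2x scale
idx_b = torch.tensor([0, 1])
ref = gather_mm_reference(a, b, idx_b)  # shape (2, 2)
```

Note that b[idx_b] materializes N copies of (D1, D2) matrices, so this form is convenient for testing but can be memory-hungry for large N.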

Parameters:
  • a (torch.Tensor, shape (N, D1)) – Left operand with one row per output.

  • b (torch.Tensor, shape (R, D1, D2)) – Bank of transformation matrices.

  • idx_b (torch.Tensor, shape (N,), integer dtype) – Indices selecting which matrix in b to use for each row. Values must satisfy 0 <= idx_b[i] < R.

Returns:

Row-wise results where out[i] = a[i] @ b[idx_b[i]].

Return type:

torch.Tensor, shape (N, D2)

Raises:
  • NotImplementedError – If the fallback path is used on a PyTorch version lacking nested tensor matmul support (requires PyTorch >= 2.4).

  • ValueError – If inputs are not tensors, ranks are incorrect, or sizes are incompatible.
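
The shape and index constraints listed under Parameters can be checked up front along the following lines. This is only a sketch: the helper check_gather_mm_inputs is hypothetical, the library's actual checks and messages may differ, and raising ValueError for out-of-range indices is an assumption rather than documented behavior.

```python
import torch


def check_gather_mm_inputs(a, b, idx_b):
    # Hypothetical helper; not the library's own validation code.
    if not all(isinstance(t, torch.Tensor) for t in (a, b, idx_b)):
        raise ValueError("a, b and idx_b must all be torch.Tensor")
    if a.dim() != 2 or b.dim() != 3 or idx_b.dim() != 1:
        raise ValueError("expected a: (N, D1), b: (R, D1, D2), idx_b: (N,)")
    if a.shape[1] != b.shape[1]:
        raise ValueError("a.shape[1] must equal b.shape[1] (D1)")
    if idx_b.shape[0] != a.shape[0]:
        raise ValueError("idx_b must have one entry per row of a")
    if (idx_b.is_floating_point() or idx_b.is_complex()
            or idx_b.dtype == torch.bool):
        raise ValueError("idx_b must have an integer dtype")
    # Assumption: out-of-range indices are reported as ValueError.
    if idx_b.numel() > 0 and (idx_b.min() < 0 or idx_b.max() >= b.shape[0]):
        raise ValueError("idx_b values must satisfy 0 <= idx_b[i] < R")


# Valid inputs pass silently.
check_gather_mm_inputs(torch.randn(5, 3), torch.randn(3, 3, 2),
                       torch.tensor([0, 1, 0, 2, 1]))
```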

Notes

If DGL is available, this uses dgl.ops.gather_mm() [1b]. Otherwise it uses a dependency-free PyTorch nested-tensor fallback.

See also

segment_mm

Segmented matrix multiplication over contiguous chunks.
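
One way to see how the two operations relate: sorting the rows of a by idx_b turns the gathered product into a product over contiguous chunks. The sketch below illustrates that idea in plain PyTorch. It does not call segment_mm, whose signature is not shown here, and the helper name gather_mm_via_segments is hypothetical.

```python
import torch


def gather_mm_via_segments(a, b, idx_b):
    # Hypothetical helper for illustration only.
    order = torch.argsort(idx_b)  # group rows by matrix index
    counts = torch.bincount(idx_b, minlength=b.shape[0]).tolist()
    chunks = torch.split(a[order], counts)  # one contiguous chunk per matrix
    out_sorted = torch.cat([chunk @ b[r] for r, chunk in enumerate(chunks)])
    out = torch.empty_like(out_sorted)
    out[order] = out_sorted  # restore the original row order
    return out


a = torch.randn(5, 3)
b = torch.randn(3, 3, 2)
idx_b = torch.tensor([0, 1, 0, 2, 1])
out = gather_mm_via_segments(a, b, idx_b)  # shape (5, 2)
```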

References

Examples

>>> import torch
>>> from torchsparsegradutils import gather_mm
>>> # N = 5, D1 = 3, D2 = 2, R = 3
>>> a = torch.randn(5, 3)
>>> b = torch.randn(3, 3, 2)
>>> idx_b = torch.tensor([0, 1, 0, 2, 1])
>>> out = gather_mm(a, b, idx_b)
>>> out.shape
torch.Size([5, 2])

All rows using the same matrix:

>>> torch.allclose(gather_mm(a, b, torch.zeros(5, dtype=torch.long)), a @ b[0])
True

Mixed indexing example:

>>> # Different transformation for each row
>>> a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # (2, 2)
>>> b = torch.tensor([[[1.0, 0.0], [0.0, 1.0]],  # Identity
...                   [[2.0, 0.0], [0.0, 2.0]]])  # 2x scale
>>> idx_b = torch.tensor([0, 1])  # Use identity, then 2x scale
>>> result = gather_mm(a, b, idx_b)
>>> result
tensor([[1., 2.],
        [6., 8.]])