torchsparsegradutils.gather_mm
torchsparsegradutils.gather_mm(a: Tensor, b: Tensor, idx_b: Tensor) → Tensor
Per-row indexed matrix multiplication.
For each row i in a, this computes a[i] @ b[idx_b[i]] and stacks the results into the output.

Parameters:
- a (torch.Tensor, shape (N, D1)) – Left operand with one row per output.
- b (torch.Tensor, shape (R, D1, D2)) – Bank of transformation matrices.
- idx_b (torch.Tensor, shape (N,), integer dtype) – Indices selecting which matrix in b to use for each row. Values must satisfy 0 <= idx_b[i] < R.
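To make these shapes and the indexing rule concrete, here is a minimal dense reference sketch in plain PyTorch. It is an illustration only, not the library's DGL path or its nested-tensor fallback, and the name gather_mm_reference is hypothetical.

```python
import torch


def gather_mm_reference(a: torch.Tensor, b: torch.Tensor, idx_b: torch.Tensor) -> torch.Tensor:
    """Hypothetical dense reference for out[i] = a[i] @ b[idx_b[i]] (not the library code)."""
    # Gather one (D1, D2) matrix per row of a: shape (N, D1, D2).
    b_per_row = b[idx_b]
    # Batched (1, D1) @ (D1, D2) products, one per row, then drop the singleton dim.
    return torch.bmm(a.unsqueeze(1), b_per_row).squeeze(1)


torch.manual_seed(0)
a = torch.randn(5, 3)                   # (N, D1)
b = torch.randn(3, 3, 2)                # (R, D1, D2)
idx_b = torch.tensor([0, 1, 0, 2, 1])   # (N,), int64
out = gather_mm_reference(a, b, idx_b)

# Compare against an explicit per-row loop.
expected = torch.stack([a[i] @ b[idx_b[i]] for i in range(a.shape[0])])
print(out.shape)                                 # torch.Size([5, 2])
print(torch.allclose(out, expected, atol=1e-6))  # True
```

Indexing b[idx_b] copies one (D1, D2) matrix per row, so this sketch needs O(N * D1 * D2) extra memory; avoiding that copy is presumably one reason to prefer a dedicated gather_mm kernel.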
Returns:
    Row-wise results where out[i] = a[i] @ b[idx_b[i]].
Return type:
    torch.Tensor, shape (N, D2)
Raises:
- NotImplementedError – If the fallback path is used on a PyTorch version lacking nested tensor matmul support (requires PyTorch >= 2.4).
- ValueError – If inputs are not tensors, ranks are incorrect, or sizes are incompatible.
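The documented ValueError conditions can be sketched as explicit checks. This is an assumed reconstruction for illustration: the helper name check_gather_mm_inputs and its messages are hypothetical, and the real function may validate more, less, or in a different order.

```python
import torch


def check_gather_mm_inputs(a, b, idx_b) -> None:
    """Hypothetical validation mirroring the documented ValueError conditions."""
    # Condition 1: every input must be a tensor.
    for name, t in (("a", a), ("b", b), ("idx_b", idx_b)):
        if not isinstance(t, torch.Tensor):
            raise ValueError(f"{name} must be a torch.Tensor, got {type(t).__name__}")
    # Condition 2: ranks must be a -> (N, D1), b -> (R, D1, D2), idx_b -> (N,).
    if a.dim() != 2 or b.dim() != 3 or idx_b.dim() != 1:
        raise ValueError(
            f"expected ranks (2, 3, 1), got ({a.dim()}, {b.dim()}, {idx_b.dim()})"
        )
    # Condition 3: sizes must agree (one index per row, matching inner dim D1).
    if idx_b.shape[0] != a.shape[0]:
        raise ValueError("idx_b must have one entry per row of a")
    if a.shape[1] != b.shape[1]:
        raise ValueError("a.shape[1] must equal b.shape[1] (D1)")


# Valid inputs pass silently.
check_gather_mm_inputs(torch.randn(5, 3), torch.randn(3, 3, 2), torch.tensor([0, 1, 0, 2, 1]))
```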
Notes
If DGL is available, this uses dgl.ops.gather_mm() [1b]. Otherwise it uses a dependency-free PyTorch nested-tensor fallback.

See also
segment_mm: Segmented matrix multiplication over contiguous chunks.
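When rows are sorted by idx_b, each matrix in b applies to one contiguous chunk of a, so a gather-style product can be recast as a segmented one. The sketch below makes no assumption about segment_mm's actual signature and does not call it; gather_mm_via_segments is a hypothetical name, and the per-segment products are plain PyTorch matmuls.

```python
import torch


def gather_mm_via_segments(a: torch.Tensor, b: torch.Tensor, idx_b: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: out[i] = a[i] @ b[idx_b[i]] computed segment by segment."""
    # Sort rows by matrix index so rows sharing a matrix become contiguous.
    order = torch.argsort(idx_b)
    idx_sorted = idx_b[order]
    a_sorted = a[order]
    # Each run of equal indices is one segment: its matrix id and its length.
    ids, seglen = torch.unique_consecutive(idx_sorted, return_counts=True)
    # One dense (len_k, D1) @ (D1, D2) matmul per segment.
    chunks = torch.split(a_sorted, seglen.tolist())
    out_sorted = torch.cat([chunk @ b[k] for chunk, k in zip(chunks, ids.tolist())])
    # Undo the sort so out[i] lines up with a[i].
    out = torch.empty_like(out_sorted)
    out[order] = out_sorted
    return out


torch.manual_seed(0)
a = torch.randn(6, 3)
b = torch.randn(4, 3, 2)
idx_b = torch.tensor([2, 0, 2, 1, 0, 3])
out = gather_mm_via_segments(a, b, idx_b)
```

Sorting costs O(N log N) but turns N small products into at most R larger ones, which is the trade-off that makes the segmented formulation attractive when idx_b is already grouped.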
References
[1b] DGL gather_mm documentation: https://www.dgl.ai/dgl_docs/generated/dgl.ops.gather_mm.html

Examples
>>> import torch
>>> from torchsparsegradutils import gather_mm
>>> # N = 5, D1 = 3, D2 = 2, R = 3
>>> a = torch.randn(5, 3)
>>> b = torch.randn(3, 3, 2)
>>> idx_b = torch.tensor([0, 1, 0, 2, 1])
>>> out = gather_mm(a, b, idx_b)
>>> out.shape
torch.Size([5, 2])
All rows using the same matrix:
>>> torch.allclose(gather_mm(a, b, torch.zeros(5, dtype=torch.long)), a @ b[0])
True
Mixed indexing example:
>>> # Different transformation for each row
>>> a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # (2, 2)
>>> b = torch.tensor([[[1.0, 0.0], [0.0, 1.0]],   # Identity
...                   [[2.0, 0.0], [0.0, 2.0]]])  # 2x scale
>>> idx_b = torch.tensor([0, 1])  # Use identity, then 2x scale
>>> result = gather_mm(a, b, idx_b)
>>> result
tensor([[1., 2.],
        [6., 8.]])