torchsparsegradutils.sparse_mm
- torchsparsegradutils.sparse_mm(A: Tensor, B: Tensor) → Tensor
Sparse–dense matrix multiplication with memory-efficient gradients.
Computes \(\mathbf{C} = \mathbf{A}\,\mathbf{B}\) where \(\mathbf{A} \in \mathbb{R}^{n\times m}\) is sparse (COO/CSR), \(\mathbf{B} \in \mathbb{R}^{m\times p}\) is dense, and \(\mathbf{C} \in \mathbb{R}^{n\times p}\). Gradients preserve the sparsity pattern of \(\mathbf{A}\). Supports unbatched 2D inputs, (n, m) @ (m, p), and batched 3D inputs, which are handled by block-diagonalising the batch of sparse matrices and concatenating the dense matrices along the batch dimension.
Let the upstream gradient be \(\mathbf{G} = \frac{\partial \mathcal{L}}{\partial \mathbf{C}} \in \mathbb{R}^{n\times p}\). The gradients are:
Gradient with respect to B (dense):
\[\frac{\partial \mathcal{L}}{\partial \mathbf{B}} \;=\; \mathbf{A}^{\top} \, \mathbf{G}.\]
Gradient with respect to A (sparse): Viewed as a dense matrix, the gradient is
\[\frac{\partial \mathcal{L}}{\partial \mathbf{A}} \;=\; \mathbf{G}\, \mathbf{B}^{\top},\]
but only the entries at the nonzeros of \(\mathbf{A}\) are evaluated. Equivalently, for a nonzero entry \(\mathbf{A}_{ij}\) the contribution is
\[\bigg[\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\bigg]_{ij} \;=\; \sum_{k=1}^{p} \mathbf{G}_{ik} \, \mathbf{B}_{jk} \;=\; \mathbf{G}_{i,:} \,\cdot\, \mathbf{B}_{j,:},\]
where the dot denotes the inner product of row \(i\) of \(\mathbf{G}\) with row \(j\) of \(\mathbf{B}\), taken across the \(p\) right-hand sides.
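The two gradient formulas above can be checked numerically in plain PyTorch. The following is a minimal sketch, not the library's implementation: the helper name `sparse_mm_grads` is hypothetical, and it assumes A is given by its COO indices and values. It evaluates the gradient of A only at the stored nonzeros and compares both gradients against dense autograd.

```python
import torch

def sparse_mm_grads(indices, values, B, G):
    """Reference gradients for C = A @ B, with A given in COO form.

    indices: (2, nnz) long tensor of (row, col) positions of A's nonzeros
    values:  (nnz,) values of A
    B:       (m, p) dense right operand
    G:       (n, p) upstream gradient dL/dC
    """
    rows, cols = indices
    # dL/dA at each nonzero (i, j): inner product of G[i, :] and B[j, :]
    grad_values = (G[rows] * B[cols]).sum(dim=1)
    # dL/dB = A^T @ G, accumulated nonzero by nonzero without densifying A
    grad_B = torch.zeros_like(B)
    grad_B.index_add_(0, cols, values.unsqueeze(1) * G[rows])
    return grad_values, grad_B

indices = torch.tensor([[0, 0, 1, 1, 2, 2],
                        [0, 2, 1, 3, 0, 2]])
values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
B = torch.randn(4, 2)
G = torch.randn(3, 2)

grad_values, grad_B = sparse_mm_grads(indices, values, B, G)

# Dense reference computed with autograd
A_dense = torch.zeros(3, 4)
A_dense[indices[0], indices[1]] = values
A_dense.requires_grad_(True)
B_ref = B.clone().requires_grad_(True)
(A_dense @ B_ref).backward(G)

# Entries of the dense gradient G @ B^T at the nonzero positions of A
dense_at_nonzeros = A_dense.grad[indices[0], indices[1]]
```

Here `grad_values` holds one number per stored nonzero, so its memory cost scales with nnz rather than with n × m, which is the saving the sparse gradient is meant to provide.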
- Parameters:
A (torch.Tensor, sparse COO or CSR, shape (n, m) or (b, n, m)) – Left operand. For batched input, all batch items must share (n, m). All tensors must be on the same device.
B (torch.Tensor, dense (strided), shape (m, p) or (b, m, p)) – Right operand. Must have the same number of dimensions as A and matching batch size / inner dimension m.
- Returns:
Dense result of shape (n, p) or (b, n, p).
- Return type:
torch.Tensor
- Raises:
ValueError – If A or B is not a tensor; if either has fewer than 2 dimensions, or they are not both 2D or both 3D; if layouts are incompatible (A not COO/CSR or B not dense); if shapes are incompatible (batch or inner dims).
RuntimeError – If the underlying sparse matmul fails.
Notes
This avoids dense gradients for sparse matrices [1a] (a known issue with torch.sparse.mm() backprop), computing only gradients at the nonzero entries of \(\mathbf{A}\) to reduce memory use.
See also
torch.sparse.mm
PyTorch’s native sparse @ dense.
sparse_generic_lstsq
Sparse least-squares with sparse-aware gradients.
References
[1a] PyTorch issue on dense gradients for sparse ops: https://github.com/pytorch/pytorch/issues/41128
Examples
Basic (unbatched):
>>> import torch
>>> from torchsparsegradutils import sparse_mm
>>> indices = torch.tensor([[0, 0, 1, 1, 2, 2],
...                         [0, 2, 1, 3, 0, 2]])
>>> values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
>>> A = torch.sparse_coo_tensor(indices, values, (3, 4))
>>> B = torch.randn(4, 2)
>>> out = sparse_mm(A, B)
>>> out.shape
torch.Size([3, 2])
Batched:
>>> A_batch = torch.stack([A, A])   # COO stack, shape (2, 3, 4)
>>> B_batch = torch.randn(2, 4, 2)  # shape (2, 4, 2)
>>> out = sparse_mm(A_batch, B_batch)
>>> out.shape
torch.Size([2, 3, 2])
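The block-diagonal strategy described for batched inputs can be illustrated in plain PyTorch. This is a sketch of the idea only, under the assumption that the batch is given as 3D COO indices (batch, row, col); the helper name `blockdiag_sparse_mm` is hypothetical and the library's actual implementation may differ.

```python
import torch

def blockdiag_sparse_mm(indices, values, b, n, m, B):
    """Batched sparse @ dense via one block-diagonal 2D product.

    indices: (3, nnz) long tensor with rows (batch, row, col)
    values:  (nnz,) nonzero values
    B:       (b, m, p) dense batch of right operands
    """
    batch, rows, cols = indices
    # Shift each batch item's rows and columns into its own diagonal block
    bd_indices = torch.stack([batch * n + rows, batch * m + cols])
    A_bd = torch.sparse_coo_tensor(bd_indices, values, (b * n, b * m))
    # Concatenate the dense matrices along the batch dimension: (b*m, p)
    B_cat = B.reshape(b * m, B.shape[-1])
    C = torch.sparse.mm(A_bd, B_cat)  # dense result of shape (b*n, p)
    return C.reshape(b, n, -1)

b, n, m, p = 2, 3, 4, 2
indices = torch.tensor([[0, 0, 0, 1, 1, 1],
                        [0, 1, 2, 0, 1, 2],
                        [0, 1, 3, 2, 1, 0]])
values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
B = torch.randn(b, m, p)

out = blockdiag_sparse_mm(indices, values, b, n, m, B)

# Dense reference using a batched matmul
A_dense = torch.zeros(b, n, m)
A_dense[indices[0], indices[1], indices[2]] = values
expected = torch.bmm(A_dense, B)
```

Because the off-diagonal blocks contain no stored entries, the single 2D product does the same arithmetic as b separate products while keeping everything in one sparse matrix.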
With gradients:
>>> A.requires_grad_(True)
tensor(...)
>>> B.requires_grad_(True)
tensor(...)
>>> out = sparse_mm(A, B)
>>> out.sum().backward()
>>> A.grad.is_sparse
True