torchsparsegradutils.sparse_mm

torchsparsegradutils.sparse_mm(A: Tensor, B: Tensor) → Tensor

Sparse–dense matrix multiplication with memory-efficient gradients.

Computes \(\mathbf{C} = \mathbf{A}\,\mathbf{B}\) where \(\mathbf{A} \in \mathbb{R}^{n\times m}\) is sparse (COO/CSR), \(\mathbf{B} \in \mathbb{R}^{m\times p}\) is dense, and \(\mathbf{C} \in \mathbb{R}^{n\times p}\). Gradients preserve the sparsity pattern of \(\mathbf{A}\). Supports unbatched 2D inputs, (n, m) @ (m, p), and batched 3D inputs, (b, n, m) @ (b, m, p). Batched inputs are handled by block-diagonalising the batch of sparse matrices and concatenating the dense matrices along the batch dimension.
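The block-diagonal treatment of batches can be sketched with plain PyTorch. This is an illustration of the idea only, not the library's implementation: it uses dense stand-ins for the sparse matrices, and the sizes and variable names are assumptions chosen for the example.

```python
import torch

torch.manual_seed(0)
b, n, m, p = 2, 3, 4, 2

# Dense stand-ins for the sparse batch, used purely for illustration.
A_batch = torch.randn(b, n, m)
B_batch = torch.randn(b, m, p)

# Block-diagonalise the batch: (b, n, m) -> (b*n, b*m).
A_block = torch.block_diag(*A_batch)

# Concatenate the dense operands along the batch dimension: (b, m, p) -> (b*m, p).
B_cat = B_batch.reshape(b * m, p)

# A single 2D product, then split the rows back into batch items.
C_batch = (A_block @ B_cat).reshape(b, n, p)

# Reference: an ordinary batched matrix product.
C_ref = torch.bmm(A_batch, B_batch)
batched_matches = torch.allclose(C_batch, C_ref, atol=1e-6)
```

Because the off-diagonal blocks are zero, each batch item only ever multiplies its own slice of the concatenated dense operand, so the single 2D product reproduces the batched result.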

Let the upstream gradient be \(\mathbf{G} = \frac{\partial \mathcal{L}}{\partial \mathbf{C}} \in \mathbb{R}^{n\times p}\). The gradients are:

Gradient with respect to B (dense):

\[\frac{\partial \mathcal{L}}{\partial \mathbf{B}} \;=\; \mathbf{A}^{\top} \, \mathbf{G}.\]
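The formula for the dense gradient can be checked against autograd. The sketch below uses a dense stand-in for \(\mathbf{A}\) and arbitrary sizes, so it only verifies the identity itself and says nothing about how the library evaluates it.

```python
import torch

torch.manual_seed(0)
n, m, p = 3, 4, 2

A = torch.randn(n, m)                      # dense stand-in for the sparse operand
B = torch.randn(m, p, requires_grad=True)
G = torch.randn(n, p)                      # upstream gradient dL/dC

C = A @ B
C.backward(G)                              # backpropagate G through the product

# Closed-form gradient from the formula above.
grad_B_formula = A.T @ G
grad_B_matches = torch.allclose(B.grad, grad_B_formula, atol=1e-6)
```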

Gradient with respect to A (sparse): Viewed as a dense matrix, the gradient is

\[\frac{\partial \mathcal{L}}{\partial \mathbf{A}} \;=\; \mathbf{G}\, \mathbf{B}^{\top},\]

but we evaluate only the entries at the nonzeros of \(\mathbf{A}\). Equivalently, for a nonzero entry \(\mathbf{A}_{ij}\) the contribution is

\[\bigg[\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\bigg]_{ij} \;=\; \sum_{k=1}^{p} \mathbf{G}_{ik} \, \mathbf{B}_{jk} \;=\; \mathbf{G}_{i,:} \,\cdot\, \mathbf{B}_{j,:},\]

where the dot denotes the inner product of row \(i\) of \(\mathbf{G}\) with row \(j\) of \(\mathbf{B}\), each of length \(p\) (one entry per right-hand side).
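A minimal sketch of the per-entry formula, assuming the sparsity pattern is given as COO row and column indices. The names `rows`, `cols`, and `grad_vals` are hypothetical and used only for this illustration; the library's internal implementation may differ.

```python
import torch

torch.manual_seed(0)
n, m, p = 3, 4, 2

# Sparsity pattern of A: row and column index of each stored entry.
rows = torch.tensor([0, 0, 1, 1, 2, 2])
cols = torch.tensor([0, 2, 1, 3, 0, 2])

B = torch.randn(m, p)
G = torch.randn(n, p)                      # upstream gradient dL/dC

# Per-entry formula: inner product of row i of G with row j of B,
# evaluated only at the stored (i, j) positions.
grad_vals = (G[rows] * B[cols]).sum(dim=1)

# Reference: form the dense product G B^T, then read the same positions.
dense_grad = G @ B.T
grad_matches = torch.allclose(grad_vals, dense_grad[rows, cols], atol=1e-6)

# Assemble the result as a sparse COO tensor with the pattern of A.
grad_A = torch.sparse_coo_tensor(torch.stack([rows, cols]), grad_vals, (n, m))
```

The gathered form needs storage proportional to the number of stored entries times \(p\), whereas the dense reference allocates all \(n \times m\) entries before discarding most of them.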

Parameters:
  • A (torch.Tensor, sparse COO or CSR, shape (n, m) or (b, n, m)) – Left operand. For batched input, all batch items must share (n, m). All tensors must be on the same device.

  • B (torch.Tensor, dense (strided), shape (m, p) or (b, m, p)) – Right operand. Must have the same number of dimensions as A and matching batch size / inner dimension m.

Returns:

Dense result of shape (n, p) or (b, n, p).

Return type:

torch.Tensor

Raises:
  • ValueError – If A or B is not a tensor; if either has fewer than 2 dimensions, or the two are not both 2D or both 3D; if layouts are incompatible (A not COO/CSR, or B not dense); if batch or inner dimensions do not match.

  • RuntimeError – If the underlying sparse matmul fails.

Notes

This avoids dense gradients for sparse matrices [1a], a known issue with torch.sparse.mm() backpropagation. Gradients are computed only at the nonzero entries of \(\mathbf{A}\), which reduces memory use.
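To give a sense of scale, the arithmetic below compares the storage for a dense gradient with that for a gradient held only at the stored entries. The sizes are hypothetical (a square matrix with one million nonzeros, float32 values, int64 COO indices) and are not measurements of the library.

```python
# Hypothetical problem size, chosen only for illustration.
n = m = 100_000
nnz = 1_000_000
float32_bytes = 4
int64_bytes = 8

# Dense gradient: one float32 per (i, j) position.
dense_bytes = n * m * float32_bytes

# Sparse COO gradient: one float32 value plus two int64 indices per stored entry.
sparse_bytes = nnz * float32_bytes + 2 * nnz * int64_bytes

ratio = dense_bytes // sparse_bytes
```

Under these assumptions the dense gradient would occupy 40 GB, while the sparse one occupies 20 MB.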

See also

torch.sparse.mm

PyTorch’s native sparse @ dense.

sparse_generic_lstsq

Sparse least-squares with sparse-aware gradients.

References

[1a]

PyTorch issue on dense gradients for sparse ops: https://github.com/pytorch/pytorch/issues/41128

Examples

Basic (unbatched):

>>> import torch
>>> from torchsparsegradutils import sparse_mm
>>> indices = torch.tensor([[0, 0, 1, 1, 2, 2],
...                         [0, 2, 1, 3, 0, 2]])
>>> values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
>>> A = torch.sparse_coo_tensor(indices, values, (3, 4))
>>> B = torch.randn(4, 2)
>>> out = sparse_mm(A, B)
>>> out.shape
torch.Size([3, 2])

Batched:

>>> A_batch = torch.stack([A, A])          # (2, 3, 4) — COO stack
>>> B_batch = torch.randn(2, 4, 2)         # (2, 4, 2)
>>> out = sparse_mm(A_batch, B_batch)
>>> out.shape
torch.Size([2, 3, 2])

With gradients:

>>> A.requires_grad_(True)  
tensor(...)
>>> B.requires_grad_(True)  
tensor(...)
>>> out = sparse_mm(A, B)
>>> out.sum().backward()
>>> A.grad.is_sparse
True