# Source code for torchsparsegradutils.sparse_matmul

from typing import cast

import torch

from torchsparsegradutils.utils import sparse_block_diag, sparse_block_diag_split, stack_csr


def sparse_mm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    r"""Sparse-dense matrix multiplication with memory-efficient gradients.

    Computes :math:`\mathbf{C} = \mathbf{A}\,\mathbf{B}` where
    :math:`\mathbf{A} \in \mathbb{R}^{n\times m}` is sparse (COO/CSR),
    :math:`\mathbf{B} \in \mathbb{R}^{m\times p}` is dense, and
    :math:`\mathbf{C} \in \mathbb{R}^{n\times p}`. Gradients preserve the
    sparsity pattern of :math:`\mathbf{A}`.

    Supports unbatched 2D ``(n,m) @ (m,p)`` and batched 3D inputs by
    block-diagonalising the batch of sparse matrices and concatenating dense
    matrices along the batch dimension.

    Let the upstream gradient be
    :math:`\mathbf{G} = \frac{\partial \mathcal{L}}{\partial \mathbf{C}}
    \in \mathbb{R}^{n\times p}`. The gradients are:

    Gradient with respect to B (dense):

    .. math::

        \frac{\partial \mathcal{L}}{\partial \mathbf{B}}
        \;=\; \mathbf{A}^{\top} \, \mathbf{G}.

    Gradient with respect to A (sparse):

    For a dense view one has

    .. math::

        \frac{\partial \mathcal{L}}{\partial \mathbf{A}}
        \;=\; \mathbf{G}\, \mathbf{B}^{\top},

    but we evaluate only the entries at the nonzeros of :math:`\mathbf{A}`.
    Equivalently, for a nonzero entry :math:`\mathbf{A}_{ij}` the
    contribution is

    .. math::

        \bigg[\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\bigg]_{ij}
        \;=\; \sum_{k=1}^{p} \mathbf{G}_{ik} \, \mathbf{B}_{jk}
        \;=\; \mathbf{G}_{i,:} \,\cdot\, \mathbf{B}_{j,:},

    where the dot denotes a row-wise inner product across the :math:`p`
    right-hand sides.

    Parameters
    ----------
    A : torch.Tensor, sparse COO or CSR, shape ``(n, m)`` or ``(b, n, m)``
        Left operand. For batched input, all batch items must share
        ``(n, m)``. All tensors must be on the same device (this function
        does not check devices itself).
    B : torch.Tensor, dense (strided), shape ``(m, p)`` or ``(b, m, p)``
        Right operand. Must have the same number of dimensions as ``A`` and
        matching batch size / inner dimension ``m``.

    Returns
    -------
    torch.Tensor
        Dense result of shape ``(n, p)`` or ``(b, n, p)``.

    Raises
    ------
    ValueError
        If ``A`` or ``B`` are not tensors; if ranks are < 2 or not both
        2D/3D; if layouts are incompatible (``A`` not COO/CSR or ``B`` not
        dense); if shapes are incompatible (batch or inner dims).
    RuntimeError
        If the underlying sparse matmul fails.

    Notes
    -----
    This avoids dense gradients for sparse matrices [1a]_ (a known issue
    with :func:`torch.sparse.mm` backprop), computing only gradients at the
    nonzero entries of :math:`A` to reduce memory use.

    See Also
    --------
    torch.sparse.mm : PyTorch's native sparse ``@`` dense.
    sparse_generic_lstsq : Sparse least-squares with sparse-aware gradients.

    References
    ----------
    .. [1a] PyTorch issue on dense gradients for sparse ops:
       https://github.com/pytorch/pytorch/issues/41128

    Examples
    --------
    Basic (unbatched)::

        >>> indices = torch.tensor([[0, 0, 1, 1, 2, 2],
        ...                         [0, 2, 1, 3, 0, 2]])
        >>> values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
        >>> A = torch.sparse_coo_tensor(indices, values, (3, 4))
        >>> B = torch.randn(4, 2)
        >>> out = sparse_mm(A, B)
        >>> out.shape
        torch.Size([3, 2])

    Batched::

        >>> A_batch = torch.stack([A, A])   # (2, 3, 4) - COO stack
        >>> B_batch = torch.randn(2, 4, 2)  # (2, 4, 2)
        >>> out = sparse_mm(A_batch, B_batch)
        >>> out.shape
        torch.Size([2, 3, 2])

    With gradients::

        >>> A.requires_grad_(True)  # doctest: +ELLIPSIS
        tensor(...)
        >>> B.requires_grad_(True)  # doctest: +ELLIPSIS
        tensor(...)
        >>> out = sparse_mm(A, B)
        >>> out.sum().backward()
        >>> A.grad.is_sparse
        True
    """
    # --- Type and rank checks (order matters: later checks call tensor methods) ---
    if not isinstance(A, torch.Tensor) or not isinstance(B, torch.Tensor):
        raise ValueError("Both A and B should be instances of torch.Tensor")

    if A.dim() < 2 or B.dim() < 2:
        raise ValueError("Both A and B should be at least 2-dimensional tensors")

    if A.dim() != B.dim() or A.dim() not in (2, 3):
        raise ValueError("A and B must both be 2D or both be 3D tensors")

    # --- Layout checks: sparse (COO/CSR) on the left, strided on the right ---
    if A.layout not in {torch.sparse_coo, torch.sparse_csr}:
        raise ValueError("A should be in either COO or CSR sparse format")

    if B.layout != torch.strided:
        raise ValueError("B must be a dense (strided) tensor")

    # --- Shape compatibility: batch size (3D only) and inner dimension m ---
    if A.dim() == 3 and A.size(0) != B.size(0):
        raise ValueError("If batched, A and B must have the same batch size")

    if A.size(-1) != B.size(-2):
        raise ValueError(f"Incompatible inner dimensions: A[..., {A.size(-1)}] vs B[..., {B.size(-2)}]")

    # ``Function.apply`` is loosely typed; narrow the result for type checkers.
    return cast(torch.Tensor, SparseMatMul.apply(A, B))
class SparseMatMul(torch.autograd.Function):
    r"""Autograd kernel for memory-efficient sparse matrix multiplication.

    The forward pass computes ``A @ B`` with ``torch.sparse.mm``. The backward
    pass returns a gradient for ``A`` that has the same sparse layout and
    index structure as ``A`` (only the values at the stored entries are
    computed), and a dense gradient for ``B``.

    See Also
    --------
    sparse_mm : User-facing function that calls this autograd function.
    torch.sparse.mm : PyTorch's native sparse matrix multiplication.
    """

    @staticmethod
    def forward(ctx, A, B):
        """Compute ``A @ B`` and stash what ``backward`` needs on ``ctx``.

        Parameters
        ----------
        ctx
            Autograd context; receives ``batch_size``, ``A_shape``,
            ``B_shape`` and the saved (detached, possibly flattened) operands.
        A : torch.Tensor
            Sparse operand, ``(n, m)`` or ``(b, n, m)``.
        B : torch.Tensor
            Dense operand, ``(m, p)`` or ``(b, m, p)``. A 3D ``B`` selects the
            batched code path.

        Returns
        -------
        torch.Tensor
            Product of shape ``(n, p)`` or ``(b, n, p)``.
        """
        ctx.batch_size = B.size(0) if B.dim() == 3 else None
        ctx.A_shape = A.size()  # (b), n, m
        ctx.B_shape = B.size()  # (b), m, p

        # Remember whether any input tracks gradients before detaching.
        grad_flag = A.requires_grad or B.requires_grad

        A, B = A.detach(), B.detach()

        if ctx.batch_size is not None:
            # Batched case: turn the batch into a single 2D product by putting
            # the sparse matrices on a block diagonal and stacking the dense
            # matrices along the row dimension.
            A = sparse_block_diag(*A)
            B = B.reshape(-1, B.size(-1))

        x = torch.sparse.mm(A, B)

        # Saved in their 2D (possibly block-diagonal / flattened) form.
        ctx.save_for_backward(A, B)

        if ctx.batch_size is not None:
            # Restore the (b, n, p) layout expected by the caller.
            x = x.view(ctx.batch_size, ctx.A_shape[-2], ctx.B_shape[-1])

        x.requires_grad_(grad_flag)
        return x

    @staticmethod
    def backward(ctx, grad):  # type: ignore[override]
        """Return ``(gradA, gradB)`` for the upstream gradient ``grad``.

        ``gradA`` is sparse with the same layout as the saved ``A`` and holds
        values only at its stored entries; ``gradB`` is dense with the shape
        of the original ``B``. Either is ``None`` when ``ctx.needs_input_grad``
        says it is not required.
        """
        A, B = ctx.saved_tensors
        batched = ctx.batch_size is not None

        gradA = None
        gradB = None

        # The saved operands are 2D, so both gradient computations need the
        # upstream gradient flattened to 2D in the batched case. Do it once.
        grad_2d = grad.reshape(-1, grad.size(-1)) if batched else grad

        # -------- Only compute gradA if needed --------
        if ctx.needs_input_grad[0]:
            # The gradient with respect to the matrix A, seen as a dense
            # matrix, would lead to a backprop rule as follows:
            #     gradA = grad @ b.T
            # but we are only interested in the gradient with respect to the
            # (non-zero) values of A. To save memory, instead of computing the
            # full dense matrix grad @ b.T and then subsampling at the nnz
            # locations in A, we can directly only compute the required values:
            #     grad_a[i, j] = dotprod(grad[i, :], b[j, :])

            # We start by getting the i and j indices:
            if A.layout == torch.sparse_coo:
                A_row_idx, A_col_idx = A._indices()
            elif A.layout == torch.sparse_csr:
                A_col_idx = A.col_indices()
                A_crow_idx = A.crow_indices()
                # Uncompress row indices: repeat each row id by its entry count.
                A_row_idx = torch.repeat_interleave(
                    torch.arange(A.size(0), device=A.device),
                    A_crow_idx[1:] - A_crow_idx[:-1],
                )
            else:
                raise ValueError(f"Unsupported layout: {A.layout}")

            grad_select = grad_2d.index_select(0, A_row_idx)  # grad[i, :]
            B_select = B.index_select(0, A_col_idx)  # B[j, :]

            # Dot product:
            gradA = (grad_select * B_select).sum(dim=1)

            # Create a sparse matrix of the gradient with respect to the nnz of A
            if A.layout == torch.sparse_coo:
                gradA = torch.sparse_coo_tensor(A._indices(), gradA, A.shape)
            elif A.layout == torch.sparse_csr:
                gradA = torch.sparse_csr_tensor(A.crow_indices(), A_col_idx, gradA, A.shape)

            if batched:
                # Undo the block-diagonal packing: one (n, m) block per batch item.
                shapes = ctx.A_shape[0] * (ctx.A_shape[-2:],)
                gradA = sparse_block_diag_split(gradA, *shapes)
                if A.layout == torch.sparse_coo:
                    gradA = torch.stack([*gradA])
                else:
                    # NOTE: torch.stack does not work for csr tensors
                    gradA = stack_csr([*gradA])

        # -------- Only compute gradB if needed --------
        if ctx.needs_input_grad[1]:
            # Now compute the dense gradient with respect to B
            gradB = torch.sparse.mm(A.t(), grad_2d)
            if batched:
                gradB = gradB.view(ctx.B_shape)

        return gradA, gradB