torchsparsegradutils.sparse_generic_solve
torchsparsegradutils.sparse_generic_solve(A: Tensor, B: Tensor, solve: Callable[[...], Tensor] | None = None, transpose_solve: Callable[[...], Tensor] | None = None, **kwargs) -> Tensor
Sparse linear solve with iterative methods and sparse-aware gradients.
Solves \(\mathbf{A}\,\mathbf{x} = \mathbf{B}\) with sparse \(\mathbf{A} \in \mathbb{R}^{n\times n}\) (COO/CSR) and dense \(\mathbf{B} \in \mathbb{R}^{n\times k}\) using iterative methods, while preserving sparsity in \(\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\). Supports a single (vector) or multiple (matrix) right-hand sides. Because the backward pass follows from the implicit function theorem rather than from differentiating through the solver iterations, it also works with non-differentiable solvers.
Let \(\mathbf{G} = \frac{\partial \mathcal{L}}{\partial \mathbf{x}}\) be the upstream gradient and \(\mathbf{x}\) the solution. The dense-form gradients are:
Gradient with respect to B (dense):
\[\frac{\partial \mathcal{L}}{\partial \mathbf{B}} \;=\; \mathbf{A}^{-\top} \, \mathbf{G} \;\equiv\; \mathbf{G}_B.\]
Gradient with respect to A (sparse):
\[\frac{\partial \mathcal{L}}{\partial \mathbf{A}} \;=\; -\, \mathbf{G}_B\, \mathbf{x}^{\top}.\]
We evaluate only the entries corresponding to nonzeros of \(\mathbf{A}\), yielding a sparse gradient tensor with memory proportional to nnz(A). Equivalently, for a nonzero \(\mathbf{A}_{ij}\) the contribution is
\[\bigg[\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\bigg]_{ij} \;=\; -\, (\mathbf{G}_B)_{i,:} \,\cdot\, \mathbf{x}_{j,:}.\]
- Parameters:
  - A (torch.Tensor, sparse COO or CSR, shape (n, n)) – Sparse square coefficient matrix. Must be invertible (or otherwise suitable) for the chosen solver. All tensors must be on the same device.
  - B (torch.Tensor, dense (strided), shape (n,) or (n, k)) – Right-hand side(s). B.shape[0] must equal A.shape[0].
  - solve (callable, optional) – Forward solver with signature solve(A, B, **kwargs) -> X. If None, uses minres (recommended for symmetric indefinite matrices). Other typical choices include:
    - linear_cg (SPD matrices)
    - bicgstab (general non-symmetric matrices)
  - transpose_solve (callable, optional) – Solver for the transpose system used in backprop, with signature transpose_solve(A, G, **kwargs) -> Y that solves \(A^\top Y = G\) in the least-squares / iterative sense. If None, defaults to solve.
  - **kwargs (dict) – Extra keyword arguments forwarded to the solvers (e.g., tolerances, iteration caps, or solver-specific settings objects).
- Returns:
  Solution tensor X with the same shape as B: (n,) or (n, k).
- Return type:
  torch.Tensor
- Raises:
  - ValueError – If inputs are not tensors, shapes are incompatible, or ranks are invalid.
  - TypeError – If A is not COO/CSR or if B is not dense (strided).
  - UserWarning – If A and B use different dtypes (may affect solver behavior).
Notes
Only entries at the nonzeros of \(\mathbf{A}\) are computed, keeping the gradient sparse and memory-efficient.
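The gradient formulas above can be checked on a small system. The following is a minimal sketch in plain PyTorch, not the library's implementation: dense `torch.linalg.solve` calls stand in for the iterative solvers, and names such as `grad_values` and `A_ref` are illustrative. It evaluates \(-(\mathbf{G}_B)_{i,:} \cdot \mathbf{x}_{j,:}\) only at the stored positions of \(\mathbf{A}\) and compares the result with autograd through a dense solve.

```python
import torch

torch.manual_seed(0)

# Small sparse system in COO form (same sparsity pattern as the SPD example).
indices = torch.tensor([[0, 0, 1, 1, 2],
                        [0, 1, 0, 1, 2]])
values = torch.tensor([4.0, -1.0, -1.0, 4.0, 2.0], dtype=torch.float64)
A_dense = torch.sparse_coo_tensor(indices, values, (3, 3)).to_dense()

B = torch.randn(3, 2, dtype=torch.float64)
G = torch.randn(3, 2, dtype=torch.float64)  # stand-in upstream gradient dL/dx

# Forward and transpose solves; dense solves replace an iterative solver here.
x = torch.linalg.solve(A_dense, B)
G_B = torch.linalg.solve(A_dense.T, G)      # dL/dB = A^{-T} G

# Evaluate -(G_B)_{i,:} . x_{j,:} only at the stored (i, j) positions of A.
rows, cols = indices
grad_values = -(G_B[rows] * x[cols]).sum(dim=1)
grad_A = torch.sparse_coo_tensor(indices, grad_values, (3, 3))

# Reference: autograd through a dense solve with loss L = sum(x * G).
A_ref = A_dense.clone().requires_grad_(True)
B_ref = B.clone().requires_grad_(True)
(torch.linalg.solve(A_ref, B_ref) * G).sum().backward()

print(torch.allclose(grad_values, A_ref.grad[rows, cols]))
print(torch.allclose(G_B, B_ref.grad))
```

Only `nnz(A)` values are stored in `grad_A`, whereas the dense reference `A_ref.grad` holds all `n * n` entries, including those at positions where \(\mathbf{A}\) has no stored value.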
See also
sparse_triangular_solve – Triangular systems with sparse-aware gradients.
sparse_generic_lstsq – Overdetermined least-squares with sparse-aware gradients.
Examples
>>> import torch
>>> from torchsparsegradutils import sparse_generic_solve
>>> from torchsparsegradutils.utils import linear_cg, bicgstab, minres
>>> # Symmetric positive definite example
>>> indices = torch.tensor([[0, 0, 1, 1, 2],
...                         [0, 1, 0, 1, 2]])
>>> values = torch.tensor([4.0, -1.0, -1.0, 4.0, 2.0])
>>> A = torch.sparse_coo_tensor(indices, values, (3, 3))
>>> B = torch.tensor([1.0, 2.0, 3.0])
>>> x = sparse_generic_solve(A, B, solve=linear_cg)
>>> x.shape
torch.Size([3])

>>> # Multiple RHS with BiCGSTAB
>>> X = sparse_generic_solve(A, torch.randn(3, 5), solve=bicgstab)
>>> X.shape
torch.Size([3, 5])

>>> # Default solver (MINRES)
>>> x = sparse_generic_solve(A, B)

>>> # With custom solver settings:
>>> from torchsparsegradutils.utils.linear_cg import LinearCGSettings
>>> settings = LinearCGSettings(max_cg_iterations=1000, cg_tolerance=1e-8)
>>> x = sparse_generic_solve(A, B, solve=linear_cg, settings=settings)

>>> # With gradients (A.grad is sparse)
>>> A.requires_grad_(True)
tensor(...)
>>> B.requires_grad_(True)
tensor(...)
>>> x = sparse_generic_solve(A, B)
>>> x.sum().backward()
>>> A.grad.is_sparse
True
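Since `solve` accepts any callable with the signature `solve(A, B, **kwargs) -> X`, a user-written solver should also fit. The sketch below is a hypothetical, unpreconditioned conjugate gradient routine named `simple_cg`; it is not part of the library, and it assumes a symmetric positive definite `A` and a relative residual stopping rule. It accepts both `(n,)` and `(n, k)` right-hand sides and ignores any extra keyword arguments it receives.

```python
import torch

def simple_cg(A, B, max_iter=100, tol=1e-8, **kwargs):
    """Hypothetical CG solver following the solve(A, B, **kwargs) -> X contract.

    Assumes A is symmetric positive definite. Extra keyword arguments are
    accepted and ignored so that forwarded settings do not raise.
    """
    is_vector = B.dim() == 1
    R = B.unsqueeze(1).clone() if is_vector else B.clone()  # residual for X = 0
    X = torch.zeros_like(R)
    P = R.clone()
    tiny = torch.finfo(R.dtype).tiny        # guards against 0/0 in converged columns
    b_norm = R.norm(dim=0).clamp_min(tiny)
    rs_old = (R * R).sum(dim=0)
    for _ in range(max_iter):
        # Stop once every column meets the relative residual tolerance.
        if (rs_old.sqrt() <= tol * b_norm).all():
            break
        AP = A @ P                          # sparse-dense product
        alpha = rs_old / (P * AP).sum(dim=0).clamp_min(tiny)
        X = X + alpha * P
        R = R - alpha * AP
        rs_new = (R * R).sum(dim=0)
        P = R + (rs_new / rs_old.clamp_min(tiny)) * P
        rs_old = rs_new
    return X.squeeze(1) if is_vector else X

# Standalone check on the SPD matrix from the examples above.
indices = torch.tensor([[0, 0, 1, 1, 2],
                        [0, 1, 0, 1, 2]])
values = torch.tensor([4.0, -1.0, -1.0, 4.0, 2.0], dtype=torch.float64)
A = torch.sparse_coo_tensor(indices, values, (3, 3))
B = torch.tensor([[1.0, 0.0], [2.0, 1.0], [3.0, -1.0]], dtype=torch.float64)

X = simple_cg(A, B)
residual = (A.to_dense() @ X - B).abs().max()
```

Under the callable contract described in the parameter list, such a function would presumably be passed as `sparse_generic_solve(A, B, solve=simple_cg)`. For a symmetric matrix the transpose system coincides with the forward one, so leaving `transpose_solve` unset should be adequate in that case; a non-symmetric matrix needs a solver suited to it, such as `bicgstab`.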