torchsparsegradutils.sparse_generic_solve

torchsparsegradutils.sparse_generic_solve(A: Tensor, B: Tensor, solve: Callable[..., Tensor] | None = None, transpose_solve: Callable[..., Tensor] | None = None, **kwargs) -> Tensor

Sparse linear solve with iterative methods and sparse-aware gradients.

Solves \(\mathbf{A}\,\mathbf{x} = \mathbf{B}\) for a sparse \(\mathbf{A} \in \mathbb{R}^{n\times n}\) (COO or CSR) and a dense \(\mathbf{B} \in \mathbb{R}^{n\times k}\) using iterative methods, while keeping \(\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\) sparse. Both a single right-hand side (a vector of shape (n,)) and multiple right-hand sides (a matrix of shape (n, k)) are supported. The solver itself need not be differentiable: gradients are obtained analytically via the implicit function theorem rather than by backpropagating through the solver's iterations.
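
Both accepted layouts for A can be built with standard torch calls. This minimal sketch reuses the matrix from the Examples section and shows the COO and CSR forms of the same A, which describe identical entries.

```python
import torch

# Same 3x3 symmetric positive definite matrix as in the Examples section.
indices = torch.tensor([[0, 0, 1, 1, 2],
                        [0, 1, 0, 1, 2]])
values = torch.tensor([4.0, -1.0, -1.0, 4.0, 2.0])

# COO layout: explicit (row, col) index pairs plus values.
A_coo = torch.sparse_coo_tensor(indices, values, (3, 3)).coalesce()

# CSR layout: compressed row pointers, column indices, values.
A_csr = A_coo.to_sparse_csr()

print(A_coo.layout, A_csr.layout)  # torch.sparse_coo torch.sparse_csr
print(A_csr.crow_indices())        # tensor([0, 2, 4, 5])
```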

Let \(\mathbf{G} = \frac{\partial \mathcal{L}}{\partial \mathbf{x}}\) be the upstream gradient and \(\mathbf{x}\) the solution. Written in dense form, the gradients are:

Gradient with respect to B (dense):

\[\frac{\partial \mathcal{L}}{\partial \mathbf{B}} \;=\; \mathbf{A}^{-\top} \, \mathbf{G} \;\equiv\; \mathbf{G}_B.\]
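
The formula for \(\frac{\partial \mathcal{L}}{\partial \mathbf{B}}\) can be sanity-checked against PyTorch autograd through a dense torch.linalg.solve on a small problem. This sketch uses only torch, not this library; G is a random stand-in for the upstream gradient, and A is deliberately non-symmetric so that \(\mathbf{A}^{-\top}\) differs from \(\mathbf{A}^{-1}\).

```python
import torch

torch.manual_seed(0)
n, k = 4, 3

# Dense, well-conditioned, non-symmetric A.
A = torch.randn(n, n, dtype=torch.float64) + n * torch.eye(n, dtype=torch.float64)
B = torch.randn(n, k, dtype=torch.float64, requires_grad=True)

x = torch.linalg.solve(A, B)                  # x = A^{-1} B
G = torch.randn(n, k, dtype=torch.float64)    # stand-in upstream gradient dL/dx
x.backward(G)                                 # autograd fills B.grad = dL/dB

# Formula from the text: dL/dB = A^{-T} G, i.e. solve A^T G_B = G.
G_B = torch.linalg.solve(A.T, G)
print(torch.allclose(B.grad, G_B))            # True
```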

Gradient with respect to A (sparse):

\[\frac{\partial \mathcal{L}}{\partial \mathbf{A}} \;=\; -\, \mathbf{G}_B\, \mathbf{x}^{\top}.\]
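
Both gradients follow from implicitly differentiating the constraint \(\mathbf{A}\,\mathbf{x} = \mathbf{B}\), which is why the forward solver does not have to be differentiable. A short derivation, with \(\langle\cdot,\cdot\rangle\) denoting the Frobenius inner product:

\[\mathrm{d}\mathbf{A}\,\mathbf{x} + \mathbf{A}\,\mathrm{d}\mathbf{x} = \mathrm{d}\mathbf{B} \;\Longrightarrow\; \mathrm{d}\mathbf{x} = \mathbf{A}^{-1}\big(\mathrm{d}\mathbf{B} - \mathrm{d}\mathbf{A}\,\mathbf{x}\big),\]

\[\mathrm{d}\mathcal{L} = \langle \mathbf{G}, \mathrm{d}\mathbf{x}\rangle = \langle \mathbf{A}^{-\top}\mathbf{G},\, \mathrm{d}\mathbf{B} - \mathrm{d}\mathbf{A}\,\mathbf{x}\rangle = \langle \mathbf{G}_B, \mathrm{d}\mathbf{B}\rangle - \langle \mathbf{G}_B\,\mathbf{x}^{\top}, \mathrm{d}\mathbf{A}\rangle.\]

Reading off the coefficients of \(\mathrm{d}\mathbf{B}\) and \(\mathrm{d}\mathbf{A}\) gives the two formulas above; the last step uses \(\langle \mathbf{M}, \mathrm{d}\mathbf{A}\,\mathbf{x}\rangle = \operatorname{tr}(\mathbf{M}^{\top}\mathrm{d}\mathbf{A}\,\mathbf{x}) = \langle \mathbf{M}\,\mathbf{x}^{\top}, \mathrm{d}\mathbf{A}\rangle\).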

Only the entries at the nonzeros of \(\mathbf{A}\) are evaluated, so the gradient is a sparse tensor whose memory is proportional to nnz(A) rather than \(n^2\). Entrywise, for each stored nonzero \(\mathbf{A}_{ij}\):

\[\bigg[\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\bigg]_{ij} \;=\; -\, (\mathbf{G}_B)_{i,:} \,\cdot\, \mathbf{x}_{j,:}.\]
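
The entrywise form is what makes the gradient cheap: each stored nonzero needs one length-k dot product, so the cost is O(nnz·k) instead of forming the full n×n outer product. The following torch-only sketch (not the library's implementation) computes x and \(\mathbf{G}_B\) with dense solves for simplicity, gathers rows of \(\mathbf{G}_B\) and \(\mathbf{x}\) at A's indices, and checks the result against dense autograd at the same positions.

```python
import torch

torch.manual_seed(0)

# Same sparsity pattern as the Examples section (3x3, 5 nonzeros).
indices = torch.tensor([[0, 0, 1, 1, 2],
                        [0, 1, 0, 1, 2]])
values = torch.tensor([4.0, -1.0, -1.0, 4.0, 2.0], dtype=torch.float64)
A = torch.sparse_coo_tensor(indices, values, (3, 3)).coalesce()

k = 2
B = torch.randn(3, k, dtype=torch.float64)
G = torch.randn(3, k, dtype=torch.float64)    # stand-in upstream gradient dL/dx

A_dense = A.to_dense()
x = torch.linalg.solve(A_dense, B)            # solution, shape (n, k)
G_B = torch.linalg.solve(A_dense.T, G)        # G_B = A^{-T} G, shape (n, k)

# Sparse gradient: one row-dot product per stored nonzero (i, j).
rows, cols = A.indices()
grad_values = -(G_B[rows] * x[cols]).sum(dim=-1)
grad_A = torch.sparse_coo_tensor(A.indices(), grad_values, A.shape)

# Reference: dense autograd gradient, read at the same nonzero positions.
A_ref = A_dense.clone().requires_grad_(True)
torch.linalg.solve(A_ref, B).backward(G)
print(torch.allclose(grad_values, A_ref.grad[rows, cols]))  # True
```

For a single right-hand side, x and \(\mathbf{G}_B\) have shape (n,); reshaping them to (n, 1) first lets the same gather-and-sum code apply.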

Parameters:
  • A (torch.Tensor, sparse COO or CSR, shape (n, n)) – Sparse square coefficient matrix. Must be nonsingular and meet the requirements of the chosen solver (e.g. symmetric positive definite for linear_cg). All tensors must be on the same device.

  • B (torch.Tensor, dense (strided), shape (n,) or (n, k)) – Right-hand side(s). B.shape[0] must equal A.shape[0].

  • solve (callable, optional) –

    Forward solver with signature solve(A, B, **kwargs) -> X. If None, uses minres (suited to symmetric, possibly indefinite, matrices). Other typical choices include:

    • linear_cg (SPD matrices)

    • bicgstab (general non-symmetric)

  • transpose_solve (callable, optional) – Solver for the transposed system used in the backward pass, with signature transpose_solve(A, G, **kwargs) -> Y; it should solve \(\mathbf{A}^\top \mathbf{Y} = \mathbf{G}\), exactly or approximately (e.g. in the least-squares sense or up to an iterative solver's tolerance). If None, defaults to solve.

  • **kwargs (dict) – Extra keyword arguments forwarded to the solvers (e.g., tolerances, iteration caps, or solver-specific settings objects).
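
Any callable with the solve(A, B, **kwargs) -> X signature can be plugged in. The sketch below is a hypothetical Jacobi solver: jacobi_solve and jacobi_transpose_solve are illustrative names, not part of this library. It uses only torch, assumes A is sparse COO and diagonally dominant, accepts both (n,) and (n, k) right-hand sides, and ignores unknown keyword arguments. The transpose variant follows the transpose_solve contract above by running the same iteration on \(\mathbf{A}^\top\).

```python
import torch


def jacobi_solve(A, B, tol=1e-10, max_iter=1000, **kwargs):
    """Hypothetical solver matching solve(A, B, **kwargs) -> X.

    Plain Jacobi iteration; assumes A is sparse COO and diagonally dominant.
    Extra keyword arguments are accepted and ignored.
    """
    A = A.coalesce()
    n = A.shape[0]
    rows, cols = A.indices()
    on_diag = rows == cols
    # Gather the diagonal without densifying A.
    diag = torch.zeros(n, dtype=A.dtype, device=A.device)
    diag.index_add_(0, rows[on_diag], A.values()[on_diag])

    vector_rhs = B.dim() == 1
    B2 = B.unsqueeze(-1) if vector_rhs else B      # work with shape (n, k)
    X = torch.zeros_like(B2)
    b_norm = torch.linalg.norm(B2)
    for _ in range(max_iter):
        R = B2 - A @ X                              # residual B - A X
        if torch.linalg.norm(R) <= tol * b_norm:    # relative stopping test
            break
        X = X + R / diag.unsqueeze(-1)              # Jacobi update: X += D^{-1} R
    return X.squeeze(-1) if vector_rhs else X


def jacobi_transpose_solve(A, G, **kwargs):
    """Hypothetical transpose solver: solves A^T Y = G by running Jacobi on A^T."""
    return jacobi_solve(A.coalesce().t(), G, **kwargs)
```

With the library installed, these would be passed as sparse_generic_solve(A, B, solve=jacobi_solve, transpose_solve=jacobi_transpose_solve, tol=1e-8), assuming **kwargs reach both solvers as described above. Jacobi is used here only because it is short; the bundled linear_cg, minres, or bicgstab are better choices in practice.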

Returns:

Solution tensor X with the same shape as B: (n,) or (n, k).

Return type:

torch.Tensor

Raises:
  • ValueError – If inputs are not tensors, their shapes are incompatible, or their ranks are invalid.

  • TypeError – If A is not COO/CSR or if B is not dense (strided).

  • UserWarning – If A and B use different dtypes (may affect solver behavior).
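
The checks listed under Raises can be expressed as a small validator. check_solve_inputs below is a hypothetical helper written from the documented contract, not the library's actual code; the order of the checks and the messages are assumptions.

```python
import warnings

import torch

_SPARSE_LAYOUTS = (torch.sparse_coo, torch.sparse_csr)


def check_solve_inputs(A, B):
    """Hypothetical validator mirroring the documented Raises contract."""
    if not isinstance(A, torch.Tensor) or not isinstance(B, torch.Tensor):
        raise ValueError("A and B must be torch.Tensor instances")
    if A.layout not in _SPARSE_LAYOUTS:
        raise TypeError(f"A must be sparse COO or CSR, got {A.layout}")
    if B.layout != torch.strided:
        raise TypeError(f"B must be dense (strided), got {B.layout}")
    if A.dim() != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"A must be square and 2-D, got shape {tuple(A.shape)}")
    if B.dim() not in (1, 2):
        raise ValueError(f"B must be 1-D or 2-D, got {B.dim()}-D")
    if B.shape[0] != A.shape[0]:
        raise ValueError(f"B.shape[0]={B.shape[0]} does not match A.shape[0]={A.shape[0]}")
    if A.dtype != B.dtype:
        # A dtype mismatch is reported as a warning, not an error.
        warnings.warn(f"A.dtype={A.dtype} differs from B.dtype={B.dtype}", UserWarning)
```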

Notes

Only entries at the nonzeros of \(\mathbf{A}\) are computed, keeping the gradient sparse and memory-efficient.
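
To put the memory saving in numbers, here is a back-of-the-envelope sketch. It assumes float32 values, a COO gradient with int64 indices, and about five nonzeros per row; all three are assumptions for illustration, not properties guaranteed by this function.

```python
def dense_grad_bytes(n, itemsize=4):
    """Bytes for a dense n x n gradient (float32 by default)."""
    return n * n * itemsize


def coo_grad_bytes(nnz, itemsize=4, index_itemsize=8):
    """Bytes for a COO gradient: nnz values plus a 2 x nnz int64 index array."""
    return nnz * itemsize + 2 * nnz * index_itemsize


n = 100_000
nnz = 5 * n                        # assumed: about 5 nonzeros per row
print(dense_grad_bytes(n) / 1e9)   # 40.0  (GB)
print(coo_grad_bytes(nnz) / 1e6)   # 10.0  (MB)
```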

See also

sparse_triangular_solve

Triangular systems with sparse-aware gradients.

sparse_generic_lstsq

Overdetermined least-squares with sparse-aware gradients.

Examples

>>> import torch
>>> from torchsparsegradutils import sparse_generic_solve
>>> from torchsparsegradutils.utils import linear_cg, bicgstab, minres
>>> # Symmetric positive definite example
>>> indices = torch.tensor([[0, 0, 1, 1, 2],
...                         [0, 1, 0, 1, 2]])
>>> values = torch.tensor([4.0, -1.0, -1.0, 4.0, 2.0])
>>> A = torch.sparse_coo_tensor(indices, values, (3, 3))
>>> B = torch.tensor([1.0, 2.0, 3.0])
>>> x = sparse_generic_solve(A, B, solve=linear_cg)
>>> x.shape
torch.Size([3])
>>> # Multiple RHS with BiCGSTAB
>>> X = sparse_generic_solve(A, torch.randn(3, 5), solve=bicgstab)
>>> X.shape
torch.Size([3, 5])
>>> # Default solver (MINRES)
>>> x = sparse_generic_solve(A, B)
>>> # With custom solver settings:
>>> from torchsparsegradutils.utils.linear_cg import LinearCGSettings
>>> settings = LinearCGSettings(max_cg_iterations=1000, cg_tolerance=1e-8)
>>> x = sparse_generic_solve(A, B, solve=linear_cg, settings=settings)
>>> # With gradients (A.grad is sparse)
>>> A.requires_grad_(True)  
tensor(...)
>>> B.requires_grad_(True)  
tensor(...)
>>> x = sparse_generic_solve(A, B)
>>> x.sum().backward()
>>> A.grad.is_sparse
True