import warnings
from typing import Callable, Optional, cast
import torch
from torchsparsegradutils.utils import convert_coo_to_csr, sparse_block_diag, sparse_block_diag_split, stack_csr
[docs]
def sparse_triangular_solve(
    A: torch.Tensor,
    B: torch.Tensor,
    upper: bool = True,
    unitriangular: bool = False,
    transpose: bool = False,
) -> torch.Tensor:
    r"""Solve a sparse triangular system while keeping the gradient of ``A`` sparse.

    Finds :math:`\mathbf{x}` with :math:`\mathbf{A}\,\mathbf{x} = \mathbf{B}`
    (:math:`\mathbf{A}^{\top}\,\mathbf{x} = \mathbf{B}` when ``transpose=True``) for a
    sparse triangular :math:`\mathbf{A} \in \mathbb{R}^{m\times m}` (COO or CSR) and a
    dense :math:`\mathbf{B} \in \mathbb{R}^{m\times p}`. Both unbatched (2D) and batched
    (3D) inputs are accepted. This function only validates its arguments; the solve and
    its backward rule live in :class:`SparseTriangularSolve`, which converts COO input
    to CSR before calling the solver.

    With upstream gradient :math:`\mathbf{G} = \frac{\partial \mathcal{L}}{\partial \mathbf{x}}`
    for a scalar objective :math:`\mathcal{L}`, the dense-form gradients are as follows.

    Gradient with respect to B (dense):

    .. math::
        \frac{\partial \mathcal{L}}{\partial \mathbf{B}} \;=\; \mathbf{A}^{-\top} \, \mathbf{G},

    and with ``transpose=True`` the roles of :math:`\mathbf{A}` and :math:`\mathbf{A}^{\top}`
    swap, giving
    :math:`\frac{\partial \mathcal{L}}{\partial \mathbf{B}} = \left(\mathbf{A}^{\top}\right)^{-\top} \mathbf{G} = \mathbf{A}^{-1} \mathbf{G}`.

    Gradient with respect to A (sparse):

    .. math::
        \frac{\partial \mathcal{L}}{\partial \mathbf{A}} \;=\; -\big(\mathbf{A}^{-\top} \, \mathbf{G}\big)\, \mathbf{x}^{\top},

    of which only the entries at the nonzeros of :math:`\mathbf{A}` are evaluated. For a
    stored entry :math:`\mathbf{A}_{ij}`:

    .. math::
        \bigg[\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\bigg]_{ij}
        \;=\; -\, \big(\mathbf{A}^{-\top} \, \mathbf{G}\big)_{i,:} \,\cdot\, \mathbf{x}_{j,:},

    where the dot is an inner product over the :math:`p` right-hand sides.

    Parameters
    ----------
    A : torch.Tensor, sparse COO or CSR, shape ``(m, m)`` or ``(b, m, m)``
        Sparse triangular coefficient matrix, square in its last two dimensions.
        All tensors are expected to live on the same device (not checked here).
    B : torch.Tensor, dense (strided), shape ``(m, p)`` or ``(b, m, p)``
        Right-hand side. ``B.shape[-2]`` must equal ``A.shape[-2]`` (``m``).
    upper : bool, optional
        ``True`` (default) treats ``A`` as upper-triangular, ``False`` as lower-triangular.
    unitriangular : bool, optional
        ``True`` assumes an implicit unit diagonal, in which case the stored matrix must
        be strictly triangular (no explicit diagonal entries). Default: ``False``.
    transpose : bool, optional
        ``True`` solves :math:`A^\top x = B`; ``False`` (default) solves :math:`A x = B`.

    Returns
    -------
    torch.Tensor
        The solution, shaped like ``B``: ``(m, p)`` or ``(b, m, p)``.

    Raises
    ------
    ValueError
        At call time, if the inputs are not tensors; if either rank is below 2 or the
        ranks are not both 2 or both 3; if ``A`` is not COO/CSR or ``B`` is not dense;
        if ``A`` is not square; if the row counts of ``A`` and ``B`` differ; or if the
        batch sizes differ. During the *backward* pass, if ``unitriangular=True`` but
        ``A`` stores explicit diagonal entries (that check is performed in
        ``SparseTriangularSolve.backward``, not here).
    RuntimeError
        If the underlying triangular solve fails.

    Notes
    -----
    Backprop computes gradients only at nonzero entries of :math:`\mathbf{A}`, keeping the
    gradient sparse and reducing memory. COO inputs are converted to CSR since PyTorch's
    triangular solver requires CSR [1e]_. For autograd implementation details, see [2e]_.

    See Also
    --------
    torch.sparse.mm : Sparse ``@`` dense multiply.
    torch.linalg.solve_triangular : Dense triangular solver (modern API).

    References
    ----------
    .. [1e] PyTorch issue on sparse triangular solve:
       https://github.com/pytorch/pytorch/issues/87358
    .. [2e] PyTorch issue on autograd/triangular solve:
       https://github.com/pytorch/pytorch/issues/88890

    Examples
    --------
    Upper-triangular::

        >>> import torch
        >>> from torchsparsegradutils import sparse_triangular_solve
        >>> A = torch.sparse_csr_tensor([0, 2, 3, 4], [0, 2, 1, 2],
        ...                             torch.tensor([2.0, 1.0, 3.0, 1.0]), (3, 3))
        >>> B = torch.tensor([[1.0], [2.0], [3.0]])
        >>> x = sparse_triangular_solve(A, B, upper=True)
        >>> x.shape
        torch.Size([3, 1])

    Lower-triangular::

        >>> A_low = torch.sparse_csr_tensor([0, 1, 3, 5], [0, 0, 1, 0, 2],
        ...                                 torch.tensor([2.0, 1.0, 3.0, 0.5, 1.0]), (3, 3))
        >>> x = sparse_triangular_solve(A_low, B, upper=False)

    Batched::

        >>> # Convert to COO for batching (since torch.stack doesn't work with CSR)
        >>> A_coo = A.to_sparse_coo()
        >>> A_b = torch.stack([A_coo, A_coo])  # (2, 3, 3)
        >>> B_b = torch.stack([B, B])          # (2, 3, 1)
        >>> x_b = sparse_triangular_solve(A_b, B_b)
        >>> x_b.shape
        torch.Size([2, 3, 1])
    """
    # Guard clauses run in a fixed order so that the first violated requirement
    # determines the error message.
    if not (isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor)):
        raise ValueError("Both A and B should be instances of torch.Tensor")

    rank_a, rank_b = A.dim(), B.dim()
    if min(rank_a, rank_b) < 2:
        raise ValueError("Both A and B should be at least 2-dimensional tensors")
    if rank_a != rank_b or rank_a not in (2, 3):
        raise ValueError("A and B must both be 2D or both be 3D tensors")

    if A.layout != torch.sparse_coo and A.layout != torch.sparse_csr:
        raise ValueError("A should be in either COO or CSR sparse format")
    if B.layout != torch.strided:
        raise ValueError("B must be a dense (strided) tensor")

    n_rows, n_cols = A.shape[-2], A.shape[-1]
    if n_rows != n_cols:
        raise ValueError("A must be square on its last two dimensions")
    if n_rows != B.shape[-2]:
        raise ValueError(f"Incompatible inner dimensions: A[..., {A.size(-2)}] vs B[..., {B.size(-2)}]")

    # Only the batched (3D) case has a leading batch dimension to compare.
    if rank_a == 3 and A.shape[0] != B.shape[0]:
        raise ValueError("If batched, A and B must have the same batch size")

    solution = SparseTriangularSolve.apply(A, B, upper, unitriangular, transpose)
    return cast(torch.Tensor, solution)
[docs]
class SparseTriangularSolve(torch.autograd.Function):
    r"""
    Autograd function for memory-efficient sparse triangular system solving.

    ``forward`` runs ``torch.triangular_solve`` on detached inputs; ``backward`` applies
    the analytic rule :math:`\mathbf{G}_B = \mathbf{A}^{-\top}\mathbf{G}` and evaluates
    :math:`-\mathbf{G}_B\,\mathbf{x}^{\top}` only at the stored entries of ``A``, so the
    gradient returned for ``A`` is sparse with the same layout family (COO or CSR) as
    the input.

    See Also
    --------
    sparse_triangular_solve : User-facing function that calls this autograd function.
    torch.triangular_solve : PyTorch's native triangular solver.
    """

    @staticmethod
    def forward(ctx, A, B, upper, unitriangular, transpose):
        """Solve the (optionally batched) triangular system and stash state for backward.

        Batched input is folded into a single 2D problem: ``sparse_block_diag`` is
        applied to the batch of matrices (presumably producing one block-diagonal
        matrix) and ``B`` is flattened to ``(b * m, p)``.
        """
        is_batched = B.dim() == 3
        ctx.batch_size = B.size(0) if is_batched else None
        ctx.A_shape = A.size()  # (b), m, m
        ctx.B_shape = B.size()  # (b), m, p
        ctx.upper = upper
        ctx.unitriangular = unitriangular
        ctx.transpose = transpose
        track_grad = A.requires_grad or B.requires_grad

        if is_batched:
            A = sparse_block_diag(*A)
            B = B.reshape(-1, B.size(-1))

        # Remember the caller's layout so backward can hand back a matching gradient.
        ctx.csr = A.layout != torch.sparse_coo
        if not ctx.csr:
            # NOTE: triangular solve doesn't work with sparse coo
            A = convert_coo_to_csr(A)

        # NOTE: DEPRECATED: Check if a workaround for https://github.com/pytorch/pytorch/issues/88890 is needed
        solved = torch.triangular_solve(
            B.detach(),
            A.detach(),
            upper=upper,
            unitriangular=unitriangular,
            transpose=transpose,
        )
        x = solved.solution
        x.requires_grad = track_grad

        # The CSR form of A is saved, so backward never has to convert again.
        ctx.save_for_backward(A, x.detach())

        if is_batched:
            x = x.view(ctx.batch_size, ctx.A_shape[-2], ctx.B_shape[-1])
        return x

    @staticmethod
    def backward(ctx, grad):  # type: ignore[override]
        """Return ``(gradA, gradB, None, None, None)`` for the five forward inputs.

        Raises
        ------
        ValueError
            If ``unitriangular`` is ``True`` and the saved matrix stores any
            diagonal entry.
        """
        is_batched = ctx.batch_size is not None
        if is_batched:
            grad = grad.reshape(-1, grad.size(-1))
        A, x = ctx.saved_tensors

        # Backprop rule: gradB = A^{-T} grad (hence the flipped transpose flag).
        # NOTE: DEPRECATED: Check if a workaround for https://github.com/pytorch/pytorch/issues/88890 is needed
        gradB = torch.triangular_solve(
            grad,
            A,
            upper=ctx.upper,
            transpose=not ctx.transpose,
            unitriangular=ctx.unitriangular,
        ).solution

        # Seen as a dense matrix, the gradient of A would be
        #     gradA = -(A^{-T} grad)(A^{-1} B)^T = -gradB @ x.T
        # Only the values at the stored (nonzero) positions of A are wanted, so
        # rather than forming the dense product and sub-sampling it, each entry
        # is computed directly:
        #     gradA[i, j] = -dotprod(gradB[i, :], x[j, :])
        col_idx = A.col_indices()
        crow_idx = A.crow_indices()
        # Expand the compressed row pointers into one row index per stored entry.
        entries_per_row = crow_idx[1:] - crow_idx[:-1]
        row_idx = torch.repeat_interleave(torch.arange(A.size()[0], device=A.device), entries_per_row)

        # For the transposed system the roles of i and j are exchanged.
        gradb_idx, x_idx = (col_idx, row_idx) if ctx.transpose else (row_idx, col_idx)
        neg_gradb_rows = -gradB.index_select(0, gradb_idx)
        x_rows = x.index_select(0, x_idx)

        if ctx.unitriangular is True and torch.any(row_idx == col_idx):
            raise ValueError("First input should be strictly triangular (i.e. unit diagonals is implicit)")

        # Row-wise dot product across the right-hand sides.
        values = torch.sum(neg_gradb_rows * x_rows, dim=1)

        if ctx.csr:
            gradA = torch.sparse_csr_tensor(crow_idx, col_idx, values, A.shape)
        else:
            gradA = torch.sparse_coo_tensor(torch.stack([row_idx, col_idx]), values, A.shape)

        if is_batched:
            # One (m, m) shape per batch element, passed to sparse_block_diag_split.
            block_shapes = [ctx.A_shape[-2:]] * ctx.A_shape[0]
            blocks = sparse_block_diag_split(gradA, *block_shapes)
            stack_blocks = stack_csr if ctx.csr else torch.stack
            gradA = stack_blocks(list(blocks))
            gradB = gradB.view(ctx.B_shape)

        return gradA, gradB, None, None, None
[docs]
def sparse_generic_solve(
    A: torch.Tensor,
    B: torch.Tensor,
    solve: Optional[Callable[..., torch.Tensor]] = None,
    transpose_solve: Optional[Callable[..., torch.Tensor]] = None,
    **kwargs,
) -> torch.Tensor:
    r"""Solve a sparse linear system with a pluggable solver and sparse-aware gradients.

    Solves :math:`\mathbf{A}\,\mathbf{x} = \mathbf{B}` for a sparse square
    :math:`\mathbf{A} \in \mathbb{R}^{n\times n}` (COO/CSR) and a dense right-hand side
    :math:`\mathbf{B}` that is either a vector or an :math:`n\times p` matrix. The solver
    itself does not need to be differentiable: gradients come from the implicit function
    theorem, and :math:`\frac{\partial \mathcal{L}}{\partial \mathbf{A}}` is only evaluated
    at the stored entries of :math:`\mathbf{A}`.

    With upstream gradient :math:`\mathbf{G} = \frac{\partial \mathcal{L}}{\partial \mathbf{x}}`
    and solution :math:`\mathbf{x}`, the dense-form gradients are as follows.

    Gradient with respect to B (dense):

    .. math::
        \frac{\partial \mathcal{L}}{\partial \mathbf{B}} \;=\; \mathbf{A}^{-\top} \, \mathbf{G}
        \;\equiv\; \mathbf{G}_B.

    Gradient with respect to A (sparse):

    .. math::
        \frac{\partial \mathcal{L}}{\partial \mathbf{A}} \;=\; -\, \mathbf{G}_B\, \mathbf{x}^{\top}.

    Only the entries at the nonzeros of :math:`\mathbf{A}` are evaluated, so the gradient
    tensor needs memory proportional to ``nnz(A)``. For a stored entry
    :math:`\mathbf{A}_{ij}`:

    .. math::
        \bigg[\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\bigg]_{ij}
        \;=\; -\, (\mathbf{G}_B)_{i,:} \,\cdot\, \mathbf{x}_{j,:}.

    Parameters
    ----------
    A : torch.Tensor, sparse COO or CSR, shape ``(n, n)``
        Sparse square coefficient matrix. Must be invertible (or suitable) for the
        chosen solver. All tensors are expected to be on the same device (not checked).
    B : torch.Tensor, dense (strided), shape ``(n,)`` or ``(n, k)``
        Right-hand side(s). ``B.shape[0]`` must equal ``A.shape[0]``.
    solve : callable, optional
        Forward solver with signature ``solve(A, B, **kwargs) -> X``.
        If both ``solve`` and ``transpose_solve`` are ``None``, ``minres`` is used
        (recommended for symmetric indefinite). If only ``solve`` is ``None``, the given
        ``transpose_solve`` is reused for the forward solve. Other typical choices:

        * ``linear_cg`` (SPD matrices)
        * ``bicgstab`` (general non-symmetric)
    transpose_solve : callable, optional
        Solver for the transpose system used in backprop, with signature
        ``transpose_solve(A, G, **kwargs) -> Y`` that solves :math:`A^\top Y = G` in the
        least-squares / iterative sense. If ``None``, defaults to ``solve``.
    **kwargs : dict
        Extra keyword arguments forwarded to both solvers (e.g., tolerances,
        iteration caps, or solver-specific settings objects).

    Returns
    -------
    torch.Tensor
        Solution tensor ``X`` with the same rank as ``B``: ``(n,)`` or ``(n, k)``.

    Raises
    ------
    ValueError
        If inputs are not tensors; shapes are incompatible; ranks are invalid.
    TypeError
        If ``A`` is not COO/CSR or if ``B`` is not dense (strided).

    Warns
    -----
    UserWarning
        If ``A`` and ``B`` use different dtypes (may affect solver behavior).

    Notes
    -----
    Only entries at the nonzeros of :math:`\mathbf{A}` are computed, keeping the gradient
    sparse and memory-efficient.

    NOTE(review): when one solver is substituted for the other, the same callable is
    applied to ``A`` in both passes; that matches the formulas above only if ``A`` is
    symmetric or the callable handles the transpose itself -- confirm for your solver.

    See Also
    --------
    sparse_triangular_solve : Triangular systems with sparse-aware gradients.
    sparse_generic_lstsq : Overdetermined least-squares with sparse-aware gradients.

    Examples
    --------
    >>> import torch
    >>> from torchsparsegradutils import sparse_generic_solve
    >>> from torchsparsegradutils.utils import linear_cg, bicgstab, minres
    >>> # Symmetric positive definite example
    >>> indices = torch.tensor([[0, 0, 1, 1, 2],
    ...                         [0, 1, 0, 1, 2]])
    >>> values = torch.tensor([4.0, -1.0, -1.0, 4.0, 2.0])
    >>> A = torch.sparse_coo_tensor(indices, values, (3, 3))
    >>> B = torch.tensor([1.0, 2.0, 3.0])
    >>> x = sparse_generic_solve(A, B, solve=linear_cg)
    >>> x.shape
    torch.Size([3])
    >>> # Multiple RHS with BiCGSTAB
    >>> X = sparse_generic_solve(A, torch.randn(3, 5), solve=bicgstab)
    >>> X.shape
    torch.Size([3, 5])
    >>> # Default solver (MINRES)
    >>> x = sparse_generic_solve(A, B)
    >>> # With custom solver settings:
    >>> from torchsparsegradutils.utils.linear_cg import LinearCGSettings
    >>> settings = LinearCGSettings(max_cg_iterations=1000, cg_tolerance=1e-8)
    >>> x = sparse_generic_solve(A, B, solve=linear_cg, settings=settings)
    >>> # With gradients (A.grad is sparse)
    >>> A.requires_grad_(True)  # doctest: +ELLIPSIS
    tensor(...)
    >>> B.requires_grad_(True)  # doctest: +ELLIPSIS
    tensor(...)
    >>> x = sparse_generic_solve(A, B)
    >>> x.sum().backward()
    >>> A.grad.is_sparse
    True
    """
    # ---------- input validation (order determines which error is reported) ----------
    if not (isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor)):
        raise ValueError("Both A and B should be instances of torch.Tensor")
    if A.layout != torch.sparse_coo and A.layout != torch.sparse_csr:
        raise TypeError(f"Unsupported sparse layout: {A.layout}. Only COO and CSR are supported.")
    if A.dim() != 2:
        raise ValueError("A must be a 2D tensor")

    n_rows, n_cols = A.shape
    if n_rows != n_cols:
        raise ValueError("A must be square")

    rhs_rank = B.dim()
    if rhs_rank != 1 and rhs_rank != 2:
        raise ValueError("B must be a 1D or 2D tensor")
    if B.shape[0] != n_rows:
        raise ValueError(f"Incompatible dimensions: A has shape {tuple(A.shape)}, B has shape {tuple(B.shape)}")
    if B.layout != torch.strided:
        raise TypeError("B must be a dense (strided) tensor")

    # A dtype mismatch is tolerated but reported, attributed to the caller's frame.
    if A.dtype != B.dtype:
        warnings.warn(
            f"A and B have different dtypes: A={A.dtype}, B={B.dtype}. This may affect solver behavior.",
            UserWarning,
            stacklevel=2,
        )

    # ---------- default solvers ----------
    if solve is None and transpose_solve is None:
        # Imported lazily, only when the default is actually needed.
        from .utils import minres

        solve = transpose_solve = minres
    else:
        # At most one of the two is missing here; it borrows the other one.
        if solve is None:
            solve = transpose_solve
        if transpose_solve is None:
            transpose_solve = solve

    X = cast(torch.Tensor, SparseGenericSolve.apply(A, B, solve, transpose_solve, kwargs))

    # Make the output rank follow B (a solver might return (n, 1) for 1D B, etc.).
    if rhs_rank == 1:
        if X.dim() == 2 and X.shape[1] == 1:
            X = X.squeeze(-1)
    elif rhs_rank == 2 and X.dim() == 1:
        X = X.unsqueeze(-1)
    return X
[docs]
class SparseGenericSolve(torch.autograd.Function):
    r"""
    Autograd function for sparse linear system solving with iterative methods.

    The forward pass hands detached inputs to a user-supplied solver, so the solver
    does not have to be differentiable. The backward pass uses

    .. math::
        \mathbf{G}_B = \mathbf{A}^{-\top}\mathbf{G}, \qquad
        \bigg[\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\bigg]_{ij}
        = -\,(\mathbf{G}_B)_{i,:}\cdot\mathbf{x}_{j,:},

    evaluating the second expression only at the stored entries of ``A`` so that the
    gradient of ``A`` is sparse with the same layout (COO or CSR) as ``A``.

    See Also
    --------
    sparse_generic_solve : User-facing function that calls this autograd function.
    """

    @staticmethod
    def forward(ctx, A, B, solve, transpose_solve, kwargs):
        """Run ``solve(A, B, **kwargs)`` on detached inputs and save state for backward.

        Parameters
        ----------
        A : torch.Tensor
            Sparse COO or CSR coefficient matrix.
        B : torch.Tensor
            Dense right-hand side.
        solve, transpose_solve : callable
            Forward solver and the solver used for ``A^T Y = G`` in backward.
        kwargs : dict
            Keyword arguments forwarded to both solvers.

        Returns
        -------
        torch.Tensor
            Whatever ``solve`` returned, cast to ``A.dtype`` if necessary.
        """
        grad_flag = A.requires_grad or B.requires_grad
        ctx.transpose_solve = transpose_solve
        ctx.kwargs = kwargs  # Store kwargs for backward pass
        # The gradient handed back for B must have exactly B's shape, even when
        # the solver returns its result with a different rank (e.g. (n, 1) for a
        # 1-D right-hand side), so remember it here.
        ctx.B_shape = B.shape

        x = solve(A.detach(), B.detach(), **kwargs)
        # Ensure output dtype matches input dtype
        if x.dtype != A.dtype:
            x = x.to(dtype=A.dtype)
        x.requires_grad = grad_flag
        ctx.save_for_backward(A, x.detach())
        return x

    @staticmethod
    def backward(ctx, grad):  # type: ignore[override]
        """Return ``(gradA, gradB, None, None, None)`` for the five forward inputs.

        ``gradA`` is a sparse tensor on the sparsity pattern of ``A`` (coalesced
        pattern for COO); ``gradB`` has the shape of ``B``. Either is ``None`` when
        the corresponding input does not require a gradient.
        """
        need_gradA, need_gradB = ctx.needs_input_grad[0], ctx.needs_input_grad[1]
        A, x = ctx.saved_tensors

        # Work with (n, k) matrices throughout; vectors become single columns.
        if x.ndim == 1:
            x = x.unsqueeze(-1)
        if grad.ndim == 1:
            grad = grad.unsqueeze(-1)

        # Backprop rule: gradB = A^{-T} grad
        gradB = ctx.transpose_solve(A, grad, **ctx.kwargs)
        # Ensure gradient dtype matches input dtype
        if gradB.dtype != A.dtype:
            gradB = gradB.to(dtype=A.dtype)
        # A solver may drop the trailing singleton dimension; restore it so the
        # row-wise product below cannot broadcast (nnz,) against (nnz, 1).
        if gradB.ndim == 1:
            gradB = gradB.unsqueeze(-1)

        gradA = None
        if need_gradA:
            # Seen as a dense matrix, the gradient of A would be
            #     gradA = -(A^{-T} grad)(A^{-1} B)^T = -gradB @ x.T
            # Only the values at the stored (nonzero) positions of A are wanted,
            # so each one is computed directly instead of sub-sampling the dense
            # product:
            #     gradA[i, j] = -dotprod(gradB[i, :], x[j, :])
            is_coo = A.layout == torch.sparse_coo
            if is_coo:
                # Indices are only defined on a coalesced COO tensor.
                coo_indices = A.coalesce().indices()
                A_row_idx = coo_indices[0, :]
                A_col_idx = coo_indices[1, :]
            else:
                A_col_idx = A.col_indices()
                A_crow_idx = A.crow_indices()
                # Uncompress row indices:
                A_row_idx = torch.repeat_interleave(
                    torch.arange(A.size()[0], device=A.device), A_crow_idx[1:] - A_crow_idx[:-1]
                )

            mgradbselect = -gradB.index_select(0, A_row_idx)  # -gradB[i, :]
            xselect = x.index_select(0, A_col_idx)  # x[j, :]
            # Dot product across the right-hand sides:
            values = torch.sum(mgradbselect * xselect, dim=1)
            # Ensure gradient dtype matches input dtype
            if values.dtype != A.dtype:
                values = values.to(dtype=A.dtype)

            if is_coo:
                gradA = torch.sparse_coo_tensor(torch.stack([A_row_idx, A_col_idx]), values, A.shape)
            else:
                gradA = torch.sparse_csr_tensor(A_crow_idx, A_col_idx, values, A.shape)

        # Hand gradB back in exactly the shape B had in forward; autograd rejects
        # a gradient whose rank differs from its input's.
        gradB = gradB.reshape(ctx.B_shape) if need_gradB else None
        return gradA, gradB, None, None, None