torchsparsegradutils.sparse_triangular_solve
- torchsparsegradutils.sparse_triangular_solve(A: Tensor, B: Tensor, upper: bool = True, unitriangular: bool = False, transpose: bool = False) → Tensor
Sparse triangular solve with memory-efficient sparse gradients.
Solves the triangular system \(\mathbf{A}\,\mathbf{x} = \mathbf{B}\) (or \(\mathbf{A}^{\top}\,\mathbf{x} = \mathbf{B}\) if transpose=True), where \(\mathbf{A} \in \mathbb{R}^{m\times m}\) is sparse triangular (COO/CSR) and \(\mathbf{B} \in \mathbb{R}^{m\times p}\) is dense. Gradients preserve the sparsity pattern of \(\mathbf{A}\) by evaluating only at its nonzero entries. Supports unbatched 2D and batched 3D inputs; COO inputs are converted to CSR internally for the triangular solve.
Let the upstream gradient be \(\mathbf{G} = \frac{\partial \mathcal{L}}{\partial \mathbf{x}}\) for a scalar objective \(\mathcal{L}\) and solution \(\mathbf{x}\). The dense-form gradients are as follows.
Gradient with respect to B (dense):
\[\frac{\partial \mathcal{L}}{\partial \mathbf{B}} \;=\; \mathbf{A}^{-\top} \, \mathbf{G},\]
and for transpose=True replace \(\mathbf{A}\) by \(\mathbf{A}^{\top}\), so that \(\frac{\partial \mathcal{L}}{\partial \mathbf{B}} = \left(\mathbf{A}^{\top}\right)^{-\top} \mathbf{G} = \mathbf{A}^{-1} \mathbf{G}\).
Gradient with respect to A (sparse):
\[\frac{\partial \mathcal{L}}{\partial \mathbf{A}} \;=\; -\big(\mathbf{A}^{-\top} \, \mathbf{G}\big)\, \mathbf{x}^{\top},\]
and only entries at the nonzeros of \(\mathbf{A}\) are evaluated. Equivalently, for a nonzero \(\mathbf{A}_{ij}\) the contribution is
\[\bigg[\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\bigg]_{ij} \;=\; -\, \big(\mathbf{A}^{-\top} \, \mathbf{G}\big)_{i,:} \,\cdot\, \mathbf{x}_{j,:},\]
where the dot denotes the inner product of the two length-\(p\) rows, i.e. a sum over the \(p\) right-hand sides. For transpose=True the same argument applied to \(\mathbf{A}^{\top}\) gives \(\frac{\partial \mathcal{L}}{\partial \mathbf{A}} = -\,\mathbf{x}\,\big(\mathbf{A}^{-1}\mathbf{G}\big)^{\top}\).
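As a sanity check on the formulas above, the sketch below compares them against dense autograd through torch.linalg.solve_triangular. It does not call sparse_triangular_solve; the random test matrix, its sparsity mask, and all variable names are illustrative assumptions. The gradient for A is evaluated only at the stored nonzeros using the row-wise inner product, which is the memory-saving idea described here, while the dense reference materializes a full m×m gradient.

```python
import torch
from torch.linalg import solve_triangular

torch.manual_seed(0)
m, p = 5, 3
dt = torch.float64

# Illustrative sparse upper-triangular A: random strictly-upper pattern plus a safe diagonal
mask = torch.rand(m, m) > 0.5
A_dense = (torch.triu(torch.randn(m, m, dtype=dt), diagonal=1) * mask
           + torch.diag(2.0 + torch.rand(m, dtype=dt)))
B = torch.randn(m, p, dtype=dt)
G = torch.randn(m, p, dtype=dt)  # upstream gradient dL/dx

# Forward solve A x = B and adjoint solve Y = A^{-T} G (A^T is lower-triangular)
X = solve_triangular(A_dense, B, upper=True)
Y = solve_triangular(A_dense.T, G, upper=False)  # this is dL/dB

# dL/dA evaluated only at the nonzeros of A: -(Y[i, :] . X[j, :])
A_coo = A_dense.to_sparse_coo().coalesce()
rows, cols = A_coo.indices()
grad_vals = -(Y[rows] * X[cols]).sum(dim=-1)
grad_A = torch.sparse_coo_tensor(A_coo.indices(), grad_vals, A_coo.shape)

# Dense autograd reference for transpose=False
A_leaf = A_dense.clone().requires_grad_(True)
B_leaf = B.clone().requires_grad_(True)
x_ref = solve_triangular(A_leaf, B_leaf, upper=True)
gA_ref, gB_ref = torch.autograd.grad(x_ref, (A_leaf, B_leaf), grad_outputs=G)

# transpose=True solves A^T x = B; its B-gradient should be A^{-1} G
xt_ref = solve_triangular(A_leaf.T, B_leaf, upper=False)
(gBt_ref,) = torch.autograd.grad(xt_ref, B_leaf, grad_outputs=G)
gBt_formula = solve_triangular(A_dense, G, upper=True)
```

The sparse gradient stores one value per nonzero of A, so its value storage grows with nnz(A) instead of m².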
- Parameters:
A (torch.Tensor, sparse COO or CSR, shape (m, m) or (b, m, m)) – Sparse triangular coefficient matrix. Must be square per batch. All tensors must be on the same device.
B (torch.Tensor, dense (strided), shape (m, p) or (b, m, p)) – Right-hand side. B.shape[-2] must equal A.shape[-2] (m).
upper (bool, optional) – If True (default), treat A as upper-triangular; otherwise lower-triangular.
unitriangular (bool, optional) – If True, assume a unit diagonal (implicit ones). The stored matrix must be strictly triangular (no explicit diagonal entries). Default: False.
transpose (bool, optional) – If True, solves \(\mathbf{A}^{\top}\mathbf{x} = \mathbf{B}\); otherwise \(\mathbf{A}\mathbf{x} = \mathbf{B}\). Default: False.
- Returns:
Solution with the same shape as B: (m, p) or (b, m, p).
- Return type:
torch.Tensor
- Raises:
ValueError – If A or B is not a tensor; either input has fewer than 2 dimensions, or the inputs are not both 2D or both 3D; the layouts are incompatible (A is not COO/CSR, or B is not dense); the shapes are incompatible; the batch sizes differ; or unitriangular=True but A stores explicit diagonal entries.
RuntimeError – If the underlying triangular solve fails.
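The ValueError conditions listed above can be mirrored by a small validator. This is a hypothetical sketch, assuming checks in roughly this order; check_triangular_solve_inputs is not a torchsparsegradutils function and the library's own messages and ordering may differ. The explicit-diagonal check converts to COO and looks for any stored index with row == col.

```python
import torch


def check_triangular_solve_inputs(A, B, unitriangular=False):
    """Hypothetical validator mirroring the documented ValueError conditions.

    Illustrative only; not the library's own validation code.
    """
    if not isinstance(A, torch.Tensor) or not isinstance(B, torch.Tensor):
        raise ValueError("A and B must be tensors")
    if A.dim() < 2 or B.dim() < 2 or A.dim() != B.dim() or A.dim() > 3:
        raise ValueError("A and B must both be 2D or both be 3D")
    if A.layout not in (torch.sparse_coo, torch.sparse_csr):
        raise ValueError("A must be sparse COO or CSR")
    if B.layout != torch.strided:
        raise ValueError("B must be dense (strided)")
    if A.shape[-1] != A.shape[-2] or B.shape[-2] != A.shape[-2]:
        raise ValueError("A must be square and B.shape[-2] must equal A.shape[-2]")
    if A.dim() == 3 and A.shape[0] != B.shape[0]:
        raise ValueError("batch sizes of A and B differ")
    if unitriangular:
        # Index rows are (batch,) row, col; any row == col is an explicit diagonal entry
        coo = A if A.layout == torch.sparse_coo else A.to_sparse_coo()
        idx = coo.coalesce().indices()
        if bool((idx[-2] == idx[-1]).any()):
            raise ValueError("unitriangular=True but A stores explicit diagonal entries")
```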
Notes
Backprop computes gradients only at the nonzero entries of \(\mathbf{A}\), so the gradient stays sparse and its value storage scales with the number of stored nonzeros rather than with \(m^2\). COO inputs are converted to CSR because PyTorch’s sparse triangular solver requires CSR [1e]. For autograd implementation details, see [2e].
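The two points in these notes can be illustrated with plain PyTorch: converting COO to CSR keeps the same nonzeros and only changes the index encoding, and a sparse gradient stores one value per nonzero. The sketch below reuses the upper-triangular matrix from the Examples; gradient_bytes and the sizes m = 10 000, nnz = 50 000 are illustrative assumptions, and the byte counts cover values only (index arrays add more).

```python
import torch

# Upper-triangular matrix from the Examples, assembled in COO form
A_coo = torch.sparse_coo_tensor(
    torch.tensor([[0, 0, 1, 2],    # row indices
                  [0, 2, 1, 2]]),  # column indices
    torch.tensor([2.0, 1.0, 3.0, 1.0]),
    (3, 3),
).coalesce()

# COO -> CSR: same nonzeros, compressed row pointers instead of explicit row indices
A_csr = A_coo.to_sparse_csr()
print(A_csr.crow_indices().tolist(), A_csr.col_indices().tolist())


def gradient_bytes(m, nnz, itemsize=4):
    """Hypothetical helper: value storage of dL/dA, dense (m*m) vs. sparse (nnz)."""
    return m * m * itemsize, nnz * itemsize


print(gradient_bytes(10_000, 50_000))  # float32: 400 MB dense vs. 200 KB of values
```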
See also
torch.sparse.mm
Sparse @ dense multiply.
torch.linalg.solve_triangular
Dense triangular solver (modern API).
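When testing, it can help to have a dense reference with the same options. The sketch below maps upper, unitriangular, and transpose onto torch.linalg.solve_triangular, assuming sparse_triangular_solve follows the semantics documented in Parameters; dense_triangular_reference is a hypothetical helper, not part of torchsparsegradutils. Transposing the last two dimensions turns an upper-triangular matrix into a lower-triangular one, so upper is flipped when transpose=True.

```python
import torch


def dense_triangular_reference(A, B, upper=True, unitriangular=False, transpose=False):
    """Hypothetical dense reference for the documented options (testing aid only)."""
    A_d = A.to_dense() if A.layout != torch.strided else A
    if transpose:
        # Solving A^T x = B: transposing swaps upper- and lower-triangular structure
        A_d = A_d.transpose(-2, -1)
        upper = not upper
    return torch.linalg.solve_triangular(A_d, B, upper=upper, unitriangular=unitriangular)


torch.manual_seed(0)
m = 4
I = torch.eye(m, dtype=torch.float64)
S = torch.triu(torch.randn(m, m, dtype=torch.float64), diagonal=1)  # strictly upper
B = torch.randn(m, 2, dtype=torch.float64)

# unitriangular=True: the diagonal is implicit ones, so this solves (S + I) x = B
x_unit = dense_triangular_reference(S.to_sparse_csr(), B, unitriangular=True)

# transpose=True with an upper-triangular A solves A^T x = B
A = S + 2.0 * I
x_t = dense_triangular_reference(A.to_sparse_csr(), B, transpose=True)
```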
References
[1e] PyTorch issue on sparse triangular solve: https://github.com/pytorch/pytorch/issues/87358
[2e] PyTorch issue on autograd/triangular solve: https://github.com/pytorch/pytorch/issues/88890
Examples
Upper-triangular:
>>> import torch
>>> from torchsparsegradutils import sparse_triangular_solve
>>> A = torch.sparse_csr_tensor([0, 2, 3, 4], [0, 2, 1, 2],
...                             torch.tensor([2.0, 1.0, 3.0, 1.0]), (3, 3))
>>> B = torch.tensor([[1.0], [2.0], [3.0]])
>>> x = sparse_triangular_solve(A, B, upper=True)
>>> x.shape
torch.Size([3, 1])
Lower-triangular:
>>> A_low = torch.sparse_csr_tensor([0, 1, 3, 5], [0, 0, 1, 0, 2],
...                                 torch.tensor([2.0, 1.0, 3.0, 0.5, 1.0]), (3, 3))
>>> x = sparse_triangular_solve(A_low, B, upper=False)
Batched:
>>> # Convert to COO for batching (torch.stack does not support CSR tensors)
>>> A_coo = A.to_sparse_coo()
>>> A_b = torch.stack([A_coo, A_coo])  # (2, 3, 3)
>>> B_b = torch.stack([B, B])  # (2, 3, 1)
>>> x_b = sparse_triangular_solve(A_b, B_b)
>>> x_b.shape
torch.Size([2, 3, 1])
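The example systems are small enough to check by hand, and a dense solve gives the values that sparse_triangular_solve would be expected to return, assuming it matches torch.linalg.solve_triangular. The sketch below densifies the example matrices and solves them directly; back substitution on the upper system gives x = (-1, 2/3, 3), and forward substitution on the lower system gives x = (0.5, 0.5, 2.75).

```python
import torch

# Dense copies of the example CSR matrices
A = torch.sparse_csr_tensor([0, 2, 3, 4], [0, 2, 1, 2],
                            torch.tensor([2.0, 1.0, 3.0, 1.0]), (3, 3)).to_dense()
A_low = torch.sparse_csr_tensor([0, 1, 3, 5], [0, 0, 1, 0, 2],
                                torch.tensor([2.0, 1.0, 3.0, 0.5, 1.0]), (3, 3)).to_dense()
B = torch.tensor([[1.0], [2.0], [3.0]])

x_up = torch.linalg.solve_triangular(A, B, upper=True)        # [[-1.0], [0.6667], [3.0]]
x_low = torch.linalg.solve_triangular(A_low, B, upper=False)  # [[0.5], [0.5], [2.75]]

# Dense batched reference: solve_triangular accepts leading batch dimensions
x_b = torch.linalg.solve_triangular(torch.stack([A, A]), torch.stack([B, B]), upper=True)
```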