torchsparsegradutils.sparse_triangular_solve

torchsparsegradutils.sparse_triangular_solve(A: Tensor, B: Tensor, upper: bool = True, unitriangular: bool = False, transpose: bool = False) → Tensor

Sparse triangular solve with memory-efficient sparse gradients.

Solves the triangular system \(\mathbf{A}\,\mathbf{x} = \mathbf{B}\) (or \(\mathbf{A}^{\top}\,\mathbf{x} = \mathbf{B}\) if transpose=True), where \(\mathbf{A} \in \mathbb{R}^{m\times m}\) is a sparse triangular matrix (COO or CSR) and \(\mathbf{B} \in \mathbb{R}^{m\times p}\) is dense. Gradients preserve the sparsity pattern of \(\mathbf{A}\): they are evaluated only at its nonzero entries. Both unbatched 2D and batched 3D inputs are supported; COO inputs are converted to CSR internally for the triangular solve.

Let \(\mathbf{G} = \frac{\partial \mathcal{L}}{\partial \mathbf{x}}\) be the upstream gradient of a scalar objective \(\mathcal{L}\) with respect to the solution \(\mathbf{x}\). In dense form, the gradients are as follows.

Gradient with respect to B (dense):

\[\frac{\partial \mathcal{L}}{\partial \mathbf{B}} \;=\; \mathbf{A}^{-\top} \, \mathbf{G},\]

and for transpose=True replace \(\mathbf{A}\) by \(\mathbf{A}^{\top}\) so that \(\frac{\partial \mathcal{L}}{\partial \mathbf{B}} = \left(\mathbf{A}^{\top}\right)^{-\top} \mathbf{G} = \mathbf{A}^{-1} \mathbf{G}\).
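As a sanity check of both expressions for \(\partial \mathcal{L} / \partial \mathbf{B}\), the sketch below compares them with PyTorch autograd. It assumes a dense, well-conditioned upper-triangular matrix and uses torch.linalg.solve_triangular in place of the sparse solver, purely to keep the example self-contained; grad_B_error is a hypothetical helper, not part of torchsparsegradutils.

```python
import torch

torch.manual_seed(0)
m, p = 4, 2
# Dense, well-conditioned upper-triangular stand-in for the sparse A.
A = torch.triu(torch.rand(m, m, dtype=torch.float64)) + m * torch.eye(m, dtype=torch.float64)
G = torch.rand(m, p, dtype=torch.float64)  # upstream gradient dL/dx


def grad_B_error(transpose: bool) -> float:
    """Max |autograd dL/dB - closed form| for A x = B (or A^T x = B)."""
    B = torch.rand(m, p, dtype=torch.float64, requires_grad=True)
    op = A.T if transpose else A  # A^T is lower-triangular
    x = torch.linalg.solve_triangular(op, B, upper=not transpose)
    (G * x).sum().backward()  # L = <G, x>, so dL/dx = G
    # Closed form op^{-T} G: A^{-T} G if transpose=False, A^{-1} G if transpose=True.
    closed = torch.linalg.solve_triangular(op.T, G, upper=transpose)
    return (B.grad - closed).abs().max().item()


print(grad_B_error(False), grad_B_error(True))
```

Both printed differences should be at the level of float64 round-off.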

Gradient with respect to A (sparse):

\[\frac{\partial \mathcal{L}}{\partial \mathbf{A}} \;=\; -\big(\mathbf{A}^{-\top} \, \mathbf{G}\big)\, \mathbf{x}^{\top},\]

evaluated only at the nonzero entries of \(\mathbf{A}\). For each nonzero \(\mathbf{A}_{ij}\), the gradient entry is

\[\bigg[\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\bigg]_{ij} \;=\; -\, \big(\mathbf{A}^{-\top} \, \mathbf{G}\big)_{i,:} \,\cdot\, \mathbf{x}_{j,:},\]

where the dot denotes a row-wise inner product across the \(p\) right-hand sides.
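Both gradients follow from differentiating the constraint. Perturbing \(\mathbf{A}\,\mathbf{x} = \mathbf{B}\) gives \(d\mathbf{A}\,\mathbf{x} + \mathbf{A}\,d\mathbf{x} = d\mathbf{B}\), so \(d\mathbf{x} = \mathbf{A}^{-1}(d\mathbf{B} - d\mathbf{A}\,\mathbf{x})\). With the Frobenius inner product \(\langle \mathbf{U}, \mathbf{V} \rangle = \operatorname{tr}(\mathbf{U}^{\top}\mathbf{V})\),

\[d\mathcal{L} \;=\; \big\langle \mathbf{G},\, d\mathbf{x} \big\rangle \;=\; \big\langle \mathbf{A}^{-\top}\mathbf{G},\, d\mathbf{B} \big\rangle \;-\; \big\langle (\mathbf{A}^{-\top}\mathbf{G})\,\mathbf{x}^{\top},\, d\mathbf{A} \big\rangle,\]

from which both gradients above can be read off. Repeating the argument for \(\mathbf{A}^{\top}\,\mathbf{x} = \mathbf{B}\) (transpose=True) gives \(\frac{\partial \mathcal{L}}{\partial \mathbf{A}} = -\,\mathbf{x}\,\big(\mathbf{A}^{-1}\mathbf{G}\big)^{\top}\), which would likewise be restricted to the nonzeros of \(\mathbf{A}\).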

Parameters:
  • A (torch.Tensor, sparse COO or CSR, shape (m, m) or (b, m, m)) – Sparse triangular coefficient matrix. Each matrix in the batch must be square. A and B must be on the same device.

  • B (torch.Tensor, dense (strided), shape (m, p) or (b, m, p)) – Right-hand side. B.shape[-2] must equal A.shape[-2] (m).

  • upper (bool, optional) – If True (default), treat A as upper-triangular; else lower-triangular.

  • unitriangular (bool, optional) – If True, assume unit diagonal (implicit ones). The stored matrix must be strictly triangular (no explicit diagonal entries). Default: False.

  • transpose (bool, optional) – If True, solves \(\mathbf{A}^{\top}\,\mathbf{x} = \mathbf{B}\); otherwise \(\mathbf{A}\,\mathbf{x} = \mathbf{B}\). Default: False.
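The interaction of upper, unitriangular, and transpose can be pinned down with a dense reference. The sketch below is a hypothetical helper (dense_reference_solve is not part of torchsparsegradutils) that densifies A and delegates to torch.linalg.solve_triangular; it is only meant to illustrate the documented semantics and forgoes the memory savings of the sparse solver.

```python
import torch


def dense_reference_solve(A, B, upper=True, unitriangular=False, transpose=False):
    """Hypothetical dense reference for the documented solve semantics.

    A: sparse COO/CSR or dense, shape (m, m) or (b, m, m).
    B: dense, shape (m, p) or (b, m, p).
    """
    if A.layout != torch.strided:
        A = A.to_dense()  # densify sparse COO/CSR input
    if transpose:
        # A^T of an upper-triangular matrix is lower-triangular, and vice versa.
        A = A.transpose(-2, -1)
        upper = not upper
    # solve_triangular reads only the selected triangle; with unitriangular=True
    # the diagonal is taken to be ones and is never read.
    return torch.linalg.solve_triangular(A, B, upper=upper, unitriangular=unitriangular)


# Upper-triangular matrix from the Examples section: [[2, 0, 1], [0, 3, 0], [0, 0, 1]].
A = torch.sparse_csr_tensor([0, 2, 3, 4], [0, 2, 1, 2],
                            torch.tensor([2.0, 1.0, 3.0, 1.0]), (3, 3))
B = torch.tensor([[1.0], [2.0], [3.0]])
print(dense_reference_solve(A, B))                  # solves A x = B
print(dense_reference_solve(A, B, transpose=True))  # solves A^T x = B
```

Back-substitution by hand gives x = (-1, 2/3, 3) for A x = B and x = (1/2, 2/3, 5/2) for the transposed system.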

Returns:

Solution with the same shape as B: (m, p) or (b, m, p).

Return type:

torch.Tensor

Raises:
  • ValueError – If inputs are not tensors; either input has fewer than 2 dimensions, or A and B are not both 2D or both 3D; layouts are incompatible (A is not COO/CSR, or B is not dense); shapes are incompatible; batch sizes differ; or unitriangular=True but A stores explicit diagonal entries.
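The ValueError conditions listed above could be checked roughly as in the sketch below. check_triangular_solve_inputs is a hypothetical helper, not the library's actual validation code, and the diagonal test assumes A has no dense (hybrid) dimensions, so that the last two rows of its COO indices are the row and column coordinates.

```python
import torch

_SPARSE_LAYOUTS = (torch.sparse_coo, torch.sparse_csr)


def check_triangular_solve_inputs(A, B, unitriangular=False):
    """Hypothetical validator mirroring the documented ValueError conditions."""
    if not isinstance(A, torch.Tensor) or not isinstance(B, torch.Tensor):
        raise ValueError("A and B must be torch.Tensor")
    if A.dim() < 2 or B.dim() < 2 or A.dim() != B.dim() or A.dim() not in (2, 3):
        raise ValueError("A and B must both be 2D or both be 3D")
    if A.layout not in _SPARSE_LAYOUTS or B.layout != torch.strided:
        raise ValueError("A must be sparse COO/CSR and B must be dense")
    if A.shape[-1] != A.shape[-2] or B.shape[-2] != A.shape[-2]:
        raise ValueError(f"incompatible shapes {tuple(A.shape)} and {tuple(B.shape)}")
    if A.dim() == 3 and A.shape[0] != B.shape[0]:
        raise ValueError("batch sizes of A and B differ")
    if unitriangular:
        # With an implicit unit diagonal, no diagonal entry may be stored.
        coo = A if A.layout == torch.sparse_coo else A.to_sparse_coo()
        idx = coo.coalesce().indices()
        if bool((idx[-2] == idx[-1]).any()):
            raise ValueError("unitriangular=True but A stores diagonal entries")
```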

  • RuntimeError – If the underlying triangular solve fails.

Notes

Backpropagation computes gradients only at the nonzero entries of \(\mathbf{A}\), so the gradient stays sparse and stores nnz(\(\mathbf{A}\)) values instead of a dense \(m \times m\) matrix. COO inputs are converted to CSR because PyTorch’s sparse triangular solver requires CSR [1e]. For autograd implementation details, see [2e].
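The memory saving can be illustrated with a small sketch: given the forward solution \(\mathbf{x}\) and the adjoint \(\boldsymbol{\lambda} = \mathbf{A}^{-\top}\mathbf{G}\), the gradient values are gathered only at the nonzero coordinates of \(\mathbf{A}\) using the row-wise inner product from the per-entry formula, and stored as a sparse COO tensor. Dense torch.linalg.solve_triangular calls stand in for the sparse solver here; this is an illustration under that assumption, not the library's implementation.

```python
import torch

torch.manual_seed(0)
m, p = 4, 3
# Sparse upper-triangular pattern: diagonal plus the first superdiagonal.
dense = (torch.diag(torch.full((m,), 2.0, dtype=torch.float64))
         + torch.diag(torch.rand(m - 1, dtype=torch.float64), diagonal=1))
A = dense.to_sparse().coalesce()
rows, cols = A.indices()  # coordinates of the nnz stored entries
B = torch.rand(m, p, dtype=torch.float64)
G = torch.rand(m, p, dtype=torch.float64)  # upstream gradient dL/dx

# Forward solve and adjoint solve lam = A^{-T} G (A^T is lower-triangular).
x = torch.linalg.solve_triangular(dense, B, upper=True)
lam = torch.linalg.solve_triangular(dense.T, G, upper=False)

# Gradient values only at the nonzeros: -<lam[i, :], x[j, :]> for each (i, j).
grad_vals = -(lam[rows] * x[cols]).sum(dim=-1)
grad_A = torch.sparse_coo_tensor(A.indices(), grad_vals, A.shape)

# Reference: dense autograd gradient, masked to the sparsity pattern of A.
A_dense = dense.clone().requires_grad_(True)
x_ref = torch.linalg.solve_triangular(A_dense, B, upper=True)
(G * x_ref).sum().backward()
ref = A_dense.grad * (dense != 0)
max_err = (grad_A.to_dense() - ref).abs().max().item()
print(grad_A._nnz(), max_err)
```

Only nnz values are ever materialized for the gradient; the dense reference is computed here solely to check the result.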

See also

torch.sparse.mm

Sparse @ dense multiply.

torch.linalg.solve_triangular

Dense triangular solver (modern API).

References

[1e]

PyTorch issue on sparse triangular solve: https://github.com/pytorch/pytorch/issues/87358

[2e]

PyTorch issue on autograd/triangular solve: https://github.com/pytorch/pytorch/issues/88890

Examples

Upper-triangular:

>>> import torch
>>> from torchsparsegradutils import sparse_triangular_solve
>>> A = torch.sparse_csr_tensor([0, 2, 3, 4], [0, 2, 1, 2],
...                             torch.tensor([2.0, 1.0, 3.0, 1.0]), (3, 3))
>>> B = torch.tensor([[1.0], [2.0], [3.0]])
>>> x = sparse_triangular_solve(A, B, upper=True)
>>> x.shape
torch.Size([3, 1])

Lower-triangular:

>>> A_low = torch.sparse_csr_tensor([0, 1, 3, 5], [0, 0, 1, 0, 2],
...                                 torch.tensor([2.0, 1.0, 3.0, 0.5, 1.0]), (3, 3))
>>> x = sparse_triangular_solve(A_low, B, upper=False)

Batched:

>>> # Convert to COO for batching (since torch.stack doesn't work with CSR)
>>> A_coo = A.to_sparse_coo()
>>> A_b = torch.stack([A_coo, A_coo])   # (2, 3, 3)
>>> B_b = torch.stack([B, B])   # (2, 3, 1)
>>> x_b = sparse_triangular_solve(A_b, B_b)
>>> x_b.shape
torch.Size([2, 3, 1])