Core Operations

This module contains the core sparse tensor operations.

Sparse Matrix Multiplication

torchsparsegradutils.sparse_matmul.sparse_mm(A: Tensor, B: Tensor) → Tensor

Sparse–dense matrix multiplication with memory-efficient gradients.

Computes \(\mathbf{C} = \mathbf{A}\,\mathbf{B}\) where \(\mathbf{A} \in \mathbb{R}^{n\times m}\) is sparse (COO/CSR), \(\mathbf{B} \in \mathbb{R}^{m\times p}\) is dense, and \(\mathbf{C} \in \mathbb{R}^{n\times p}\). Gradients preserve the sparsity pattern of \(\mathbf{A}\). Supports unbatched 2D inputs, (n, m) @ (m, p), and batched 3D inputs, (b, n, m) @ (b, m, p). Batched inputs are handled by block-diagonalising the batch of sparse matrices and concatenating the dense matrices along the batch dimension.
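The block-diagonal treatment of batches can be illustrated with a small dependency-free sketch. It uses plain Python lists rather than tensors so the index arithmetic is visible. The helper names block_diag_coo and coo_matmul are hypothetical and are not part of torchsparsegradutils; the library's actual implementation may differ.

```python
def block_diag_coo(batch, n, m):
    """Merge b sparse (n x m) matrices into one (b*n x b*m) block-diagonal COO matrix.

    batch: list of (rows, cols, vals) triples, one per batch item.
    Hypothetical helper for illustration only.
    """
    rows, cols, vals = [], [], []
    for k, (r, c, v) in enumerate(batch):
        rows += [i + k * n for i in r]   # shift rows into block k
        cols += [j + k * m for j in c]   # shift columns into block k
        vals += list(v)
    return rows, cols, vals


def coo_matmul(rows, cols, vals, B, n_rows):
    """Sparse (COO) @ dense product using plain lists."""
    p = len(B[0])
    out = [[0.0] * p for _ in range(n_rows)]
    for i, j, v in zip(rows, cols, vals):
        for k in range(p):
            out[i][k] += v * B[j][k]
    return out


n, m = 2, 3
A0 = ([0, 1], [0, 2], [1.0, 2.0])             # [[1, 0, 0], [0, 0, 2]]
A1 = ([0, 1, 1], [1, 0, 1], [3.0, 4.0, 5.0])  # [[0, 3, 0], [4, 5, 0]]
B0 = [[1.0], [2.0], [3.0]]
B1 = [[4.0], [5.0], [6.0]]

rows, cols, vals = block_diag_coo([A0, A1], n, m)
stacked_B = B0 + B1                            # dense operands stacked: (b*m) x p
flat = coo_matmul(rows, cols, vals, stacked_B, 2 * n)
batched = [flat[k * n:(k + 1) * n] for k in range(2)]  # back to b x n x p
```

Because the off-diagonal blocks are empty, a single sparse product over the merged matrix gives the same result as multiplying each batch item separately.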

Let the upstream gradient be \(\mathbf{G} = \frac{\partial \mathcal{L}}{\partial \mathbf{C}} \in \mathbb{R}^{n\times p}\). The gradients are:

Gradient with respect to B (dense):

\[\frac{\partial \mathcal{L}}{\partial \mathbf{B}} \;=\; \mathbf{A}^{\top} \, \mathbf{G}.\]

Gradient with respect to A (sparse): For a dense view one has

\[\frac{\partial \mathcal{L}}{\partial \mathbf{A}} \;=\; \mathbf{G}\, \mathbf{B}^{\top},\]

but we evaluate only the entries at the nonzeros of \(\mathbf{A}\). Equivalently, for a nonzero entry \(\mathbf{A}_{ij}\) the contribution is

\[\bigg[\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\bigg]_{ij} \;=\; \sum_{k=1}^{p} \mathbf{G}_{ik} \, \mathbf{B}_{jk} \;=\; \mathbf{G}_{i,:} \,\cdot\, \mathbf{B}_{j,:},\]

where the dot denotes a row-wise inner product across the \(p\) right-hand sides.
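As a minimal sketch of the per-entry formula above, the following pure-Python code evaluates the gradient only at the stored coordinates of A and compares it with the corresponding entries of the dense product. The function name sparse_grad_at_nonzeros is hypothetical, and the values of B and G are made up for illustration.

```python
# Dependency-free sketch: gradient of C = A @ B with respect to the
# stored entries of a COO matrix A, evaluated only at its nonzeros.

def sparse_grad_at_nonzeros(rows, cols, G, B):
    """Return one gradient value per stored entry (i, j) of A.

    rows, cols : coordinates of the nonzeros of A (COO layout)
    G          : upstream gradient dL/dC, an n x p list of lists
    B          : dense right operand, an m x p list of lists
    """
    return [sum(g * b for g, b in zip(G[i], B[j])) for i, j in zip(rows, cols)]


# Sparsity pattern of the 3 x 4 matrix used in the Examples section.
rows = [0, 0, 1, 1, 2, 2]
cols = [0, 2, 1, 3, 0, 2]

B = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]  # m x p = 4 x 2
G = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]              # dL/dC for L = C.sum()

grad_values = sparse_grad_at_nonzeros(rows, cols, G, B)

# Dense reference: the full n x m product G @ B^T, which the sparse
# path never needs to materialise.
dense = [[sum(G[i][k] * B[j][k] for k in range(2)) for j in range(4)]
         for i in range(3)]
```

The sparse path stores nnz(A) values (six here), whereas the dense product holds all n·m entries (twelve here); the gap grows with the size and sparsity of A.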

Parameters:
  • A (torch.Tensor, sparse COO or CSR, shape (n, m) or (b, n, m)) – Left operand. For batched input, all batch items must share (n, m). All tensors must be on the same device.

  • B (torch.Tensor, dense (strided), shape (m, p) or (b, m, p)) – Right operand. Must have the same number of dimensions as A and matching batch size / inner dimension m.

Returns:

Dense result of shape (n, p) or (b, n, p).

Return type:

torch.Tensor

Raises:
  • ValueError – If A or B are not tensors; if ranks are < 2 or not both 2D/3D; if layouts are incompatible (A not COO/CSR or B not dense); if shapes are incompatible (batch or inner dims).

  • RuntimeError – If the underlying sparse matmul fails.

Notes

This avoids dense gradients for sparse matrices [1a] (a known issue with torch.sparse.mm() backprop) by computing gradients only at the nonzero entries of \(A\), which reduces memory use.

See also

torch.sparse.mm

PyTorch’s native sparse @ dense.

sparse_generic_lstsq

Sparse least-squares with sparse-aware gradients.

References

[1a]

PyTorch issue on dense gradients for sparse ops: https://github.com/pytorch/pytorch/issues/41128

Examples

Basic (unbatched):

>>> indices = torch.tensor([[0, 0, 1, 1, 2, 2],
...                         [0, 2, 1, 3, 0, 2]])
>>> values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
>>> A = torch.sparse_coo_tensor(indices, values, (3, 4))
>>> B = torch.randn(4, 2)
>>> out = sparse_mm(A, B)
>>> out.shape
torch.Size([3, 2])

Batched:

>>> A_batch = torch.stack([A, A])          # (2, 3, 4) — COO stack
>>> B_batch = torch.randn(2, 4, 2)         # (2, 4, 2)
>>> out = sparse_mm(A_batch, B_batch)
>>> out.shape
torch.Size([2, 3, 2])

With gradients:

>>> A.requires_grad_(True)  
tensor(...)
>>> B.requires_grad_(True)  
tensor(...)
>>> out = sparse_mm(A, B)
>>> out.sum().backward()
>>> A.grad.is_sparse
True
class torchsparsegradutils.sparse_matmul.SparseMatMul(*args, **kwargs)

Bases: Function

Autograd kernel for memory-efficient sparse matrix multiplication.

See also

sparse_mm

User-facing function that calls this autograd function.

torch.sparse.mm

PyTorch’s native sparse matrix multiplication.

Indexed Matrix Multiplication

torchsparsegradutils.indexed_matmul.gather_mm(a: Tensor, b: Tensor, idx_b: Tensor) → Tensor

Per-row indexed matrix multiplication.

For each row i in a this computes a[i] @ b[idx_b[i]] and stacks the results into the output.
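The semantics can be pinned down with a loop-based reference written in plain Python. This is only a sketch of what the operation computes, not how the library or DGL implements it; the name gather_mm_reference is hypothetical.

```python
def gather_mm_reference(a, b, idx_b):
    """Loop-based reference: out[i] = a[i] @ b[idx_b[i]].

    a     : N x D1 list of lists
    b     : R x D1 x D2 nested list (bank of matrices)
    idx_b : N integers in [0, R)
    """
    out = []
    for row, r in zip(a, idx_b):
        mat = b[r]                       # (D1, D2) matrix selected for this row
        d2 = len(mat[0])
        out.append([sum(row[k] * mat[k][c] for k in range(len(row)))
                    for c in range(d2)])
    return out


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[[1.0, 0.0], [0.0, 1.0]],   # identity
     [[2.0, 0.0], [0.0, 2.0]]]   # 2x scale
result = gather_mm_reference(a, b, [0, 1])
```

The first row passes through the identity unchanged and the second row is doubled.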

Parameters:
  • a (torch.Tensor, shape (N, D1)) – Left operand with one row per output.

  • b (torch.Tensor, shape (R, D1, D2)) – Bank of transformation matrices.

  • idx_b (torch.Tensor, shape (N,), integer dtype) – Indices selecting which matrix in b to use for each row. Values must satisfy 0 <= idx_b[i] < R.

Returns:

Row-wise results where out[i] = a[i] @ b[idx_b[i]].

Return type:

torch.Tensor, shape (N, D2)

Raises:
  • NotImplementedError – If the fallback path is used on a PyTorch version lacking nested tensor matmul support (requires PyTorch >= 2.4).

  • ValueError – If inputs are not tensors, ranks are incorrect, or sizes are incompatible.

Notes

If DGL is available, this uses dgl.ops.gather_mm() [1b]. Otherwise it uses a dependency-free PyTorch nested-tensor fallback.

See also

segment_mm

Segmented matrix multiplication over contiguous chunks.

References

Examples

>>> import torch
>>> # N = 5, D1 = 3, D2 = 2, R = 3
>>> a = torch.randn(5, 3)
>>> b = torch.randn(3, 3, 2)
>>> idx_b = torch.tensor([0, 1, 0, 2, 1])
>>> out = gather_mm(a, b, idx_b)
>>> out.shape
torch.Size([5, 2])

All rows using the same matrix:

>>> torch.allclose(gather_mm(a, b, torch.zeros(5, dtype=torch.long)), a @ b[0])
True

Mixed indexing example:

>>> # Different transformation for each row
>>> a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # (2, 2)
>>> b = torch.tensor([[[1.0, 0.0], [0.0, 1.0]],  # Identity
...                   [[2.0, 0.0], [0.0, 2.0]]])  # 2x scale
>>> idx_b = torch.tensor([0, 1])  # Use identity, then 2x scale
>>> result = gather_mm(a, b, idx_b)
>>> result
tensor([[1., 2.],
        [6., 8.]])
torchsparsegradutils.indexed_matmul.segment_mm(a: Tensor, b: Tensor, seglen_a: Tensor) → Tensor

Segmented matrix multiplication with variable-length segments.

Performs matrix multiplication between contiguous segments of a and the corresponding matrices in b. If seglen_a == [10, 5, 0, 3], the operator computes:

a[0:10] @ b[0], a[10:15] @ b[1],
a[15:15] @ b[2], a[15:18] @ b[3]
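The slice boundaries above are the running sums of seglen_a. A small dependency-free sketch makes this explicit, including the empty slice produced by a zero-length segment. The names segment_bounds and segment_mm_reference are hypothetical and only illustrate the semantics.

```python
from itertools import accumulate


def segment_bounds(seglen_a):
    """Turn segment lengths into (start, stop) row ranges."""
    stops = list(accumulate(seglen_a))
    starts = [0] + stops[:-1]
    return list(zip(starts, stops))


def segment_mm_reference(a, b, seglen_a):
    """Loop-based reference: rows a[start:stop] are multiplied by b[k]."""
    out = []
    for (start, stop), mat in zip(segment_bounds(seglen_a), b):
        for row in a[start:stop]:        # empty for zero-length segments
            out.append([sum(x * mat[k][c] for k, x in enumerate(row))
                        for c in range(len(mat[0]))])
    return out


bounds = segment_bounds([10, 5, 0, 3])

# Tiny numeric case with D1 = D2 = 1 and an empty middle segment.
a = [[1.0], [2.0], [3.0]]
b = [[[10.0]], [[100.0]], [[1000.0]]]
out = segment_mm_reference(a, b, [2, 0, 1])
```

A zero-length segment contributes no output rows, so the result still has N rows in the original order.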
Parameters:
  • a (torch.Tensor, shape (N, D1)) – Left operand containing the concatenation of all segments.

  • b (torch.Tensor, shape (R, D1, D2)) – Right operand containing one (D1, D2) matrix per segment.

  • seglen_a (torch.Tensor, shape (R,), integer dtype) – Length of each segment in a. seglen_a.sum() must equal N.

Returns:

Concatenation of all segment results in original order.

Return type:

torch.Tensor, shape (N, D2)

Raises:
  • NotImplementedError – If the fallback path is used on a PyTorch version lacking nested tensor matmul support (requires PyTorch >= 2.4).

  • ValueError – If input ranks or sizes are incompatible.

Notes

If DGL is available, this uses dgl.ops.segment_mm() [1c] (typically faster). Otherwise it falls back to a PyTorch nested-tensor implementation.

See also

gather_mm

Per-row indexed matrix multiplication.

References

Examples

>>> import torch
>>> # N = 18, D1 = 4, D2 = 2
>>> a = torch.randn(18, 4)
>>> b = torch.randn(3, 4, 2)
>>> seglen_a = torch.tensor([10, 5, 3])
>>> out = segment_mm(a, b, seglen_a)
>>> out.shape
torch.Size([18, 2])

Zero-length segment:

>>> seglen_a = torch.tensor([10, 5, 0, 3])
>>> b = torch.randn(4, 4, 2)
>>> segment_mm(a, b, seglen_a).shape
torch.Size([18, 2])

Sparse Linear Solvers

torchsparsegradutils.sparse_solve.sparse_triangular_solve(A: Tensor, B: Tensor, upper: bool = True, unitriangular: bool = False, transpose: bool = False) → Tensor

Sparse triangular solve with memory-efficient sparse gradients.

Solves the triangular system \(\mathbf{A}\,\mathbf{x} = \mathbf{B}\) (or \(\mathbf{A}^{\top}\,\mathbf{x} = \mathbf{B}\) if transpose=True), where \(\mathbf{A} \in \mathbb{R}^{m\times m}\) is sparse triangular (COO/CSR) and \(\mathbf{B} \in \mathbb{R}^{m\times p}\) is dense. Gradients preserve the sparsity pattern of \(\mathbf{A}\) by evaluating only at its nonzero entries. Supports unbatched 2D and batched 3D inputs; COO inputs are converted to CSR internally for the triangular solve.

Let the upstream gradient be \(\mathbf{G} = \frac{\partial \mathcal{L}}{\partial \mathbf{x}}\) for a scalar objective \(\mathcal{L}\) and solution \(\mathbf{x}\). The dense-form gradients are

Gradient with respect to B (dense):

\[\frac{\partial \mathcal{L}}{\partial \mathbf{B}} \;=\; \mathbf{A}^{-\top} \, \mathbf{G},\]

and for transpose=True replace \(\mathbf{A}\) by \(\mathbf{A}^{\top}\) so that \(\frac{\partial \mathcal{L}}{\partial \mathbf{B}} = \left(\mathbf{A}^{\top}\right)^{-\top} \mathbf{G} = \mathbf{A}^{-1} \mathbf{G}\).

Gradient with respect to A (sparse):

\[\frac{\partial \mathcal{L}}{\partial \mathbf{A}} \;=\; -\big(\mathbf{A}^{-\top} \, \mathbf{G}\big)\, \mathbf{x}^{\top},\]

and only entries at the nonzeros of \(\mathbf{A}\) are evaluated. Equivalently, for a nonzero \(\mathbf{A}_{ij}\) the contribution is

\[\bigg[\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\bigg]_{ij} \;=\; -\, \big(\mathbf{A}^{-\top} \, \mathbf{G}\big)_{i,:} \,\cdot\, \mathbf{x}_{j,:},\]

where the dot denotes a row-wise inner product across the \(p\) right-hand sides.
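The gradient formulas for the non-transposed case can be checked numerically on a small system. The sketch below is pure Python with a single right-hand side (p = 1): it uses back substitution for the forward solve, forward substitution on the transpose for \(\mathbf{A}^{-\top}\mathbf{G}\), and central finite differences as an independent check. All function names are hypothetical and the code does not reflect the library's internals.

```python
def solve_upper(A, b):
    """Back substitution for a dense upper-triangular A (list of lists)."""
    m = len(b)
    x = [0.0] * m
    for i in reversed(range(m)):
        s = sum(A[i][j] * x[j] for j in range(i + 1, m))
        x[i] = (b[i] - s) / A[i][i]
    return x


def solve_upper_transposed(A, g):
    """Solve A^T y = g by forward substitution (A^T is lower-triangular)."""
    m = len(g)
    y = [0.0] * m
    for i in range(m):
        s = sum(A[j][i] * y[j] for j in range(i))
        y[i] = (g[i] - s) / A[i][i]
    return y


# Upper-triangular matrix from the Examples section, written densely.
A = [[2.0, 0.0, 1.0],
     [0.0, 3.0, 0.0],
     [0.0, 0.0, 1.0]]
b = [1.0, 2.0, 3.0]
nonzeros = [(0, 0), (0, 2), (1, 1), (2, 2)]

x = solve_upper(A, b)
G = [1.0, 1.0, 1.0]                      # dL/dx for L = sum(x)
grad_B = solve_upper_transposed(A, G)    # A^{-T} G
grad_A = {(i, j): -grad_B[i] * x[j] for i, j in nonzeros}  # nonzeros only


def loss(A_, b_):
    return sum(solve_upper(A_, b_))


# Central finite differences on each stored entry of A.
eps = 1e-6
fd_A = {}
for i, j in nonzeros:
    A_plus = [row[:] for row in A]
    A_minus = [row[:] for row in A]
    A_plus[i][j] += eps
    A_minus[i][j] -= eps
    fd_A[(i, j)] = (loss(A_plus, b) - loss(A_minus, b)) / (2 * eps)
```

The backward pass costs one extra triangular solve with the transposed matrix plus one inner product per stored entry of A.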

Parameters:
  • A (torch.Tensor, sparse COO or CSR, shape (m, m) or (b, m, m)) – Sparse triangular coefficient matrix. Must be square per batch. All tensors must be on the same device.

  • B (torch.Tensor, dense (strided), shape (m, p) or (b, m, p)) – Right-hand side. B.shape[-2] must equal A.shape[-2] (m).

  • upper (bool, optional) – If True (default), treat A as upper-triangular; else lower-triangular.

  • unitriangular (bool, optional) – If True, assume unit diagonal (implicit ones). The stored matrix must be strictly triangular (no explicit diagonal entries). Default: False.

  • transpose (bool, optional) – If True, solves \(A^\top x = B\); otherwise \(A x = B\). Default: False.

Returns:

Solution with the same shape as B: (m, p) or (b, m, p).

Return type:

torch.Tensor

Raises:
  • ValueError – If inputs are not tensors; ranks are < 2 or not both 2D/3D; layouts are incompatible (A not COO/CSR or B not dense); shapes are incompatible; batch sizes differ; or if unitriangular=True but explicit diagonal entries are present.

  • RuntimeError – If the underlying triangular solve fails.

Notes

Backprop computes gradients only at nonzero entries of \(\mathbf{A}\), keeping the gradient sparse and reducing memory. COO inputs are converted to CSR since PyTorch’s triangular solver requires CSR [1e]. For autograd implementation details, see [2e].

See also

torch.sparse.mm

Sparse @ dense multiply.

torch.linalg.solve_triangular

Dense triangular solver (modern API).

References

[1e]

PyTorch issue on sparse triangular solve: https://github.com/pytorch/pytorch/issues/87358

[2e]

PyTorch issue on autograd/triangular solve: https://github.com/pytorch/pytorch/issues/88890

Examples

Upper-triangular:

>>> import torch
>>> from torchsparsegradutils import sparse_triangular_solve
>>> A = torch.sparse_csr_tensor([0, 2, 3, 4], [0, 2, 1, 2],
...                             torch.tensor([2.0, 1.0, 3.0, 1.0]), (3, 3))
>>> B = torch.tensor([[1.0], [2.0], [3.0]])
>>> x = sparse_triangular_solve(A, B, upper=True)
>>> x.shape
torch.Size([3, 1])

Lower-triangular:

>>> A_low = torch.sparse_csr_tensor([0, 1, 3, 5], [0, 0, 1, 0, 2],
...                                 torch.tensor([2.0, 1.0, 3.0, 0.5, 1.0]), (3, 3))
>>> x = sparse_triangular_solve(A_low, B, upper=False)

Batched:

>>> # Convert to COO for batching (since torch.stack doesn't work with CSR)
>>> A_coo = A.to_sparse_coo()
>>> A_b = torch.stack([A_coo, A_coo])   # (2, 3, 3)
>>> B_b = torch.stack([B, B])   # (2, 3, 1)
>>> x_b = sparse_triangular_solve(A_b, B_b)
>>> x_b.shape
torch.Size([2, 3, 1])
class torchsparsegradutils.sparse_solve.SparseTriangularSolve(*args, **kwargs)

Bases: Function

Autograd function for memory-efficient sparse triangular system solving.

See also

sparse_triangular_solve

User-facing function that calls this autograd function.

torch.triangular_solve

PyTorch’s native triangular solver.

torchsparsegradutils.sparse_solve.sparse_generic_solve(A: Tensor, B: Tensor, solve: Callable[[...], Tensor] | None = None, transpose_solve: Callable[[...], Tensor] | None = None, **kwargs) → Tensor

Sparse linear solve with iterative methods and sparse-aware gradients.

Solves \(\mathbf{A}\,\mathbf{x} = \mathbf{B}\) with sparse \(\mathbf{A} \in \mathbb{R}^{n\times n}\) (COO/CSR) and dense \(\mathbf{B} \in \mathbb{R}^{n\times p}\) using iterative methods, while preserving sparsity in \(\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\). Supports single (vector) and multiple (matrix) right-hand sides and works with non-differentiable solvers via the implicit function theorem.

Let \(\mathbf{G} = \frac{\partial \mathcal{L}}{\partial \mathbf{x}}\) be the upstream gradient and \(\mathbf{x}\) the solution. The dense-form gradients are

Gradient with respect to B (dense):

\[\frac{\partial \mathcal{L}}{\partial \mathbf{B}} \;=\; \mathbf{A}^{-\top} \, \mathbf{G} \;\equiv\; \mathbf{G}_B.\]

Gradient with respect to A (sparse):

\[\frac{\partial \mathcal{L}}{\partial \mathbf{A}} \;=\; -\, \mathbf{G}_B\, \mathbf{x}^{\top}.\]

We evaluate only the entries corresponding to nonzeros of \(\mathbf{A}\), yielding a sparse gradient tensor with memory proportional to nnz(A). Equivalently, for a nonzero \(\mathbf{A}_{ij}\) the contribution is

\[\bigg[\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\bigg]_{ij} \;=\; -\, (\mathbf{G}_B)_{i,:} \,\cdot\, \mathbf{x}_{j,:}.\]
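The following dependency-free sketch shows how these formulas let a black-box iterative solver be used in the backward pass: one forward solve gives \(\mathbf{x}\), one transpose solve gives \(\mathbf{G}_B\), and the gradient of \(\mathbf{A}\) is assembled entry by entry. The conjugate gradient routine here is a textbook version written for illustration, not the library's linear_cg, and all names are hypothetical.

```python
def matvec(entries, n, v):
    """y = A v for A given as a {(i, j): value} dictionary."""
    y = [0.0] * n
    for (i, j), a in entries.items():
        y[i] += a * v[j]
    return y


def conjugate_gradient(entries, n, b, tol=1e-12, max_iter=100):
    """Textbook CG for symmetric positive definite A, treated as a black box."""
    x = [0.0] * n
    r = b[:]
    p = r[:]
    rs = sum(ri * ri for ri in r)
    for _ in range(max_iter):
        if rs ** 0.5 < tol:
            break
        Ap = matvec(entries, n, p)
        alpha = rs / sum(pi * api for pi, api in zip(p, Ap))
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = sum(ri * ri for ri in r)
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return x


# SPD matrix [[4, -1, 0], [-1, 4, 0], [0, 0, 2]] stored by its nonzeros.
A = {(0, 0): 4.0, (0, 1): -1.0, (1, 0): -1.0, (1, 1): 4.0, (2, 2): 2.0}
A_T = {(j, i): a for (i, j), a in A.items()}
b = [1.0, 2.0, 3.0]

x = conjugate_gradient(A, 3, b)            # forward solve: A x = b
G = [1.0, 1.0, 1.0]                        # dL/dx for L = sum(x)
grad_B = conjugate_gradient(A_T, 3, G)     # transpose solve: A^T G_B = G
grad_A = {(i, j): -grad_B[i] * x[j] for (i, j) in A}  # nonzeros only
```

No derivative of the solver iterations is needed, which is why a non-differentiable solver suffices; the result is only as accurate as the two solves.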
Parameters:
  • A (torch.Tensor, sparse COO or CSR, shape (n, n)) – Sparse square coefficient matrix. Must be invertible and meet the requirements of the chosen solver (for example, symmetric positive definite for linear_cg). All tensors must be on the same device.

  • B (torch.Tensor, dense (strided), shape (n,) or (n, p)) – Right-hand side(s). B.shape[0] must equal A.shape[0].

  • solve (callable, optional) –

    Forward solver with signature solve(A, B, **kwargs) -> X. If None, uses minres (recommended for symmetric indefinite matrices). Other typical choices include:

    • linear_cg (SPD matrices)

    • bicgstab (general non-symmetric)

  • transpose_solve (callable, optional) – Solver for the transpose system used in backprop, with signature transpose_solve(A, G, **kwargs) -> Y that solves \(A^\top Y = G\) in the least-squares / iterative sense. If None, defaults to solve.

  • **kwargs (dict) – Extra keyword arguments forwarded to the solvers (e.g., tolerances, iteration caps, or solver-specific settings objects).

Returns:

Solution tensor X with the same shape as B: (n,) or (n, p).

Return type:

torch.Tensor

Raises:
  • ValueError – If inputs are not tensors, shapes are incompatible, or ranks are invalid.

  • TypeError – If A is not COO/CSR or if B is not dense (strided).

  • UserWarning – If A and B use different dtypes (may affect solver behavior).

Notes

Only entries at the nonzeros of \(\mathbf{A}\) are computed, keeping the gradient sparse and memory-efficient.

See also

sparse_triangular_solve

Triangular systems with sparse-aware gradients.

sparse_generic_lstsq

Overdetermined least-squares with sparse-aware gradients.

Examples

>>> import torch
>>> from torchsparsegradutils import sparse_generic_solve
>>> from torchsparsegradutils.utils import linear_cg, bicgstab, minres
>>> # Symmetric positive definite example
>>> indices = torch.tensor([[0, 0, 1, 1, 2],
...                         [0, 1, 0, 1, 2]])
>>> values = torch.tensor([4.0, -1.0, -1.0, 4.0, 2.0])
>>> A = torch.sparse_coo_tensor(indices, values, (3, 3))
>>> B = torch.tensor([1.0, 2.0, 3.0])
>>> x = sparse_generic_solve(A, B, solve=linear_cg)
>>> x.shape
torch.Size([3])
>>> # Multiple RHS with BiCGSTAB
>>> X = sparse_generic_solve(A, torch.randn(3, 5), solve=bicgstab)
>>> X.shape
torch.Size([3, 5])
>>> # Default solver (MINRES)
>>> x = sparse_generic_solve(A, B)
>>> # With custom solver settings:
>>> from torchsparsegradutils.utils.linear_cg import LinearCGSettings
>>> settings = LinearCGSettings(max_cg_iterations=1000, cg_tolerance=1e-8)
>>> x = sparse_generic_solve(A, B, solve=linear_cg, settings=settings)
>>> # With gradients (A.grad is sparse)
>>> A.requires_grad_(True)  
tensor(...)
>>> B.requires_grad_(True)  
tensor(...)
>>> x = sparse_generic_solve(A, B)
>>> x.sum().backward()
>>> A.grad.is_sparse
True
class torchsparsegradutils.sparse_solve.SparseGenericSolve(*args, **kwargs)

Bases: Function

Autograd function for sparse linear system solving with iterative methods.

See also

sparse_generic_solve

User-facing function that calls this autograd function.

Sparse Least Squares

torchsparsegradutils.sparse_lstsq.sparse_generic_lstsq(A: Tensor, B: Tensor, lstsq: Callable[[Tensor, Tensor], Tensor] | None = None, transpose_lstsq: Callable[[Tensor, Tensor], Tensor] | None = None) → Tensor

Sparse linear least squares with sparse-aware gradients.

Solves the overdetermined problem \(\min_x \|\mathbf{A}x - \mathbf{B}\|_2^2\) where \(\mathbf{A} \in \mathbb{R}^{m\times n}\) is sparse and tall (\(m>n\)) and \(\mathbf{B} \in \mathbb{R}^{m\times p}\) is dense. Backprop preserves the sparsity pattern by returning sparse gradients for \(\mathbf{A}\) at its nonzero entries only.

We assume \(\mathbf{A}\) has full column rank so that \(\mathbf{A}^{+}\mathbf{A}=\mathbf{I}\) (with \(\,\cdot^{+}\) the Moore–Penrose pseudoinverse). Let \(\mathbf{x} \in \mathbb{R}^{n\times p}\) denote the solution and let the upstream gradient be \(\frac{\partial \mathcal{L}}{\partial \mathbf{x}} \in \mathbb{R}^{n\times p}\) for some scalar objective \(\mathcal{L}\). Using Golub & Pereyra (1973) [1f], the gradients are:

Gradient with respect to B (dense):

\[\frac{\partial \mathcal{L}}{\partial \mathbf{B}} \;=\; (\mathbf{A}^{\top})^{+} \, \frac{\partial \mathcal{L}}{\partial \mathbf{x}} \;\equiv\; \mathbf{G}_B.\]

Gradient with respect to A (sparse): The dense form is

\[\frac{\partial \mathcal{L}}{\partial \mathbf{A}} \;=\; -\, \mathbf{G}_B\, \mathbf{x}^{\top} \; -\; (\mathbf{A}\,\mathbf{x} - \mathbf{B})\; \big(\mathbf{A}^{+}\, \mathbf{G}_B\big)^{\top},\]

and we evaluate only the entries corresponding to nonzeros of \(\mathbf{A}\) to keep the gradient sparse. Equivalently, for a nonzero entry \(\mathbf{A}_{ij}\) with residuals \(\mathbf{r}=\mathbf{A}\,\mathbf{x}-\mathbf{B}\) and \(\mathbf{H}=\mathbf{A}^{+}\,\mathbf{G}_B\), the contribution is

\[\bigg[\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\bigg]_{ij} \;=\; -\, (\mathbf{G}_B)_{i,:}\,\cdot\, \mathbf{x}_{j,:} \; -\; \mathbf{r}_{i,:}\,\cdot\, \mathbf{H}_{j,:},\]

where dots denote row-wise inner products over the \(p\) right-hand sides.
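The two-term gradient can be verified numerically on a tiny tall matrix. The sketch below is pure Python for a 3 x 2 matrix with a single right-hand side, solving the normal equations in closed form and comparing the formula with central finite differences. It assumes full column rank, uses hypothetical helper names, and is unrelated to the LSMR solver the library uses by default.

```python
def gram_solve(A, c):
    """Solve (A^T A) y = c for a tall m x 2 matrix A (closed-form 2 x 2 inverse)."""
    g00 = sum(row[0] * row[0] for row in A)
    g01 = sum(row[0] * row[1] for row in A)
    g11 = sum(row[1] * row[1] for row in A)
    det = g00 * g11 - g01 * g01
    return [(g11 * c[0] - g01 * c[1]) / det, (g00 * c[1] - g01 * c[0]) / det]


def lstsq(A, b):
    """x = A^+ b via the normal equations (full column rank assumed)."""
    At_b = [sum(row[j] * bi for row, bi in zip(A, b)) for j in range(2)]
    return gram_solve(A, At_b)


A = [[1.0, 0.0],
     [2.0, 1.0],
     [0.0, 3.0]]
b = [1.0, 0.0, 2.0]
nonzeros = [(0, 0), (1, 0), (1, 1), (2, 1)]

x = lstsq(A, b)
G = [1.0, 1.0]                                         # dL/dx for L = sum(x)
y = gram_solve(A, G)
G_B = [row[0] * y[0] + row[1] * y[1] for row in A]     # (A^T)^+ G = A (A^T A)^{-1} G
H = lstsq(A, G_B)                                      # A^+ G_B
res = [row[0] * x[0] + row[1] * x[1] - bi for row, bi in zip(A, b)]  # A x - B

# Both terms of the formula, evaluated at the nonzeros of A only.
grad_A = {(i, j): -G_B[i] * x[j] - res[i] * H[j] for i, j in nonzeros}

# Central finite differences on each stored entry of A.
eps = 1e-6
fd_A = {}
for i, j in nonzeros:
    A_plus = [row[:] for row in A]
    A_minus = [row[:] for row in A]
    A_plus[i][j] += eps
    A_minus[i][j] -= eps
    fd_A[(i, j)] = (sum(lstsq(A_plus, b)) - sum(lstsq(A_minus, b))) / (2 * eps)
```

The second term vanishes when the residual is zero, in which case the expression reduces to the gradient of the square solve.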

Parameters:
  • A (torch.Tensor) – Sparse COO/CSR tensor of shape (m, n) with m>n and full column rank.

  • B (torch.Tensor) – Dense right-hand side of shape (m,) or (m, p) with B.shape[0] == A.shape[0].

  • lstsq (callable, optional) – Forward solver with signature lstsq(A, B) -> X, where X has shape (n,) or (n, p). Default: LSMR (torchsparsegradutils.utils.lsmr()).

  • transpose_lstsq (callable, optional) – Solver for the transposed system \(A^\top Y = G\) used in the backward pass. Default: LSMR applied to \(A^\top\).

Returns:

Solution X minimizing \(\|AX - B\|_2^2\) with shape (n,) or (n, p).

Return type:

torch.Tensor

Raises:
  • TypeError – If A is not sparse COO/CSR.

  • ValueError – If dimensions are mismatched, or if the backward pass encounters a non-tall A.

  • RuntimeError – If a provided solver fails or returns unexpected shape.

See also

SparseGenericLstsq

Autograd implementation.

References

[1f]

Gene H. Golub and Victor Pereyra. The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems Whose Variables Separate. SIAM Journal on Numerical Analysis, 10(2):413-432, 1973.

Examples

>>> # Simple sparse least squares example:
>>> import torch
>>> from torchsparsegradutils import sparse_generic_lstsq
>>> indices = torch.tensor([[0, 1, 2, 3, 4, 1, 2, 3],
...                         [0, 0, 0, 0, 1, 1, 1, 2]])
>>> values = torch.tensor([1.0, 2.0, 1.0, 3.0, 1.0, 2.0, 1.0, 1.0])
>>> A = torch.sparse_coo_tensor(indices, values, (5, 3)).coalesce()
>>> B = torch.randn(5)
>>> x = sparse_generic_lstsq(A, B)
>>> x.shape
torch.Size([3])
>>> # Multiple RHS:
>>> Bm = torch.randn(5, 4)
>>> Xm = sparse_generic_lstsq(A, Bm)
>>> Xm.shape
torch.Size([3, 4])
>>> # Custom solver:
>>> from torchsparsegradutils.utils import lsmr
>>> def my_lstsq(A_, B_):
...     return lsmr(A_, B_, atol=1e-10, btol=1e-10)[0]
>>> _ = sparse_generic_lstsq(A, B, lstsq=my_lstsq)
>>> # Gradients:
>>> A.requires_grad_(True)  
tensor(...)
>>> B.requires_grad_(True)  
tensor(...)
>>> x = sparse_generic_lstsq(A, B)
>>> loss = x.sum()  # Simple loss to preserve sparsity
>>> loss.backward()
>>> A.grad.is_sparse
True
class torchsparsegradutils.sparse_lstsq.SparseGenericLstsq(*args, **kwargs)

Bases: Function

Autograd kernel for sparse least squares with sparse-aware gradients.

See also

sparse_generic_lstsq

User-facing function that calls this autograd function.