Backend Integrations

This module contains integrations with external backends like CuPy and JAX.

CuPy Backend

torchsparsegradutils.cupy.cupy_sparse_solve.sparse_solve_c4t(A: Tensor, B: Tensor, solve: str | Callable[[...], Any] | None = None, transpose_solve: str | Callable[[...], Any] | None = None, **kwargs: Any) → Tensor

Solve sparse linear systems using CuPy / SciPy with automatic backend selection.

Solves \(A X = B\) using CPU (NumPy / SciPy) or GPU (CuPy / cupyx.scipy) backends, chosen from the device of the input PyTorch sparse tensor A. Supports selected iterative solvers and a direct sparse solve with automatic COO / CSR format conversion and an autograd-compatible backward pass via a transpose solve.

Parameters:
  • A (torch.Tensor) – Sparse square matrix of shape (n, n) in torch.sparse_coo or torch.sparse_csr layout.

  • B (torch.Tensor) – Right-hand side(s). Shape (n,) (vector RHS) or (n, k) (multi-RHS).

  • solve ({"cg", "cgs", "minres", "gmres", "spsolve"} or callable, optional) –

    Solver selector or a custom callable solve(A, b, **kwargs) -> x. Built-ins:

    • "cg" : Conjugate Gradient (SPD; vector RHS only)

    • "cgs" : Conjugate Gradient Squared (vector RHS only)

    • "minres" : MINRES (symmetric; vector RHS only)

    • "gmres" : GMRES (vector RHS only)

    • "spsolve" : Direct sparse solve (supports multi-RHS)

    If None (default):

    • vector RHS → direct spsolve

    • multi-RHS → factorize then solve (SciPy/CuPy factorized)

  • transpose_solve ({"cg", "cgs", "minres", "gmres", "spsolve"} or callable, optional) – Solver for the transpose system \(A^T y = g\) used in backprop. Defaults to the same selection as solve (or the factorized solve).

  • **kwargs (dict) –

    Additional solver parameters passed through to the chosen backend:

    • Iterative solvers commonly accept tol/rtol, atol, maxiter, and optionally x0, M, callback; unsupported kwargs are ignored.

    • Direct spsolve ignores iteration controls.

Returns:

Solution tensor X with the same shape and dtype as B (or cast to match A.dtype if necessary) and on the same device as A.

Return type:

torch.Tensor

Raises:
  • TypeError – If A is not torch.sparse_coo or torch.sparse_csr.

  • ValueError – If A is not square; if B has incompatible shape; or if an iterative solver is requested for a multi-RHS input.

Notes

Backend selection
  • CPU tensors → NumPy/SciPy

  • CUDA tensors → CuPy/cupyx.scipy

Solver compatibility
  • Iterative solvers (cg, cgs, minres, gmres): vector RHS only

  • Direct solver (spsolve): supports vector and multi-RHS
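
The selection and compatibility rules above can be sketched for the CPU case with plain SciPy. This is only an illustration under stated assumptions, not the library's implementation: `solve_cpu` and `ITERATIVE` are hypothetical names, and `splu` stands in for the factorized multi-RHS path.

```python
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla

# Hypothetical lookup table for the built-in iterative solver names.
ITERATIVE = {"cg": spla.cg, "cgs": spla.cgs, "minres": spla.minres, "gmres": spla.gmres}


def solve_cpu(A, B, solve=None, **kwargs):
    """Hypothetical CPU-only dispatcher mirroring the documented rules."""
    A = sp.csc_matrix(A)  # CSC avoids efficiency warnings in the direct solvers
    B = np.asarray(B)
    if A.shape[0] != A.shape[1]:
        raise ValueError("A must be square")
    if B.ndim not in (1, 2) or B.shape[0] != A.shape[0]:
        raise ValueError("B must have shape (n,) or (n, k)")
    if solve is None:
        if B.ndim == 1:
            return spla.spsolve(A, B)  # vector RHS: direct solve
        return spla.splu(A).solve(B)  # multi-RHS: factorize once, then solve
    if solve == "spsolve":
        return spla.spsolve(A, B)
    if solve in ITERATIVE:
        if B.ndim != 1:
            raise ValueError("iterative solvers accept a vector RHS only")
        x, info = ITERATIVE[solve](A, B, **kwargs)
        if info != 0:
            raise RuntimeError(f"{solve} did not converge (info={info})")
        return x
    return solve(A, B, **kwargs)  # custom callable


# Small SPD system so that "cg" is applicable.
A = sp.csr_matrix(np.array([[4.0, 1.0], [1.0, 3.0]]))
b = np.array([1.0, 2.0])
B = np.array([[1.0, 0.0], [2.0, 1.0]])

x_direct = solve_cpu(A, b)
X_multi = solve_cpu(A, B)
x_cg = solve_cpu(A, b, solve="cg", atol=1e-12)
```

Requesting an iterative solver with the two-column `B` raises `ValueError` in this sketch, matching the behaviour listed under Raises.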

Performance considerations
  • CSR is typically more efficient than COO for these solvers.

  • Backends may internally convert to CSC/CSR and emit efficiency warnings.

  • SciPy minres may upcast float32 to float64 on CPU.

Gradients

The backward pass solves \(A^T y = \mathrm{grad}\) using the same backend and forms sparse gradients for A by only computing entries at its nonzero positions.
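
As a sanity check on this rule, the sketch below uses dense NumPy arrays (not this library) to compare the adjoint-based values with central finite differences. The test matrix, the scalar loss \(L = g^T x\), and all names are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4

# Random sparsity pattern with a full diagonal; small off-diagonal values
# keep the matrix well conditioned.
mask = rng.random((n, n)) < 0.5
np.fill_diagonal(mask, True)
A = np.where(mask, 0.1 * rng.standard_normal((n, n)), 0.0) + 2.0 * np.eye(n)
b = rng.standard_normal(n)
g = rng.standard_normal(n)  # upstream gradient dL/dx for L = g^T x

x = np.linalg.solve(A, b)    # forward solve A x = b
y = np.linalg.solve(A.T, g)  # transpose solve A^T y = g

grad_b = y                         # dL/db
rows, cols = np.nonzero(mask)
grad_A_vals = -y[rows] * x[cols]   # dL/dA, evaluated only at the nonzero pattern


def loss(A_, b_):
    return g @ np.linalg.solve(A_, b_)


def fd_wrt_A(i, j, eps=1e-6):
    E = np.zeros_like(A)
    E[i, j] = eps
    return (loss(A + E, b) - loss(A - E, b)) / (2 * eps)


def fd_wrt_b(i, eps=1e-6):
    e = np.zeros_like(b)
    e[i] = eps
    return (loss(A, b + e) - loss(A, b - e)) / (2 * eps)


fd_A = np.array([fd_wrt_A(i, j) for i, j in zip(rows, cols)])
fd_b = np.array([fd_wrt_b(i) for i in range(n)])

max_err_A = np.max(np.abs(fd_A - grad_A_vals))
max_err_b = np.max(np.abs(fd_b - grad_b))
```

Only `len(rows)` gradient values are formed for A, one per stored entry, which is what keeps the gradient as sparse as A itself.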

Examples

Basic solve (default direct solver)

>>> import torch
>>> from torchsparsegradutils.cupy import sparse_solve_c4t
>>> idx = torch.tensor([[0, 1, 1], [0, 0, 1]])
>>> val = torch.tensor([2.0, -1.0, 2.0])
>>> A = torch.sparse_coo_tensor(idx, val, (2, 2))
>>> b = torch.tensor([1.0, 3.0])
>>> x = sparse_solve_c4t(A, b)
>>> x.shape
torch.Size([2])

Iterative solver (A is not symmetric, so GMRES is used rather than CG)

>>> x_gmres = sparse_solve_c4t(A, b, solve="gmres", tol=1e-8)

Multi-RHS with direct solve

>>> B = torch.randn(2, 3)
>>> X = sparse_solve_c4t(A, B, solve="spsolve")
>>> X.shape
torch.Size([2, 3])

Note: CUDA backend (CuPy) is selected automatically when tensors are on GPU.

See also

torchsparsegradutils.jax.sparse_solve_j4t

JAX-backed sparse solver with autograd support.

torchsparsegradutils.cupy.t2c_coo, torchsparsegradutils.cupy.t2c_csr, torchsparsegradutils.cupy.c2t_coo, torchsparsegradutils.cupy.c2t_csr

class torchsparsegradutils.cupy.cupy_sparse_solve.SparseSolveC4T(*args, **kwargs)

Bases: Function

Autograd function for CuPy / SciPy–backed sparse solves.

Forward: converts PyTorch sparse A (COO / CSR) and dense B to backend sparse / dense types, then calls either an iterative solver, direct spsolve, or a cached factorized solve for multi-RHS.

Backward: solves \(A^T y = \mathrm{grad}_X\) and reconstructs \(\nabla_A\) only at the nonzero positions of A via \(\nabla_A = -(A^{-T} \, \mathrm{grad}_X) \, X^T\) (sampled at existing sparsity pattern) while returning \(\nabla_B = A^{-T} \mathrm{grad}_X\).

See also

torchsparsegradutils.jax.sparse_solve_j4t

JAX-backed sparse solver.

static forward(ctx, A: Tensor, B: Tensor, solve: Callable[[...], Any] | None, transpose_solve: Callable[[...], Any] | None, kwargs: Dict[str, Any]) → Tensor

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass


@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See Extending torch.autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced, for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or with ctx.save_for_forward() if they are intended to be used in jvp.

static backward(ctx, grad: Tensor) → Tuple[Tensor, Tensor, None, None, None]

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
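
To make the forward / backward contract concrete, here is a minimal dense stand-in written against the public torch.autograd.Function API. It is a sketch, not the source of SparseSolveC4T: the class name `DenseSolve` and the non-tensor `note` argument are hypothetical, and a dense `torch.linalg.solve` replaces the sparse backends.

```python
import torch


class DenseSolve(torch.autograd.Function):
    """Hypothetical dense analogue of a solve with a transpose-solve backward."""

    @staticmethod
    def forward(ctx, A, B, note):
        # `note` is a non-tensor input, present only to show its None gradient slot.
        X = torch.linalg.solve(A, B)
        ctx.save_for_backward(A, X)  # tensors go through save_for_backward, not ctx attributes
        return X

    @staticmethod
    def backward(ctx, grad):
        A, X = ctx.saved_tensors
        Y = torch.linalg.solve(A.T, grad)  # transpose solve: A^T Y = grad
        grad_A = grad_B = None
        if ctx.needs_input_grad[0]:
            # Dense here; a sparse variant would sample this at the nonzeros of A.
            grad_A = -Y @ X.T
        if ctx.needs_input_grad[1]:
            grad_B = Y
        # One return value per forward input; None for the non-tensor `note`.
        return grad_A, grad_B, None


torch.manual_seed(0)
eye = torch.eye(3, dtype=torch.float64)
A = (0.1 * torch.randn(3, 3, dtype=torch.float64) + eye).requires_grad_()
B = torch.randn(3, 2, dtype=torch.float64, requires_grad=True)

X = DenseSolve.apply(A, B, "demo")
# gradcheck compares the analytical backward with numerical differentiation.
ok = torch.autograd.gradcheck(lambda a, b: DenseSolve.apply(a, b, "demo"), (A, B))
```

The three return values of `backward` line up with the three inputs of `forward`, in the same way that the documented signature returns two tensors followed by three None entries for its five inputs.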

CuPy / NumPy ↔ PyTorch sparse interop.

Utilities to convert between PyTorch sparse tensors (COO / CSR) and SciPy / CuPy sparse matrices. The NumPy / SciPy vs CuPy / cuSPARSE stack is selected automatically from the device (CPU vs CUDA) and CuPy availability.

Notes

  • COO conversions require coalesced PyTorch tensors. If an input COO is not coalesced, it is coalesced with a warning.

  • Index / indptr arrays use zero-copy views where possible (no host/device round-trips unless needed for dtype / layout changes).

  • CSR shapes and index semantics match PyTorch’s torch.sparse_csr_tensor.

  • Conversion preserves dtype & device (except when backend routines upcast).
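
The CSR correspondence described in these notes can be sketched for CPU tensors as follows. This is an assumption-laden illustration rather than the code of t2c_csr / c2t_csr: the helper names `torch_csr_to_scipy` and `scipy_csr_to_torch` are hypothetical, and only the SciPy side (no CuPy) is shown.

```python
import numpy as np
import scipy.sparse as sp
import torch


def torch_csr_to_scipy(x):
    """Hypothetical CPU-only torch CSR -> SciPy CSR conversion."""
    if x.layout != torch.sparse_csr or x.dim() != 2:
        raise ValueError("expected a 2D CSR tensor")
    # .numpy() returns views of the CPU tensors; SciPy may still copy
    # if it decides to change the index dtype.
    data = x.values().detach().numpy()
    indices = x.col_indices().numpy()
    indptr = x.crow_indices().numpy()
    return sp.csr_matrix((data, indices, indptr), shape=tuple(x.shape))


def scipy_csr_to_torch(X):
    """Hypothetical SciPy CSR -> torch CSR conversion."""
    return torch.sparse_csr_tensor(
        torch.from_numpy(X.indptr),   # row pointers  -> crow_indices
        torch.from_numpy(X.indices),  # column ids    -> col_indices
        torch.from_numpy(X.data),     # stored values -> values
        size=X.shape,
    )


dense = torch.tensor([[1.0, 0.0, 2.0], [0.0, 0.0, 3.0]])
x = dense.to_sparse_csr()
X = torch_csr_to_scipy(x)
back = scipy_csr_to_torch(X)
```

The round trip leaves shape, dtype, and values unchanged; a CUDA variant would hand the same three arrays to cupyx.scipy.sparse instead.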

See also

torch.sparse_coo_tensor, torch.sparse_csr_tensor, scipy.sparse, cupyx.scipy.sparse

torchsparsegradutils.cupy.cupy_bindings.t2c_csr(x_torch: Tensor) → Any

Convert a PyTorch CSR tensor to a CuPy / NumPy CSR matrix.

Parameters:

x_torch (torch.Tensor) – 2D sparse CSR tensor (layout torch.sparse_csr).

Returns:

cupyx.scipy.sparse.csr_matrix (CUDA with CuPy) else scipy.sparse.csr_matrix.

Return type:

Any

Raises:

ValueError – If x_torch is not a 2D CSR tensor.

See also

c2t_csr

PyTorch conversion in the opposite direction.

t2c_coo, c2t_coo

Examples

>>> import torch
>>> from torchsparsegradutils.cupy import t2c_csr
>>> x = torch.randn(4, 4).to_sparse_csr()
>>> X = t2c_csr(x)
>>> X.shape
(4, 4)

torchsparsegradutils.cupy.cupy_bindings.c2t_csr(x_cupy: Any) → Tensor

Convert a CuPy / NumPy CSR matrix to a PyTorch CSR tensor.

Parameters:

x_cupy (Any) – CSR matrix with attributes data, indices, indptr.

Returns:

2D sparse CSR tensor (same shape & numeric data) with layout torch.sparse_csr.

Return type:

torch.Tensor

See also

t2c_csr

Reverse conversion.

Examples

>>> import torch
>>> import numpy as np, scipy.sparse as nsp
>>> from torchsparsegradutils.cupy import c2t_csr
>>> X = nsp.random(5, 3, density=0.2, format='csr')
>>> x = c2t_csr(X)
>>> x.layout is torch.sparse_csr
True

torchsparsegradutils.cupy.cupy_bindings.t2c_coo(x_torch: Tensor) → Any

Convert a PyTorch COO tensor to a CuPy / NumPy COO matrix.

Parameters:

x_torch (torch.Tensor) – 2D sparse COO tensor (layout torch.sparse_coo). Coalesced automatically (with a warning) if duplicates are present.

Returns:

cupyx.scipy.sparse.coo_matrix (CUDA+CuPy) else scipy.sparse.coo_matrix.

Return type:

Any

Warns:

UserWarning – If the input is not coalesced.

Raises:

ValueError – If x_torch is not a 2D COO tensor.

See also

c2t_coo

Reverse COO conversion.

t2c_csr, c2t_csr

Examples

>>> import torch
>>> from torchsparsegradutils.cupy import t2c_coo
>>> idx = torch.tensor([[0, 1], [1, 0]])
>>> val = torch.tensor([2.0, 3.0])
>>> x = torch.sparse_coo_tensor(idx, val, (2, 2))
>>> X = t2c_coo(x)
>>> X.shape
(2, 2)

torchsparsegradutils.cupy.cupy_bindings.c2t_coo(x_cupy: Any) → Tensor

Convert a CuPy / NumPy COO matrix to a PyTorch COO tensor.

Parameters:

x_cupy (Any) – COO matrix with data, row, col attributes.

Returns:

2D sparse COO tensor with identical numerical content.

Return type:

torch.Tensor

See also

t2c_coo

Reverse conversion.

Examples

>>> import torch
>>> import numpy as np, scipy.sparse as nsp
>>> from torchsparsegradutils.cupy import c2t_coo
>>> X = nsp.coo_matrix(np.array([[0, 1], [2, 0]]))
>>> x = c2t_coo(X)
>>> x.layout is torch.sparse_coo
True

JAX Backend