Source code for torchsparsegradutils.cupy.cupy_sparse_solve

import inspect
import warnings
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch

import torchsparsegradutils.cupy as tsgucupy

# from cupyx.scipy.sparse.linalg import cg, cgs, minres, gmres, spsolve


def _wrap_iterative_solver(base_solver, backend_type, solver_name=None):
    """Wrap an iterative solver to handle parameter mapping and return format."""

    # Introspect the actual solver signature to determine accepted parameter names
    try:
        sig_params = set(inspect.signature(base_solver).parameters.keys())
    except (ValueError, TypeError):
        sig_params = set()

    def wrapped_solver(A, b, **solver_kwargs):
        # Create a copy to avoid modifying the original
        filtered_kwargs = solver_kwargs.copy()

        # Extract tolerance parameter and map to correct name for the backend
        tolerance = filtered_kwargs.pop("tol", None)
        atol = filtered_kwargs.pop("atol", None)

        if tolerance is not None:
            # Use introspection to determine correct tolerance parameter name
            if "rtol" in sig_params:
                filtered_kwargs["rtol"] = tolerance
            elif "tol" in sig_params:
                filtered_kwargs["tol"] = tolerance

        # Handle atol parameter if the solver accepts it
        if atol is not None and "atol" in sig_params:
            filtered_kwargs["atol"] = atol

        # Filter to only parameters the solver actually accepts
        common_params = {"x0", "M", "callback", "show", "check"}
        if sig_params:
            final_kwargs = {k: v for k, v in filtered_kwargs.items() if k in sig_params or k in common_params}
        else:
            # Fallback: pass everything if we couldn't introspect
            final_kwargs = filtered_kwargs

        # Call the base solver
        result = base_solver(A, b, **final_kwargs)

        # Handle return format - some solvers return (solution, info) tuples
        if isinstance(result, tuple):
            return result[0]  # Return just the solution
        return result

    return wrapped_solver


def _wrap_direct_solver(base_solver):
    """Wrap a direct solver to ignore tolerance parameters."""

    def wrapped_solver(A, b, **solver_kwargs):
        # Direct solvers don't use iterative solver parameters, so ignore them all
        filtered_kwargs = {
            k: v
            for k, v in solver_kwargs.items()
            if k not in ["tol", "tolerance", "atol", "rtol", "maxiter", "matvec_max"]
        }
        return base_solver(A, b, **filtered_kwargs)

    return wrapped_solver


def _get_solver_function(solver_name, xsp, device):
    """Get the appropriate solver function based on the backend."""
    if solver_name is None or callable(solver_name):
        return solver_name

    # Determine backend type
    backend_type = "scipy" if device.type == "cpu" else "cupy"

    solver_map = {
        "cg": _wrap_iterative_solver(xsp.linalg.cg, backend_type, "cg"),
        "cgs": _wrap_iterative_solver(xsp.linalg.cgs, backend_type, "cgs"),
        "minres": _wrap_iterative_solver(xsp.linalg.minres, backend_type, "minres"),
        "gmres": _wrap_iterative_solver(xsp.linalg.gmres, backend_type, "gmres"),
        "spsolve": _wrap_direct_solver(xsp.linalg.spsolve),
    }

    if solver_name not in solver_map:
        raise ValueError(f"Unknown solver: {solver_name}. Supported solvers: {list(solver_map.keys())}")

    return solver_map[solver_name]


[docs] def sparse_solve_c4t( A: torch.Tensor, B: torch.Tensor, solve: Optional[Union[str, Callable[..., Any]]] = None, transpose_solve: Optional[Union[str, Callable[..., Any]]] = None, **kwargs: Any, ) -> torch.Tensor: r""" Solve sparse linear systems using CuPy / SciPy with automatic backend selection. Solves :math:`A X = B` using CPU (NumPy / SciPy) or GPU (CuPy / cupyx.scipy) backends, chosen from the device of the input PyTorch sparse tensor ``A``. Supports selected iterative solvers and a direct sparse solve with automatic COO / CSR format conversion and an autograd-compatible backward pass via a transpose solve. Parameters ---------- A : torch.Tensor Sparse square matrix of shape ``(n, n)`` in ``torch.sparse_coo`` or ``torch.sparse_csr`` layout. B : torch.Tensor Right-hand side(s). Shape ``(n,)`` (vector RHS) or ``(n, k)`` (multi-RHS). solve : {"cg", "cgs", "minres", "gmres", "spsolve"} or callable, optional Solver selector or a custom callable ``solve(A, b, **kwargs) -> x``. Built-ins: - ``"cg"`` : Conjugate Gradient (SPD; vector RHS only) - ``"cgs"`` : Conjugate Gradient Squared (vector RHS only) - ``"minres"`` : MINRES (symmetric; vector RHS only) - ``"gmres"`` : GMRES (vector RHS only) - ``"spsolve"`` : Direct sparse solve (supports multi-RHS) If ``None`` (default): - vector RHS → direct ``spsolve`` - multi-RHS → factorize then solve (SciPy/CuPy factorized) transpose_solve : {"cg", "cgs", "minres", "gmres", "spsolve"} or callable, optional Solver for the transpose system :math:`A^T y = g` used in backprop. Defaults to using the same selection as ``solve`` (or factorized). **kwargs : dict Additional solver parameters passed through to the chosen backend: - Iterative solvers commonly accept ``tol``/``rtol``, ``atol``, ``maxiter``, and optionally ``x0``, ``M``, ``callback``; unsupported kwargs are ignored. - Direct ``spsolve`` ignores iteration controls. Returns ------- torch.Tensor Solution tensor ``X`` with the same shape and dtype as ``B`` (or cast to match ``A.dtype`` if necessary) and on the same device as ``A``. Raises ------ TypeError If ``A`` is not ``torch.sparse_coo`` or ``torch.sparse_csr``. ValueError If ``A`` is not square; if ``B`` has incompatible shape; or if an iterative solver is requested for a multi-RHS input. Notes ----- Backend selection - CPU tensors → NumPy/SciPy - CUDA tensors → CuPy/cupyx.scipy Solver compatibility - Iterative solvers (``cg``, ``cgs``, ``minres``, ``gmres``): **vector RHS only** - Direct solver (``spsolve``): supports vector **and** multi-RHS Performance considerations - CSR is typically more efficient than COO for these solvers. - Backends may internally convert to CSC/CSR and emit efficiency warnings. - SciPy ``minres`` may upcast float32 to float64 on CPU. Gradients The backward pass solves :math:`A^T y = \\mathrm{grad}` using the same backend and forms sparse gradients for ``A`` by only computing entries at its nonzero positions. Examples -------- Basic solve (default direct solver) >>> import torch >>> from torchsparsegradutils.cupy import sparse_solve_c4t >>> idx = torch.tensor([[0, 1, 1], [0, 0, 1]]) >>> val = torch.tensor([2.0, -1.0, 2.0]) >>> A = torch.sparse_coo_tensor(idx, val, (2, 2)) >>> b = torch.tensor([1.0, 3.0]) >>> x = sparse_solve_c4t(A, b) >>> x.shape torch.Size([2]) Iterative solver >>> x_cg = sparse_solve_c4t(A, b, solve="cg", tol=1e-8) Multi-RHS with direct solve >>> B = torch.randn(2, 3) >>> X = sparse_solve_c4t(A, B, solve="spsolve") >>> X.shape torch.Size([2, 3]) Note: CUDA backend (CuPy) is selected automatically when tensors are on GPU. See Also -------- torchsparsegradutils.jax.sparse_solve_j4t : JAX-backed sparse solver with autograd support. torchsparsegradutils.cupy.t2c_coo, torchsparsegradutils.cupy.t2c_csr : PyTorch→CuPy/NumPy sparse converters used internally. torchsparsegradutils.cupy.c2t_coo, torchsparsegradutils.cupy.c2t_csr : CuPy/NumPy→PyTorch sparse converters used internally. """ # Input validation if not isinstance(A, torch.Tensor) or not isinstance(B, torch.Tensor): raise ValueError("Both A and B should be instances of torch.Tensor") if A.layout not in (torch.sparse_coo, torch.sparse_csr): raise TypeError(f"Unsupported sparse layout: {A.layout}. Only COO and CSR are supported.") if A.dim() != 2: raise ValueError("A must be a 2D tensor") if A.shape[0] != A.shape[1]: raise ValueError("A must be square") if B.dim() not in (1, 2): raise ValueError("B must be a 1D or 2D tensor") if B.shape[0] != A.shape[0]: raise ValueError(f"Incompatible dimensions: A has shape {A.shape}, B has shape {B.shape}") # Check for iterative solver compatibility with multi-RHS vector_solvers = {"cg", "cgs", "minres", "gmres"} is_multi_rhs = B.ndim == 2 and B.shape[1] > 1 if solve in vector_solvers and is_multi_rhs: raise ValueError( f"Solver '{solve}' does not support multi-RHS (B.shape={B.shape}). " f"Use solve='spsolve' or solve=None for multi-RHS problems, or reshape B to a vector." ) if transpose_solve in vector_solvers and is_multi_rhs: raise ValueError( f"Transpose solver '{transpose_solve}' does not support multi-RHS (B.shape={B.shape}). " f"Use transpose_solve='spsolve' or transpose_solve=None for multi-RHS problems." ) # Convert string solver names to functions xp, xsp = tsgucupy._get_array_modules(A.data) # Warn about dtype issues with minres on CPU if solve == "minres" and A.device.type == "cpu": warnings.warn( "Using 'minres' solver on CPU may change the output dtype to float64 " "even with float32 inputs due to SciPy implementation. Consider using " "'cg' or 'spsolve' for consistent dtype behavior.", UserWarning, stacklevel=2, ) if transpose_solve == "minres" and A.device.type == "cpu": warnings.warn( "Using 'minres' transpose solver on CPU may change the output dtype to float64 " "even with float32 inputs due to SciPy implementation.", UserWarning, stacklevel=2, ) solve_func = _get_solver_function(solve, xsp, A.device) transpose_solve_func = _get_solver_function(transpose_solve, xsp, A.device) return SparseSolveC4T.apply(A, B, solve_func, transpose_solve_func, kwargs)
[docs] class SparseSolveC4T(torch.autograd.Function): r""" Autograd function for CuPy / SciPy–backed sparse solves. Forward: converts PyTorch sparse ``A`` (COO / CSR) and dense ``B`` to backend sparse / dense types, then calls either an iterative solver, direct ``spsolve``, or a cached factorized solve for multi-RHS. Backward: solves :math:`A^T y = \mathrm{grad}_X` and reconstructs :math:`\nabla_A` only at the nonzero positions of ``A`` via :math:`\nabla_A = -(A^{-T} \, \mathrm{grad}_X) \, X^T` (sampled at existing sparsity pattern) while returning :math:`\nabla_B = A^{-T} \mathrm{grad}_X`. See Also -------- torchsparsegradutils.jax.sparse_solve_j4t : JAX-backed sparse solver. """
[docs] @staticmethod def forward( ctx, A: torch.Tensor, B: torch.Tensor, solve: Optional[Callable[..., Any]], transpose_solve: Optional[Callable[..., Any]], kwargs: Dict[str, Any], ) -> torch.Tensor: xp, xsp = tsgucupy._get_array_modules(A.data) grad_flag = A.requires_grad or B.requires_grad ctx.transpose_solve = transpose_solve # Transfer data to cupy/scipy if A.layout == torch.sparse_coo: A_c = tsgucupy.t2c_coo(A.detach()) elif A.layout == torch.sparse_csr: A_c = tsgucupy.t2c_csr(A.detach()) else: raise TypeError(f"Unsupported layout type: {A.layout}") B_c = tsgucupy._torch_to_backend(B.detach(), xp) # Solve the sparse system ctx.factorisedsolver = None ctx.kwargs = kwargs # Store kwargs for backward pass if solve is not None: x_c = solve(A_c, B_c, **kwargs) elif (B.ndim == 1) or (B.shape[1] == 1): # xp.sparse.linalg.spsolve only works if B is a vector but is fully on GPU with cupy # TODO: Is this still true? # see: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.spsolve.html#scipy.sparse.linalg.spsolve x_c = xsp.linalg.spsolve(A_c, B_c) else: # Make use of a factorisation (only the solver is then on the GPU with cupy) # We store it in ctx to reuse it in the backward pass ctx.factorisedsolver = xsp.linalg.factorized(A_c) x_c = ctx.factorisedsolver(B_c) if isinstance(x_c, tuple): # If the solver returns a tuple, we assume the first element is the solution x_c = x_c[0] x = tsgucupy._backend_to_torch(x_c) # Ensure output dtype matches input dtype if x.dtype != A.dtype: x = x.to(dtype=A.dtype) if (B.ndim == 2) and (x.ndim == 1): x = x.unsqueeze(-1) ctx.save_for_backward(A, x) ctx.A_c = A_c x.requires_grad = grad_flag return x
[docs] @staticmethod def backward(ctx, grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]: A, x = ctx.saved_tensors xp, xsp = tsgucupy._get_array_modules(A.data) # Unsqueeze, if necessary is_vector = x.ndim == 1 if is_vector: x = x.unsqueeze(-1) grad = grad.unsqueeze(-1) grad_c = tsgucupy._torch_to_backend(grad.detach(), xp) # Backprop rule: gradB = A^{-T} grad if ctx.transpose_solve is not None: gradB_c = ctx.transpose_solve(ctx.A_c, grad_c, **ctx.kwargs) elif ctx.factorisedsolver is None: gradB_c = xsp.linalg.spsolve(xp.transpose(ctx.A_c), grad_c) else: # Re-use factorised solver from forward pass gradB_c = ctx.factorisedsolver(grad_c, trans="T") if isinstance(gradB_c, tuple): # If the solver returns a tuple, we assume the first element is the gradient gradB_c = gradB_c[0] gradB = tsgucupy._backend_to_torch(gradB_c) # Ensure gradient dtype matches input dtype if gradB.dtype != A.dtype: gradB = gradB.to(dtype=A.dtype) if (grad.ndim == 2) and (gradB.ndim == 1): gradB = gradB.unsqueeze(-1) # The gradient with respect to the matrix A seen as a dense matrix would # lead to a backprop rule as follows # gradA = -(A^{-T} grad)(A^{-1} B) = - gradB @ x.T # but we are only interested in the gradient with respect to # the (non-zero) values of A. To save memory, instead of computing the full # dense matrix gradB @ x.T and then subsampling at the nnz locations in a, # we can directly only compute the required values: # gradA[i,j] = - dotprod(gradB[i,:], x[j,:]) # We start by getting the i and j indices: if A.layout == torch.sparse_coo: A_row_idx = A.indices()[0, :] A_col_idx = A.indices()[1, :] else: A_col_idx = A.col_indices() A_crow_idx = A.crow_indices() # Uncompress row indices: A_row_idx = torch.repeat_interleave( torch.arange(A.size()[0], device=A.device), A_crow_idx[1:] - A_crow_idx[:-1] ) mgradbselect = -gradB.index_select(0, A_row_idx) # -gradB[i, :] xselect = x.index_select(0, A_col_idx) # x[j, :] # Dot product: mgbx = mgradbselect * xselect gradA = torch.sum(mgbx, dim=1) # Ensure gradient dtype matches input dtype if gradA.dtype != A.dtype: gradA = gradA.to(dtype=A.dtype) if A.layout == torch.sparse_coo: gradA = torch.sparse_coo_tensor(torch.stack([A_row_idx, A_col_idx]), gradA, A.shape) else: gradA = torch.sparse_csr_tensor(A_crow_idx, A_col_idx, gradA, A.shape) # Squeeze gradB back to original shape if it was a vector if is_vector: gradB = gradB.squeeze(-1) return gradA, gradB, None, None, None