Source code for torchsparsegradutils.utils.bicgstab

# Code imported from https://github.com/PythonOptimizers/pykrylov/blob/master/pykrylov/bicgstab/bicgstab.py
# Modifications to fit torchsparsegradutils

import logging
from typing import Callable, NamedTuple, Optional, Union

import torch

# Default (null) logger: a module-private logger that is switched off so the
# solver emits no log output unless the caller provides a different logger
# through ``BICGSTABSettings.logger``.
# NOTE: ``logging.getLogger`` returns the process-wide logger registered under
# the name "bicgstab", so disabling it here applies to every user of that name.
_null_log = logging.getLogger("bicgstab")
_null_log.disabled = True


class BICGSTABSettings(NamedTuple):
    """Immutable option bundle consumed by :func:`bicgstab`.

    Every field has a default, so ``BICGSTABSettings()`` yields a usable
    configuration; individual options are overridden by keyword.
    """

    # Max number of matvecs (default 2n, where n is the length of the RHS).
    matvec_max: Optional[int] = None

    # Absolute stopping tolerance on the residual norm.
    abstol: float = 1.0e-8

    # Relative stopping tolerance (scaled by the initial residual norm).
    reltol: float = 1.0e-6

    # Optional preconditioner: either a tensor (applied through ``.matmul``)
    # or a callable mapping a vector to the preconditioned vector.
    precon: Optional[Union[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]]] = None

    # Logger receiving the iteration history; the default is a disabled logger.
    logger: logging.Logger = _null_log
def bicgstab(
    matmul_closure: Union[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]],
    rhs: torch.Tensor,
    initial_guess: Optional[torch.Tensor] = None,
    settings: BICGSTABSettings = BICGSTABSettings(),
) -> torch.Tensor:
    r"""
    Solve linear systems with the BiConjugate Gradient Stabilized (BiCGSTAB) method.

    Solves nonsymmetric, nonsingular systems :math:`A x = b`. Accepts either a
    matrix-like tensor (using ``.matmul``) or a callable for the matrix–vector
    product, and optionally a (left) preconditioner, also as tensor or callable,
    approximating :math:`M^{-1}`.

    Parameters
    ----------
    matmul_closure : {torch.Tensor, callable(x) -> Ax}
        Matrix–vector multiplication operator. If a tensor is provided, its
        ``.matmul`` is used.
    rhs : torch.Tensor, shape (n,) or (n, k)
        Right-hand side vector(s). For multiple RHS, **each column is solved
        independently** (no block BiCGSTAB).
    initial_guess : torch.Tensor, optional, shape like ``rhs``
        Initial guess. If ``None``, zero initialization is used. When a guess
        is supplied, the starting residual is ``rhs - A @ initial_guess``
        (costing one matvec); the guess itself is not modified.
    settings : BICGSTABSettings, optional
        Convergence tolerances, maximum matvecs, optional preconditioner and
        logger.

    Returns
    -------
    torch.Tensor
        Solution(s) ``x`` with the **same shape as ``rhs``**. If the matvec
        budget is exhausted before the tolerance is met, the latest iterate is
        returned.

    Raises
    ------
    RuntimeError
        If ``matmul_closure`` is neither tensor nor callable, or if the
        ``precon`` is neither tensor nor callable.

    Notes
    -----
    Per iteration (unpreconditioned) BiCGSTAB [1a]_ uses ~2 matvecs, several dot
    products, and vector updates. The algorithm can experience breakdown when
    certain inner products or denominators vanish (e.g.,
    :math:`\langle r_0, v \rangle = 0` or :math:`\langle t, t \rangle = 0`).
    This implementation follows a standard variant [2a]_ and solves multiple RHS
    by looping over columns (no shared Krylov subspace).

    References
    ----------
    .. [1a] Van der Vorst, H. A. (1992). *Bi-CGSTAB: A fast and smoothly
       converging variant of Bi-CG for the solution of nonsymmetric linear
       systems*. SIAM J. Sci. Stat. Comput., 13(2), 631–644.
    .. [2a] Kelley, C. T. (1995). *Iterative Methods for Linear and Nonlinear
       Equations*. SIAM.

    Examples
    --------
    >>> import torch
    >>> from torchsparsegradutils.utils import bicgstab
    >>> A = torch.tensor([[3.0, 1.0], [2.0, 4.0]])
    >>> b = torch.tensor([1.0, 2.0])
    >>> x = bicgstab(A.matmul, b)
    >>> x.shape
    torch.Size([2])

    Multiple right-hand sides:

    >>> B = torch.randn(2, 3)
    >>> X = bicgstab(A.matmul, B)
    >>> X.shape
    torch.Size([2, 3])

    With custom settings:

    >>> from torchsparsegradutils.utils.bicgstab import BICGSTABSettings
    >>> settings = BICGSTABSettings(abstol=1e-10, reltol=1e-8, matvec_max=1000)
    >>> x = bicgstab(A.matmul, b, settings=settings)

    With preconditioning:

    >>> # Diagonal preconditioner
    >>> # Extract and regularize diagonal
    >>> diagA = torch.diag(A)
    >>> eps = 1e-12
    >>> safe_diag = torch.where(diagA.abs() < eps, torch.full_like(diagA, eps), diagA)
    >>> inv_diag = 1.0 / safe_diag
    >>> # Supply as an operator (apply M^{-1} r = inv_diag * r elementwise)
    >>> settings_precond = BICGSTABSettings(
    ...     precon=lambda r: inv_diag * r  # r has same shape as b
    ... )
    >>> x = bicgstab(A.matmul, b, settings=settings_precond)
    """
    # Support multiple right-hand sides by solving each column separately.
    if rhs.dim() > 1:
        cols = rhs.shape[1]
        if cols == 0:
            # No columns to solve; torch.stack would raise on an empty list.
            return torch.zeros_like(rhs)
        sols = [
            bicgstab(
                matmul_closure,
                rhs[:, i],
                None if initial_guess is None else initial_guess[:, i],
                settings,
            )
            for i in range(cols)
        ]
        return torch.stack(sols, dim=1)

    n = rhs.shape[0]
    nMatvec = 0

    # Resolve the operator A into a plain callable.
    if torch.is_tensor(matmul_closure):
        op = matmul_closure.matmul
    elif callable(matmul_closure):
        op = matmul_closure
    else:
        raise RuntimeError("matmul_closure must be a tensor, or a callable object!")

    # Resolve the (optional) preconditioner into a callable or None.
    if settings.precon is None:
        precon = None
    elif torch.is_tensor(settings.precon):
        precon = settings.precon.matmul
    elif callable(settings.precon):
        precon = settings.precon
    else:
        raise RuntimeError("settings.precon must be a tensor, or a callable object!")

    log = settings.logger
    res_device = rhs.device
    res_dtype = rhs.dtype

    # Initial guess is zero unless one is supplied. The initial residual r0 is
    # the fixed "shadow" vector of the method: with a zero guess it equals rhs
    # (no matvec needed); with a supplied guess it must be rhs - A x.
    if initial_guess is None:
        x = torch.zeros(n, dtype=res_dtype, device=res_device)
        r0 = rhs.clone()
    else:
        x = initial_guess.clone()  # clone so the caller's tensor is not mutated
        r0 = rhs - op(x)
        nMatvec += 1

    matvec_max = 2 * n if settings.matvec_max is None else settings.matvec_max

    rho = alpha = omega = 1.0
    rho_next = torch.dot(r0, r0)
    residNorm = residNorm0 = torch.abs(torch.sqrt(rho_next))
    threshold = max(settings.abstol, settings.reltol * residNorm0)

    # Arguments are passed lazily so nothing is formatted (and no tensor is
    # converted to a Python float) when the logger is disabled.
    log.info("Initial residual = %8.2e", residNorm0)
    log.info("Threshold = %8.2e", threshold)
    hdr = "%6s %8s" % ("Matvec", "Residual")
    log.info(hdr)
    log.info("-" * len(hdr))

    if residNorm <= threshold or nMatvec >= matvec_max:
        return x

    r = r0.clone()
    p = torch.zeros(n, dtype=res_dtype, device=res_device)
    v = torch.zeros(n, dtype=res_dtype, device=res_device)

    while True:
        beta = rho_next / rho * alpha / omega
        rho = rho_next

        # Update p in-place: p = r + beta * (p - omega * v)
        p *= beta
        p -= beta * omega * v
        p += r

        # Compute preconditioned search direction
        q = precon(p) if precon is not None else p
        v = op(q)
        nMatvec += 1

        alpha = rho / torch.dot(r0, v)
        s = r - alpha * v

        # Check for CGS termination (half-step already accurate enough)
        residNorm = torch.linalg.norm(s)
        log.info("%6d %8.2e", nMatvec, residNorm)

        if residNorm <= threshold:
            x += alpha * q
            break

        if nMatvec >= matvec_max:
            break

        z = precon(s) if precon is not None else s
        t = op(z)
        nMatvec += 1

        omega = torch.dot(t, s) / torch.dot(t, t)
        rho_next = -omega * torch.dot(r0, t)

        # Update residual
        r = s - omega * t

        # Update solution in-place-ish. Note that 'z *= omega' alters s if
        # precon = None. That's ok since s is no longer needed in this iter.
        # 'q *= alpha' would alter p, so that product is not done in place.
        z *= omega
        x += z
        x += alpha * q

        residNorm = torch.linalg.norm(r)
        log.info("%6d %8.2e", nMatvec, residNorm)

        if residNorm <= threshold or nMatvec >= matvec_max:
            break

    return x