# Source code for torchsparsegradutils.utils.lsmr

"""
Code adapted from https://github.com/rfeinman/pytorch-minimize/blob/master/torchmin/lstsq/lsmr.py
Code modified from scipy.sparse.linalg.lsmr

Copyright (C) 2010 David Fong and Michael Saunders
"""

from typing import Callable, Optional, Tuple, Union

import torch


def _sym_ortho(a, b, out):
    """Write the plane-rotation triple for ``(a, b)`` into preallocated buffers.

    The three tensors in ``out`` are overwritten in place:

    * ``out[2]`` receives ``hypot(a, b)``,
    * ``out[0]`` receives ``a / out[2]``,
    * ``out[1]`` receives ``b / out[2]``.

    No branching is performed, so ``a == b == 0`` yields a 0/0 division.

    Parameters
    ----------
    a, b : torch.Tensor
        Inputs to the rotation. They are only read here.
    out : sequence of three torch.Tensor
        Destination buffers, indexed as described above.

    Returns
    -------
    The same ``out`` object that was passed in.
    """
    cos_buf, sin_buf, hyp_buf = out[0], out[1], out[2]
    # The hypotenuse must be written first: both quotients divide by it.
    torch.hypot(a, b, out=hyp_buf)
    torch.div(a, hyp_buf, out=cos_buf)
    torch.div(b, hyp_buf, out=sin_buf)
    return out


@torch.no_grad()
def lsmr(
    A: Union[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]],
    b: torch.Tensor,
    Armat: Optional[Union[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]]] = None,
    n: Optional[int] = None,
    damp: float = 0.0,
    atol: float = 1e-6,
    btol: float = 1e-6,
    conlim: float = 1e8,
    maxiter: Optional[int] = None,
    x0: Optional[torch.Tensor] = None,
    check_nonzero: bool = True,
) -> Tuple[torch.Tensor, int]:
    r"""
    Least Squares Minimal Residual (LSMR) solver.

    Iterative solver for :math:`A x = b` and least-squares problems
    :math:`\min_x \|A x - b\|_2`. Works with large, sparse, or rectangular
    :math:`A` and is often more stable than LSQR on ill-conditioned problems.

    Parameters
    ----------
    A : {torch.Tensor, callable(x) -> A @ x}
        System matrix or matvec closure. If a tensor is given, it may be dense
        or sparse and ``.matmul`` is used.
    b : torch.Tensor, shape (m,)
        Right-hand side vector. Must be on the same device/dtype as ``A``.
        Inputs with more than one dimension are squeezed.
    Armat : {torch.Tensor, callable(x) -> A^T @ x}, optional
        Transpose matvec or matrix. If ``A`` is a tensor and ``Armat`` is
        ``None``, uses ``A.adjoint().matmul``. If ``A`` is callable, ``Armat``
        is **required**.
    n : int, optional
        Number of columns of ``A``. Required if ``A`` is callable; inferred
        from ``A.shape[1]`` if ``A`` is a tensor.
    damp : float, optional
        Tikhonov damping parameter (ridge). Solves
        :math:`\min_x \|(A; \text{damp} I) x - (b; 0)\|_2`. Default: 0.0.
    atol : float, optional
        Stopping tolerance. Iteration stops when the estimate of
        :math:`\|A^\top r\| / (\|A\| \|r\|)` is at most ``atol``; ``atol`` also
        scales the :math:`\|A\| \|x\|` term of the residual test (see Notes).
        Default: 1e-6.
    btol : float, optional
        Stopping tolerance scaling the :math:`\|b\|` term of the residual test
        (see Notes). Default: 1e-6.
    conlim : float, optional
        Condition estimate limit; stops if the estimate exceeds this value.
        A non-positive value disables this user-level test. Default: 1e8.
    maxiter : int, optional
        Maximum iterations. If ``None``, uses ``min(m, n)``. If the resulting
        limit is not positive, no iteration is run and the initial guess is
        returned with ``iterations == 0``.
    x0 : torch.Tensor, optional, shape (n,)
        Initial guess. If ``None``, zeros are used.
    check_nonzero : bool, optional
        Skip the rare ``beta == 0`` synchronization check for performance when
        set to ``False`` (use with caution). Default: True.

    Returns
    -------
    x : torch.Tensor, shape (n,)
        Approximate solution that minimizes :math:`\|A x - b\|_2` (with damped
        variant when ``damp > 0``).
    iterations : int
        Number of iterations executed.

    Raises
    ------
    RuntimeError
        If ``A`` is neither a tensor nor a callable.
    RuntimeError
        If ``A`` is callable and ``n`` is not provided.
    RuntimeError
        If ``Armat`` is missing or is neither a tensor nor a callable.

    Notes
    -----
    Uses Golub–Kahan bidiagonalization [1f]_ with specialized QR steps.

    For overdetermined systems (``m > n``), returns the least-squares solution.
    For underdetermined systems (``m < n``) with ``damp = 0``, returns the
    minimum-norm least-squares solution.

    Convergence checks (roughly):

    - Consistent:
      :math:`\|r\|_2 \le \text{atol} \, \|A\| \, \|x\| + \text{btol} \, \|b\|`
    - Inconsistent:
      :math:`\|A^\top r\|_2 \le \text{atol} \, \|A\| \, \|r\|`

    References
    ----------
    .. [1f] Fong, D. C., & Saunders, M. (2011). LSMR: An iterative algorithm
       for sparse least-squares problems. SIAM Journal on Scientific
       Computing, 33(5), 2950-2971.

    Examples
    --------
    Basic least squares problem:

    >>> import torch
    >>> from torchsparsegradutils.utils import lsmr
    >>> # Over-determined system (3x2)
    >>> A = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    >>> b = torch.tensor([1.0, 2.0, 3.0])
    >>> x, iterations = lsmr(A, b)
    >>> x.shape
    torch.Size([2])

    Sparse matrix least squares:

    >>> # Create sparse matrix
    >>> indices = torch.tensor([[0, 1, 2, 1, 2], [0, 0, 0, 1, 1]])
    >>> values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
    >>> A_sparse = torch.sparse_coo_tensor(indices, values, (3, 2))
    >>> x, it = lsmr(A_sparse, b)

    With damping for regularization (Tikhonov / ridge):

    >>> # Regularized least squares
    >>> x_reg, it = lsmr(A, b, damp=0.1)

    Callable matrix interface:

    >>> x, it = lsmr(lambda v: A @ v, b, Armat=lambda v: A.T @ v, n=2)

    Under-determined system (minimum norm solution):

    >>> A_under = torch.randn(2, 4)  # 2x4 system
    >>> b_under = torch.randn(2)
    >>> x_min_norm, it = lsmr(A_under, b_under)
    >>> x_min_norm.shape
    torch.Size([4])

    Custom tolerances and limits:

    >>> x, it = lsmr(A, b, atol=1e-10, btol=1e-10, conlim=1e12, maxiter=1000)
    """
    # Normalise A / Armat into matvec closures.
    if torch.is_tensor(A):
        if n is None:
            n = A.shape[1]
        if Armat is None:
            # Must be taken before A is rebound to its bound ``matmul``.
            Armat = (torch.adjoint(A)).matmul
        A = A.matmul
    elif not callable(A):
        raise RuntimeError("matmul_closure must be a tensor, or a callable object!")

    if n is None:
        raise RuntimeError("n needs to be provided or computed from A given as a tensor")

    if torch.is_tensor(Armat):
        Armat = Armat.matmul
    elif not callable(Armat):
        raise RuntimeError("matmul_closure must be a tensor, or a callable object!")

    # Scalar recurrences (norms, rotations) are real even for complex data.
    sdtype = b.dtype
    if b.dtype == torch.complex64:
        sdtype = torch.float32
    elif b.dtype == torch.complex128:
        sdtype = torch.float64

    b = torch.atleast_1d(b)
    if b.dim() > 1:
        b = b.squeeze()
    eps = torch.finfo(sdtype).eps
    damp = torch.as_tensor(damp, dtype=sdtype, device=b.device)
    ctol = 1 / conlim if conlim > 0 else 0.0
    m = b.shape[0]
    if maxiter is None:
        maxiter = min(m, n)

    u = b.clone()
    normb = b.norm()
    if x0 is None:
        x = b.new_zeros(n)
        beta = normb.clone()
    else:
        x = torch.atleast_1d(x0).clone()
        u.sub_(A(x))
        beta = u.norm()

    if beta > 0:
        u.div_(beta)
        v = Armat(u)
        alpha = v.norm()
    else:
        v = b.new_zeros(n)
        alpha = b.new_tensor(0, dtype=sdtype)

    v = torch.where(alpha > 0, v / alpha, v)

    # Initialize variables for 1st iteration.
    zetabar = alpha * beta
    alphabar = alpha.clone()
    rho = b.new_tensor(1, dtype=sdtype)
    rhobar = b.new_tensor(1, dtype=sdtype)
    cbar = b.new_tensor(1, dtype=sdtype)
    sbar = b.new_tensor(0, dtype=sdtype)

    h = v.clone()
    hbar = b.new_zeros(n)

    # Initialize variables for estimation of ||r||.
    betadd = beta.clone()
    betad = b.new_tensor(0, dtype=sdtype)
    rhodold = b.new_tensor(1, dtype=sdtype)
    tautildeold = b.new_tensor(0, dtype=sdtype)
    thetatilde = b.new_tensor(0, dtype=sdtype)
    zeta = b.new_tensor(0, dtype=sdtype)
    d = b.new_tensor(0, dtype=sdtype)

    # Initialize variables for estimation of ||A|| and cond(A)
    normA2 = alpha.square()
    maxrbar = b.new_tensor(0, dtype=sdtype)
    minrbar = b.new_tensor(0.99 * torch.finfo(sdtype).max, dtype=sdtype)
    normA = normA2.sqrt()
    condA = b.new_tensor(1, dtype=sdtype)
    normx = b.new_tensor(0, dtype=sdtype)
    normr = beta.clone()
    normar = alpha * beta

    # Trivial cases: the starting point already satisfies the normal
    # equations, or the right-hand side is zero.
    if normar == 0:
        return x, 0
    if normb == 0:
        x[:] = 0
        return x, 0

    # extra buffers (added by Reuben)
    c = b.new_tensor(0, dtype=sdtype)
    s = b.new_tensor(0, dtype=sdtype)
    chat = b.new_tensor(0, dtype=sdtype)
    shat = b.new_tensor(0, dtype=sdtype)
    alphahat = b.new_tensor(0, dtype=sdtype)
    ctildeold = b.new_tensor(0, dtype=sdtype)
    stildeold = b.new_tensor(0, dtype=sdtype)
    rhotildeold = b.new_tensor(0, dtype=sdtype)
    rhoold = b.new_tensor(0, dtype=sdtype)
    rhobarold = b.new_tensor(0, dtype=sdtype)
    zetaold = b.new_tensor(0, dtype=sdtype)
    thetatildeold = b.new_tensor(0, dtype=sdtype)
    betaacute = b.new_tensor(0, dtype=sdtype)
    betahat = b.new_tensor(0, dtype=sdtype)
    betacheck = b.new_tensor(0, dtype=sdtype)
    taud = b.new_tensor(0, dtype=sdtype)

    # Bound up front so the final ``return`` is valid even when the loop body
    # never runs (e.g. ``maxiter <= 0``).
    itn = 0

    # Main iteration loop.
    for itn in range(1, maxiter + 1):
        # Perform the next step of the bidiagonalization to obtain the
        # next beta, u, alpha, v.  These satisfy the relations
        #         beta*u  =  A*v   -  alpha*u,
        #        alpha*v  =  A'*u  -  beta*v.
        u.mul_(-alpha).add_(A(v))
        torch.norm(u, out=beta)

        if (not check_nonzero) or beta > 0:
            # check_nonzero option provides a means to avoid the GPU-CPU
            # synchronization of a `beta > 0` check. For most cases
            # beta == 0 is unlikely, but use this option with caution.
            u.div_(beta)
            v.mul_(-beta).add_(Armat(u))
            torch.norm(v, out=alpha)
            v = torch.where(alpha > 0, v / alpha, v)

        # At this point, beta = beta_{k+1}, alpha = alpha_{k+1}.

        _sym_ortho(alphabar, damp, out=(chat, shat, alphahat))

        # Use a plane rotation (Q_i) to turn B_i to R_i
        rhoold.copy_(rho, non_blocking=True)
        _sym_ortho(alphahat, beta, out=(c, s, rho))
        thetanew = torch.mul(s, alpha)
        torch.mul(c, alpha, out=alphabar)

        # Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar
        rhobarold.copy_(rhobar, non_blocking=True)
        zetaold.copy_(zeta, non_blocking=True)
        thetabar = sbar * rho
        # rhotemp is a fresh tensor, so it is safe to pass it as the input
        # while cbar is overwritten as an output buffer.
        rhotemp = cbar * rho
        _sym_ortho(rhotemp, thetanew, out=(cbar, sbar, rhobar))
        torch.mul(cbar, zetabar, out=zeta)
        zetabar.mul_(-sbar)

        # Update h, h_hat, x.
        hbar.mul_(-thetabar * rho).div_(rhoold * rhobarold)
        hbar.add_(h)
        x.addcdiv_(zeta * hbar, rho * rhobar)
        h.mul_(-thetanew).div_(rho)
        h.add_(v)

        # Estimate of ||r||.

        # Apply rotation Qhat_{k,2k+1}.
        torch.mul(chat, betadd, out=betaacute)
        torch.mul(-shat, betadd, out=betacheck)

        # Apply rotation Q_{k,k+1}.
        torch.mul(c, betaacute, out=betahat)
        torch.mul(-s, betaacute, out=betadd)

        # Apply rotation Qtilde_{k-1}.
        # betad = betad_{k-1} here.
        thetatildeold.copy_(thetatilde, non_blocking=True)
        _sym_ortho(rhodold, thetabar, out=(ctildeold, stildeold, rhotildeold))
        torch.mul(stildeold, rhobar, out=thetatilde)
        torch.mul(ctildeold, rhobar, out=rhodold)
        betad.mul_(-stildeold).addcmul_(ctildeold, betahat)

        # betad   = betad_k here.
        # rhodold = rhod_k  here.
        tautildeold.mul_(-thetatildeold).add_(zetaold).div_(rhotildeold)
        torch.div(zeta - thetatilde * tautildeold, rhodold, out=taud)
        d.addcmul_(betacheck, betacheck)
        torch.sqrt(d + (betad - taud).square() + betadd.square(), out=normr)

        # Estimate ||A||.
        normA2.addcmul_(beta, beta)
        torch.sqrt(normA2, out=normA)
        normA2.addcmul_(alpha, alpha)

        # Estimate cond(A).
        torch.max(maxrbar, rhobarold, out=maxrbar)
        if itn > 1:
            torch.min(minrbar, rhobarold, out=minrbar)

        # ------- Test for convergence (performed every iteration) --------

        # Compute norms for convergence testing.
        torch.abs(zetabar, out=normar)
        torch.norm(x, out=normx)
        torch.div(torch.max(maxrbar, rhotemp), torch.min(minrbar, rhotemp), out=condA)

        # Now use these norms to estimate certain other quantities,
        # some of which will be small near a solution.
        test1 = normr / normb
        test2 = normar / (normA * normr + eps)
        test3 = 1 / (condA + eps)
        t1 = test1 / (1 + normA * normx / normb)
        rtol = btol + atol * normA * normx / normb

        # The first 3 tests guard against extremely small values of
        # atol, btol or ctol.  (The user may have set any or all of
        # the parameters atol, btol, conlim to 0.)
        # The effect is equivalent to the normal tests using
        # atol = eps,  btol = eps,  conlim = 1/eps.

        # The second 3 tests allow for tolerances set by the user.
        stop = (
            (1 + test3 <= 1)
            | (1 + test2 <= 1)
            | (1 + t1 <= 1)
            | (test3 <= ctol)
            | (test2 <= atol)
            | (test1 <= rtol)
        )

        if stop:
            break

    return x, itn