# MIT-licensed code imported from https://github.com/cornellius-gp/linear_operator
# Minor modifications for torchsparsegradutils to remove dependencies
import warnings
from typing import Callable, NamedTuple, Optional, Union
import torch
class LinearCGSettings(NamedTuple):
    """Iteration caps, tolerance and logging switches consumed by :func:`linear_cg`.

    Passed explicitly to :func:`linear_cg` (via its ``settings`` argument) so no global
    settings machinery is needed. Field order is part of the public interface because
    the tuple can be constructed positionally.
    """

    # Maximum number of conjugate gradient iterations to perform when computing matrix solves.
    # A higher value rarely results in more accurate solves -- instead, lower the CG tolerance.
    max_cg_iterations: int = 1000
    # Maximum number of Lanczos iterations to perform when doing stochastic Lanczos quadrature.
    # This is ONLY used for log determinant calculations and computing Tr(K^{-1} dK/d\theta).
    max_lanczos_quadrature_iterations: int = 20
    # Relative residual tolerance used for terminating CG (compared against the mean residual norm
    # of the normalized system).
    cg_tolerance: float = 1
    # If set to true, CG will terminate after n iterations for an n x n matrix.
    terminate_cg_by_size: bool = False
    # Print out information whenever running an expensive linear algebra routine.
    verbose_linalg: bool = False
def _default_preconditioner(x):
    """Identity preconditioner used when the caller supplies none.

    Returns an independent copy of ``x`` rather than ``x`` itself: ``linear_cg`` uses the
    result as its initial search direction and updates both that direction and the residual
    in place, so the two must not share storage.
    """
    return torch.clone(x)
@torch.jit.script
def _jit_linear_cg_updates(
    result, alpha, residual_inner_prod, eps, beta, residual, precond_residual, mul_storage, is_zero, curr_conjugate_vec
):
    """Finish one CG step in place: advance the solution, compute beta, refresh the search direction.

    Every tensor argument is preallocated storage that is mutated in place; nothing is returned.
    On exit, residual_inner_prod holds the new column-wise precond_residual^T residual, beta holds
    new / old inner product (forced to 0 where the old one was below eps), and curr_conjugate_vec
    holds precond_residual + beta * curr_conjugate_vec.
    """
    # Solution step along the previous search direction: x_k = x_{k-1} + alpha_k * p_{k-1}.
    torch.addcmul(result, alpha, curr_conjugate_vec, out=result)

    # Keep the previous inner product as beta's denominator before its buffer is overwritten.
    beta.resize_as_(residual_inner_prod).copy_(residual_inner_prod)
    # Column-wise z_k^T r_k, reduced over the row dimension (-2) into the caller's buffer.
    torch.sum(torch.mul(residual, precond_residual, out=mul_storage), -2, keepdim=True, out=residual_inner_prod)

    # Guarded division: tiny denominators are temporarily set to 1, and beta is zeroed there afterwards.
    tiny_denominator = torch.lt(beta, eps, out=is_zero)
    beta.masked_fill_(tiny_denominator, 1)
    torch.div(residual_inner_prod, beta, out=beta)
    beta.masked_fill_(tiny_denominator, 0)

    # New search direction: p_k = z_k + beta_k * p_{k-1}.
    curr_conjugate_vec.mul_(beta).add_(precond_residual)
@torch.jit.script
def _jit_linear_cg_updates_no_precond(
    mvms,
    result,
    has_converged,
    alpha,
    residual_inner_prod,
    eps,
    beta,
    residual,
    precond_residual,
    mul_storage,
    is_zero,
    curr_conjugate_vec,
):
    """Perform one full unpreconditioned CG step in place on the caller's buffers.

    mvms is expected to hold the operator applied to curr_conjugate_vec. The incoming
    precond_residual is not read: without a preconditioner, the preconditioned residual is
    just a fresh copy of the updated residual.
    """
    # Step length alpha_k = (r^T r) / (p^T A p), reduced column-wise over dim -2.
    torch.sum(torch.mul(curr_conjugate_vec, mvms, out=mul_storage), dim=-2, keepdim=True, out=alpha)
    # Guarded division: tiny denominators are temporarily set to 1, and alpha is zeroed there afterwards.
    tiny_denominator = torch.lt(alpha, eps, out=is_zero)
    alpha.masked_fill_(tiny_denominator, 1)
    torch.div(residual_inner_prod, alpha, out=alpha)
    alpha.masked_fill_(tiny_denominator, 0)
    # Columns that already converged get a zero step, so they stop changing.
    alpha.masked_fill_(has_converged, 0)

    # r_k = r_{k-1} - alpha_k * A p_{k-1}
    torch.addcmul(residual, -alpha, mvms, out=residual)

    # Identity preconditioning (z_k is a copy of r_k), then the shared solution/beta/direction update.
    _jit_linear_cg_updates(
        result,
        alpha,
        residual_inner_prod,
        eps,
        beta,
        residual,
        residual.clone(),
        mul_storage,
        is_zero,
        curr_conjugate_vec,
    )
def linear_cg(
    matmul_closure: Union[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]],
    rhs: torch.Tensor,
    n_tridiag: int = 0,
    tolerance: Optional[float] = None,
    eps: float = 1e-10,
    stop_updating_after: float = 1e-10,
    max_iter: Optional[int] = None,
    max_tridiag_iter: Optional[int] = None,
    initial_guess: Optional[torch.Tensor] = None,
    preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    settings: LinearCGSettings = LinearCGSettings(),
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    r"""
    Solve symmetric positive definite linear systems using conjugate gradient (CG).

    Implements linear CG for systems :math:`A x = b` with symmetric positive definite
    operator :math:`A`. Supports single/multiple RHS and optional stochastic Lanczos
    tridiagonalization for eigenvalue/log-determinant estimation.

    Parameters
    ----------
    matmul_closure : {``torch.Tensor``, callable(x) -> ``A x``}
        Matrix–vector multiply. If a tensor is provided, its ``.matmul`` is used.
        The callable is always invoked with inputs of shape ``(..., n, k)`` (a 1-D
        ``rhs`` is unsqueezed to ``(n, 1)`` first) and must return ``A @ x``.
    rhs : torch.Tensor, shape ``(n,)`` or ``(..., n, k)``
        Right-hand side(s). A 1-D tensor is a single vector; otherwise the last
        dimension indexes RHS columns and any leading dimensions are batch dimensions.
    n_tridiag : int, optional
        Number of Lanczos tridiagonalizations (probe vectors), taken from the first
        ``n_tridiag`` RHS columns. If ``> 0``, tridiagonal matrices are returned in
        addition to the solution. Default: ``0``.
    tolerance : float, optional
        Stopping criterion on the mean residual norm of the normalized system (each
        RHS column is scaled to unit norm). Only checked from iteration
        ``min(10, max_iter - 1)`` onward. If ``None``, uses ``settings.cg_tolerance``.
    eps : float, optional
        Small constant to avoid division by zero. Default: ``1e-10``.
    stop_updating_after : float, optional
        Columns whose normalized residual norm falls below this value stop receiving
        updates. Default: ``1e-10``.
    max_iter : int, optional
        Maximum CG iterations. If ``None``, uses ``settings.max_cg_iterations``.
    max_tridiag_iter : int, optional
        Maximum Lanczos size. If ``None``, uses
        ``settings.max_lanczos_quadrature_iterations``, capped at ``max_iter``.
    initial_guess : torch.Tensor, optional
        Initial guess shaped like ``rhs`` (a 1-D guess is unsqueezed to ``(n, 1)``).
        If ``None``, zeros are used. It does not affect the shape of the result.
    preconditioner : callable, optional
        Preconditioner with signature ``preconditioner(x) -> M^{-1} x``. It may return
        its input unchanged; the initial search direction is copied so it never shares
        storage with the residual. If ``None``, no preconditioning is used.
    settings : LinearCGSettings, optional
        Configuration for iteration caps, tolerances, and logging verbosity.

    Returns
    -------
    torch.Tensor or (torch.Tensor, torch.Tensor)
        * If ``n_tridiag == 0``: solution ``x`` with the same shape as ``rhs``.
        * If ``n_tridiag > 0``: ``(x, T)`` where ``T`` has shape
          ``(n_tridiag, *rhs.shape[:-2], r, r)`` with ``r = last_tridiag_iter + 1``
          and ``r <= min(max_tridiag_iter, n)``. Without batch dims this is ``(n_tridiag, r, r)``.

    Raises
    ------
    RuntimeError
        If ``n_tridiag > 0`` and an explicitly passed ``max_tridiag_iter`` exceeds ``max_iter``.
    RuntimeError
        If ``matmul_closure`` is neither a tensor nor a callable.
    RuntimeError
        If the initial residual contains NaNs.

    Warns
    -----
    UserWarning
        If the iteration budget is exhausted before the mean residual norm drops below
        ``tolerance``.

    Notes
    -----
    CG converges in at most ``n`` iterations for SPD matrices (in exact arithmetic), but
    typically much faster if eigenvalues are clustered. Preconditioning (e.g. diagonal or
    incomplete Cholesky) can significantly accelerate convergence. When
    ``n_tridiag > 0``, Lanczos tridiagonalization is accumulated alongside CG for
    spectral / log-determinant estimates.

    This implementation is based on MIT-licensed code from the linear_operator
    library [1e]_.

    Examples
    --------
    Basic CG solve::

        >>> A = torch.tensor([[4.0, -1.0], [-1.0, 4.0]])
        >>> b = torch.tensor([1.0, 2.0])
        >>> x = linear_cg(A.matmul, b)
        >>> x.shape
        torch.Size([2])

    Multiple RHS::

        >>> B = torch.randn(2, 5)  # 5 RHS
        >>> X = linear_cg(A.matmul, B, max_iter=100, tolerance=1e-8)
        >>> X.shape
        torch.Size([2, 5])

    With preconditioning::

        >>> M_inv = torch.diag(1.0 / torch.diag(A))
        >>> x = linear_cg(A.matmul, b, preconditioner=lambda v: M_inv @ v)

    With Lanczos tridiagonalization::

        >>> x, T = linear_cg(A.matmul, b, n_tridiag=1)
        >>> T.shape  # (n_tridiag, r, r) with r <= n
        torch.Size([1, 2, 2])

    Sparse operator via closure::

        >>> indices = torch.tensor([[0, 0, 1, 1, 2], [0, 1, 0, 1, 2]])
        >>> values = torch.tensor([4.0, -1.0, -1.0, 4.0, 2.0])
        >>> A_sp = torch.sparse_coo_tensor(indices, values, (3, 3))
        >>> x = linear_cg(lambda v: A_sp @ v, torch.randn(3))

    References
    ----------
    .. [1e] linear_operator library. https://github.com/cornellius-gp/linear_operator
    """
    # The output shape follows ``rhs`` only; a 1-D rhs is solved as an (n, 1) system and squeezed back.
    is_vector = rhs.ndimension() == 1
    if is_vector:
        rhs = rhs.unsqueeze(-1)

    # Some default arguments
    if max_iter is None:
        max_iter = settings.max_cg_iterations
    if max_tridiag_iter is None:
        # m CG iterations yield at most m Lanczos coefficients, so never default above max_iter.
        max_tridiag_iter = min(settings.max_lanczos_quadrature_iterations, max_iter)
    if initial_guess is None:
        initial_guess = torch.zeros_like(rhs)
    elif initial_guess.ndimension() == 1:
        # Deliberately leaves ``is_vector`` alone: the guess's rank must not change the output shape.
        initial_guess = initial_guess.unsqueeze(-1)
    if tolerance is None:
        tolerance = settings.cg_tolerance
    if preconditioner is None:
        preconditioner = _default_preconditioner
        precond = False
    else:
        precond = True

    # If we are running m CG iterations, we can't get more than m Lanczos coefficients.
    # The limit only matters when a tridiagonalization was actually requested.
    if n_tridiag and max_tridiag_iter > max_iter:
        raise RuntimeError("Getting a tridiagonalization larger than the number of CG iterations run is not possible!")

    # Check matmul_closure object
    if torch.is_tensor(matmul_closure):
        matmul_closure = matmul_closure.matmul
    elif not callable(matmul_closure):
        raise RuntimeError("matmul_closure must be a tensor, or a callable object!")

    # Get some constants
    num_rows = rhs.size(-2)
    n_iter = min(max_iter, num_rows) if settings.terminate_cg_by_size else max_iter
    n_tridiag_iter = min(max_tridiag_iter, num_rows)
    eps = torch.tensor(eps, dtype=rhs.dtype, device=rhs.device)

    # Norm of each rhs column, used for normalization and convergence checks. Near-zero norms are
    # replaced by 1 (avoiding divide-by-zero), and we remember which columns they were.
    rhs_norm = rhs.norm(2, dim=-2, keepdim=True)
    rhs_is_zero = rhs_norm.lt(eps)
    rhs_norm = rhs_norm.masked_fill_(rhs_is_zero, 1)

    # Solve the normalized system; the result is un-normalized at the end.
    rhs = rhs.div(rhs_norm)
    initial_guess = initial_guess.div(rhs_norm)

    # residual_{0} = b - A x_{0}
    residual = rhs - matmul_closure(initial_guess)
    batch_shape = residual.shape[:-2]

    # result <- x_{0}
    result = initial_guess.expand_as(residual).contiguous()

    if settings.verbose_linalg:
        print(f"Running CG on a {rhs.shape} RHS for {n_iter} iterations (tol={tolerance}). Output: {result.shape}.")

    # NaN != NaN, so this equality fails exactly when the residual contains NaNs.
    if not torch.equal(residual, residual):
        raise RuntimeError("NaNs encountered when trying to perform matrix-vector multiplication")

    # Sometimes the initial guess already solves the system.
    residual_norm = residual.norm(2, dim=-2, keepdim=True)
    has_converged = torch.lt(residual_norm, stop_updating_after)

    if has_converged.all() and not n_tridiag:
        n_iter = 0  # Skip the iteration!
    else:
        # precond_residual_{0} = M^{-1} residual_{0}
        precond_residual = preconditioner(residual)
        # The search direction is updated in place while ``residual`` is also updated in place, so they
        # must not share storage. A user preconditioner may return its input (e.g. ``lambda v: v``),
        # hence the copy. The default preconditioner already returns a fresh tensor.
        curr_conjugate_vec = precond_residual.clone() if precond else precond_residual
        residual_inner_prod = precond_residual.mul(residual).sum(-2, keepdim=True)

        # Storage reused across iterations by the in-place updates.
        mul_storage = torch.empty_like(residual)
        alpha = torch.empty(*batch_shape, 1, rhs.size(-1), dtype=residual.dtype, device=residual.device)
        beta = torch.empty_like(alpha)
        is_zero = torch.empty(*batch_shape, 1, rhs.size(-1), dtype=torch.bool, device=residual.device)

    # Tridiagonal matrices, laid out as (r, r, *batch, n_tridiag) and permuted on return.
    if n_tridiag:
        t_mat = torch.zeros(
            n_tridiag_iter, n_tridiag_iter, *batch_shape, n_tridiag, dtype=alpha.dtype, device=alpha.device
        )
        alpha_tridiag_is_zero = torch.empty(*batch_shape, n_tridiag, dtype=torch.bool, device=t_mat.device)
        alpha_reciprocal = torch.empty(*batch_shape, n_tridiag, dtype=t_mat.dtype, device=t_mat.device)
        prev_alpha_reciprocal = torch.empty_like(alpha_reciprocal)
        prev_beta = torch.empty_like(alpha_reciprocal)

        update_tridiag = True
        last_tridiag_iter = 0

    # It's conceivable we reach the tolerance on the last iteration, so can't just check iteration number.
    tolerance_reached = False

    for k in range(n_iter):
        # alpha_{k} = (residual_{k-1}^T precond_residual_{k-1}) / (p_{k-1}^T A p_{k-1})
        mvms = matmul_closure(curr_conjugate_vec)
        if precond:
            torch.mul(curr_conjugate_vec, mvms, out=mul_storage)
            torch.sum(mul_storage, -2, keepdim=True, out=alpha)

            # Safe division: tiny denominators yield alpha = 0 instead of inf/nan.
            torch.lt(alpha, eps, out=is_zero)
            alpha.masked_fill_(is_zero, 1)
            torch.div(residual_inner_prod, alpha, out=alpha)
            alpha.masked_fill_(is_zero, 0)

            # Cancel out any updates for columns that have already converged.
            alpha.masked_fill_(has_converged, 0)

            # residual_{k} = residual_{k-1} - alpha_{k} A p_{k-1}
            residual = torch.addcmul(residual, alpha, mvms, value=-1, out=residual)

            # precond_residual_{k} = M^{-1} residual_{k}
            precond_residual = preconditioner(residual)

            _jit_linear_cg_updates(
                result,
                alpha,
                residual_inner_prod,
                eps,
                beta,
                residual,
                precond_residual,
                mul_storage,
                is_zero,
                curr_conjugate_vec,
            )
        else:
            _jit_linear_cg_updates_no_precond(
                mvms,
                result,
                has_converged,
                alpha,
                residual_inner_prod,
                eps,
                beta,
                residual,
                precond_residual,
                mul_storage,
                is_zero,
                curr_conjugate_vec,
            )

        torch.norm(residual, 2, dim=-2, keepdim=True, out=residual_norm)
        residual_norm.masked_fill_(rhs_is_zero, 0)
        torch.lt(residual_norm, stop_updating_after, out=has_converged)

        # Stop once the mean residual is small enough, but not before the requested number of
        # Lanczos coefficients has been gathered.
        if (
            k >= min(10, max_iter - 1)
            and bool(residual_norm.mean() < tolerance)
            and not (n_tridiag and k < min(n_tridiag_iter, max_iter - 1))
        ):
            tolerance_reached = True
            break

        # Update tridiagonal matrices, if applicable
        if n_tridiag and k < n_tridiag_iter and update_tridiag:
            alpha_tridiag = alpha.squeeze(-2).narrow(-1, 0, n_tridiag)
            beta_tridiag = beta.squeeze(-2).narrow(-1, 0, n_tridiag)
            # 1 / alpha, where entries with alpha == 0 get reciprocal 1 (alpha itself is restored).
            torch.eq(alpha_tridiag, 0, out=alpha_tridiag_is_zero)
            alpha_tridiag.masked_fill_(alpha_tridiag_is_zero, 1)
            torch.reciprocal(alpha_tridiag, out=alpha_reciprocal)
            alpha_tridiag.masked_fill_(alpha_tridiag_is_zero, 0)

            if k == 0:
                t_mat[k, k].copy_(alpha_reciprocal)
            else:
                torch.addcmul(alpha_reciprocal, prev_beta, prev_alpha_reciprocal, out=t_mat[k, k])
                torch.mul(prev_beta.sqrt_(), prev_alpha_reciprocal, out=t_mat[k, k - 1])
                t_mat[k - 1, k].copy_(t_mat[k, k - 1])

                # Stop extending T once the off-diagonal entry becomes negligible.
                if t_mat[k - 1, k].max() < 1e-6:
                    update_tridiag = False

            last_tridiag_iter = k

            prev_alpha_reciprocal.copy_(alpha_reciprocal)
            prev_beta.copy_(beta_tridiag)

    # Un-normalize
    result = result.mul(rhs_norm)

    if not tolerance_reached and n_iter > 0:
        warnings.warn(
            "CG terminated in {} iterations with average residual norm {}"
            " which is larger than the tolerance of {} specified by"
            " the `tolerance` argument (default: LinearCGSettings.cg_tolerance)."
            " If performance is affected, consider raising the maximum number of CG iterations via"
            " the `max_iter` argument (default: LinearCGSettings.max_cg_iterations).".format(
                k + 1, residual_norm.mean().item(), tolerance
            ),
            UserWarning,
        )

    if is_vector:
        result = result.squeeze(-1)

    if n_tridiag:
        t_mat = t_mat[: last_tridiag_iter + 1, : last_tridiag_iter + 1]
        # (r, r, *batch, n_tridiag) -> (n_tridiag, *batch, r, r)
        return result, t_mat.permute(-1, *range(2, 2 + len(batch_shape)), 0, 1).contiguous()
    else:
        return result