# Source code for torchsparsegradutils.distributions.sparse_multivariate_normal

import math
import warnings

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _standard_normal

from torchsparsegradutils import sparse_mm as spmm, sparse_triangular_solve as spts

# from .contraints import sparse_strictly_lower_triangular

# Public API: the names exported by a star-import of this module.
__all__ = ["SparseMultivariateNormal", "SparseMultivariateNormalNative"]


def _batch_sparse_mv(op, bmat, bvec, **kwargs):
    r"""
    Batched sparse–dense matvec helper (no broadcasting).

    Performs a matrix–vector (or matrix–matrix) operation where the matrix is
    sparse (COO/CSR) and the vector/array is dense. Limited batch combinations
    are supported and **no broadcasting** of batch dimensions is performed.

    Supported input ranks
    ---------------------
    - ``bmat.ndim == 2`` with ``bvec.ndim == 1``  → returns ``(n,)``
    - ``bmat.ndim == 2`` with ``bvec.ndim == 2``  → returns ``(n, k)``
    - ``bmat.ndim == 3`` with ``bvec.ndim == 2``  → returns ``(B, n)``
    - ``bmat.ndim == 3`` with ``bvec.ndim == 3``  → returns ``(B, n, k)``

    Parameters
    ----------
    op : callable
        Sparse operator to apply. Typically :func:`torchsparsegradutils.sparse_mm`
        (SpMM) or :func:`torchsparsegradutils.sparse_triangular_solve`.
        Must accept ``(bmat, bvec, **kwargs)`` and return a tensor shaped
        like a dense matmul result.
    bmat : torch.Tensor
        Sparse matrix of shape ``(n, n)`` or batched sparse matrix of shape
        ``(B, n, n)`` in COO or CSR layout.
    bvec : torch.Tensor
        Dense vector/array. Shape may be ``(n,)``, ``(n, k)``, ``(B, n)``,
        or ``(B, n, k)`` as listed above.
    **kwargs
        Extra keyword arguments forwarded to ``op`` (e.g., ``upper``,
        ``unitriangular``, ``transpose`` for triangular solves).

    Returns
    -------
    torch.Tensor
        Dense result with shape as indicated in *Supported input ranks*.

    Raises
    ------
    ValueError
        If the pair of ranks is not one of the supported combinations.

    Notes
    -----
    - Batch sizes must match exactly when ``bmat`` is batched.
    - This helper exists to centralize the small reshaping/permutation logic
      so that SpMM and triangular solves can share a common pathway.

    See Also
    --------
    torchsparsegradutils.sparse_mm : Sparse matrix–dense matrix multiply.
    torchsparsegradutils.sparse_triangular_solve : Sparse triangular solver.

    Examples
    --------
    Vector RHS (2D × 1D)::

        >>> A = torch.eye(4, dtype=torch.float64).to_sparse_csr()
        >>> v = torch.arange(4., dtype=torch.float64)
        >>> _batch_sparse_mv(spmm, A, v).shape
        torch.Size([4])

    Matrix RHS (2D × 2D)::

        >>> X = torch.randn(5, 4, dtype=torch.float64)  # 5 vectors of size 4
        >>> _batch_sparse_mv(spmm, A, X).shape
        torch.Size([5, 4])

    Batched matrix with vector RHS (3D × 2D)::

        >>> # Create batched sparse tensors using stack_csr utility
        >>> from torchsparsegradutils.utils import stack_csr
        >>> Ab = stack_csr([A, A])  # (B=2, 4, 4)
        >>> vb = torch.randn(2, 4, dtype=torch.float64)
        >>> _batch_sparse_mv(spmm, Ab, vb).shape
        torch.Size([2, 4])
    """
    if bmat.dim() == 2 and bvec.dim() == 1:
        return op(bmat, bvec.unsqueeze(-1), **kwargs).squeeze(-1)
    elif bmat.dim() == 2 and bvec.dim() == 2:
        return op(bmat, bvec.t(), **kwargs).t()
    elif bmat.dim() == 3 and bvec.dim() == 2:
        return op(bmat, bvec.unsqueeze(-1), **kwargs).squeeze(-1)
    elif bmat.dim() == 3 and bvec.dim() == 3:
        return op(bmat, bvec.permute(1, 2, 0), **kwargs).permute(2, 0, 1)
    else:
        raise ValueError("Invalid dimensions for bmat and bvec")


class SparseMultivariateNormal(Distribution):
    r"""
    Multivariate normal with sparse Cholesky / LDL^T parameterizations.

    Supports sparse covariance **or** sparse precision factors using either the
    standard Cholesky (:math:`L L^\top`) or the modified Cholesky
    (:math:`L D L^\top`) parameterization. Sparse triangular factors can be
    given in COO or CSR layout and (optionally) batched with a single leading
    batch dimension.

    Parameterizations
    -----------------
    **Cholesky (LL^T)**

    - Covariance form: :math:`\Sigma = L L^\top` with lower-triangular
      :math:`L` (incl. diagonal).
    - Precision form: :math:`\Omega = L L^\top` with lower-triangular
      :math:`L` (incl. diagonal).

    **Modified Cholesky (LDL^T)**

    - Covariance form: :math:`\Sigma = L D L^\top` with *unit* lower-triangular
      :math:`L` and diagonal :math:`D = \operatorname{diag}(\text{diagonal}) > 0`.
    - Precision form: :math:`\Omega = L D L^\top` with *unit* lower-triangular
      :math:`L` and diagonal :math:`D = \operatorname{diag}(\text{diagonal}) > 0`.
    - In both cases, the strictly lower part lives in the sparse ``*_tril``
      factor and the diagonal is provided separately via ``diagonal``.

    Parameters
    ----------
    loc : torch.Tensor
        Mean vector, shape ``(n,)`` or ``(B, n)``.
    diagonal : torch.Tensor, optional
        Diagonal entries for the :math:`D` in :math:`L D L^\top`. Shape
        ``(n,)`` or ``(B, n)``. Must be positive in both the covariance and
        the precision parameterization, because sampling takes its square
        root. If ``None``, the :math:`L L^\top` parameterization is used.
    scale_tril : torch.Tensor, optional
        Sparse lower-triangular factor for covariance (:math:`L`). COO or CSR
        with shape ``(n, n)`` or ``(B, n, n)``. Mutually exclusive with
        ``precision_tril``.
    precision_tril : torch.Tensor, optional
        Sparse lower-triangular factor for precision (:math:`L`). COO or CSR
        with shape ``(n, n)`` or ``(B, n, n)``. Mutually exclusive with
        ``scale_tril``.
    validate_args : bool, optional
        If ``True``, run argument checks.

    Attributes
    ----------
    loc : torch.Tensor
        Mean.
    diagonal : torch.Tensor or None
        Diagonal entries for :math:`D` in :math:`L D L^\top` form, if provided.
    scale_tril : torch.Tensor or None
        Covariance Cholesky factor (sparse lower-triangular), if provided.
    precision_tril : torch.Tensor or None
        Precision Cholesky factor (sparse lower-triangular), if provided.
    has_rsample : bool
        Reparameterized sampling is supported.

    Returns
    -------
    SparseMultivariateNormal
        A distribution instance compatible with :mod:`torch.distributions`.

    Raises
    ------
    ValueError
        On invalid shapes, unsupported layouts, or incompatible batching.

    Notes
    -----
    - Only a **single** batch dimension is supported for the sparse factors.
    - Sampling uses reparameterization:

      - Covariance :math:`L L^\top`: :math:`x = L \varepsilon`.
      - Covariance :math:`L D L^\top`:
        :math:`x = L (\sqrt{D} \odot \varepsilon) + (\sqrt{D} \odot \varepsilon)`
        since :math:`L` is unit lower-triangular (strictly lower part in
        sparse factor).
      - Precision forms use sparse triangular solves with the transpose.

    - ``sample_shape`` may have any number of dimensions; extra leading sample
      dimensions are flattened before the sparse operation and restored
      afterwards.
    - ``log_prob`` is not implemented in this class.
    - Sparse operations are delegated to :func:`torchsparsegradutils.sparse_mm`
      and :func:`torchsparsegradutils.sparse_triangular_solve`.

    See Also
    --------
    torch.distributions.MultivariateNormal : Dense baseline distribution.
    torchsparsegradutils.sparse_mm : Sparse matrix–dense matrix multiply used during sampling.
    torchsparsegradutils.sparse_triangular_solve : Sparse triangular solver used with precision factors.

    Examples
    --------
    :math:`L L^\top` parameterization with sparse covariance::

        >>> import torch
        >>> from torchsparsegradutils.distributions import SparseMultivariateNormal
        >>> loc = torch.zeros(4)
        >>> indices = torch.tensor([[1, 2, 2, 3], [0, 0, 1, 2]])
        >>> values = torch.tensor([0.5, 0.3, 0.8, 0.2])
        >>> diag_indices = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]])
        >>> diag_values = torch.tensor([1.0, 1.0, 1.0, 1.0])
        >>> all_indices = torch.cat([diag_indices, indices], dim=1)
        >>> all_values = torch.cat([diag_values, values])
        >>> scale_tril = torch.sparse_coo_tensor(all_indices, all_values, (4, 4))
        >>> mvn = SparseMultivariateNormal(loc=loc, scale_tril=scale_tril)
        >>> samples = mvn.sample((100,))
        >>> samples.shape
        torch.Size([100, 4])

    :math:`L D L^\top` parameterization with diagonal precision::

        >>> diagonal = torch.tensor([2.0, 1.5, 3.0, 1.0])  # Precision diagonal
        >>> precision_tril = torch.sparse_coo_tensor(indices, values, (4, 4))
        >>> mvn_ldlt = SparseMultivariateNormal(loc=loc, diagonal=diagonal,
        ...                                     precision_tril=precision_tril)
        >>> samples = mvn_ldlt.sample((50,))

    Batched distributions::

        >>> loc_batch = torch.randn(3, 4)
        >>> diagonal_batch = torch.abs(torch.randn(3, 4)) + 0.1
        >>> precision_batch = torch.stack([precision_tril] * 3)
        >>> mvn_batch = SparseMultivariateNormal(loc=loc_batch, diagonal=diagonal_batch,
        ...                                      precision_tril=precision_batch)
        >>> samples = mvn_batch.sample()
        >>> samples.shape
        torch.Size([3, 4])
    """

    arg_constraints = {}
    # TODO: add in constraints
    # For LDL^T parameterization:
    # arg_constraints = {'loc': constraints.real_vector,
    #                    'diagonal': constraints.independent(constraints.positive, 1),
    #                    'scale_tril': sparse_strictly_lower_triangular,
    #                    'precision_tril': sparse_strictly_lower_triangular}
    # For LL^T parameterization:
    # arg_constraints = {'loc': constraints.real_vector,
    #                    'scale_tril': constraints.lower_cholesky,
    #                    'precision_tril': constraints.lower_cholesky}
    support = constraints.real_vector
    has_rsample = True

    def __init__(self, loc, diagonal=None, scale_tril=None, precision_tril=None, validate_args=None):
        """Validate the parameters and derive the batch/event shapes.

        Exactly one of ``scale_tril`` / ``precision_tril`` must be given.
        Sparse COO factors are coalesced; CSR factors are stored as passed.

        Raises
        ------
        ValueError
            If ``loc``/``diagonal`` have an unsupported rank, ``diagonal``
            does not match the event size, zero or both factors are given,
            or the factor has an unsupported layout, rank or shape.
        """
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        elif loc.dim() > 2:
            raise ValueError(
                "loc must be at most two-dimensional as the current implementation only supports 1 batch dimension."
            )

        event_shape = loc.shape[-1:]
        self._loc = loc

        if diagonal is not None:
            if diagonal.dim() < 1:
                raise ValueError("diagonal must be at least one-dimensional.")
            elif diagonal.dim() > 2:
                raise ValueError(
                    "diagonal must be at most two-dimensional as the current implementation only supports 1 batch dimension."
                )
            if diagonal.shape[-1:] != event_shape:
                raise ValueError("diagonal must be a batch of vectors with shape {}".format(event_shape))
        self._diagonal = diagonal

        if (scale_tril is not None) + (precision_tril is not None) != 1:
            raise ValueError("Exactly one of scale_tril or precision_tril may be specified.")

        # Always define both attributes so the public properties return None
        # (as documented) for the factor that was not supplied.
        self._scale_tril = None
        self._precision_tril = None

        if scale_tril is not None:
            tril = self._prepare_tril(
                scale_tril,
                "scale_tril",
                "scale_tril matrix must be at least two-dimensional, with optional leading batch dimension",
                event_shape,
            )
            self._scale_tril = tril
        else:
            tril = self._prepare_tril(
                precision_tril,
                "precision_tril",
                "precision_tril must be at least two-dimensional, with optional leading batch dimensions",
                event_shape,
            )
            self._precision_tril = tril

        batch_parts = [loc.shape[:-1]]
        if diagonal is not None:
            batch_parts.append(diagonal.shape[:-1])
        batch_parts.append(tril.shape[:-2])
        batch_shape = torch.broadcast_shapes(*batch_parts)

        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    @staticmethod
    def _prepare_tril(tril, name, too_few_dims_msg, event_shape):
        """Check layout, rank and shape of a sparse factor; coalesce COO input."""
        if tril.layout == torch.sparse_coo:
            # coalesce() returns the tensor itself when it is already coalesced.
            tril = tril.coalesce()
        elif tril.layout != torch.sparse_csr:
            raise ValueError("{} must be sparse COO or CSR, instead of {}".format(name, tril.layout))

        if tril.dim() < 2:
            raise ValueError(too_few_dims_msg)
        elif tril.dim() > 3:
            raise ValueError("{} can only have 1 batch dimension, but has {}".format(name, tril.dim() - 2))

        if tril.shape[-2] != tril.shape[-1]:
            raise ValueError("{} must be square.".format(name))
        if tril.shape[-1] != event_shape[0]:
            raise ValueError("{} must have the same size as loc.".format(name))
        return tril

    @staticmethod
    def _apply_factor(op, factor, rhs, **kwargs):
        """Apply ``op`` to ``rhs``, flattening any extra leading sample dims.

        ``rhs`` is passed straight to ``_batch_sparse_mv`` while its rank does
        not exceed the factor's rank. Otherwise the leading dims in front of
        the ``(n,)`` tail (unbatched factor) or the ``(B, n)`` tail (batched
        factor) are collapsed into one, and the result is reshaped back.
        """
        if rhs.dim() <= factor.dim():
            return _batch_sparse_mv(op, factor, rhs, **kwargs)
        tail = factor.dim() - 1  # 1 -> (n,), 2 -> (B, n)
        flat = rhs.reshape(rhs.shape[:-tail].numel(), *rhs.shape[-tail:])
        return _batch_sparse_mv(op, factor, flat, **kwargs).reshape(rhs.shape)

    @property
    def diagonal(self):
        return self._diagonal

    @property
    def scale_tril(self):
        return self._scale_tril

    @property
    def precision_tril(self):
        return self._precision_tril

    @property
    def loc(self):
        return self._loc

    @property
    def mean(self):
        return self._loc

    @property
    def mode(self):
        return self._loc

    @property
    def is_ldlt_parameterization(self):
        r"""Return ``True`` if using :math:`L D L^\top` parameterization (``diagonal`` provided), else ``False`` (:math:`L L^\top`)."""
        return self._diagonal is not None

    def rsample(self, sample_shape=torch.Size()):
        r"""Draw reparameterized samples of shape ``sample_shape + batch_shape + event_shape``."""
        shape = self._extended_shape(sample_shape)
        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        ldlt = self._diagonal is not None

        if self._scale_tril is not None:
            if ldlt:
                # LDL^T: scale_tril holds only the strictly lower part of the
                # unit-triangular L, so the identity contribution is added back.
                eta = self._diagonal.sqrt() * eps
                x = self._apply_factor(spmm, self._scale_tril, eta) + eta
            else:
                # LL^T: scale_tril is lower triangular including its diagonal.
                x = self._apply_factor(spmm, self._scale_tril, eps)
        else:
            # Precision forms solve L^T x = rhs; for LDL^T the factor is unit
            # triangular and the noise is pre-scaled by D^{-1/2}.
            rhs = eps / self._diagonal.sqrt() if ldlt else eps
            x = self._apply_factor(
                spts,
                self._precision_tril,
                rhs,
                upper=False,
                unitriangular=ldlt,
                transpose=True,
            )

        return self.loc + x
class SparseMultivariateNormalNative(Distribution):
    r"""
    Sparse multivariate normal (native ``torch.sparse.mm`` backend).

    This distribution models :math:`x \sim \mathcal N(\mu, \Sigma)` where the
    covariance is parameterized via a sparse Cholesky factor
    :math:`\Sigma = L L^\top`. Sampling uses ``torch.sparse.mm`` directly so
    gradients propagate to the CSR values of ``L``.

    **Scope & limitations**

    - Layout: **CSR only** (requirement of ``torch.sparse.mm``)
    - Batching: **unbatched only** (2D factor)
    - Parameterization: **LLᵀ (covariance) only**
    - No precision/LDLᵀ support in this native variant
      (see :class:`SparseMultivariateNormal`)

    Parameters
    ----------
    loc : torch.Tensor
        Mean vector of shape ``(n,)``.
    scale_tril : torch.Tensor
        Sparse lower-triangular Cholesky factor ``L`` of shape ``(n, n)`` in
        ``torch.sparse_csr`` layout with positive diagonal.
    validate_args : bool, optional
        If ``True``, validate input shapes/dtypes where feasible.

    Attributes
    ----------
    loc : torch.Tensor
        Mean vector.
    scale_tril : torch.Tensor
        CSR Cholesky factor ``L`` such that ``Σ = L @ L.T``.

    Notes
    -----
    Sampling uses
    :math:`x = \mu + L \varepsilon, \quad \varepsilon \sim \mathcal N(0, I)`,
    implemented as a sparse–dense matmul via ``torch.sparse.mm``. Although
    PyTorch docs historically understate gradient support for some sparse ops,
    in practice autograd computes gradients w.r.t. the CSR **values** here.

    ``covariance_matrix``, ``variance`` and ``log_prob`` are computed by
    densifying the factor, which can be memory-expensive for large problems.

    See Also
    --------
    SparseMultivariateNormal
        Full-featured sparse MVN with COO/CSR, batched inputs, and
        precision/LDLᵀ forms.

    Examples
    --------
    Basic usage (LLᵀ parameterization):

    >>> import torch
    >>> n = 4
    >>> # Build a small lower-triangular with positive diagonal in CSR
    >>> crow = torch.tensor([0, 1, 3, 4, 5], dtype=torch.int64)
    >>> col = torch.tensor([0, 0, 1, 2, 3], dtype=torch.int64)
    >>> vals = torch.tensor([1.0, 0.2, 1.1, 0.3, 0.9], dtype=torch.float64)
    >>> L = torch.sparse_csr_tensor(crow, col, vals, size=(n, n))
    >>> loc = torch.zeros(n, dtype=torch.float64)
    >>> mvn = SparseMultivariateNormalNative(loc, L)
    >>> x = mvn.rsample()  # (n,)
    >>> x.shape
    torch.Size([4])

    Multiple samples:

    >>> xs = mvn.rsample((100,))  # (100, n)
    >>> xs.shape
    torch.Size([100, 4])

    Log probability (densifies internally):

    >>> lp = mvn.log_prob(x)
    >>> torch.isfinite(lp).item()  # doctest: +SKIP
    True
    """

    arg_constraints = {
        "loc": constraints.real_vector,
        # TODO: create custom sparse lower triangular constraint
        # 'scale_tril': constraints.lower_cholesky,
    }
    support = constraints.real_vector
    has_rsample = True

    def __init__(self, loc, scale_tril, validate_args=None):
        """Validate ``loc`` / ``scale_tril`` and set an empty batch shape.

        Raises
        ------
        ValueError
            If ``loc`` is not 1-D, or ``scale_tril`` is not an unbatched,
            square CSR matrix whose size matches ``loc``.
        """
        if loc.dim() != 1:
            raise ValueError("loc must be one-dimensional for SparseMultivariateNormalNative.")

        if scale_tril.layout != torch.sparse_csr:
            raise ValueError("scale_tril must be sparse CSR for SparseMultivariateNormalNative.")

        if scale_tril.dim() != 2:
            raise ValueError("scale_tril must be two-dimensional (unbatched) for SparseMultivariateNormalNative.")

        if scale_tril.shape[0] != scale_tril.shape[1]:
            raise ValueError("scale_tril must be square.")

        if scale_tril.shape[0] != loc.shape[0]:
            raise ValueError("scale_tril must have the same size as loc.")

        event_shape = loc.shape
        self._loc = loc
        self._scale_tril = scale_tril

        # No batch dimensions for this implementation
        batch_shape = torch.Size()
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    @property
    def scale_tril(self):
        return self._scale_tril

    @property
    def loc(self):
        return self._loc

    @property
    def mean(self):
        return self._loc

    @property
    def mode(self):
        return self._loc

    @property
    def covariance_matrix(self):
        r"""Compute covariance matrix :math:`\Sigma = L L^\top` as a dense ``L @ L.T`` (densifies ``L``)."""
        # Convert to dense for covariance computation - this is expensive but needed
        warnings.warn(
            "Computing covariance_matrix requires converting sparse matrix to dense format. "
            "This may cause memory issues for large sparse matrices. "
            "Consider using variance property for diagonal elements only.",
            UserWarning,
            stacklevel=2,
        )
        L_dense = self._scale_tril.to_dense()
        return L_dense @ L_dense.T

    @property
    def variance(self):
        r"""Compute diagonal (variance) of the covariance matrix :math:`\operatorname{diag}(L L^\top)`."""
        # For LL^T parameterization, variance is sum of squares of each row
        warnings.warn(
            "Computing variance requires converting sparse matrix to dense format. "
            "This may cause memory issues for large sparse matrices.",
            UserWarning,
            stacklevel=2,
        )
        L_dense = self._scale_tril.to_dense()
        return (L_dense**2).sum(dim=-1)

    def rsample(self, sample_shape=torch.Size()):
        r"""Sample from the distribution using :func:`torch.sparse.mm` (reparameterized).

        Returns a tensor of shape ``sample_shape + (n,)``; ``sample_shape``
        may have any number of dimensions.
        """
        shape = self._extended_shape(sample_shape)
        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)

        # torch.sparse.mm needs a 2-D right-hand side: collapse every sample
        # dimension into one, multiply as (n, n) @ (n, S), then restore.
        flat = eps.reshape(eps.shape[:-1].numel(), eps.shape[-1])
        x = torch.sparse.mm(self._scale_tril, flat.t()).t().reshape(eps.shape)

        return self.loc + x

    def log_prob(self, value):
        r"""Compute log probability density (densifies internally, may be memory intensive).

        ``value`` has shape ``(..., n)``; the result has shape ``(...)``.
        """
        if self._validate_args:
            self._validate_sample(value)

        # Convert to dense for log_prob computation
        warnings.warn(
            "Computing log_prob requires converting sparse matrix to dense format. "
            "This may cause memory issues for large sparse matrices. "
            "Consider using rsample() only if you don't need log_prob computation.",
            UserWarning,
            stacklevel=2,
        )

        L_dense = self._scale_tril.to_dense()

        # Solve L @ z = (value - loc) for z, one column per flattened sample.
        diff = value - self.loc
        lead_shape = diff.shape[:-1]
        flat = diff.reshape(lead_shape.numel(), diff.shape[-1])
        z = torch.linalg.solve_triangular(L_dense, flat.t(), upper=False).t()

        # Compute log probability
        M = (z**2).sum(-1).reshape(lead_shape)  # Mahalanobis distance squared
        half_log_det = L_dense.diagonal().log().sum()

        return -0.5 * (self.event_shape[0] * math.log(2 * math.pi) + M) - half_log_det