Probability Distributions

This module contains sparse multivariate normal distributions.

Sparse Multivariate Normal

class torchsparsegradutils.distributions.sparse_multivariate_normal.SparseMultivariateNormal(loc, diagonal=None, scale_tril=None, precision_tril=None, validate_args=None)[source]

Bases: Distribution

Multivariate normal with sparse Cholesky / LDL^T parameterizations.

Supports sparse covariance or sparse precision factors using either the standard Cholesky (\(L L^\top\)) or the modified Cholesky (\(L D L^\top\)) parameterization. Sparse triangular factors can be given in COO or CSR layout and (optionally) batched with a single leading batch dimension.

Parameterizations

Cholesky (LL^T)
  • Covariance form: \(\Sigma = L L^\top\) with lower-triangular \(L\) (incl. diagonal).

  • Precision form: \(\Omega = L L^\top\) with lower-triangular \(L\) (incl. diagonal).

Modified Cholesky (LDL^T)
  • Covariance form: \(\Sigma = L D L^\top\) with unit lower-triangular \(L\) and diagonal \(D = \operatorname{diag}(\text{diagonal}) > 0\).

  • Precision form: \(\Omega = L D L^\top\) with unit lower-triangular \(L\) and diagonal \(D = \operatorname{diag}(\text{diagonal})\) (entries may be any real numbers if you only care about sampling via \(\Omega^{-1/2}\)).

  • In both cases, the strictly lower part lives in the sparse *_tril factor and the diagonal is provided separately via diagonal.
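The relationship between the two parameterizations can be checked with dense tensors. The following is a minimal sketch using plain PyTorch only; the numeric values and the names `strict_lower`, `chol_diag`, `L_unit`, and `L_equiv` are illustrative assumptions, not part of the library.

```python
import torch

n = 4
dtype = torch.float64

# Strictly lower-triangular entries (illustrative values).
indices = torch.tensor([[1, 2, 2, 3], [0, 0, 1, 2]])
values = torch.tensor([0.5, 0.3, 0.8, 0.2], dtype=dtype)
strict_lower = torch.sparse_coo_tensor(indices, values, (n, n)).to_dense()

# Cholesky (LL^T): the factor carries its own positive diagonal.
chol_diag = torch.tensor([1.0, 1.2, 0.9, 1.1], dtype=dtype)
L_chol = strict_lower + torch.diag(chol_diag)
sigma_llt = L_chol @ L_chol.T

# Modified Cholesky (LDL^T): unit lower-triangular L, diagonal given separately.
diagonal = torch.tensor([2.0, 1.5, 3.0, 1.0], dtype=dtype)
L_unit = strict_lower + torch.eye(n, dtype=dtype)
sigma_ldlt = L_unit @ torch.diag(diagonal) @ L_unit.T

# With positive D, an LDL^T factorization equals an LL^T one with L = L_unit sqrt(D).
L_equiv = L_unit @ torch.diag(diagonal.sqrt())
```

When the diagonal is positive, the two forms describe the same family of matrices; the LDL^T form simply keeps the scale in a separate dense vector.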

Parameters:
  • loc (torch.Tensor) – Mean vector, shape (n,) or (B, n).

  • diagonal (torch.Tensor, optional) – Diagonal entries for the \(D\) in \(L D L^\top\). Shape (n,) or (B, n). Must be positive for covariance parameterization; free real for precision parameterization. If None, the \(L L^\top\) parameterization is used.

  • scale_tril (torch.Tensor, optional) – Sparse lower-triangular factor for covariance (\(L\)). COO or CSR with shape (n, n) or (B, n, n). Mutually exclusive with precision_tril.

  • precision_tril (torch.Tensor, optional) – Sparse lower-triangular factor for precision (\(L\)). COO or CSR with shape (n, n) or (B, n, n). Mutually exclusive with scale_tril.

  • validate_args (bool, optional) – If True, run argument checks.

loc

Mean.

Type:

torch.Tensor

diagonal

Diagonal entries for \(D\) in \(L D L^\top\) form, if provided.

Type:

torch.Tensor or None

scale_tril

Covariance Cholesky factor (sparse lower-triangular), if provided.

Type:

torch.Tensor or None

precision_tril

Precision Cholesky factor (sparse lower-triangular), if provided.

Type:

torch.Tensor or None

has_rsample

Reparameterized sampling is supported.

Type:

bool

Returns:

A distribution instance compatible with torch.distributions.

Return type:

SparseMultivariateNormal

Raises:

ValueError – On invalid shapes, unsupported layouts, or incompatible batching.

Notes

  • Only a single batch dimension is supported for the sparse factors.

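  • Sampling uses reparameterization with \(\varepsilon \sim \mathcal N(0, I)\) and \(\mu\) = loc:

    • Covariance \(L L^\top\): \(x = \mu + L \varepsilon\).

    • Covariance \(L D L^\top\): \(x = \mu + L (\sqrt{D} \odot \varepsilon) + (\sqrt{D} \odot \varepsilon)\), where \(L\) here denotes the strictly lower part stored in the sparse factor; the second term accounts for the implicit unit diagonal.

    • Precision forms use sparse triangular solves with the transpose, e.g. \(x = \mu + L^{-\top} \varepsilon\) for \(\Omega = L L^\top\), so that \(\operatorname{Cov}(x) = L^{-\top} L^{-1} = \Omega^{-1}\).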

  • log_prob is not implemented in this class.

  • Sparse operations are delegated to torchsparsegradutils.sparse_mm() and torchsparsegradutils.sparse_triangular_solve().
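The sampling maps listed in these notes can be sanity-checked without drawing random samples: feeding the columns of the identity matrix in place of \(\varepsilon\) exposes the linear map \(A\) itself, and \(A A^\top\) should equal the target covariance. The sketch below uses dense PyTorch only and is not the library's implementation. The precision case assumes the map \(x = L^{-\top} \varepsilon\), which is one reading of "triangular solves with the transpose". All variable names are illustrative.

```python
import torch

torch.manual_seed(0)
n = 4
dtype = torch.float64

# Illustrative strictly lower-triangular part and positive diagonal.
strict_lower = torch.tril(torch.randn(n, n, dtype=dtype), diagonal=-1)
d = torch.rand(n, dtype=dtype) + 0.5
eps = torch.eye(n, dtype=dtype)  # columns of I reveal the linear map itself

# Covariance LL^T: x = L eps, with the diagonal stored inside L.
L = strict_lower + torch.diag(d)
A_cov_llt = L @ eps

# Covariance LDL^T: x = L_strict (sqrt(D) * eps) + sqrt(D) * eps.
scaled = d.sqrt().unsqueeze(-1) * eps
A_cov_ldlt = strict_lower @ scaled + scaled

# Precision LL^T (assumed map): solve L^T x = eps, so Cov(x) = (L L^T)^{-1}.
A_prec_llt = torch.linalg.solve_triangular(L.T, eps, upper=True)

# Targets to compare against.
L_unit = strict_lower + torch.eye(n, dtype=dtype)
sigma_llt = L @ L.T
sigma_ldlt = L_unit @ torch.diag(d) @ L_unit.T
sigma_from_precision = torch.linalg.inv(L @ L.T)
```

The mean is omitted here because it only shifts the samples and does not affect the covariance.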

See also

torch.distributions.MultivariateNormal

Dense baseline distribution.

torchsparsegradutils.sparse_mm

Sparse matrix–dense matrix multiply used during sampling.

torchsparsegradutils.sparse_triangular_solve

Sparse triangular solver used with precision factors.

Examples

\(L L^\top\) parameterization with sparse covariance:

>>> import torch
>>> from torchsparsegradutils.distributions import SparseMultivariateNormal
>>> loc = torch.zeros(4)
>>> indices = torch.tensor([[1, 2, 2, 3], [0, 0, 1, 2]])
>>> values = torch.tensor([0.5, 0.3, 0.8, 0.2])
>>> diag_indices = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]])
>>> diag_values = torch.tensor([1.0, 1.0, 1.0, 1.0])
>>> all_indices = torch.cat([diag_indices, indices], dim=1)
>>> all_values = torch.cat([diag_values, values])
>>> scale_tril = torch.sparse_coo_tensor(all_indices, all_values, (4, 4))
>>> mvn = SparseMultivariateNormal(loc=loc, scale_tril=scale_tril)
>>> samples = mvn.sample((100,))
>>> samples.shape
torch.Size([100, 4])
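For small problems, a dense baseline is a convenient cross-check. The sketch below builds the same factor entries as a dense matrix and passes it to torch.distributions.MultivariateNormal; it uses only PyTorch, and the name `dense_mvn` is illustrative.

```python
import torch
from torch.distributions import MultivariateNormal

loc = torch.zeros(4)
# Unit diagonal followed by the four strictly lower entries.
indices = torch.tensor([[0, 1, 2, 3, 1, 2, 2, 3],
                        [0, 1, 2, 3, 0, 0, 1, 2]])
values = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.5, 0.3, 0.8, 0.2])
scale_tril_dense = torch.sparse_coo_tensor(indices, values, (4, 4)).to_dense()

# Dense reference distribution with covariance L L^T.
dense_mvn = MultivariateNormal(loc=loc, scale_tril=scale_tril_dense)
dense_samples = dense_mvn.sample((100,))
```

Comparing sample statistics from the sparse distribution against `dense_mvn.covariance_matrix` is one way to validate a sparse factor before scaling up.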

\(L D L^\top\) parameterization of the precision, with a sparse strictly lower factor and a separate diagonal:

>>> diagonal = torch.tensor([2.0, 1.5, 3.0, 1.0])  # Precision diagonal
>>> precision_tril = torch.sparse_coo_tensor(indices, values, (4, 4))
>>> mvn_ldlt = SparseMultivariateNormal(loc=loc, diagonal=diagonal,
...                                     precision_tril=precision_tril)
>>> samples = mvn_ldlt.sample((50,))

Batched distributions:

>>> loc_batch = torch.randn(3, 4)
>>> diagonal_batch = torch.abs(torch.randn(3, 4)) + 0.1
>>> precision_batch = torch.stack([precision_tril] * 3)
>>> mvn_batch = SparseMultivariateNormal(loc=loc_batch, diagonal=diagonal_batch,
...                                      precision_tril=precision_batch)
>>> samples = mvn_batch.sample()
>>> samples.shape
torch.Size([3, 4])

arg_constraints = {}
support = IndependentConstraint(Real(), 1)
__init__(loc, diagonal=None, scale_tril=None, precision_tril=None, validate_args=None)[source]
property mean

Returns the mean of the distribution.

property mode

Returns the mode of the distribution.

property is_ldlt_parameterization

Return True if using \(L D L^\top\) parameterization (diagonal provided), else False (\(L L^\top\)).

rsample(sample_shape=())[source]

Generates a reparameterized sample of shape sample_shape, or a sample_shape-shaped batch of reparameterized samples if the distribution parameters are batched.
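The shape convention in torch.distributions is sample_shape + batch_shape + event_shape. The sketch below illustrates it with the dense torch.distributions.MultivariateNormal; it assumes the sparse class follows the same convention, which matches the shapes shown in the examples above.

```python
import torch
from torch.distributions import MultivariateNormal

# Batch of 3 four-dimensional Gaussians with identity covariance factors.
loc_batch = torch.zeros(3, 4)
scale_tril_batch = torch.eye(4).expand(3, 4, 4)
dist = MultivariateNormal(loc=loc_batch, scale_tril=scale_tril_batch)

single = dist.rsample()      # batch_shape + event_shape
many = dist.rsample((50,))   # sample_shape + batch_shape + event_shape
```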

class torchsparsegradutils.distributions.sparse_multivariate_normal.SparseMultivariateNormalNative(loc, scale_tril, validate_args=None)[source]

Bases: Distribution

Sparse multivariate normal (native torch.sparse.mm backend).

This distribution models \(x \sim \mathcal N(\mu, \Sigma)\) where the covariance is parameterized via a sparse Cholesky factor \(\Sigma = L L^\top\). Sampling uses torch.sparse.mm directly so gradients propagate to the CSR values of L.

Scope & limitations

  • Layout: CSR only (requirement of torch.sparse.mm)

  • Batching: unbatched only (2D factor)

  • Parameterization: LLᵀ (covariance) only

  • No precision/LDLᵀ support in this native variant (see SparseMultivariateNormal)

Parameters:
  • loc (torch.Tensor) – Mean vector of shape (n,).

  • scale_tril (torch.Tensor) – Sparse lower-triangular Cholesky factor L of shape (n, n) in torch.sparse_csr layout with positive diagonal.

  • validate_args (bool, optional) – If True, validate input shapes/dtypes where feasible.
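A CSR factor can look lower-triangular while silently missing a diagonal entry, for example when the last row is empty. The helper below is a hypothetical pre-check written for illustration; `check_csr_cholesky_factor` is not a library function, and it is not necessarily what validate_args does.

```python
import torch


def check_csr_cholesky_factor(L_csr: torch.Tensor) -> None:
    """Raise ValueError unless L_csr is a square, lower-triangular CSR matrix
    whose diagonal entries are all stored and positive (CPU tensors assumed)."""
    if L_csr.layout != torch.sparse_csr:
        raise ValueError("expected torch.sparse_csr layout")
    if L_csr.dim() != 2 or L_csr.shape[0] != L_csr.shape[1]:
        raise ValueError("expected a square 2D factor")
    n = L_csr.shape[0]
    crow, col, vals = L_csr.crow_indices(), L_csr.col_indices(), L_csr.values()
    # Expand the compressed row pointers into one row index per stored value.
    rows = torch.repeat_interleave(torch.arange(n), crow[1:] - crow[:-1])
    if bool((col > rows).any()):
        raise ValueError("found entries above the diagonal")
    diag = torch.zeros(n, dtype=vals.dtype)
    on_diag = col == rows
    diag[rows[on_diag]] = vals[on_diag]
    if not bool((diag > 0).all()):
        raise ValueError("every diagonal entry must be stored and positive")


# A valid factor: every row stores its diagonal entry.
good = torch.sparse_csr_tensor(
    torch.tensor([0, 1, 3, 4, 5]),
    torch.tensor([0, 0, 1, 2, 3]),
    torch.tensor([1.0, 0.2, 1.1, 0.3, 0.9], dtype=torch.float64),
    size=(4, 4),
)
# An invalid factor: the last row is empty, so L[3, 3] is implicitly zero.
bad = torch.sparse_csr_tensor(
    torch.tensor([0, 1, 3, 4, 4]),
    torch.tensor([0, 0, 1, 2]),
    torch.tensor([1.0, 0.2, 1.1, 0.3], dtype=torch.float64),
    size=(4, 4),
)
```

Calling `check_csr_cholesky_factor(good)` returns silently, while `check_csr_cholesky_factor(bad)` raises ValueError because the implicit zero on the diagonal makes \(L L^\top\) singular.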

loc

Mean vector.

Type:

torch.Tensor

scale_tril

CSR Cholesky factor L such that Σ = L @ L.T.

Type:

torch.Tensor

Notes

Sampling uses \(x = \mu + L \varepsilon, \quad \varepsilon \sim \mathcal N(0, I)\), implemented as a sparse–dense matmul via torch.sparse.mm (sparse \(L\), dense \(\varepsilon\)). Although the PyTorch docs historically understate gradient support for some sparse ops, in practice autograd computes gradients w.r.t. the CSR values here.
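The sampling rule can be sketched directly with torch.sparse.mm. This is an illustrative reconstruction of the forward pass only, not the library's code; it assumes a PyTorch build in which CSR times dense multiplication is available on CPU, and the names `num_samples` and `eps` are illustrative.

```python
import torch

torch.manual_seed(0)
n = 4
crow = torch.tensor([0, 1, 3, 4, 5])
col = torch.tensor([0, 0, 1, 2, 3])
vals = torch.tensor([1.0, 0.2, 1.1, 0.3, 0.9], dtype=torch.float64)
L = torch.sparse_csr_tensor(crow, col, vals, size=(n, n))
loc = torch.zeros(n, dtype=torch.float64)

num_samples = 100
# One standard-normal column per draw, so a single matmul handles all samples.
eps = torch.randn(n, num_samples, dtype=torch.float64)
x = loc + torch.sparse.mm(L, eps).T  # shape (num_samples, n)
```

Stacking the noise as columns keeps the sparse factor on the left, which is the argument order torch.sparse.mm expects for a sparse first operand.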

covariance_matrix and variance are computed by densifying the factor, which can be memory-expensive for large problems.

See also

SparseMultivariateNormal

Full-featured sparse MVN with COO/CSR, batched inputs, and precision/LDLᵀ forms.

Examples

Basic usage (LLᵀ parameterization):

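>>> import torch
>>> from torchsparsegradutils.distributions.sparse_multivariate_normal import SparseMultivariateNormalNative
>>> n = 4
>>> # Build a small lower-triangular factor with positive diagonal in CSR.
>>> # Every row, including the last, must store its diagonal entry.
>>> crow = torch.tensor([0, 1, 3, 4, 5], dtype=torch.int64)
>>> col  = torch.tensor([0, 0, 1, 2, 3], dtype=torch.int64)
>>> vals = torch.tensor([1.0, 0.2, 1.1, 0.3, 0.9], dtype=torch.float64)
>>> L = torch.sparse_csr_tensor(crow, col, vals, size=(n, n))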
>>> loc = torch.zeros(n, dtype=torch.float64)
>>> mvn = SparseMultivariateNormalNative(loc, L)
>>> x = mvn.rsample()   # (n,)
>>> x.shape
torch.Size([4])

Multiple samples:

>>> xs = mvn.rsample((100,))  # (100, n)
>>> xs.shape
torch.Size([100, 4])

Log probability (densifies internally):

>>> lp = mvn.log_prob(x)
>>> torch.isfinite(lp).item()  
True

arg_constraints = {'loc': IndependentConstraint(Real(), 1)}
support = IndependentConstraint(Real(), 1)
has_rsample = True
__init__(loc, scale_tril, validate_args=None)[source]
property mean

Returns the mean of the distribution.

property mode

Returns the mode of the distribution.

property covariance_matrix

Compute covariance matrix \(\Sigma = L L^\top\) as L @ L.T using sparse operations.

property variance

Compute diagonal (variance) of the covariance matrix \(\operatorname{diag}(L L^\top)\).
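Because \(\operatorname{diag}(L L^\top)_i = \sum_j L_{ij}^2\), the variance can in principle be obtained from the stored CSR values alone, without forming the full product. The sketch below shows this identity with plain PyTorch; it is an illustration and not necessarily how the property is computed in the library.

```python
import torch

n = 4
crow = torch.tensor([0, 1, 3, 4, 5])
col = torch.tensor([0, 0, 1, 2, 3])
vals = torch.tensor([1.0, 0.2, 1.1, 0.3, 0.9], dtype=torch.float64)
L = torch.sparse_csr_tensor(crow, col, vals, size=(n, n))

# Expand the compressed row pointers into one row index per stored value.
counts = L.crow_indices()[1:] - L.crow_indices()[:-1]
rows = torch.repeat_interleave(torch.arange(n), counts)

# Row-wise sum of squared entries gives diag(L L^T).
variance = torch.zeros(n, dtype=vals.dtype).index_add_(0, rows, L.values() ** 2)

# Dense reference for comparison.
L_dense = L.to_dense()
variance_dense = torch.diagonal(L_dense @ L_dense.T)
```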

rsample(sample_shape=())[source]

Sample from the distribution using torch.sparse.mm() (reparameterized).

log_prob(value)[source]

Compute log probability density (densifies internally, may be memory intensive).
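For the \(L L^\top\) covariance form, the log density is \(\log p(x) = -\tfrac{n}{2}\log(2\pi) - \sum_i \log L_{ii} - \tfrac12 \lVert L^{-1}(x - \mu) \rVert^2\). The sketch below evaluates this formula on a densified factor and compares it with torch.distributions.MultivariateNormal. The function `dense_log_prob` is hypothetical and only handles a single vector of shape (n,); it is not the library's implementation.

```python
import math

import torch
from torch.distributions import MultivariateNormal


def dense_log_prob(value, loc, L_dense):
    """Log density of N(loc, L L^T) at a single vector `value` of shape (n,)."""
    n = loc.shape[-1]
    diff = (value - loc).unsqueeze(-1)  # column vector, shape (n, 1)
    # Whitened residual z = L^{-1} (value - loc) via a lower-triangular solve.
    z = torch.linalg.solve_triangular(L_dense, diff, upper=False).squeeze(-1)
    half_log_det = torch.log(torch.diagonal(L_dense)).sum()  # 0.5 * log|Sigma|
    return -0.5 * n * math.log(2 * math.pi) - half_log_det - 0.5 * (z**2).sum()


crow = torch.tensor([0, 1, 3, 4, 5])
col = torch.tensor([0, 0, 1, 2, 3])
vals = torch.tensor([1.0, 0.2, 1.1, 0.3, 0.9], dtype=torch.float64)
L_dense = torch.sparse_csr_tensor(crow, col, vals, size=(4, 4)).to_dense()
loc = torch.zeros(4, dtype=torch.float64)
x = torch.tensor([0.3, -1.2, 0.5, 2.0], dtype=torch.float64)

lp = dense_log_prob(x, loc, L_dense)
reference = MultivariateNormal(loc, scale_tril=L_dense).log_prob(x)
```

The log-determinant term is where a missing or zero diagonal entry shows up: \(\log 0\) makes the result non-finite.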

Constraints