Probability Distributions
This module contains sparse multivariate normal distributions.
Sparse Multivariate Normal
- class torchsparsegradutils.distributions.sparse_multivariate_normal.SparseMultivariateNormal(loc, diagonal=None, scale_tril=None, precision_tril=None, validate_args=None)
Bases: Distribution
Multivariate normal with sparse Cholesky / LDL^T parameterizations.
Supports sparse covariance or sparse precision factors using either the standard Cholesky (\(L L^\top\)) or the modified Cholesky (\(L D L^\top\)) parameterization. Sparse triangular factors can be given in COO or CSR layout and (optionally) batched with a single leading batch dimension.
Parameterizations
- Cholesky (LL^T)
Covariance form: \(\Sigma = L L^\top\) with lower-triangular \(L\) (incl. diagonal).
Precision form: \(\Omega = L L^\top\) with lower-triangular \(L\) (incl. diagonal).
- Modified Cholesky (LDL^T)
Covariance form: \(\Sigma = L D L^\top\) with unit lower-triangular \(L\) and diagonal \(D = \operatorname{diag}(\text{diagonal}) > 0\).
Precision form: \(\Omega = L D L^\top\) with unit lower-triangular \(L\) and diagonal \(D = \operatorname{diag}(\text{diagonal})\) (entries may be any real numbers if you only care about sampling via \(\Omega^{-1/2}\)).
In both \(L D L^\top\) forms, the strictly lower part of \(L\) lives in the sparse *_tril factor and the diagonal is provided separately via diagonal.
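The relationship between the two parameterizations can be checked numerically. The following is a minimal dense sketch in plain PyTorch, not the library's implementation. The names strict_lower, d, L_unit and L_chol are illustrative, and it assumes the sparse factor stores only the strictly lower part, so that the unit lower-triangular factor is the identity plus that part.

```python
import torch

dtype = torch.float64
n = 4

# Strictly lower-triangular entries (what the sparse *_tril factor would hold
# in the LDL^T parameterization). Dense here purely for illustration.
strict_lower = torch.tensor(
    [[0.0, 0.0, 0.0, 0.0],
     [0.5, 0.0, 0.0, 0.0],
     [0.3, 0.8, 0.0, 0.0],
     [0.0, 0.0, 0.2, 0.0]],
    dtype=dtype,
)

# Positive diagonal D, which would be passed separately as `diagonal`.
d = torch.tensor([2.0, 1.5, 3.0, 1.0], dtype=dtype)

# Unit lower-triangular factor: ones on the diagonal plus the strict lower part.
L_unit = torch.eye(n, dtype=dtype) + strict_lower

# Modified Cholesky form: Sigma = L D L^T.
sigma_ldlt = L_unit @ torch.diag(d) @ L_unit.T

# Equivalent standard Cholesky factor: scale column j of L by sqrt(d_j).
L_chol = L_unit * d.sqrt()
sigma_llt = L_chol @ L_chol.T

print(torch.allclose(sigma_ldlt, sigma_llt))
```

Because L_chol is lower-triangular with a positive diagonal, it is also the factor that a dense Cholesky decomposition of the same covariance would return.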
- param loc:
Mean vector, shape (n,) or (B, n).
- type loc:
torch.Tensor
- param diagonal:
Diagonal entries for the \(D\) in \(L D L^\top\). Shape (n,) or (B, n). Must be positive for covariance parameterization; free real for precision parameterization. If None, the \(L L^\top\) parameterization is used.
- type diagonal:
torch.Tensor, optional
- param scale_tril:
Sparse lower-triangular factor for covariance (\(L\)). COO or CSR with shape (n, n) or (B, n, n). Mutually exclusive with precision_tril.
- type scale_tril:
torch.Tensor, optional
- param precision_tril:
Sparse lower-triangular factor for precision (\(L\)). COO or CSR with shape (n, n) or (B, n, n). Mutually exclusive with scale_tril.
- type precision_tril:
torch.Tensor, optional
- param validate_args:
If True, run argument checks.
- type validate_args:
bool, optional
- loc
Mean.
- Type:
torch.Tensor
- diagonal
Diagonal entries for \(D\) in \(L D L^\top\) form, if provided.
- Type:
torch.Tensor or None
- scale_tril
Covariance Cholesky factor (sparse lower-triangular), if provided.
- Type:
torch.Tensor or None
- precision_tril
Precision Cholesky factor (sparse lower-triangular), if provided.
- Type:
torch.Tensor or None
- returns:
A distribution instance compatible with torch.distributions.
- rtype:
SparseMultivariateNormal
- raises ValueError:
On invalid shapes, unsupported layouts, or incompatible batching.
Notes
Only a single batch dimension is supported for the sparse factors.
Sampling uses reparameterization with \(\varepsilon \sim \mathcal N(0, I)\) and mean \(\mu\) (loc):
Covariance \(L L^\top\): \(x = \mu + L \varepsilon\).
Covariance \(L D L^\top\): \(x = \mu + \tilde{L} (\sqrt{D} \odot \varepsilon) + (\sqrt{D} \odot \varepsilon)\), where \(\tilde{L}\) is the strictly lower part stored in the sparse factor, so that the unit lower-triangular factor is \(L = I + \tilde{L}\).
Precision forms use sparse triangular solves with the transpose.
log_prob is not implemented in this class.
Sparse operations are delegated to torchsparsegradutils.sparse_mm() and torchsparsegradutils.sparse_triangular_solve().
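The sampling rules in these notes can be illustrated with dense tensors. This is a sketch under stated assumptions, not the library's code: it uses plain torch matmul and torch.linalg.solve_triangular in place of the sparse routines, takes a zero mean, assumes a positive diagonal in every form, and all function and variable names are hypothetical. Passing the identity matrix as the noise returns the linear map itself, so the implied covariance can be checked exactly.

```python
import torch

dtype = torch.float64
n = 4
eye = torch.eye(n, dtype=dtype)

# Dense stand-ins for the sparse inputs (illustrative values only).
strict_lower = torch.tensor(
    [[0.0, 0.0, 0.0, 0.0],
     [0.5, 0.0, 0.0, 0.0],
     [0.3, 0.8, 0.0, 0.0],
     [0.0, 0.0, 0.2, 0.0]],
    dtype=dtype,
)
d = torch.tensor([2.0, 1.5, 3.0, 1.0], dtype=dtype)  # assumed positive
L_unit = eye + strict_lower      # unit lower-triangular factor for LDL^T
L_chol = L_unit * d.sqrt()       # lower-triangular factor (incl. diagonal) for LL^T
sqrt_d = d.sqrt().unsqueeze(-1)  # column vector, scales row i of the noise


def cov_llt(eps):
    # Covariance LL^T: x = L eps.
    return L_chol @ eps


def cov_ldlt(eps):
    # Covariance LDL^T: z = sqrt(D) * eps, x = strict_lower @ z + z.
    z = sqrt_d * eps
    return strict_lower @ z + z


def prec_llt(eps):
    # Precision LL^T: solve L^T x = eps, giving Cov(x) = (L L^T)^{-1}.
    return torch.linalg.solve_triangular(L_chol.T, eps, upper=True)


def prec_ldlt(eps):
    # Precision LDL^T: solve L^T x = eps / sqrt(D), giving Cov(x) = (L D L^T)^{-1}.
    return torch.linalg.solve_triangular(
        L_unit.T, eps / sqrt_d, upper=True, unitriangular=True
    )


# With eps = I, each function returns the matrix A of its linear map,
# and the covariance of x = A eps is A A^T.
sigma = L_unit @ torch.diag(d) @ L_unit.T
A_cov_llt, A_cov_ldlt = cov_llt(eye), cov_ldlt(eye)
A_prec_llt, A_prec_ldlt = prec_llt(eye), prec_ldlt(eye)

print(torch.allclose(A_cov_ldlt @ A_cov_ldlt.T, sigma))
print(torch.allclose(A_prec_ldlt @ A_prec_ldlt.T, torch.linalg.inv(sigma)))
```

In the two precision functions the same matrix plays the role of the precision, so the resulting covariance is its inverse.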
See also
torch.distributions.MultivariateNormal
Dense baseline distribution.
torchsparsegradutils.sparse_mm
Sparse matrix–dense matrix multiply used during sampling.
torchsparsegradutils.sparse_triangular_solve
Sparse triangular solver used with precision factors.
Examples
\(L L^\top\) parameterization with sparse covariance:
>>> import torch
>>> from torchsparsegradutils.distributions import SparseMultivariateNormal
>>> loc = torch.zeros(4)
>>> indices = torch.tensor([[1, 2, 2, 3], [0, 0, 1, 2]])
>>> values = torch.tensor([0.5, 0.3, 0.8, 0.2])
>>> diag_indices = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]])
>>> diag_values = torch.tensor([1.0, 1.0, 1.0, 1.0])
>>> all_indices = torch.cat([diag_indices, indices], dim=1)
>>> all_values = torch.cat([diag_values, values])
>>> scale_tril = torch.sparse_coo_tensor(all_indices, all_values, (4, 4))
>>> mvn = SparseMultivariateNormal(loc=loc, scale_tril=scale_tril)
>>> samples = mvn.sample((100,))
>>> samples.shape
torch.Size([100, 4])
\(L D L^\top\) parameterization with diagonal precision:
>>> diagonal = torch.tensor([2.0, 1.5, 3.0, 1.0])  # Precision diagonal
>>> precision_tril = torch.sparse_coo_tensor(indices, values, (4, 4))
>>> mvn_ldlt = SparseMultivariateNormal(loc=loc, diagonal=diagonal,
...                                     precision_tril=precision_tril)
>>> samples = mvn_ldlt.sample((50,))
Batched distributions:
>>> loc_batch = torch.randn(3, 4)
>>> diagonal_batch = torch.abs(torch.randn(3, 4)) + 0.1
>>> precision_batch = torch.stack([precision_tril] * 3)
>>> mvn_batch = SparseMultivariateNormal(loc=loc_batch, diagonal=diagonal_batch,
...                                      precision_tril=precision_batch)
>>> samples = mvn_batch.sample()
>>> samples.shape
torch.Size([3, 4])
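A batched COO factor with a single leading batch dimension can also be assembled directly from indices. The sketch below uses only core PyTorch; the variable names are illustrative, and whether the class accepts a factor built this way in every case is an assumption rather than something stated above.

```python
import torch

# Strictly lower-triangular pattern shared by every batch element
# (illustrative values).
rows = torch.tensor([1, 2, 2, 3])
cols = torch.tensor([0, 0, 1, 2])
vals = torch.tensor([0.5, 0.3, 0.8, 0.2])
B, n, nnz = 3, 4, vals.numel()

# A (B, n, n) COO tensor needs a third index row holding the batch position
# of each stored value.
batch_idx = torch.arange(B).repeat_interleave(nnz)
indices = torch.stack([batch_idx, rows.repeat(B), cols.repeat(B)])
values = vals.repeat(B)
batched = torch.sparse_coo_tensor(indices, values, (B, n, n)).coalesce()

# Unbatched factor with the same pattern, for comparison.
single = torch.sparse_coo_tensor(torch.stack([rows, cols]), vals, (n, n))

print(batched.shape)
```

The dense content of batched equals the unbatched factor repeated B times along the leading dimension.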
- arg_constraints = {}
- support = IndependentConstraint(Real(), 1)
- property mean
Returns the mean of the distribution.
- property mode
Returns the mode of the distribution.
- property is_ldlt_parameterization
Return True if using \(L D L^\top\) parameterization (diagonal provided), else False (\(L L^\top\)).
- class torchsparsegradutils.distributions.sparse_multivariate_normal.SparseMultivariateNormalNative(loc, scale_tril, validate_args=None)
Bases: Distribution
Sparse multivariate normal (native torch.sparse.mm backend).
This distribution models \(x \sim \mathcal N(\mu, \Sigma)\) where the covariance is parameterized via a sparse Cholesky factor \(\Sigma = L L^\top\). Sampling uses torch.sparse.mm directly so gradients propagate to the CSR values of L.
Scope & limitations
Layout: CSR only (requirement of torch.sparse.mm)
Batching: unbatched only (2D factor)
Parameterization: LLᵀ (covariance) only
No precision/LDLᵀ support in this native variant (see SparseMultivariateNormal)
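Since this variant accepts CSR only, a factor that already exists in COO layout has to be converted first. A minimal sketch using the core PyTorch method Tensor.to_sparse_csr(), with illustrative values:

```python
import torch

# Lower-triangular factor with a positive diagonal, built in COO layout.
indices = torch.tensor([[0, 1, 1, 2, 3],
                        [0, 0, 1, 2, 3]])
values = torch.tensor([1.0, 0.2, 1.1, 0.3, 0.9], dtype=torch.float64)
L_coo = torch.sparse_coo_tensor(indices, values, (4, 4)).coalesce()

# Convert to CSR; stored values are kept, only the index structure changes.
L_csr = L_coo.to_sparse_csr()

print(L_csr.layout)
print(L_csr.crow_indices())
```

The row pointer tensor has n + 1 entries, and consecutive differences give the number of stored values in each row.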
- Parameters:
loc (torch.Tensor) – Mean vector of shape (n,).
scale_tril (torch.Tensor) – Sparse lower-triangular Cholesky factor L of shape (n, n) in torch.sparse_csr layout with positive diagonal.
validate_args (bool, optional) – If True, validate input shapes/dtypes where feasible.
- loc
Mean vector.
- Type:
torch.Tensor
- scale_tril
CSR Cholesky factor L such that Σ = L @ L.T.
- Type:
torch.Tensor
Notes
Sampling uses \(x = \mu + L \varepsilon, \quad \varepsilon \sim \mathcal N(0, I)\), implemented as dense–sparse matmul via torch.sparse.mm. Although PyTorch docs historically understate gradient support for some sparse ops, in practice autograd computes gradients w.r.t. the CSR values here.
covariance_matrix and variance are computed by densifying the factor, which can be memory-expensive for large problems.
See also
SparseMultivariateNormal
Full-featured sparse MVN with COO/CSR, batched inputs, and precision/LDLᵀ forms.
Examples
Basic usage (LLᵀ parameterization):
>>> import torch
>>> from torchsparsegradutils.distributions.sparse_multivariate_normal import SparseMultivariateNormalNative
>>> n = 4
>>> # Build a small lower-triangular factor with positive diagonal in CSR.
>>> # Rows 0..3 hold 1, 2, 1 and 1 stored entries, so every row has a diagonal value.
>>> crow = torch.tensor([0, 1, 3, 4, 5], dtype=torch.int64)
>>> col = torch.tensor([0, 0, 1, 2, 3], dtype=torch.int64)
>>> vals = torch.tensor([1.0, 0.2, 1.1, 0.3, 0.9], dtype=torch.float64)
>>> L = torch.sparse_csr_tensor(crow, col, vals, size=(n, n))
>>> loc = torch.zeros(n, dtype=torch.float64)
>>> mvn = SparseMultivariateNormalNative(loc, L)
>>> x = mvn.rsample()  # (n,)
>>> x.shape
torch.Size([4])
Multiple samples:
>>> xs = mvn.rsample((100,))  # (100, n)
>>> xs.shape
torch.Size([100, 4])
Log probability (densifies internally):
>>> lp = mvn.log_prob(x)
>>> torch.isfinite(lp).item()
True
- arg_constraints = {'loc': IndependentConstraint(Real(), 1)}
- support = IndependentConstraint(Real(), 1)
- has_rsample = True
- property mean
Returns the mean of the distribution.
- property mode
Returns the mode of the distribution.
- property covariance_matrix
Compute covariance matrix \(\Sigma = L L^\top\) as L @ L.T using sparse operations.
- property variance
Compute diagonal (variance) of the covariance matrix \(\operatorname{diag}(L L^\top)\).
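Entry i of \(\operatorname{diag}(L L^\top)\) is the sum of squares of row i of \(L\), so the variance can in principle be read off the stored CSR values. The sketch below is an illustration in core PyTorch and not how this property is implemented; the names counts and row_of_value are hypothetical, and a dense computation is included only as a reference.

```python
import torch

n = 4
crow = torch.tensor([0, 1, 3, 4, 5], dtype=torch.int64)
col = torch.tensor([0, 0, 1, 2, 3], dtype=torch.int64)
vals = torch.tensor([1.0, 0.2, 1.1, 0.3, 0.9], dtype=torch.float64)
L = torch.sparse_csr_tensor(crow, col, vals, size=(n, n))

# Number of stored entries per row, then the row index of every stored value.
row_ptr = L.crow_indices()
counts = row_ptr[1:] - row_ptr[:-1]
row_of_value = torch.repeat_interleave(torch.arange(n), counts)

# diag(L L^T)_i = sum_j L_ij^2: accumulate the squared values row by row.
variance = torch.zeros(n, dtype=vals.dtype).index_add_(
    0, row_of_value, L.values() ** 2
)

# Dense reference, which needs O(n^2) memory.
L_dense = L.to_dense()
reference = torch.diagonal(L_dense @ L_dense.T)

print(torch.allclose(variance, reference))
```

The row-wise accumulation touches only the stored values, whereas the dense reference materializes the full factor.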
- rsample(sample_shape=())
Sample from the distribution using torch.sparse.mm() (reparameterized).
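A reparameterized draw of this kind can be sketched with torch.sparse.mm as follows. This is an illustration of the forward computation under stated assumptions, not the method's actual source: rsample_sketch is a hypothetical helper, the noise is laid out with one column per sample, and gradient flow to the CSR values is not exercised here.

```python
import torch

n = 4
crow = torch.tensor([0, 1, 3, 4, 5], dtype=torch.int64)
col = torch.tensor([0, 0, 1, 2, 3], dtype=torch.int64)
vals = torch.tensor([1.0, 0.2, 1.1, 0.3, 0.9], dtype=torch.float64)
L = torch.sparse_csr_tensor(crow, col, vals, size=(n, n))
loc = torch.full((n,), 0.5, dtype=torch.float64)


def rsample_sketch(loc, L, sample_shape=()):
    """Hypothetical helper: draw x = loc + L @ eps with eps ~ N(0, I)."""
    sample_shape = torch.Size(sample_shape)
    num = sample_shape.numel()                 # 1 when sample_shape is ()
    eps = torch.randn(loc.shape[-1], num, dtype=loc.dtype)
    x = torch.sparse.mm(L, eps)                # sparse (n, n) @ dense (n, num)
    return (loc.unsqueeze(-1) + x).T.reshape(*sample_shape, loc.shape[-1])


torch.manual_seed(0)
x = rsample_sketch(loc, L)
xs = rsample_sketch(loc, L, (100,))
print(x.shape, xs.shape)

# The sparse product agrees with the dense one for a fixed noise matrix.
eps_fixed = torch.randn(n, 3, dtype=torch.float64)
sparse_out = torch.sparse.mm(L, eps_fixed)
dense_out = L.to_dense() @ eps_fixed
```

Drawing all samples as columns of one noise matrix keeps the work in a single sparse product instead of one product per sample.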