torchsparsegradutils Documentation
torchsparsegradutils is a collection of utility functions for working with PyTorch sparse tensors. It provides memory-efficient, sparsity-preserving tensor operations that support automatic differentiation. The package fills gaps in PyTorch's sparse tensor ecosystem by providing essential operations whose gradients stay sparse during backpropagation instead of being materialised as dense tensors.
Key Features
Core Sparse Operations with Sparse Gradient Support
Memory-Efficient Sparse Matrix Multiplication
sparse_mm(): Memory-efficient sparse matrix multiplication with batch support
Preserves sparsity in gradients during backpropagation
Workaround for PyTorch issue #41128
Supports both COO and CSR formats with optional batching
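To illustrate what "preserves sparsity in gradients" means, here is a minimal sketch, not the library's implementation, of a custom autograd function for a sparse-dense product. It assumes an unbatched 2-D COO matrix A and a dense B; the class name SparseDenseMM is hypothetical. Rather than forming the dense gradient grad_C @ B.T, backward evaluates it only at A's stored indices, so the gradient for A needs O(nnz · p) memory instead of O(n · m).

```python
import torch


class SparseDenseMM(torch.autograd.Function):
    """Hypothetical sketch: C = A @ B for a 2-D sparse COO A and dense B,
    with dL/dA evaluated only at A's stored nonzeros (so it stays sparse)."""

    @staticmethod
    def forward(ctx, A, B):
        A = A.coalesce()  # unique, sorted indices so A.indices() is usable
        ctx.save_for_backward(A, B)
        return torch.sparse.mm(A, B)

    @staticmethod
    def backward(ctx, grad_C):
        A, B = ctx.saved_tensors
        grad_C = grad_C.contiguous()
        grad_A = grad_B = None
        if ctx.needs_input_grad[0]:
            rows, cols = A.indices()
            # Full gradient would be grad_C @ B.T; keep only entries where A is nonzero:
            # dL/dA[i, j] = sum_k grad_C[i, k] * B[j, k]
            vals = (grad_C[rows] * B[cols]).sum(dim=1)
            grad_A = torch.sparse_coo_tensor(A.indices(), vals, A.shape)
        if ctx.needs_input_grad[1]:
            # dL/dB = A.T @ grad_C, computed as a sparse-dense product
            grad_B = torch.sparse.mm(A.t(), grad_C)
        return grad_A, grad_B
```

Usage: call SparseDenseMM.apply(A, B) where A was created with requires_grad=True; after backward(), A.grad is a sparse COO tensor with A's sparsity pattern.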
Sparse Linear System Solvers
sparse_triangular_solve(): Sparse triangular solver with batch support
sparse_generic_solve(): Generic sparse linear solver with pluggable backends
sparse_generic_lstsq(): Generic sparse linear least-squares solver
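The same idea carries over to linear solves. For X = A⁻¹B, the gradients are grad_B = A⁻ᵀ grad_X and grad_A = -grad_B Xᵀ. The sketch below keeps grad_A only at A's nonzeros and takes the solver as a callable, loosely mirroring the idea of pluggable backends. SparseSolve and dense_solve are hypothetical names, and dense_solve densifies A purely for illustration; a practical backend would use a sparse direct or iterative solver instead.

```python
import torch


def dense_solve(A, B):
    # Hypothetical illustrative backend: densify A and call torch.linalg.solve.
    return torch.linalg.solve(A.to_dense(), B)


class SparseSolve(torch.autograd.Function):
    """Hypothetical sketch: X = A^{-1} B for a square sparse COO A, using a
    pluggable solve(A, B) callable; dL/dA is kept only at A's nonzeros."""

    @staticmethod
    def forward(ctx, A, B, solve):
        A = A.coalesce()
        X = solve(A, B)
        ctx.solve = solve
        ctx.save_for_backward(A, X)
        return X

    @staticmethod
    def backward(ctx, grad_X):
        A, X = ctx.saved_tensors
        # Adjoint system A^T grad_B = grad_X, solved with the same backend
        grad_B = ctx.solve(A.t().coalesce(), grad_X.contiguous())
        grad_A = None
        if ctx.needs_input_grad[0]:
            rows, cols = A.indices()
            # dL/dA = -grad_B @ X^T, evaluated only where A is nonzero
            vals = -(grad_B[rows] * X[cols]).sum(dim=1)
            grad_A = torch.sparse_coo_tensor(A.indices(), vals, A.shape)
        # No gradient for the non-tensor `solve` argument
        return grad_A, (grad_B if ctx.needs_input_grad[1] else None), None
```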
Built-in Iterative Solvers (No External Dependencies)
Pure PyTorch Implementations
BICGSTAB: Biconjugate Gradient Stabilized method
CG: Conjugate Gradient method
LSMR: Least Squares Minimal Residual method
MINRES: Minimal Residual method
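As an example of the kind of pure-PyTorch iterative method listed above, here is a textbook conjugate gradient sketch for a symmetric positive-definite sparse A. It only needs sparse matrix-vector products via torch.sparse.mm. The function name, its signature, the column-vector shape of b, and the stopping rule (residual norm below tol) are assumptions for illustration, not the package's own CG interface.

```python
import torch


def conjugate_gradient(A, b, tol=1e-8, max_iter=None):
    """Hypothetical sketch: solve A x = b by conjugate gradients.

    A: (n, n) sparse COO, symmetric positive definite.
    b: (n, 1) dense column vector.
    """
    max_iter = max_iter if max_iter is not None else 10 * b.shape[0]
    x = torch.zeros_like(b)
    r = b.clone()  # residual b - A x, with x = 0 initially
    p = r.clone()  # first search direction
    rs_old = (r * r).sum()
    for _ in range(max_iter):
        Ap = torch.sparse.mm(A, p)
        alpha = rs_old / (p * Ap).sum()  # step length along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = (r * r).sum()
        if torch.sqrt(rs_new) < tol:  # converged: residual norm is small
            break
        p = r + (rs_new / rs_old) * p  # next A-conjugate direction
        rs_old = rs_new
    return x
```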
Sparse Multivariate Normal Distributions
SparseMultivariateNormal: structured Gaussian distribution with reparameterised sampling
SparseMultivariateNormalNative: Native implementation using torch.sparse.mm
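Reparameterised sampling with a sparse factor can be sketched as x = μ + L ε with ε ~ N(0, I), which gives Cov[x] = L Lᵀ. The helper below (hypothetical name sample_sparse_scale_tril) assumes L is a sparse COO lower-triangular scale factor and computes L ε with torch.sparse.mm, so the samples are a differentiable function of the mean. It is a sketch of the general technique, not the package's SparseMultivariateNormal API.

```python
import torch


def sample_sparse_scale_tril(loc, scale_tril, num_samples):
    """Hypothetical sketch: reparameterised samples x = loc + L @ eps, eps ~ N(0, I).

    loc: (n,) dense mean.
    scale_tril: (n, n) sparse COO lower-triangular L, so Cov[x] = L @ L.T.
    Returns a dense (num_samples, n) tensor.
    """
    n = loc.shape[0]
    # Noise is drawn independently of the parameters (reparameterisation trick)
    eps = torch.randn(n, num_samples, dtype=loc.dtype, device=loc.device)
    # Sparse-dense product L @ eps, then shift by the mean (broadcast over columns)
    return (loc.unsqueeze(1) + torch.sparse.mm(scale_tril, eps)).T
```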
Installation
Install from PyPI:
pip install torchsparsegradutils
For a development (editable) installation:
git clone https://github.com/cai4cai/torchsparsegradutils.git
cd torchsparsegradutils
pip install -e .
Quick Start
import torch
from torchsparsegradutils import sparse_mm, sparse_triangular_solve
# Create a 2x3 sparse COO matrix that tracks gradients
indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
values = torch.tensor([3., 4., 5.])
A = torch.sparse_coo_tensor(indices, values, (2, 3), requires_grad=True)
# Dense matrix
B = torch.randn(3, 4, requires_grad=True)
# Sparse matrix multiplication
result = sparse_mm(A, B)
# Backpropagate; the gradient with respect to A stays sparse
loss = result.sum()
loss.backward()
print(A.grad)  # sparse gradient rather than a dense 2x3 tensor