torchsparsegradutils Documentation


torchsparsegradutils is a collection of utility functions for working with PyTorch sparse tensors. It provides memory-efficient, sparsity-preserving tensor operations that support automatic differentiation. The package fills gaps in PyTorch's sparse tensor support so that gradients with respect to sparse inputs remain sparse during backpropagation.

🚀 Key Features

Core Sparse Operations with Sparse Gradient Support

Memory-Efficient Sparse Matrix Multiplication

  • sparse_mm(): memory-efficient multiplication of a sparse matrix by a dense matrix

  • Preserves sparsity in gradients during backpropagation

  • Workaround for PyTorch issue #41128

  • Supports both COO and CSR formats with optional batching
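The idea behind a sparsity-preserving gradient can be sketched with a custom autograd function. For C = A @ B with upstream gradient G, the full gradient dL/dA = G @ B^T is dense, but only its entries at A's nonzero positions are needed, so those entries can be computed directly. This is a minimal sketch that assumes a 2-D COO input. `MaskedSparseMM` is a hypothetical name, and this is not the library's implementation of sparse_mm().

```python
import torch


class MaskedSparseMM(torch.autograd.Function):
    """Hypothetical sketch: C = A @ B with a sparse (COO) gradient for A."""

    @staticmethod
    def forward(ctx, A, B):
        A = A.coalesce()  # unique, sorted indices
        ctx.save_for_backward(A, B)
        return torch.sparse.mm(A, B)  # dense result

    @staticmethod
    def backward(ctx, grad_out):
        A, B = ctx.saved_tensors
        grad_A = grad_B = None
        if ctx.needs_input_grad[0]:
            rows, cols = A.indices()
            # (G @ B^T)[i, j] = sum_k G[i, k] * B[j, k], evaluated only at nnz(A)
            vals = (grad_out[rows] * B[cols]).sum(dim=1)
            grad_A = torch.sparse_coo_tensor(A.indices(), vals, A.shape)
        if ctx.needs_input_grad[1]:
            # dL/dB = A^T @ G, still computed with a sparse-dense product
            grad_B = torch.sparse.mm(A.t(), grad_out)
        return grad_A, grad_B


indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
values = torch.tensor([3.0, 4.0, 5.0])
A = torch.sparse_coo_tensor(indices, values, (2, 3), requires_grad=True)
B = torch.randn(3, 4, requires_grad=True)
MaskedSparseMM.apply(A, B).sum().backward()

# Reference: dense autograd, masked to A's sparsity pattern
A_dense = A.detach().to_dense().requires_grad_(True)
(A_dense @ B.detach()).sum().backward()
mask = (A.detach().to_dense() != 0).to(A_dense.dtype)
print(A.grad.is_sparse)
print(torch.allclose(A.grad.to_dense(), A_dense.grad * mask))
```

The gradient for A holds only three values, one per stored entry, rather than a full 2x3 dense matrix.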

Sparse Linear System Solvers

Built-in Iterative Solvers (Pure PyTorch, No External Dependencies)

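  • BiCGSTAB: Biconjugate Gradient Stabilized method (general square systems)

  • CG: Conjugate Gradient method (symmetric positive-definite systems)

  • LSMR: Least Squares Minimal Residual method (least-squares problems, including non-square systems)

  • MINRES: Minimal Residual method (symmetric, possibly indefinite systems)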
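The following minimal Conjugate Gradient sketch, written in plain PyTorch, shows why iterative solvers suit sparse systems. It is not the library's solver or API. The function name `conjugate_gradient` and its `tol`/`max_iter` parameters are hypothetical. CG uses A only through sparse-dense products, so the matrix never has to be densified.

```python
import torch


def conjugate_gradient(A, b, tol=1e-10, max_iter=None):
    """Hypothetical CG sketch: solve A x = b for a sparse SPD matrix A.

    A is a 2-D sparse COO tensor of shape (n, n); b is a dense (n, 1) column.
    A is only used via torch.sparse.mm, so it is never converted to dense.
    """
    max_iter = b.shape[0] if max_iter is None else max_iter
    x = torch.zeros_like(b)
    r = b.clone()  # residual b - A @ x, with x = 0
    p = r.clone()  # initial search direction
    rs_old = (r * r).sum()
    for _ in range(max_iter):
        Ap = torch.sparse.mm(A, p)
        alpha = rs_old / (p * Ap).sum()
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = (r * r).sum()
        if rs_new.sqrt() < tol:
            break
        # Next direction, A-conjugate to the previous ones
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


# Usage: a 5x5 SPD tridiagonal matrix (4 on the diagonal, -1 beside it)
n = 5
dense = 4 * torch.eye(n, dtype=torch.float64)
dense -= torch.diag(torch.ones(n - 1, dtype=torch.float64), 1)
dense -= torch.diag(torch.ones(n - 1, dtype=torch.float64), -1)
A = dense.to_sparse()
b = torch.ones(n, 1, dtype=torch.float64)
x = conjugate_gradient(A, b)
print(torch.allclose(dense @ x, b))
```

In exact arithmetic, CG converges in at most n iterations for an n x n SPD matrix, which is why the sketch defaults max_iter to n.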

Sparse Multivariate Normal Distributions

  • SparseMultivariateNormal: structured Gaussian distribution with reparameterised sampling

  • SparseMultivariateNormalNative: Native implementation using torch.sparse.mm
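Reparameterised sampling with a sparse scale factor can be sketched as follows. This is a minimal illustration of the general idea, not the API of SparseMultivariateNormal. It assumes a sparse lower-triangular factor `L` with covariance `L @ L.T`, and the helper `rsample` is hypothetical. Samples are formed as `loc + L @ eps` with torch.sparse.mm, so gradients reach both `loc` and the nonzero values of `L`.

```python
import torch

torch.manual_seed(0)

n = 4
# Hypothetical sparse lower-triangular factor L (covariance = L @ L.T):
# unit diagonal plus one sub-diagonal of 0.5, listed in row-major order.
rows = torch.tensor([0, 1, 1, 2, 2, 3, 3])
cols = torch.tensor([0, 0, 1, 1, 2, 2, 3])
vals = torch.tensor([1.0, 0.5, 1.0, 0.5, 1.0, 0.5, 1.0], requires_grad=True)
L = torch.sparse_coo_tensor(torch.stack([rows, cols]), vals, (n, n))
loc = torch.zeros(n, requires_grad=True)


def rsample(loc, L, num_samples):
    # Reparameterisation trick: x = loc + L @ eps with eps ~ N(0, I),
    # so the randomness is independent of the parameters being learned.
    eps = torch.randn(L.shape[1], num_samples)
    return (loc.unsqueeze(1) + torch.sparse.mm(L, eps)).T  # (num_samples, n)


samples = rsample(loc, L, num_samples=1000)
# Monte Carlo estimate of E[||x||^2]; gradients flow into loc and vals
samples.pow(2).sum(dim=1).mean().backward()
print(vals.grad)
```

Since E[||x||^2] = ||loc||^2 + trace(L @ L.T), the estimated gradient for each stored value of L should be close to twice that value.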

Installation

Install from PyPI:

pip install torchsparsegradutils

For a development (editable) installation:

git clone https://github.com/cai4cai/torchsparsegradutils.git
cd torchsparsegradutils
pip install -e .

Quick Start

import torch
from torchsparsegradutils import sparse_mm, sparse_triangular_solve

# Create a 2x3 sparse COO matrix that requires gradients
indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
values = torch.tensor([3., 4., 5.])
A = torch.sparse_coo_tensor(indices, values, (2, 3), requires_grad=True)

# Dense matrix
B = torch.randn(3, 4, requires_grad=True)

# Sparse matrix multiplication (the result is a dense 2x4 tensor)
result = sparse_mm(A, B)

# Backpropagate: A.grad is sparse, with the same sparsity pattern as A
loss = result.sum()
loss.backward()
print(A.grad)
