Quick Start Guide

This guide will help you get started with torchsparsegradutils quickly.

Basic Sparse Matrix Operations

Sparse Matrix Multiplication

The core function sparse_mm() performs memory-efficient sparse matrix multiplication:

import torch
from torchsparsegradutils import sparse_mm

# Sparse operand: a 2x3 COO matrix with three stored entries
# (first index row = row positions, second index row = column positions)
coords = torch.tensor([[0, 1, 1], [2, 0, 2]], dtype=torch.int64)
entries = torch.tensor([3., 4., 5.])
A = torch.sparse_coo_tensor(coords, entries, (2, 3)).requires_grad_(True)

# Dense operand, also tracked by autograd
B = torch.randn(3, 4, requires_grad=True)

# Sparse @ dense product with gradient support
result = sparse_mm(A, B)
print(f"Result shape: {result.shape}")  # torch.Size([2, 4])

# Backpropagate a scalar loss; the gradient w.r.t. A stays sparse
loss = result.sum()
loss.backward()
print(f"A gradient is sparse: {A.grad.is_sparse}")  # True
print(f"A gradient nnz: {A.grad._nnz()}")  # Number of non-zeros

Batched Operations

torchsparsegradutils supports batched operations for sparse_mm() and sparse_triangular_solve():

import torch
from torchsparsegradutils import sparse_mm

# Build a batch of sparse matrices by stacking individual COO tensors
# (torch.stack requires every matrix in the batch to have the same shape)
A1 = torch.sparse_coo_tensor([[0, 1], [0, 1]], [1., 2.], (2, 2))
A2 = torch.sparse_coo_tensor([[0, 1], [1, 0]], [3., 4.], (2, 2))
A_batch = torch.stack([A1, A2])  # Shape: (2, 2, 2)

# Dense batch: take the batch size from A_batch so the two operands
# cannot drift out of sync if more matrices are stacked
batch_size = A_batch.shape[0]
B = torch.randn(batch_size, 2, 3)

# Batched multiplication
result = sparse_mm(A_batch, B)
print(result.shape)  # torch.Size([2, 2, 3])

Solving Linear Systems

Triangular Systems

For sparse triangular systems, use sparse_triangular_solve():

import torch
from torchsparsegradutils import sparse_mm, sparse_triangular_solve
from torchsparsegradutils.utils.random_sparse import rand_sparse_tri

# Create a sparse lower triangular matrix
L = rand_sparse_tri((3, 3), nnz=5, upper=False, layout=torch.sparse_csr)

# Right-hand side with two columns
b = torch.randn(3, 2)

# Solve Lx = b
x = sparse_triangular_solve(L, b, upper=False)

# Verify the solution: the residual Lx - b should be close to zero
residual = sparse_mm(L, x) - b
print(f"Residual norm: {torch.norm(residual):.6f}")

General Linear Systems

For general sparse linear systems, use sparse_generic_solve():

import torch
from torchsparsegradutils import sparse_generic_solve, sparse_mm
from torchsparsegradutils.utils import minres
from torchsparsegradutils.utils.random_sparse import make_spd_sparse

# Size shared by the system matrix and the right-hand side
n = 10

# Sparse symmetric positive definite matrix in CSR layout
# (a dense copy is returned alongside it)
A_sparse, A_dense = make_spd_sparse(n, torch.sparse_csr, torch.float32, torch.int64, 'cpu')
b = torch.randn(n)

# Solve A x = b with MINRES as the underlying solver
x_minres = sparse_generic_solve(A_sparse, b, solve=minres, tol=1e-6)
print(f"MINRES solution shape: {x_minres.shape}")

# Check the result: A x - b should be close to zero
residual = A_sparse @ x_minres - b
print(f"MINRES residual norm: {torch.norm(residual):.6f}")

Sparse Multivariate Normal Distributions

Basic Usage with LDL^T Parameterization

Create and sample from sparse multivariate normal distributions:

import torch
from torchsparsegradutils.distributions import SparseMultivariateNormal
from torchsparsegradutils.utils.random_sparse import rand_sparse_tri

# Create parameters
dim = 10
loc = torch.zeros(dim)

# LDL^T parameterization (recommended for stability)
# D is supplied separately through `diagonal`, so the triangular factor is
# generated strictly lower triangular (strict=True), without its own diagonal.
diagonal = torch.ones(dim) * 0.5
scale_tril = rand_sparse_tri((dim, dim), nnz=15, upper=False,
                            layout=torch.sparse_coo, strict=True)

# Create distribution with LDL^T covariance parameterization
dist = SparseMultivariateNormal(
    loc=loc,
    diagonal=diagonal,
    scale_tril=scale_tril
)

# Sample from the distribution
samples = dist.rsample((100,))
print(f"Samples shape: {samples.shape}")  # torch.Size([100, 10])

Gradient Computation

The distributions support reparameterized sampling for gradient computation:

import torch
from torchsparsegradutils.distributions import SparseMultivariateNormal
from torchsparsegradutils.utils.random_sparse import rand_sparse_tri

# Dimension of the distribution
dim = 10

# Enable gradients for parameters
loc = torch.zeros(dim, requires_grad=True)
diagonal = torch.ones(dim, requires_grad=True)
scale_tril = rand_sparse_tri((dim, dim), nnz=15, upper=False,
                            layout=torch.sparse_coo, strict=True)
scale_tril.requires_grad_(True)

# Create distribution
dist = SparseMultivariateNormal(
    loc=loc,
    diagonal=diagonal,
    scale_tril=scale_tril
)

# Sample using rsample (reparameterized) so gradients flow back to the parameters
samples = dist.rsample((10,))

# Compute some loss
loss = samples.mean()
loss.backward()

print(f"Location gradient norm: {torch.norm(loc.grad):.6f}")
print(f"Diagonal gradient norm: {torch.norm(diagonal.grad):.6f}")
print(f"Scale gradient nnz: {scale_tril.grad._nnz()}")

Working with Different Backends

CuPy Backend (GPU Acceleration)

For GPU acceleration with CuPy:

import torch
from torchsparsegradutils.cupy import sparse_solve_c4t
from torchsparsegradutils.utils.random_sparse import make_spd_sparse

# Use the GPU when one is present; otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the system on the selected device
A_sparse, A_dense = make_spd_sparse(50, torch.sparse_csr, torch.float32, torch.int64, device)
b = torch.randn(50, device=device)

# The CuPy solve is only attempted on a CUDA device
if device.type != 'cuda':
    print("CUDA not available, skipping CuPy example")
else:
    # NOTE(review): `solve` is passed as a string here, whereas the JAX example
    # passes a callable - confirm which form sparse_solve_c4t expects.
    x = sparse_solve_c4t(A_sparse, b, solve='cg', tol=1e-6)
    print(f"CuPy solution on GPU: {x.shape}")

JAX Backend

For JAX integration:

import torch
from torchsparsegradutils.utils.random_sparse import make_spd_sparse

# JAX is an optional dependency. Only the imports are guarded, so an error
# raised while running the example is not misreported as "JAX not available".
try:
    from torchsparsegradutils.jax import sparse_solve_j4t
    from jax.scipy.sparse.linalg import cg
except ImportError:
    print("JAX not available, skipping JAX example")
else:
    # Create test matrices
    A_sparse, A_dense = make_spd_sparse(20, torch.sparse_csr, torch.float32, torch.int64, 'cpu')
    b = torch.randn(20)

    # Solve using JAX backend
    x = sparse_solve_j4t(A_sparse, b, solve=cg, tol=1e-6)
    print(f"JAX solution shape: {x.shape}")

Next Steps