Source code for torchsparsegradutils.utils.utils

from typing import List, Tuple

import torch


def stack_csr(
    tensors: List[torch.Tensor],
    dim: int = 0,
) -> torch.Tensor:
    """
    Stack 2D CSR sparse tensors into one batched CSR tensor.

    This is the CSR counterpart of :func:`torch.stack`. The crow indices,
    column indices and values of every input are stacked along a new leading
    axis and wrapped into a single CSR tensor of shape
    ``(len(tensors), rows, cols)``.

    Parameters
    ----------
    tensors : list or tuple of torch.Tensor
        2D CSR tensors to stack. All must share the same shape and the same
        number of stored elements.
    dim : int, default=0
        Position of the new dimension in the 3D result, interpreted like
        :func:`torch.stack` (negative values count from the end). Only the
        leading position (``0`` or, equivalently, ``-3``) is accepted because
        a batched CSR tensor keeps its batch dimension in front of the two
        sparse dimensions.

    Returns
    -------
    torch.Tensor
        Batched CSR tensor of shape ``(len(tensors), rows, cols)``.

    Raises
    ------
    TypeError
        If ``tensors`` is not a list or tuple.
    ValueError
        If ``tensors`` is empty; if the inputs differ in shape, are not CSR,
        are not 2D, or store different numbers of elements; or if ``dim``
        does not refer to the leading dimension.

    Examples
    --------
    >>> import torch
    >>> crow = torch.tensor([0, 1, 2])
    >>> col = torch.tensor([0, 1])
    >>> A = torch.sparse_csr_tensor(crow, col, torch.tensor([1.0, 2.0]), (2, 2))
    >>> B = torch.sparse_csr_tensor(crow, col, torch.tensor([3.0, 4.0]), (2, 2))
    >>> stacked = stack_csr([A, B])
    >>> stacked.shape
    torch.Size([2, 2, 2])
    >>> stacked.crow_indices().shape
    torch.Size([2, 3])
    """
    if not isinstance(tensors, (list, tuple)):
        raise TypeError("Expected a list of tensors, but got {}.".format(type(tensors)))

    if not tensors:
        raise ValueError("Cannot stack empty list of tensors.")

    if not all(tensor.shape == tensors[0].shape for tensor in tensors):
        raise ValueError("All tensors must have the same shape.")

    if not all(tensor.layout == torch.sparse_csr for tensor in tensors):
        raise ValueError("All tensors must be in CSR layout.")

    if not all(tensor.ndim == 2 for tensor in tensors):
        raise ValueError("All tensors must be 2D.")

    # The result is 3D, so normalise ``dim`` against rank 3 exactly as
    # torch.stack would, then insist on the leading position: a batched CSR
    # tensor cannot carry its batch dimension after the sparse dimensions.
    result_ndim = 3
    normalized_dim = dim + result_ndim if dim < 0 else dim
    if normalized_dim != 0:
        raise ValueError(
            "stack_csr can only stack along the leading dimension (dim=0 or dim=-3) because "
            f"CSR batch dimensions must precede the sparse dimensions, but got dim={dim}."
        )

    # Every batch entry of a batched CSR tensor stores the same number of
    # elements; fail early with a clear message instead of inside torch.stack.
    nnz = tensors[0].col_indices().shape[0]
    if any(tensor.col_indices().shape[0] != nnz for tensor in tensors):
        raise ValueError("All tensors must have the same number of stored elements (nnz).")

    crow_indices = torch.stack([tensor.crow_indices() for tensor in tensors], dim=0)
    col_indices = torch.stack([tensor.col_indices() for tensor in tensors], dim=0)
    values = torch.stack([tensor.values() for tensor in tensors], dim=0)

    shape = (len(tensors), *tensors[0].shape)

    return torch.sparse_csr_tensor(crow_indices, col_indices, values, shape)
def _sort_coo_indices(
    indices: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sort COO indices lexicographically and report the sorting permutation.

    The columns of ``indices`` (one column per stored element) are ordered
    by the first index row, ties broken by the second row, and so on. Unlike
    ``.coalesce()``, this works directly on an index tensor, keeps its dtype
    and never merges or drops entries: duplicate coordinates are kept, in
    their original relative order.

    Parameters
    ----------
    indices : torch.Tensor
        COO indices of shape ``(2, nnz)`` (row, col) or ``(3, nnz)``
        (batch, row, col).

    Returns
    -------
    indices_sorted : torch.Tensor
        Contiguous tensor with the same shape and dtype as ``indices`` whose
        columns are in ascending lexicographic order.
    permutation : torch.Tensor
        ``int64`` tensor of shape ``(nnz,)`` such that
        ``indices[:, permutation]`` equals ``indices_sorted``. Apply it to a
        values tensor to reorder the values consistently.

    Notes
    -----
    For ``(3, nnz)`` input the batch index is the most significant key, so
    entries of the same batch end up adjacent.

    Examples
    --------
    >>> import torch
    >>> indices = torch.tensor([[1, 0, 1], [2, 1, 0]])
    >>> sorted_indices, perm = _sort_coo_indices(indices)
    >>> sorted_indices
    tensor([[0, 1, 1],
            [1, 0, 2]])
    >>> perm
    tensor([1, 2, 0])
    """
    nnz = indices.shape[-1]
    permutation = torch.arange(nnz, dtype=torch.int64, device=indices.device)

    # Lexicographic order via successive stable sorts, least-significant key
    # first. Stability is what lets earlier passes survive later ones and is
    # also what keeps duplicate coordinates in input order.
    for key in reversed(indices.unbind(0)):
        _, order = torch.sort(key[permutation], stable=True)
        permutation = permutation[order]

    return indices[:, permutation].contiguous(), permutation


def _compress_row_indices(
    row_indices: torch.Tensor,
    num_rows: int,
) -> torch.Tensor:
    """
    Convert COO row indices to CSR crow indices.

    Counts the stored elements of every row and prepends a zero to the
    cumulative sum of those counts.

    Parameters
    ----------
    row_indices : torch.Tensor
        1D ``int32``/``int64`` tensor of row indices, shape ``(nnz,)``, with
        values in ``[0, num_rows - 1]``.
    num_rows : int
        Number of rows of the matrix; must be positive.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(num_rows + 1,)`` with the dtype of
        ``row_indices``. ``crow[0]`` is ``0``, ``crow[i + 1] - crow[i]`` is
        the number of stored elements in row ``i`` and ``crow[-1]`` is
        ``nnz``.

    Raises
    ------
    TypeError
        If ``row_indices`` is not a tensor or is not ``int32``/``int64``.
    ValueError
        If ``row_indices`` is not 1D, contains negative or ``>= num_rows``
        entries, or if ``num_rows`` is not a positive ``int``.

    Examples
    --------
    >>> import torch
    >>> _compress_row_indices(torch.tensor([0, 0, 2, 2]), num_rows=3)
    tensor([0, 2, 2, 4])
    >>> _compress_row_indices(torch.tensor([0, 2]), num_rows=3)
    tensor([0, 1, 1, 2])
    """
    if not isinstance(row_indices, torch.Tensor):
        raise TypeError("row_indices must be a torch.Tensor.")
    if row_indices.ndim != 1:
        raise ValueError(f"row_indices must be 1D, got shape {tuple(row_indices.shape)}.")
    if row_indices.dtype not in (torch.int32, torch.int64):
        raise TypeError("row_indices must have integer dtype (torch.int32 or torch.int64).")
    if not isinstance(num_rows, int) or num_rows <= 0:
        raise ValueError("num_rows must be a positive integer.")
    if row_indices.numel() > 0:
        if torch.any(row_indices < 0):
            raise ValueError("row_indices contains negative entries.")
        if torch.any(row_indices >= num_rows):
            raise ValueError("row_indices contains entries >= num_rows.")

    # Number of stored elements in each row; rows without entries count 0.
    counts = torch.bincount(row_indices, minlength=num_rows).to(row_indices.dtype)

    # The cumulative sum is taken in place so the result keeps the dtype of
    # ``counts`` (i.e. the dtype of ``row_indices``).
    crow_indices = torch.cat([torch.zeros(1, dtype=row_indices.dtype, device=counts.device), counts.cumsum_(dim=0)])

    return crow_indices
def convert_coo_to_csr_indices_values(coo_indices, num_rows, values=None):
    """
    Convert COO indices (and optionally values) to CSR index arrays.

    The COO indices are sorted with ``_sort_coo_indices`` and the sorted row
    indices are compressed with ``_compress_row_indices``. Both unbatched
    ``(2, nnz)`` and single-batch-dimension ``(3, nnz)`` index tensors are
    supported.

    Parameters
    ----------
    coo_indices : torch.Tensor
        COO indices of shape ``(2, nnz)`` with rows ``[row_idx, col_idx]`` or
        ``(3, nnz)`` with rows ``[batch_idx, row_idx, col_idx]``.
    num_rows : int
        Number of rows of the matrix.
    values : torch.Tensor, optional
        Values matching the columns of ``coo_indices``. When given, they are
        reordered with the sorting permutation.

    Returns
    -------
    crow_indices : torch.Tensor
        Shape ``(num_rows + 1,)`` unbatched, ``(num_batches, num_rows + 1)``
        batched.
    col_indices : torch.Tensor
        Sorted column indices; reshaped to ``(num_batches, -1)`` when
        batched.
    values_or_permutation : torch.Tensor
        The reordered ``values`` if they were given, otherwise the sorting
        permutation; reshaped to ``(num_batches, -1)`` when batched.

    Raises
    ------
    ValueError
        If ``coo_indices`` has fewer than 2 or more than 3 rows; if a row
        index is ``>= num_rows``; if ``values`` does not match the number of
        indices; if batched indices are empty (the batch count cannot be
        inferred); or if the batches do not all hold the same number of
        stored elements.

    Notes
    -----
    - In the batched case ``num_batches`` is the number of *distinct* batch
      indices present, so a batch without any stored element is not
      represented in the output.
    - A batched CSR layout needs the same number of stored elements in every
      batch; this is validated explicitly here.

    Examples
    --------
    >>> import torch
    >>> coo_indices = torch.tensor([[0, 1, 2], [1, 0, 2]])
    >>> values = torch.tensor([1.0, 2.0, 3.0])
    >>> crow, col, vals = convert_coo_to_csr_indices_values(
    ...     coo_indices, num_rows=3, values=values)
    >>> crow
    tensor([0, 1, 2, 3])
    """
    num_index_rows = coo_indices.shape[0]
    if num_index_rows < 2:
        raise ValueError(
            f"Indices tensor must have at least 2 rows (row and column indices). Got {coo_indices.shape[0]} rows."
        )
    if num_index_rows > 3:
        raise ValueError(
            f"Current implementation only supports single batch dimension, therefore indices tensor must have at most 3 rows (batch, row and column indices). Got {coo_indices.shape[0]} rows."
        )

    # ``max`` of an empty tensor raises, so only range-check when there is
    # at least one stored element.
    if coo_indices.shape[1] > 0:
        max_row = coo_indices[-2].max()
        if max_row >= num_rows:
            raise ValueError(f"Row indices must be less than num_rows ({num_rows}). Got max row index {max_row}")

    if values is not None and values.shape[0] != coo_indices.shape[1]:
        raise ValueError(
            f"Number of values ({values.shape[0]}) does not match number of indices ({coo_indices.shape[1]})"
        )

    coo_indices, permutation = _sort_coo_indices(coo_indices)
    # Without values the caller receives the permutation itself.
    reordered = values[permutation] if values is not None else permutation

    if num_index_rows == 2:
        row_indices, col_indices = coo_indices
        crow_indices = _compress_row_indices(row_indices, num_rows)
        return crow_indices, col_indices, reordered

    batch_indices, row_indices, col_indices = coo_indices
    batch_ids, batch_counts = torch.unique(batch_indices, return_counts=True)
    num_batches = batch_ids.shape[0]
    if num_batches == 0:
        raise ValueError("Cannot infer the number of batches from empty batched COO indices.")
    # A plain reshape to (num_batches, -1) would silently mix entries of
    # different batches if their sizes differed, so reject that up front.
    if not bool(torch.all(batch_counts == batch_counts[0])):
        raise ValueError("All batches must contain the same number of non-zero elements.")
    nnz_per_batch = int(batch_counts[0])

    # The indices are sorted by batch first and every batch has the same
    # size, so consecutive chunks of ``row_indices`` are exactly the batches.
    crow_indices = torch.stack(
        [_compress_row_indices(batch_rows, num_rows) for batch_rows in row_indices.split(nnz_per_batch)]
    )
    col_indices = col_indices.reshape(num_batches, -1)
    reordered = reordered.reshape(num_batches, -1)

    return crow_indices, col_indices, reordered
def convert_coo_to_csr(sparse_coo_tensor):
    """
    Build a CSR tensor from a COO tensor.

    The input is coalesced when necessary, its indices and values are handed
    to :func:`convert_coo_to_csr_indices_values`, and the resulting arrays are
    wrapped into a CSR tensor of the same size.

    Parameters
    ----------
    sparse_coo_tensor : torch.Tensor
        Tensor with layout ``torch.sparse_coo``; either ``(m, n)`` or
        ``(b, m, n)`` with one leading batch dimension.

    Returns
    -------
    torch.Tensor
        CSR tensor with the size of the input.

    Raises
    ------
    ValueError
        If the input layout is not ``torch.sparse_coo``.

    Examples
    --------
    >>> import torch
    >>> indices = torch.tensor([[0, 1, 1], [1, 0, 2]])
    >>> values = torch.tensor([1.0, 2.0, 3.0])
    >>> coo_tensor = torch.sparse_coo_tensor(indices, values, (2, 3))
    >>> convert_coo_to_csr(coo_tensor).layout
    torch.sparse_csr
    """
    # Guard clause: anything that is not COO is rejected immediately.
    if sparse_coo_tensor.layout != torch.sparse_coo:
        raise ValueError(f"Unsupported layout: {sparse_coo_tensor.layout}")

    # ``indices()``/``values()`` are read from a coalesced tensor.
    coalesced = sparse_coo_tensor
    if not coalesced.is_coalesced():
        coalesced = coalesced.coalesce()

    full_size = coalesced.size()
    csr_parts = convert_coo_to_csr_indices_values(coalesced.indices(), full_size[-2], coalesced.values())
    return torch.sparse_csr_tensor(*csr_parts, full_size)
def _demcompress_crow_indices(crow_indices, num_rows):
    """
    Expand CSR crow indices into COO row indices.

    Row ``i`` is emitted ``crow_indices[i + 1] - crow_indices[i]`` times, so
    the output has one entry per stored element. This undoes
    ``_compress_row_indices``.

    Parameters
    ----------
    crow_indices : torch.Tensor
        CSR crow indices of shape ``(num_rows + 1,)``.
    num_rows : int
        Number of rows of the matrix.

    Returns
    -------
    torch.Tensor
        Row indices of shape ``(nnz,)`` with the dtype and device of
        ``crow_indices``.

    Examples
    --------
    >>> import torch
    >>> _demcompress_crow_indices(torch.tensor([0, 2, 2, 4]), num_rows=3)
    tensor([0, 0, 2, 2])
    >>> _demcompress_crow_indices(torch.tensor([0, 1, 2, 3]), num_rows=3)
    tensor([0, 1, 2])
    """
    # Stored elements per row = difference of consecutive crow entries.
    nnz_per_row = crow_indices[1:] - crow_indices[:-1]
    row_ids = torch.arange(num_rows, dtype=crow_indices.dtype, device=crow_indices.device)
    return torch.repeat_interleave(row_ids, nnz_per_row)


# use @torch.jit.script ?
def sparse_block_diag(*sparse_tensors: torch.Tensor) -> torch.Tensor:
    """
    Construct a block-diagonal sparse matrix from 2D COO or CSR inputs.

    The inputs are placed along the diagonal in the order given, with zeros
    elsewhere, analogous to :func:`torch.block_diag` for dense tensors. All
    inputs must share one layout and the result has that layout.

    Parameters
    ----------
    *sparse_tensors : torch.Tensor
        2D sparse tensors, all COO or all CSR, each with exactly two sparse
        dimensions and no dense dimensions. Blocks may differ in shape.

    Returns
    -------
    torch.Tensor
        Sparse tensor of shape ``(sum_i n_i, sum_i m_i)`` where block ``i``
        has shape ``(n_i, m_i)``. If a single tensor is passed it is returned
        unchanged.

    Raises
    ------
    TypeError
        If any argument is not a :class:`torch.Tensor`.
    ValueError
        If no tensor is given; if layouts are mixed or are neither COO nor
        CSR; if any input does not have two sparse and zero dense
        dimensions; or if any input is not 2D (e.g. a batched CSR tensor).

    Notes
    -----
    Block ``i`` is shifted by the cumulative row/column sizes of the blocks
    before it. In the CSR path the row pointers of every block after the
    first drop their leading zero and are shifted by the number of elements
    stored in the preceding blocks.

    Examples
    --------
    >>> import torch
    >>> A = torch.sparse_coo_tensor(torch.tensor([[0, 1], [0, 1]]), torch.tensor([1., 2.]), size=(2, 2))
    >>> B = torch.sparse_coo_tensor(torch.tensor([[0], [0]]), torch.tensor([3.]), size=(1, 1))
    >>> C = sparse_block_diag(A, B)
    >>> C.shape
    torch.Size([3, 3])
    >>> sparse_block_diag(A.to_sparse_csr(), B.to_sparse_csr()).layout
    torch.sparse_csr
    """
    # ---- validation ----
    for i, t in enumerate(sparse_tensors):
        if not isinstance(t, torch.Tensor):
            raise TypeError(f"TypeError: expected Tensor as element {i} in argument 0, but got {type(t).__name__}")

    if len(sparse_tensors) == 0:
        raise ValueError("At least one sparse tensor must be provided.")

    if all(t.layout == torch.sparse_coo for t in sparse_tensors):
        layout = torch.sparse_coo
    elif all(t.layout == torch.sparse_csr for t in sparse_tensors):
        layout = torch.sparse_csr
    else:
        raise ValueError("Sparse tensors must either be all sparse_coo or all sparse_csr.")

    if not all(t.sparse_dim() == 2 for t in sparse_tensors):
        raise ValueError("All sparse tensors must have exactly two sparse dimensions.")

    if not all(t.dense_dim() == 0 for t in sparse_tensors):
        raise ValueError("All sparse tensors must have zero dense dimensions.")

    # The sparse/dense dimension counts do not include batch dimensions, so
    # check the rank explicitly: the stitching below only handles matrices.
    if not all(t.ndim == 2 for t in sparse_tensors):
        raise ValueError("All sparse tensors must be 2D; batched sparse tensors are not supported.")

    if len(sparse_tensors) == 1:
        return sparse_tensors[0]

    val_parts = []
    col_parts = []
    row_offset = 0
    col_offset = 0

    # ---- COO path ----
    if layout == torch.sparse_coo:
        row_parts = []
        for t in sparse_tensors:
            t = t if t.is_coalesced() else t.coalesce()
            rows, cols = t.indices()

            # Shift the block to its place on the diagonal.
            row_parts.append(rows + row_offset)
            col_parts.append(cols + col_offset)
            val_parts.append(t.values())

            row_offset += t.size(-2)
            col_offset += t.size(-1)

        indices = torch.stack([torch.cat(row_parts, dim=0), torch.cat(col_parts, dim=0)], dim=0)
        # After the loop the offsets equal the total number of rows/columns.
        return torch.sparse_coo_tensor(indices, torch.cat(val_parts, dim=0), size=(row_offset, col_offset))

    # ---- CSR path ----
    crow_parts = []
    nnz_before = None  # 0-dim tensor: elements stored in all previous blocks
    for idx, t in enumerate(sparse_tensors):
        crow = t.crow_indices()

        # The first block keeps its full crow. Later blocks drop the leading
        # zero (it would duplicate the previous block's last pointer) and are
        # shifted by the number of elements stored so far.
        crow_parts.append(crow if idx == 0 else crow[1:] + nnz_before)
        col_parts.append(t.col_indices() + col_offset)
        val_parts.append(t.values())

        row_offset += t.size(-2)
        col_offset += t.size(-1)
        nnz_before = crow_parts[-1][-1]

    return torch.sparse_csr_tensor(
        torch.cat(crow_parts, dim=0),
        torch.cat(col_parts, dim=0),
        torch.cat(val_parts, dim=0),
        size=(row_offset, col_offset),
    )
def sparse_block_diag_split(
    sparse_block_diag_tensor: torch.Tensor, *shapes: Tuple[int, int]
) -> tuple[torch.Tensor, ...]:
    """
    Split a block-diagonal sparse matrix into its diagonal blocks.

    Inverse of :func:`sparse_block_diag`: given the block shapes in diagonal
    order, return each block as its own 2D sparse tensor in the layout of the
    input (COO or CSR).

    Parameters
    ----------
    sparse_block_diag_tensor : torch.Tensor
        Block-diagonal sparse tensor in COO or CSR layout.
    *shapes : tuple of int
        ``(rows_i, cols_i)`` for every block, in diagonal order. The row and
        column sums must equal the last two sizes of the input.

    Returns
    -------
    tuple of torch.Tensor
        One sparse tensor per entry of ``shapes``, each of size
        ``(rows_i, cols_i)``.

    Raises
    ------
    TypeError
        If the input is not a tensor.
    ValueError
        If the layout is neither COO nor CSR, if a shape does not have
        exactly two entries, or if the shapes do not add up to the input
        size.

    Notes
    -----
    - A COO input that is not coalesced is coalesced first.
    - In the COO path an entry belongs to a block when both its row and its
      column fall inside that block's range.
    - In the CSR path a block is cut out through the row pointers of its row
      range, and its column indices are shifted back by the column offset.
    """
    if not isinstance(sparse_block_diag_tensor, torch.Tensor):
        raise TypeError("Input must be a torch.Tensor.")

    layout = sparse_block_diag_tensor.layout
    if layout not in (torch.sparse_coo, torch.sparse_csr):
        raise ValueError("Input tensor layout not supported. Only sparse_coo and sparse_csr are supported.")

    if any(len(s) != 2 for s in shapes):
        raise ValueError("All shapes must be two-dimensional (rows, cols).")

    # The block shapes have to tile the whole matrix.
    expected = (sum(s[0] for s in shapes), sum(s[1] for s in shapes))
    actual = (sparse_block_diag_tensor.size(-2), sparse_block_diag_tensor.size(-1))
    if expected != actual:
        raise ValueError(
            f"Sum of provided block shapes ({expected[0]}, {expected[1]}) does not match "
            f"input tensor size ({actual[0]}, {actual[1]})."
        )

    pieces: list[torch.Tensor] = []
    r0 = 0  # first row of the current block
    c0 = 0  # first column of the current block

    if layout == torch.sparse_coo:
        coo = sparse_block_diag_tensor
        if not coo.is_coalesced():
            coo = coo.coalesce()
        all_rows, all_cols = coo.indices()
        all_vals = coo.values()

        for n_rows, n_cols in shapes:
            r1 = r0 + n_rows
            c1 = c0 + n_cols
            in_rows = (all_rows >= r0) & (all_rows < r1)
            in_cols = (all_cols >= c0) & (all_cols < c1)
            inside = in_rows & in_cols

            # Re-base the selected coordinates to the block's own origin.
            local_indices = torch.stack((all_rows[inside] - r0, all_cols[inside] - c0), dim=0)
            block_vals = all_vals[inside]
            pieces.append(
                torch.sparse_coo_tensor(
                    local_indices,
                    block_vals,
                    size=(n_rows, n_cols),
                    device=coo.device,
                    dtype=block_vals.dtype,
                )
            )
            r0, c0 = r1, c1

        return tuple(pieces)

    # CSR layout
    csr = sparse_block_diag_tensor
    crow = csr.crow_indices()
    all_cols = csr.col_indices()
    all_vals = csr.values()

    for n_rows, n_cols in shapes:
        r1 = r0 + n_rows
        # Range of stored elements covered by rows r0 .. r1 - 1.
        lo = int(crow[r0].item())
        hi = int(crow[r1].item())

        block_cols = all_cols[lo:hi] - c0
        block_vals = all_vals[lo:hi]
        # Row pointers of the block, re-based so that they start at zero.
        block_crow = crow[r0 : r1 + 1] - crow[r0]

        pieces.append(
            torch.sparse_csr_tensor(
                block_crow,
                block_cols,
                block_vals,
                size=(n_rows, n_cols),
                device=csr.device,
                dtype=block_vals.dtype,
            )
        )
        r0 = r1
        c0 += n_cols

    return tuple(pieces)
[docs] def sparse_eye( size: Tuple[int, ...], *, layout: torch.layout = torch.sparse_coo, values_dtype: torch.dtype = torch.float64, indices_dtype: torch.dtype = torch.int64, device: torch.device = torch.device("cpu"), requires_grad: bool = False, ) -> torch.Tensor: """ Create a sparse identity matrix. Constructs an identity matrix in sparse format (COO or CSR). Supports both unbatched and batched square matrices. Parameters ---------- size : tuple of int Shape of the identity matrix. Must be either ``(n, n)`` for unbatched or ``(batch_size, n, n)`` for batched. Rows and columns must be equal. layout : torch.layout, default=torch.sparse_coo Sparse tensor layout. Must be either ``torch.sparse_coo`` or ``torch.sparse_csr``. values_dtype : torch.dtype, default=torch.float64 Data type of the values. Only ``torch.float32`` and ``torch.float64`` are supported. indices_dtype : torch.dtype, default=torch.int64 Data type of the indices. Must be ``torch.int32`` or ``torch.int64``. device : torch.device, default=torch.device("cpu") Device on which to create the tensor. requires_grad : bool, default=False Whether autograd should record operations on the returned tensor. Returns ------- torch.Tensor Sparse identity matrix of shape ``(n, n)`` or batched identity matrix of shape ``(batch_size, n, n)``, in the requested sparse layout. Raises ------ ValueError If size is not 2D or 3D, matrix is not square, or if dtypes/layout are unsupported. Notes ----- - For batched inputs, each batch element is an independent identity matrix. Examples -------- Unbatched identity (COO): >>> from torchsparsegradutils.utils import sparse_eye >>> I = sparse_eye((3, 3), layout=torch.sparse_coo) >>> I.to_dense() # doctest: +ELLIPSIS tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]...) 
Batched identity (CSR): >>> I_batched = sparse_eye((2, 4, 4), layout=torch.sparse_csr) >>> I_batched.shape torch.Size([2, 4, 4]) """ if len(size) < 2: raise ValueError("size must have at least 2 dimensions") if len(size) > 3: raise ValueError("size must have at most 3 dimensions (supports 1 batch dimension)") if size[-2] != size[-1]: raise ValueError("size must define a square matrix (n, n) or batched square matrix (b, n, n)") if values_dtype not in (torch.float32, torch.float64): raise ValueError(f"Values dtype {values_dtype} not supported. Use torch.float32 or torch.float64.") if indices_dtype not in (torch.int32, torch.int64): raise ValueError(f"indices_dtype {indices_dtype} not supported. Use torch.int32 or torch.int64.") values = torch.ones(size[-1], dtype=values_dtype, device=device) if layout == torch.sparse_coo: if indices_dtype not in [torch.int32, torch.int64]: raise ValueError("For sparse_coo layout, indices_dtype can either be torch.int32 or torch.int64.") indices = torch.arange(0, size[-1], dtype=indices_dtype, device=device) indices = torch.stack([indices, indices], dim=0) if len(size) == 3: batch_dim_indices = ( torch.arange(size[0], dtype=indices_dtype, device=device).repeat_interleave(size[-1]).unsqueeze(0) ) sparse_dim_indices = torch.cat([indices] * size[0], dim=-1) indices = torch.cat([batch_dim_indices, sparse_dim_indices]) values = values.repeat(size[0]) # NOTE: is_coalesced=True since there are no duplicate indices in identity matrix, flag avails in PyTorch 2.1+ return torch.sparse_coo_tensor( indices, values, size, dtype=values_dtype, device=device, requires_grad=requires_grad, is_coalesced=True ) elif layout == torch.sparse_csr: if indices_dtype not in [torch.int32, torch.int64]: raise ValueError("For sparse_csr layout, indices_dtype can either be torch.int32 or torch.int64.") crow_indices = torch.arange(0, size[-1] + 1, dtype=indices_dtype, device=device) col_indices = torch.arange(0, size[-1], dtype=indices_dtype, device=device) if 
len(size) == 3: crow_indices = crow_indices.repeat(size[0], 1) col_indices = col_indices.repeat(size[0], 1) values = values.repeat(size[0], 1) return torch.sparse_csr_tensor( crow_indices, col_indices, values, size, dtype=values_dtype, device=device, requires_grad=requires_grad ) else: raise ValueError("Layout {} not supported. Only sparse_coo and sparse_csr are supported.".format(layout))