Source code for torchsparsegradutils.indexed_matmul

import torch
from packaging.version import parse as parse_version

try:
    import dgl.ops as dglops

    dgl_installed = True
except ImportError:
    dgl_installed = False


[docs] def segment_mm(a: torch.Tensor, b: torch.Tensor, seglen_a: torch.Tensor) -> torch.Tensor: r""" Segmented matrix multiplication with variable-length segments. Performs matrix multiplication between contiguous segments of ``a`` and the corresponding matrices in ``b``. If ``seglen_a == [10, 5, 0, 3]``, the operator computes:: a[0:10] @ b[0], a[10:15] @ b[1], a[15:15] @ b[2], a[15:18] @ b[3] Parameters ---------- a : torch.Tensor, shape ``(N, D1)`` Left operand containing the concatenation of all segments. b : torch.Tensor, shape ``(R, D1, D2)`` Right operand containing one ``(D1, D2)`` matrix per segment. seglen_a : torch.Tensor, shape ``(R,)``, integer dtype Length of each segment in ``a``. ``seglen_a.sum()`` must equal ``N``. Returns ------- torch.Tensor, shape ``(N, D2)`` Concatenation of all segment results in original order. Raises ------ NotImplementedError If the fallback path is used on a PyTorch version lacking nested tensor matmul support (requires PyTorch >= 2.4). ValueError If input ranks or sizes are incompatible. Notes ----- If DGL is available, this uses :func:`dgl.ops.segment_mm` [1c]_ (typically faster). Otherwise it falls back to a PyTorch nested-tensor implementation. See Also -------- gather_mm : Per-row indexed matrix multiplication. References ---------- .. [1c] DGL ``segment_mm`` documentation: https://www.dgl.ai/dgl_docs/generated/dgl.ops.segment_mm.html Examples -------- >>> import torch >>> # N = 18, D1 = 4, D2 = 2 >>> a = torch.randn(18, 4) >>> b = torch.randn(3, 4, 2) >>> seglen_a = torch.tensor([10, 5, 3]) >>> out = segment_mm(a, b, seglen_a) >>> out.shape torch.Size([18, 2]) Zero-length segment:: >>> seglen_a = torch.tensor([10, 5, 0, 3]) >>> b = torch.randn(4, 4, 2) >>> segment_mm(a, b, seglen_a).shape torch.Size([18, 2]) """ if parse_version(torch.__version__) < parse_version("2.4"): raise NotImplementedError("PyTorch version is too old for nested tensors") if dgl_installed: # DGL is probably more computationally efficient # See https://github.com/pytorch/pytorch/issues/136747 return dglops.segment_mm(a, b, seglen_a) if not a.dim() == 2 or not b.dim() == 3 or not seglen_a.dim() == 1: raise ValueError("Input tensors have unexpected dimensions") N, _ = a.shape R, D1, D2 = b.shape # Sanity check sizes if not a.shape[1] == D1 or not seglen_a.shape[0] == R: raise ValueError("Incompatible size for inputs") segidx_a = torch.cumsum(seglen_a[:-1], dim=0).cpu() # Ideally the conversions below to nested tensor would be handled natively nested_a = torch.nested.as_nested_tensor(torch.tensor_split(a, segidx_a, dim=0)) nested_b = torch.nested.as_nested_tensor(torch.split(b, 1, dim=0)).reshape((R, D1, D2)) # The actual gather matmul computation nested_ab = torch.matmul(nested_a, nested_b) # Convert back to tensors, again ideally this would be handled natively ab = torch.cat(nested_ab.unbind(), dim=0) return ab
[docs] def gather_mm(a: torch.Tensor, b: torch.Tensor, idx_b: torch.Tensor) -> torch.Tensor: r""" Per-row indexed matrix multiplication. For each row ``i`` in ``a`` this computes ``a[i] @ b[idx_b[i]]`` and stacks the results into the output. Parameters ---------- a : torch.Tensor, shape ``(N, D1)`` Left operand with one row per output. b : torch.Tensor, shape ``(R, D1, D2)`` Bank of transformation matrices. idx_b : torch.Tensor, shape ``(N,)``, integer dtype Indices selecting which matrix in ``b`` to use for each row. Values must satisfy ``0 <= idx_b[i] < R``. Returns ------- torch.Tensor, shape ``(N, D2)`` Row-wise results where ``out[i] = a[i] @ b[idx_b[i]]``. Raises ------ NotImplementedError If the fallback path is used on a PyTorch version lacking nested tensor matmul support (requires PyTorch >= 2.4). ValueError If inputs are not tensors, ranks are incorrect, or sizes are incompatible. Notes ----- If DGL is available, this uses :func:`dgl.ops.gather_mm` [1b]_. Otherwise it uses a dependency-free PyTorch nested-tensor fallback. See Also -------- segment_mm : Segmented matrix multiplication over contiguous chunks. References ---------- .. [1b] DGL ``gather_mm`` documentation: https://www.dgl.ai/dgl_docs/generated/dgl.ops.gather_mm.html Examples -------- >>> import torch >>> # N = 5, D1 = 3, D2 = 2, R = 3 >>> a = torch.randn(5, 3) >>> b = torch.randn(3, 3, 2) >>> idx_b = torch.tensor([0, 1, 0, 2, 1]) >>> out = gather_mm(a, b, idx_b) >>> out.shape torch.Size([5, 2]) All rows using the same matrix:: >>> torch.allclose(gather_mm(a, b, torch.zeros(5, dtype=torch.long)), a @ b[0]) True Mixed indexing example:: >>> # Different transformation for each row >>> a = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) # (2, 2) >>> b = torch.tensor([[[1.0, 0.0], [0.0, 1.0]], # Identity ... [[2.0, 0.0], [0.0, 2.0]]]) # 2x scale >>> idx_b = torch.tensor([0, 1]) # Use identity, then 2x scale >>> result = gather_mm(a, b, idx_b) >>> result tensor([[1., 2.], [6., 8.]]) """ if parse_version(torch.__version__) < parse_version("2.4"): raise NotImplementedError("PyTorch version is too old for nested tensors") if dgl_installed: # DGL is more computationally efficient # See https://github.com/pytorch/pytorch/issues/136747 return dglops.gather_mm(a, b, idx_b) # Dependency free fallback if not isinstance(a, torch.Tensor) or not isinstance(b, torch.Tensor) or not isinstance(idx_b, torch.Tensor): raise ValueError("Inputs should be instances of torch.Tensor") if not a.dim() == 2 or not b.dim() == 3 or not idx_b.dim() == 1: raise ValueError("Input tensors have unexpected dimensions") N = idx_b.shape[0] R, D1, D2 = b.shape # Sanity check sizes if not a.shape[0] == N or not a.shape[1] == D1: raise ValueError("Incompatible size for inputs") torchdevice = a.device src_idx = torch.arange(N, device=torchdevice) # Ideally the conversions below to nested tensor would be handled without for loops and without copy nested_a = torch.nested.as_nested_tensor([a[idx_b == i, :] for i in range(R)]) src_idx_reshuffled = torch.cat([src_idx[idx_b == i] for i in range(R)]) nested_b = torch.nested.as_nested_tensor(torch.split(b, 1, dim=0)).reshape((R, D1, D2)) # The actual gather matmul computation nested_ab = torch.matmul(nested_a, nested_b) # Convert back to tensors, again, ideally this would be handled natively with no copy ab_segmented = torch.cat(nested_ab.unbind(), dim=0) ab = torch.empty((N, D2), device=torchdevice) ab[src_idx_reshuffled] = ab_segmented return ab