torchsparsegradutils.segment_mm

torchsparsegradutils.segment_mm(a: Tensor, b: Tensor, seglen_a: Tensor) → Tensor

Segmented matrix multiplication with variable-length segments.

Performs matrix multiplication between contiguous segments of a and the corresponding matrices in b. If seglen_a == [10, 5, 0, 3], the operator computes:

a[0:10] @ b[0], a[10:15] @ b[1],
a[15:15] @ b[2], a[15:18] @ b[3]
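To make these semantics concrete, here is a minimal pure-PyTorch reference sketch. It is not the library's implementation, and the name segment_mm_reference is hypothetical. It repeats each matrix b[k] once per row of its segment, so every row of a is paired with its own (D1, D2) matrix, and then runs one batched matmul. This uses O(N · D1 · D2) extra memory and is only meant for checking results on small inputs.

```python
import torch


def segment_mm_reference(a, b, seglen_a):
    """Hypothetical reference for segment_mm semantics (not the library code)."""
    # b_rows[i] == b[k] whenever row i of `a` falls in segment k;
    # zero-length segments contribute no copies of their matrix.
    b_rows = torch.repeat_interleave(b, seglen_a, dim=0)  # (N, D1, D2)
    # Batched (1, D1) @ (D1, D2) product per row, then drop the singleton dim.
    return torch.bmm(a.unsqueeze(1), b_rows).squeeze(1)  # (N, D2)


torch.manual_seed(0)
a = torch.randn(18, 4)
b = torch.randn(4, 4, 2)
seglen_a = torch.tensor([10, 5, 0, 3])
out = segment_mm_reference(a, b, seglen_a)

# Same result as the explicit slices listed above.
expected = torch.cat([a[0:10] @ b[0], a[10:15] @ b[1], a[15:15] @ b[2], a[15:18] @ b[3]])
print(out.shape)                       # torch.Size([18, 2])
print(torch.allclose(out, expected))   # True
```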
Parameters:
  • a (torch.Tensor, shape (N, D1)) – Left operand containing the concatenation of all segments.

  • b (torch.Tensor, shape (R, D1, D2)) – Right operand containing one (D1, D2) matrix per segment.

  • seglen_a (torch.Tensor, shape (R,), integer dtype) – Length of each segment in a. seglen_a.sum() must equal N.

Returns:

Row-wise concatenation of the per-segment products (a segment of a @ its matrix in b), in segment order. A zero-length segment contributes no rows.

Return type:

torch.Tensor, shape (N, D2)

Raises:
  • NotImplementedError – If the fallback path is used on a PyTorch version lacking nested tensor matmul support (requires PyTorch >= 2.4).

  • ValueError – If input ranks or sizes are incompatible.
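The rank and size constraints behind this ValueError follow from the parameter shapes above. The helper below, check_segment_mm_inputs, is a hypothetical sketch of such checks; the library's actual validation, its order, and its error messages may differ.

```python
import torch

_INT_DTYPES = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)


def check_segment_mm_inputs(a, b, seglen_a):
    """Hypothetical validation mirroring the documented shapes (not the library code)."""
    if a.dim() != 2:
        raise ValueError(f"a must be 2-D (N, D1), got shape {tuple(a.shape)}")
    if b.dim() != 3:
        raise ValueError(f"b must be 3-D (R, D1, D2), got shape {tuple(b.shape)}")
    if seglen_a.dim() != 1:
        raise ValueError(f"seglen_a must be 1-D (R,), got shape {tuple(seglen_a.shape)}")
    if seglen_a.dtype not in _INT_DTYPES:
        raise ValueError(f"seglen_a must have an integer dtype, got {seglen_a.dtype}")
    # One segment length per matrix in b.
    if seglen_a.numel() != b.shape[0]:
        raise ValueError(f"len(seglen_a) ({seglen_a.numel()}) must equal R ({b.shape[0]})")
    # Inner dimensions must agree for a_seg @ b[k].
    if a.shape[1] != b.shape[1]:
        raise ValueError(f"a.shape[1] ({a.shape[1]}) must equal b.shape[1] ({b.shape[1]})")
    if bool((seglen_a < 0).any()):
        raise ValueError("seglen_a must not contain negative lengths")
    # Segments must exactly tile the rows of a.
    total = int(seglen_a.sum())
    if total != a.shape[0]:
        raise ValueError(f"seglen_a.sum() ({total}) must equal a.shape[0] ({a.shape[0]})")


check_segment_mm_inputs(torch.randn(18, 4), torch.randn(3, 4, 2), torch.tensor([10, 5, 3]))  # passes
try:
    check_segment_mm_inputs(torch.randn(18, 4), torch.randn(3, 4, 2), torch.tensor([10, 5, 2]))
except ValueError as err:
    print(err)  # seglen_a.sum() (17) must equal a.shape[0] (18)
```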

Notes

If DGL is available, this function delegates to dgl.ops.segment_mm() [1c], which is typically faster. Otherwise, it falls back to a PyTorch nested-tensor implementation.
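The optional-dependency dispatch described above can be sketched as follows. This is an illustration under assumptions, not the library's code: segment_mm_dispatch is a hypothetical name, and the fallback here is a plain per-segment loop over torch.split, swapped in for the nested-tensor path the library actually uses.

```python
import torch

try:
    import dgl  # optional dependency providing dgl.ops.segment_mm
    _HAS_DGL = True
except ImportError:
    _HAS_DGL = False


def segment_mm_dispatch(a, b, seglen_a):
    """Hypothetical dispatcher: DGL kernel if importable, else a simple loop."""
    if _HAS_DGL:
        return dgl.ops.segment_mm(a, b, seglen_a)
    # Fallback (plain loop, NOT the library's nested-tensor implementation):
    # split rows of `a` into segments (zero sizes give empty (0, D1) chunks)
    # and multiply each chunk by its matrix.
    chunks = torch.split(a, seglen_a.tolist(), dim=0)
    return torch.cat([chunk @ b[k] for k, chunk in enumerate(chunks)], dim=0)


a = torch.randn(18, 4)
b = torch.randn(4, 4, 2)
seglen_a = torch.tensor([10, 5, 0, 3])
print(segment_mm_dispatch(a, b, seglen_a).shape)  # torch.Size([18, 2])
```

Keeping seglen_a on the CPU keeps the .tolist() call in the fallback cheap, since it avoids a device-to-host copy.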

See also

gather_mm

Per-row indexed matrix multiplication.

References

Examples

>>> import torch
>>> from torchsparsegradutils import segment_mm
>>> # N = 18, D1 = 4, D2 = 2
>>> a = torch.randn(18, 4)
>>> b = torch.randn(3, 4, 2)
>>> seglen_a = torch.tensor([10, 5, 3])
>>> out = segment_mm(a, b, seglen_a)
>>> out.shape
torch.Size([18, 2])

Zero-length segment:

>>> seglen_a = torch.tensor([10, 5, 0, 3])
>>> b = torch.randn(4, 4, 2)
>>> segment_mm(a, b, seglen_a).shape
torch.Size([18, 2])