torchsparsegradutils.segment_mm
- torchsparsegradutils.segment_mm(a: Tensor, b: Tensor, seglen_a: Tensor) → Tensor
Segmented matrix multiplication with variable-length segments.
Performs matrix multiplication between contiguous segments of a and the corresponding matrices in b. If seglen_a == [10, 5, 0, 3], the operator computes:
a[0:10] @ b[0], a[10:15] @ b[1], a[15:15] @ b[2], a[15:18] @ b[3]
The empty slice a[15:15] contributes no output rows.
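A minimal pure-Python sketch of these semantics follows. It uses nested lists instead of tensors so that it runs without PyTorch or DGL. The names segment_bounds, matmul and segment_mm_reference are hypothetical, chosen for illustration. This is a naive reference for checking the row-range arithmetic, not the library's implementation.

```python
from itertools import accumulate


def segment_bounds(seglen):
    """Return the (start, stop) row range of each segment, in order."""
    if any(n < 0 for n in seglen):
        raise ValueError("segment lengths must be non-negative")
    stops = list(accumulate(seglen))   # running totals, e.g. 10, 15, 15, 18
    starts = [0] + stops[:-1]          # each segment starts where the previous one stopped
    return list(zip(starts, stops))


def matmul(x, y):
    """Nested-list matrix product x @ y (x: m x k, y: k x p)."""
    cols = list(zip(*y))
    return [[sum(xi * yi for xi, yi in zip(row, col)) for col in cols] for row in x]


def segment_mm_reference(a, b, seglen_a):
    """Naive segmented matmul: rows a[start:stop] of segment k are multiplied by b[k]."""
    if len(seglen_a) != len(b):
        raise ValueError("b must hold exactly one matrix per segment")
    if sum(seglen_a) != len(a):
        raise ValueError("seglen_a must sum to the number of rows of a")
    out = []
    for k, (start, stop) in enumerate(segment_bounds(seglen_a)):
        # An empty slice (start == stop) adds no rows, so b[k] is simply unused.
        out.extend(matmul(a[start:stop], b[k]))
    return out


print(segment_bounds([10, 5, 0, 3]))  # [(0, 10), (10, 15), (15, 15), (15, 18)]

a = [[1, 0], [0, 1], [2, 3]]          # N = 3, D1 = 2
b = [[[1, 1], [0, 1]],                # b[0]: rows 0..1
     [[5, 5], [5, 5]],                # b[1]: zero-length segment, never used
     [[2, 0], [0, 2]]]                # b[2]: row 2
print(segment_mm_reference(a, b, [2, 0, 1]))  # [[1, 1], [0, 1], [4, 6]]
```

The output rows stay in the original order of a. This matches the "concatenation of all segment results" described under Returns.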
- Parameters:
a (torch.Tensor, shape (N, D1)) – Left operand containing the concatenation of all segments.
b (torch.Tensor, shape (R, D1, D2)) – Right operand containing one (D1, D2) matrix per segment.
seglen_a (torch.Tensor, shape (R,), integer dtype) – Length of each segment in a. seglen_a.sum() must equal N.
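The shape contract above can be checked before any arithmetic is done. The helper below is hypothetical and is not the library's actual validation code. It works on plain shape tuples and a list of lengths; for tensors, a.shape, b.shape and seglen_a.tolist() would supply these. Rejecting negative lengths is an added assumption that goes beyond the documented rule that seglen_a.sum() must equal N.

```python
def check_segment_mm_shapes(a_shape, b_shape, seglen_a):
    """Hypothetical helper: validate shapes against the documented contract.

    Returns the expected output shape (N, D2) or raises ValueError.
    """
    if len(a_shape) != 2:
        raise ValueError(f"a must be 2-D (N, D1), got rank {len(a_shape)}")
    if len(b_shape) != 3:
        raise ValueError(f"b must be 3-D (R, D1, D2), got rank {len(b_shape)}")
    n, d1 = a_shape
    r, d1_b, d2 = b_shape
    if d1 != d1_b:
        raise ValueError(f"inner dimensions differ: a has D1={d1}, b has D1={d1_b}")
    if len(seglen_a) != r:
        raise ValueError(f"seglen_a has {len(seglen_a)} entries but b has R={r} matrices")
    # Assumption: lengths are non-negative integers (zero is allowed).
    if any(not isinstance(k, int) or k < 0 for k in seglen_a):
        raise ValueError("segment lengths must be non-negative integers")
    if sum(seglen_a) != n:
        raise ValueError(f"seglen_a sums to {sum(seglen_a)}, expected N={n}")
    return (n, d2)


print(check_segment_mm_shapes((18, 4), (4, 4, 2), [10, 5, 0, 3]))  # (18, 2)
```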
- Returns:
Concatenation of all segment results in original order.
- Return type:
torch.Tensor, shape (N, D2)
- Raises:
NotImplementedError – If the fallback path is used on a PyTorch version lacking nested tensor matmul support (requires PyTorch >= 2.4).
ValueError – If input ranks or sizes are incompatible.
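The NotImplementedError condition above interacts with the backend choice described under Notes. The PyTorch >= 2.4 requirement only matters when DGL is absent and the nested-tensor fallback is taken. The decision function below captures that logic. The names choose_backend, parse_version and dgl_installed are illustrative and are not part of torchsparsegradutils. The version parsing is deliberately simplistic and only compares major.minor.

```python
import importlib.util


def parse_version(v):
    """Turn a version string like '2.4.0+cu121' into a comparable (major, minor) tuple."""
    major, minor = v.split("+")[0].split(".")[:2]
    return int(major), int(minor)


def dgl_installed():
    # find_spec returns None when the package cannot be found on the import path.
    return importlib.util.find_spec("dgl") is not None


def choose_backend(dgl_available, torch_version):
    """Hypothetical dispatch mirroring the documented behaviour."""
    if dgl_available:
        return "dgl"  # would call dgl.ops.segment_mm
    if parse_version(torch_version) < (2, 4):
        raise NotImplementedError("nested-tensor fallback requires PyTorch >= 2.4")
    return "nested"


# In real code torch_version would come from torch.__version__.
print(choose_backend(False, "2.4.1"))  # nested
print(choose_backend(True, "2.1.0"))   # dgl
```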
Notes
If DGL is available, this uses dgl.ops.segment_mm() [1c], which is typically faster. Otherwise it falls back to a PyTorch nested-tensor implementation.
See also
gather_mm – Per-row indexed matrix multiplication.
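Because segments are contiguous, segment_mm is a special case of per-row indexed multiplication. Row i of a is multiplied by b[idx[i]], where idx repeats segment index k exactly seglen_a[k] times. The sketch below builds that index with the standard library. In PyTorch, the same index would be torch.repeat_interleave(torch.arange(R), seglen_a). Passing it as the index argument of gather_mm assumes that gather_mm follows DGL's convention out[i] = a[i] @ b[idx_b[i]], which this page does not state. The name segment_ids is hypothetical.

```python
from itertools import chain, repeat


def segment_ids(seglen):
    """Expand segment lengths into one segment index per row of a.

    Row i of a is multiplied by b[ids[i]]; zero-length segments produce no ids.
    """
    return list(chain.from_iterable(repeat(k, n) for k, n in enumerate(seglen)))


print(segment_ids([2, 0, 1]))  # [0, 0, 2]
```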
References
[1c] DGL segment_mm documentation: https://www.dgl.ai/dgl_docs/generated/dgl.ops.segment_mm.html
Examples
>>> import torch
>>> from torchsparsegradutils import segment_mm
>>> # N = 18, D1 = 4, D2 = 2
>>> a = torch.randn(18, 4)
>>> b = torch.randn(3, 4, 2)
>>> seglen_a = torch.tensor([10, 5, 3])
>>> out = segment_mm(a, b, seglen_a)
>>> out.shape
torch.Size([18, 2])
Zero-length segment:
>>> seglen_a = torch.tensor([10, 5, 0, 3])
>>> b = torch.randn(4, 4, 2)
>>> segment_mm(a, b, seglen_a).shape
torch.Size([18, 2])