Grouped GEMM Example#
This example demonstrates grouped matrix multiplication (GEMM) where multiple
input matrices A_i (with potentially different numbers of rows M_i)
are multiplied against a single shared weight matrix B. The results are
concatenated in the original group order.
Key ideas used in this implementation:
- Pack all groups’ rows into one contiguous tensor A_packed with shape [sum(M_i), K]. This improves memory locality and simplifies indexing.
- Represent group boundaries with group_offsets (size G+1), so that rows for group g live in A_packed[group_offsets[g]:group_offsets[g+1]] (see the short sketch after this list).
- Use data-dependent tiling over the concatenated row dimension to efficiently support jagged group sizes (a different M_i per group) without padding.
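As a quick illustration of the packing and offset scheme, here is a minimal sketch in plain PyTorch (the group sizes are made up for the example and it runs on CPU):

import torch

# Three groups with jagged row counts M_i = (2, 5, 3) sharing K = 8
A_list = [torch.randn(m, 8) for m in (2, 5, 3)]
A_packed = torch.cat(A_list, dim=0)           # shape [sum(M_i), K] == [10, 8]
group_offsets = torch.tensor([0, 2, 7, 10])   # length G+1, cumulative row starts

# Rows of group g live in A_packed[group_offsets[g]:group_offsets[g+1]]
g = 1
rows_g = A_packed[group_offsets[g] : group_offsets[g + 1]]
assert torch.equal(rows_g, A_list[g])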
Two kernels are provided:
- grouped_gemm_jagged - a simple kernel that iterates over the groups and tiles each one dynamically.
- grouped_gemm_jagged_persistent - a persistent kernel with dynamic tile assignment for better load balancing across streaming multiprocessors (SMs).
Imports and Dependencies#
from __future__ import annotations
from typing import Callable
import torch
import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl
Grouped GEMM Kernel - Basic Implementation#
@helion.kernel(static_shapes=False)
def grouped_gemm_jagged(
A_packed: torch.Tensor, # [total_M, K], where total_M == sum(M_i)
B: torch.Tensor, # [K, N] shared across all groups
group_offsets: torch.Tensor, # [G+1], int32/int64, row offsets into A_packed
) -> torch.Tensor: # [total_M, N] concatenated outputs for all groups
"""
Perform grouped GEMM on jagged inputs using row offsets.
Args:
A_packed: Row-wise concatenation of per-group inputs ``A_i``,
shape ``[sum(M_i), K]``.
B: Shared weight matrix, shape ``[K, N]``.
group_offsets: Row offsets delimiting each group within ``A_packed``,
shape ``[G+1]``. For group ``g``: rows are
``start = group_offsets[g]`` to ``end = group_offsets[g+1]``.
Returns:
Output tensor of shape ``[sum(M_i), N]`` equal to
``torch.cat([A_i @ B for i in groups], dim=0)``.
"""
total_M, K = A_packed.shape
K2, N = B.shape
assert K == K2, "K dimension mismatch between A_packed and B"
out = torch.empty(
total_M,
N,
dtype=torch.promote_types(A_packed.dtype, B.dtype),
device=A_packed.device,
)
G = group_offsets.size(0) - 1
# Process each group independently, tiling over its specific M_g dimension
for g in hl.grid(G):
start = group_offsets[g]
end = group_offsets[g + 1]
M_g = end - start
if M_g != 0:
# Create 2D tiling pattern over output dimensions (M_g x N) for current group
for tile_m, tile_n in hl.tile([M_g, N]):
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
# K-reduction loop: multiply tiles along K dimension
for tile_k in hl.tile(K):
a_blk = A_packed[start + tile_m.index, tile_k]
b_blk = B[tile_k, tile_n]
# Perform fused multiply-add with FP32 accumulation for numerical stability
acc = torch.addmm(acc, a_blk, b_blk)
# Convert accumulator to output dtype and store result
out[start + tile_m.index, tile_n] = acc.to(out.dtype)
return out
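For reference, the kernel above computes the same result as the following eager-mode sketch (a plain per-group loop, not the Helion implementation):

import torch

def grouped_gemm_eager(
    A_packed: torch.Tensor, B: torch.Tensor, group_offsets: torch.Tensor
) -> torch.Tensor:
    # One GEMM per group, written into the packed output at that group's row range
    out = torch.empty(
        A_packed.size(0),
        B.size(1),
        dtype=torch.promote_types(A_packed.dtype, B.dtype),
        device=A_packed.device,
    )
    for g in range(group_offsets.numel() - 1):
        start, end = int(group_offsets[g]), int(group_offsets[g + 1])
        if end > start:
            out[start:end] = A_packed[start:end].to(out.dtype) @ B.to(out.dtype)
    return out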
Grouped GEMM Kernel - Persistent Implementation#
@helion.kernel(static_shapes=False)
def grouped_gemm_jagged_persistent(
A_packed: torch.Tensor, # [total_M, K]
B: torch.Tensor, # [K, N]
group_offsets: torch.Tensor, # [G+1], row offsets into A_packed
) -> torch.Tensor:
"""
Persistent grouped GEMM with dynamic tile metadata computation.
This variant computes tile assignments dynamically in the kernel,
similar to TritonBench's WS variant.
Args:
A_packed: Packed A, concatenated by rows across groups, ``[sum(M_i), K]``.
B: Shared weight matrix, ``[K, N]``.
group_offsets: Row offsets delimiting each group within ``A_packed``.
Returns:
Output tensor of shape ``[sum(M_i), N]``.
"""
# Set worker count to match GPU streaming multiprocessor count
device = A_packed.device
if device.type == "xpu":
# TODO(EikanWang): gpu_subslice_count is an out-of-date term. we will update it to XeCore number.
num_workers = torch.xpu.get_device_properties(device.index).gpu_subslice_count
else:
num_workers = torch.cuda.get_device_properties(
device.index
).multi_processor_count
# Define tunable block sizes for M, N dimensions (auto-tuned at runtime)
BLOCK_M = hl.register_block_size(32, 128)
BLOCK_N = hl.register_block_size(32, 128)
total_M, K = A_packed.shape
K2, N = B.shape
    assert K == K2, "K dimension mismatch between A_packed and B"
out = torch.zeros(
total_M,
N,
dtype=torch.promote_types(A_packed.dtype, B.dtype),
device=A_packed.device,
)
G = group_offsets.size(0) - 1
for worker_id in hl.grid(num_workers):
# Persistent thread pattern: each worker processes tiles across all groups
# using strided/interleaved assignment for load balancing.
        # (i.e., each worker takes every num_workers-th tile: worker 0 handles
        # linear tiles 0, num_workers, 2*num_workers, ...)
for g in hl.grid(G):
group_start = group_offsets[g]
group_end = group_offsets[g + 1]
m_size = group_end - group_start
if m_size > 0:
# Compute tile grid dimensions for current group
num_m_tiles = (m_size + BLOCK_M - 1) // BLOCK_M
# Calculate number of N tiles (shared across all groups)
num_n_tiles = (N + BLOCK_N - 1) // BLOCK_N
num_group_tiles = num_m_tiles * num_n_tiles
# Distribute tiles among workers using strided access pattern
for local_tile in hl.grid(num_group_tiles):
tile_in_group = local_tile * num_workers + worker_id
if tile_in_group < num_group_tiles:
# Convert linear tile index to 2D (M, N) tile coordinates
m_tile_idx = tile_in_group % num_m_tiles
n_tile_idx = tile_in_group // num_m_tiles
# Compute global memory indices for current tile
base_row = group_start + m_tile_idx * BLOCK_M
base_col = n_tile_idx * BLOCK_N
# Generate row and column index ranges for tile access
row_idx = base_row + hl.arange(BLOCK_M)
col_idx = base_col + hl.arange(BLOCK_N)
# Apply boundary masks to handle partial tiles at edges
rows_valid = row_idx < group_end
cols_valid = col_idx < N
# Initialize FP32 accumulator for numerical precision
acc = hl.zeros([BLOCK_M, BLOCK_N], dtype=torch.float32)
# Iterate over K dimension in blocks for matrix multiplication
for k_tile in hl.tile(K):
k_idx = k_tile.index
# Load tiles from A_packed and B with boundary checking
a_blk = hl.load(
A_packed,
[row_idx, k_idx],
extra_mask=rows_valid[:, None],
)
b_blk = hl.load(
B,
[k_idx, col_idx],
extra_mask=cols_valid[None, :],
)
# Perform tile-level matrix multiplication and accumulate
acc = torch.addmm(acc, a_blk, b_blk)
# Write accumulated result to output with boundary masking
valid_2d = rows_valid[:, None] & cols_valid[None, :]
hl.store(
out,
[row_idx, col_idx],
acc.to(out.dtype),
extra_mask=valid_2d,
)
return out
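To make the strided tile assignment concrete, this pure-Python sketch mirrors the index arithmetic above and lists which (m_tile_idx, n_tile_idx) coordinates a given worker visits inside one group (the worker count and tile-grid sizes below are hypothetical):

def tiles_for_worker(
    worker_id: int, num_workers: int, num_m_tiles: int, num_n_tiles: int
) -> list[tuple[int, int]]:
    # Same mapping as the kernel: linear tile index -> (m_tile_idx, n_tile_idx),
    # with each worker taking every num_workers-th tile of the group.
    num_group_tiles = num_m_tiles * num_n_tiles
    coords = []
    for local_tile in range(num_group_tiles):
        tile_in_group = local_tile * num_workers + worker_id
        if tile_in_group < num_group_tiles:
            coords.append((tile_in_group % num_m_tiles, tile_in_group // num_m_tiles))
    return coords

# e.g. 3 workers over a 4 x 2 tile grid: worker 0 gets linear tiles 0, 3, 6
print(tiles_for_worker(0, num_workers=3, num_m_tiles=4, num_n_tiles=2))
# [(0, 0), (3, 0), (2, 1)]

Because num_workers is set to the SM count in the kernel, consecutive linear tiles within a group are handed to different workers, which spreads uneven group sizes across the GPU.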
Data Preparation Utilities#
def _pack_group_inputs(
group_A: list[torch.Tensor], group_B: list[torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Build ``A_packed``, shared ``B``, and ``group_offsets`` from grouped inputs.
Expectations:
- All ``A_i`` share the same ``K`` and dtype/device.
- All groups share the same ``B`` (as produced by TritonBench inputs).
Returns ``(A_packed, B_shared, group_offsets)`` where
``group_offsets`` has length ``G+1`` with ``group_offsets[0] == 0`` and
``group_offsets[g+1] - group_offsets[g] == M_g``.
"""
assert len(group_A) > 0
device = group_A[0].device
dtype = group_A[0].dtype
# Extract shared weight matrix B (same for all groups in TritonBench)
B_shared = group_B[0]
# Compute group offsets and concatenate all A matrices row-wise
M_sizes = [int(a.size(0)) for a in group_A]
starts = [0]
for m in M_sizes:
starts.append(starts[-1] + m)
group_offsets = torch.tensor(starts, device=device, dtype=torch.int32)
A_packed = torch.cat(group_A, dim=0).to(device=device, dtype=dtype).contiguous()
return A_packed, B_shared, group_offsets
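For example, with three toy groups on CPU (sizes chosen only for illustration), the helper returns the packed tensor and cumulative offsets:

group_A = [torch.randn(m, 8) for m in (2, 4, 3)]   # M_i = 2, 4, 3; K = 8
group_B = [torch.randn(8, 5)] * 3                  # shared B of shape [K, N] = [8, 5]
A_packed, B_shared, group_offsets = _pack_group_inputs(group_A, group_B)
print(A_packed.shape)          # torch.Size([9, 8])
print(group_offsets.tolist())  # [0, 2, 6, 9]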
TritonBench Integration Wrappers#
def grouped_gemm_jagged_tritonbench(
tb_op: object, group_A: list[torch.Tensor], group_B: list[torch.Tensor]
) -> Callable[[], torch.Tensor]:
"""Adapter for basic grouped GEMM kernel to work with TritonBench benchmark suite."""
def inner() -> torch.Tensor:
A_packed, B_shared, group_offsets = _pack_group_inputs(group_A, group_B)
return grouped_gemm_jagged(A_packed, B_shared, group_offsets)
return inner
def grouped_gemm_jagged_persistent_tritonbench(
tb_op: object, group_A: list[torch.Tensor], group_B: list[torch.Tensor]
) -> Callable[[], torch.Tensor]:
"""Adapter for persistent grouped GEMM kernel with dynamic work distribution for TritonBench."""
def inner() -> torch.Tensor:
A_packed, B_shared, group_offsets = _pack_group_inputs(group_A, group_B)
return grouped_gemm_jagged_persistent(
A_packed,
B_shared,
group_offsets,
)
return inner
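A minimal usage sketch for the adapters (assuming a supported GPU; the sizes mirror those built in main() below, and tb_op is unused by these adapters, so None stands in as a placeholder):

group_A = [
    torch.randn(64 * (i + 1), 256, device=DEVICE, dtype=torch.bfloat16)
    for i in range(4)
]
group_B = [torch.randn(256, 128, device=DEVICE, dtype=torch.bfloat16)] * 4
fn = grouped_gemm_jagged_tritonbench(None, group_A, group_B)
result = fn()  # packed output of shape [sum(M_i), N]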
Reference Implementation for Validation#
def _reference_grouped_gemm(
group_A: list[torch.Tensor], group_B: list[torch.Tensor]
) -> torch.Tensor:
B_shared = group_B[0]
outs = [a @ B_shared for a in group_A]
return torch.cat(outs, dim=0)
def grouped_gemm_jagged_example(
group_A: list[torch.Tensor], group_B: list[torch.Tensor]
) -> torch.Tensor:
"""
Wrapper to run grouped_gemm_jagged with unpacked TritonBench inputs.
"""
A_packed, B_shared, group_offsets = _pack_group_inputs(group_A, group_B)
return grouped_gemm_jagged(A_packed, B_shared, group_offsets)
def grouped_gemm_jagged_persistent_example(
group_A: list[torch.Tensor], group_B: list[torch.Tensor]
) -> torch.Tensor:
"""
Wrapper to run grouped_gemm_jagged_persistent with unpacked TritonBench inputs.
"""
A_packed, B_shared, group_offsets = _pack_group_inputs(group_A, group_B)
return grouped_gemm_jagged_persistent(A_packed, B_shared, group_offsets)
Test Harness and Validation#
def main() -> None:
torch.manual_seed(0) # Ensure reproducible test results
device = DEVICE
dtype = torch.bfloat16
G = 4 # Number of groups to test
K, N = 256, 128 # Shared dimensions: K (reduction), N (output columns)
# Create test data with varying group sizes (M_i = 64, 128, 192, 256)
group_A = [
torch.randn(64 * (i + 1), K, device=device, dtype=dtype).contiguous()
for i in range(G)
]
# Shared weight matrix B replicated for each group (as per TritonBench convention)
group_B = [torch.randn(K, N, device=device, dtype=dtype).contiguous()] * G
print("Testing grouped GEMM kernels...")
run_example(
grouped_gemm_jagged_example,
_reference_grouped_gemm,
(group_A, group_B),
rtol=1e-2,
atol=1e-2,
)
print("✓ Non-persistent kernel passed")
run_example(
grouped_gemm_jagged_persistent_example,
_reference_grouped_gemm,
(group_A, group_B),
rtol=1e-2,
atol=1e-2,
)
print("✓ Persistent kernel passed")
print("\nAll tests passed!")
if __name__ == "__main__":
main()