Note
Go to the end to download the full example code
INT4 General Matrix Multiplication (GEMM) with Helion#
This example demonstrates an INT4 GEMM kernel implemented in Helion. The kernel performs matrix multiplication where the second matrix B is packed with two 4-bit values per byte. The kernel unpacks the int4 values, converts to bfloat16, and performs matmul with the bfloat16 matrix A.
Imports#
INT4 GEMM Kernel#
@helion.kernel(static_shapes=False)
def matmul_bf16_int4(A: Tensor, B: Tensor) -> Tensor:
"""
BFloat16 x INT4 General Matrix Multiplication (GEMM).
This kernel performs matrix multiplication where:
- A is a bfloat16 matrix of shape [M, K]
- B is an int8 matrix of shape [K//2, N] containing packed int4 values
(two 4-bit values packed into each int8)
Args:
A (Tensor): Input tensor of shape [M, K] in bfloat16 format.
B (Tensor): Packed int4 tensor of shape [K//2, N] in int8 format.
Returns:
Tensor: Output tensor of shape [M, N] in bfloat16 format.
"""
M, K = A.shape
_, N = B.shape
C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device)
block_size_k_packed = hl.register_block_size(K // 2)
# Use Helion to tile the computation
for tile_m, tile_n in hl.tile([M, N]):
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
for tile_k_packed in hl.tile(K // 2, block_size=block_size_k_packed):
# Load corresponding tiles from A (need to load twice the packed tile size).
# We map tile_k_packed to the corresponding (factor-2 widened) range in A.
# This affine-widened slice is lowered as an affine ``hl.arange()`` index and,
# on the CuTe backend, is consumed by the packed-affine matmul load fusion
# below (the affine A load feeds directly into ``hl.dot``).
a_tile_begin = tile_k_packed.begin * 2
a_tile_len = block_size_k_packed * 2
a_tile = A[
tile_m, a_tile_begin : (a_tile_begin + a_tile_len)
] # [BLOCK_SIZE_M, BLOCK_SIZE_K]
# Load packed int8 data from B
b_tile = B[tile_k_packed, tile_n] # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]
# Extract low and high 4-bit values with sign extension.
# The low nibble is sign-extended arithmetically (mask to 4 bits, then
# subtract 16 when the sign bit is set) rather than via ``(x << 4) >> 4``;
# the explicit arithmetic form is portable across backends and avoids
# relying on narrow-int (int8) shift truncation semantics.
b_lo_raw = (b_tile & 0xF).to(torch.int32)
b_lo = torch.where(b_lo_raw >= 8, b_lo_raw - 16, b_lo_raw) # low 4 bits
b_hi = (b_tile >> 4).to(torch.int32) # high 4 bits (arithmetic >>)
# Stack and reshape to interleave low and high bits.
# Stack along a new dimension to get [BLOCK_SIZE_K//2, 2, BLOCK_SIZE_N],
# then reshape to interleave into [BLOCK_SIZE_K, BLOCK_SIZE_N] in the order
# b_lo[0], b_hi[0], b_lo[1], b_hi[1], ..., matching the A[2*kp]/A[2*kp+1] pairs.
b_stacked = torch.stack([b_lo.to(A.dtype), b_hi.to(A.dtype)], dim=1)
b_unpacked = b_stacked.reshape(
tile_k_packed.block_size * 2, tile_n.block_size
)
# Matmul over the widened K dimension. Using ``hl.dot`` lets the affine
# A load and the interleaved B unpack fuse into a single packed-affine
# matmul (the supported lowering for factor-widened affine loads).
acc = hl.dot(a_tile, b_unpacked, acc=acc) # [BLOCK_SIZE_M, BLOCK_SIZE_N]
C[tile_m, tile_n] = acc.to(torch.bfloat16)
return C
TritonBench Wrapper#
def int4_gemm_tritonbench(tb_op: object, x: torch.Tensor, w: torch.Tensor) -> Callable:
"""
Wrapper for TritonBench compatibility.
Args:
tb_op: TritonBench operator instance
x (torch.Tensor): Left input tensor in bfloat16 format.
w (torch.Tensor): Right input tensor of shape [K, N] containing int4 values.
Will be packed to int4 format.
Returns:
Callable: A function that performs the int4 gemm.
"""
# Pack w to int4 format (two 4-bit values per int8 byte)
x_2d = x.reshape(-1, x.size(-1))
w_int8 = w.to(torch.int8)
w_reshaped = w_int8.reshape(w.shape[0] // 2, 2, w.shape[1]).permute(1, 0, 2)
w_packed = ((w_reshaped[0] & 0xF) | (w_reshaped[1] << 4)).to(torch.int8)
def run_kernel() -> torch.Tensor:
return matmul_bf16_int4(x_2d, w_packed)
return run_kernel
Verification Function#
def _pack_int4_matrix(unpacked: torch.Tensor) -> torch.Tensor:
"""
Pack int4 matrix into int8 container with two values per byte.
Args:
unpacked (torch.Tensor): Tensor of shape [K, N] with values in [-8, 7].
Returns:
torch.Tensor: Packed tensor of shape [K//2, N] in int8 format.
"""
k, n = unpacked.shape
assert k % 2 == 0, "K dimension must be even for int4 packing"
reshaped = unpacked.reshape(k // 2, 2, n).permute(1, 0, 2)
return ((reshaped[0] & 0xF) | (reshaped[1] << 4)).to(torch.int8)
def _unpack_int4_matrix(packed: torch.Tensor) -> torch.Tensor:
"""
Unpack an int4 matrix stored as two 4-bit values per int8 byte.
Args:
packed (torch.Tensor): Packed tensor of shape [K//2, N] in int8 format.
Returns:
torch.Tensor: Unpacked tensor of shape [K, N] in int8 format.
"""
b_lo = ((packed << 4) >> 4).to(torch.int8)
b_hi = (packed >> 4).to(torch.int8)
stacked = torch.stack([b_lo, b_hi], dim=1)
return stacked.reshape(packed.shape[0] * 2, packed.shape[1])
def reference_matmul_bf16_int4(A: Tensor, B_packed: Tensor) -> Tensor:
"""
Reference implementation that unpacks the int4 weights and performs matmul.
Args:
A (Tensor): Input tensor in bfloat16 format.
B_packed (Tensor): Packed int4 tensor.
Returns:
Tensor: Output tensor in bfloat16 format.
"""
B_unpacked = _unpack_int4_matrix(B_packed).to(torch.bfloat16)
return torch.matmul(A, B_unpacked)
def check(m: int, k: int, n: int) -> None:
"""
Test the INT4 GEMM implementation using the run_example utility.
Args:
m (int): Number of rows in the left input matrix.
k (int): Shared dimension (must be even).
n (int): Number of columns in the right input matrix.
"""
A = torch.randn(m, k, dtype=torch.bfloat16, device=DEVICE)
B_unpacked = torch.randint(-8, 8, (k, n), dtype=torch.int8, device=DEVICE)
B_packed = _pack_int4_matrix(B_unpacked)
run_example(
matmul_bf16_int4,
reference_matmul_bf16_int4,
(A, B_packed),
rtol=2e-1,
atol=1.0,
)
print(f"Test passed for shapes: M={m}, K={k}, N={n}")
Main Function#
def main() -> None:
"""
Main function to run tests with different matrix sizes.
"""
check(4, 8192, 7168)
check(8192, 8192, 8192)
Run Example#
if __name__ == "__main__":
main()
Total running time of the script: (0 minutes 0.000 seconds)