Matrix Multiplication with Layer Normalization Example

This example demonstrates how to implement a fused matrix multiplication and layer normalization operation using Helion.

Imports

from __future__ import annotations

import torch
import torch.nn.functional as F

import helion
from helion._testing import run_example
import helion.language as hl

MatMul-LayerNorm Kernel

Passing static_shapes=True to the kernel decorator gives a performance boost for matmuls.

@helion.kernel(static_shapes=True)
def matmul_layernorm(
    x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    """
    Performs matrix multiplication followed by layer normalization.

    For each tile of rows, the product ``x @ y`` is accumulated in float32
    over tiles of the K dimension, then every row is normalized across the
    N dimension (biased variance, eps = 1e-5), scaled by ``weight`` and
    shifted by ``bias`` before being written to the output.

    Args:
        x: First input tensor of shape [M, K]
        y: Second input tensor of shape [K, N]
        weight: Layer normalization weight parameter of shape [N]
        bias: Layer normalization bias parameter of shape [N]

    Returns:
        Output tensor of shape [M, N] containing the result of matrix multiplication followed by layer normalization.
        Its dtype is ``torch.promote_types(x.dtype, y.dtype)`` and it lives on ``x.device``.
    """
    m, k = x.size()
    k2 = y.size(0)
    # Register N as a reduction dimension with Helion.
    # NOTE(review): presumably this is what lets the full N extent be used for
    # the accumulator allocation and the row-wise var/mean below — confirm
    # against the Helion documentation.
    n = hl.register_reduction_dim(y.size(1))
    # Shape checks: inner dimensions must agree and the affine parameters must
    # have one entry per output column.
    assert k == k2, f"size mismatch {k} != {k2}"
    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
    out = torch.empty(
        [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device
    )
    # Outer loop: one tile of output rows at a time.
    for tile_m in hl.tile(m):
        # float32 accumulator covering every column of the current row tile.
        acc = hl.zeros([tile_m, n], dtype=torch.float32)
        # Inner loop: sum partial products over tiles of the K dimension;
        # `y[tile_k, :]` takes all N columns so each row stays complete.
        for tile_k in hl.tile(k):
            mm = torch.matmul(x[tile_m, tile_k], y[tile_k, :])
            acc = acc + mm
        # Same value as the default eps of F.layer_norm.
        eps = 1e-5
        # correction=0 selects the biased (population) variance, which is the
        # estimator layer normalization uses.
        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
        normalized = (acc - mean) * torch.rsqrt(var + eps)
        # Affine step, with weight and bias cast to float32 to match `acc`.
        acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))
        # NOTE(review): `acc` is float32 while `out` uses the promoted input
        # dtype; the store presumably casts to `out`'s dtype — confirm.
        out[tile_m, :] = acc
    return out

Reference Implementation

def matmul_layernorm_pytorch(
    x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    """
    Eager PyTorch baseline: compute ``x @ y`` and layer-normalize each row of the product.

    The normalization itself runs in float32 (input, weight and bias are all
    cast), using the default eps of ``F.layer_norm``; the result is then cast
    to the promoted dtype of ``x`` and ``y``.

    Args:
        x: Left operand of shape [M, K]
        y: Right operand of shape [K, N]
        weight: Per-column scale applied after normalization, shape [N]
        bias: Per-column shift applied after normalization, shape [N]

    Returns:
        Tensor of shape [M, N] holding the layer-normalized matrix product
    """
    product = x @ y

    # Normalize over the last (N) dimension only.
    feature_shape = (product.size(-1),)
    normed = F.layer_norm(
        product.float(),
        feature_shape,
        weight.float(),
        bias.float(),
    )

    result_dtype = torch.promote_types(x.dtype, y.dtype)
    return normed.to(result_dtype)

Verification Function

def check(m: int, k: int, n: int) -> None:
    """
    Compare the Helion matmul_layernorm kernel with the PyTorch reference on random inputs.

    Random float16 tensors are created on the CUDA device and both
    implementations are handed to ``run_example``.

    Args:
        m: Row count of the first matrix
        k: Shared inner dimension (columns of the first matrix, rows of the second)
        n: Column count of the second matrix; also the length of weight and bias
    """
    tensor_opts = {"device": "cuda", "dtype": torch.float16}
    # Draw order is x, y, weight, bias.
    shapes = ([m, k], [k, n], [n], [n])
    inputs = tuple(torch.randn(shape, **tensor_opts) for shape in shapes)
    run_example(matmul_layernorm, matmul_layernorm_pytorch, inputs)

Main Function

def main() -> None:
    """
    Entry point: run the matmul_layernorm verification once per problem size.

    Sizes exercised, in order:
    - 32x64 * 64x200
    - 128x256 * 256x400
    """
    # TODO(yf225): n=64 or 128 throws error, need to investigate
    # check(32, 64, 64)
    # check(32, 64, 128)
    problem_sizes = (
        (32, 64, 200),
        (128, 256, 400),
    )
    for m, k, n in problem_sizes:
        check(m, k, n)


# Run the verification only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

Gallery generated by Sphinx-Gallery