Helion Layer Normalization Forward Example
This example demonstrates a Helion kernel implementation of 1D layer normalization using FP16 inputs and compares it against PyTorch’s built-in layer_norm function.
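For each row of the input, layer norm subtracts the row mean and divides by the square root of the biased row variance plus eps, then applies a learnable scale and shift: y = (x - mean) / sqrt(var + eps) * weight + bias. The kernel below computes the statistics in FP32 for numerical stability and casts the result back to FP16 on the store.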
from __future__ import annotations
import torch
import helion
from helion._testing import run_example
import helion.language as hl
@helion.kernel
def layer_norm_fwd(
    x: torch.Tensor,
    normalized_shape: list[int],
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-5,
) -> torch.Tensor:
"""
Performs 1D layer normalization on the input tensor using Helion.
Args:
x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
nomralized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
weight (torch.Tensor): Learnable scale parameter of shape [dim].
bias (torch.Tensor): Learnable bias parameter of shape [dim].
eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
Returns:
torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
"""
    m, n = x.size()
    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
    assert len(normalized_shape) == 1, (
        "Helion layer norm only supports 1D layer norm currently"
    )
    assert normalized_shape[0] == n, (
        f"normalized shape mismatch {normalized_shape[0]} != {n}"
    )
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
    # Iterate over the rows in tiles; Helion turns this outer tile loop into
    # the kernel launch grid.
    for tile_m in hl.tile(m):
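        # Load one tile of rows and upcast from fp16 to fp32 so that the
        # mean/variance (biased, correction=0, matching layer_norm) are
        # computed at full precision.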
        acc = x[tile_m, :].to(torch.float32)
        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
        normalized = (acc - mean) * torch.rsqrt(var + eps)
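        # Apply the learnable scale and shift in fp32; the store into the
        # fp16 `out` tensor casts the result back down.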
        acc = normalized * weight[:].to(torch.float32) + bias[:].to(torch.float32)
        out[tile_m, :] = acc
    return out
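For reference, the same computation in eager PyTorch looks like the sketch below (a minimal illustration, not part of the original example; the name layer_norm_fwd_eager is hypothetical):

def layer_norm_fwd_eager(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-5,
) -> torch.Tensor:
    # Hypothetical eager reference: same math as the kernel body, applied to
    # the whole batch at once instead of tile by tile.
    acc = x.to(torch.float32)
    var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
    y = (acc - mean) * torch.rsqrt(var + eps)
    return (y * weight.to(torch.float32) + bias.to(torch.float32)).to(torch.float16)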
def main() -> None:
    """
    Main execution function for the layer normalization example.

    - Generates random input, weight, and bias tensors.
    - Runs the Helion layer normalization kernel and compares its output to PyTorch's
      built-in layer_norm function using the run_example utility.
    - Prints comparison results and checks for correctness within the specified tolerances.
    """
    batch_size = 32
    dim = 64
    device = "cuda"
    x = torch.randn([batch_size, dim], device=device, dtype=torch.float16)
    weight = torch.randn([dim], device=device, dtype=torch.float16)
    bias = torch.randn([dim], device=device, dtype=torch.float16)
    eps = 1e-4
    run_example(
        layer_norm_fwd,
        torch.nn.functional.layer_norm,
        (x, [dim], weight, bias, eps),
        kernel_name="helion",
        baseline_name="torch",
        rtol=1e-3,
        atol=1e-3,
    )
if __name__ == "__main__":
    main()