Helion Layer Normalization Forward and Backward Example#
This example demonstrates Helion kernel implementations of 1D layer normalization, covering both the forward and backward passes with FP16 inputs, and compares the results against PyTorch's built-in layer_norm function.
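The code below assumes the example's standard imports; a minimal sketch is shown here (the helion._testing helper module providing DEVICE and run_example is an assumption about how the example harness is wired up):

from __future__ import annotations

from typing import Any, Callable

import torch

import helion
import helion.language as hl
from helion._testing import DEVICE, run_example  # assumed example-harness helpers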
@helion.kernel
def layer_norm_fwd(
x: torch.Tensor,
normalized_shape: list[int],
weight: torch.Tensor,
bias: torch.Tensor | None = None,
eps: float = 1e-5,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Performs 1D layer normalization on the input tensor using Helion.
Args:
x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
weight (torch.Tensor): Learnable scale parameter of shape [dim].
bias (torch.Tensor | None): Optional learnable bias parameter of shape [dim].
eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- The layer-normalized output tensor of shape [batch_size, dim], in FP16.
- Mean tensor of shape [batch_size], in FP32.
- Reciprocal standard deviation tensor of shape [batch_size], in FP32.
"""
m, n = x.size()
assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
if bias is not None:
assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
assert len(normalized_shape) == 1, (
"Helion layer norm only supports 1D layer norm currently"
)
assert normalized_shape[0] == n, (
f"normalized shape mismatch {normalized_shape[0]} != {n}"
)
out = torch.empty([m, n], dtype=x.dtype, device=x.device)
mean = torch.empty([m], dtype=torch.float32, device=x.device)
rstd = torch.empty([m], dtype=torch.float32, device=x.device)
for tile_m in hl.tile(m):
acc = x[tile_m, :].to(torch.float32)
# Compute mean
mean_val = torch.sum(acc, dim=-1) / n
# Compute variance
centered = acc - mean_val[:, None]
var_val = torch.sum(centered * centered, dim=-1) / n
# Compute reciprocal standard deviation
rstd_val = torch.rsqrt(var_val + eps)
# Normalize
normalized = centered * rstd_val[:, None]
# Apply affine transformation
if bias is not None:
acc = normalized * (weight[:].to(torch.float32)) + (
bias[:].to(torch.float32)
)
else:
acc = normalized * (weight[:].to(torch.float32))
out[tile_m, :] = acc.to(x.dtype)
mean[tile_m] = mean_val
rstd[tile_m] = rstd_val
return out, mean, rstd
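As a quick standalone check (an illustrative sketch, not part of the original script), the forward kernel's auxiliary outputs can be compared against eager PyTorch statistics; shapes here are arbitrary:

# Hypothetical shapes chosen for illustration only.
x_demo = torch.randn([32, 64], device=DEVICE, dtype=torch.float16)
w_demo = torch.randn([64], device=DEVICE, dtype=torch.float16)
y_demo, mean_demo, rstd_demo = layer_norm_fwd(x_demo, [64], w_demo)
# The kernel divides by n (biased variance), which matches correction=0 below.
x32 = x_demo.to(torch.float32)
torch.testing.assert_close(mean_demo, x32.mean(dim=-1), rtol=1e-3, atol=1e-3)
torch.testing.assert_close(
    rstd_demo, torch.rsqrt(x32.var(dim=-1, correction=0) + 1e-5), rtol=1e-3, atol=1e-3
)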
@helion.kernel
def layer_norm_bwd(
grad_out: torch.Tensor,
x: torch.Tensor,
mean: torch.Tensor,
rstd: torch.Tensor,
weight: torch.Tensor,
compute_bias_grad: hl.constexpr = True, # type: ignore[valid-type]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
"""
Compute the input gradient (dX) along with the weight (dW) and, optionally, bias (dB) gradients.
This kernel computes dX per row and reduces across the batch dimension (M) to
accumulate the per-feature weight and bias gradients.
Args:
grad_out: Gradient w.r.t. the layer norm output [M, N]
x: Original input tensor [M, N]
mean: Per-sample mean computed in forward pass [M]
rstd: Per-sample reciprocal standard deviation from forward pass [M]
weight: Weight (scale) parameter of shape [N], used when computing the input gradient
compute_bias_grad: Whether to compute bias gradient (default: True)
Returns:
(grad_x, grad_weight, grad_bias): Gradients for input, weight, and bias (if computed)
grad_bias is None if compute_bias_grad is False
"""
m_block = hl.register_block_size(x.size(0))
n = hl.specialize(x.size(1))
grad_x = torch.empty_like(x)
num_blocks = (x.size(0) + m_block - 1) // m_block
grad_weight_blocks = x.new_empty([num_blocks, n], dtype=torch.float32)
grad_bias_blocks = x.new_empty([num_blocks, n], dtype=torch.float32)
for mb_cta in hl.tile(x.size(0), block_size=m_block):
grad_w_acc = weight.new_zeros(n, dtype=torch.float32)
if compute_bias_grad:
grad_b_acc = weight.new_zeros(n, dtype=torch.float32)
weight_cta = weight[None, :].to(torch.float32)
for mb in hl.tile(mb_cta.begin, mb_cta.end):
x_mb = x[mb, :].to(torch.float32)
dy_mb = grad_out[mb, :].to(torch.float32)
mean_mb = mean[mb].to(torch.float32)
rstd_mb = rstd[mb].to(torch.float32)
x_hat = (x_mb - mean_mb[:, None]) * rstd_mb[:, None]
grad_w_acc += torch.sum(dy_mb * x_hat, dim=0)
if compute_bias_grad:
# pyrefly: ignore [unbound-name]
grad_b_acc += torch.sum(dy_mb, dim=0)
wdy = weight_cta * dy_mb
c1 = torch.sum(x_hat * wdy, dim=-1) / n
c2 = torch.sum(wdy, dim=-1) / n
dx = (wdy - (x_hat * c1[:, None] + c2[:, None])) * rstd_mb[:, None]
grad_x[mb, :] = dx.to(x.dtype)
grad_weight_blocks[mb_cta.id, :] = grad_w_acc
if compute_bias_grad:
grad_bias_blocks[mb_cta.id, :] = grad_b_acc # type: ignore[index]
grad_weight = grad_weight_blocks.sum(0).to(weight.dtype)
if compute_bias_grad:
grad_bias = grad_bias_blocks.sum(0).to(weight.dtype)
return grad_x, grad_weight, grad_bias
return grad_x, grad_weight, None
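The dX formula above follows the standard layer norm backward derivation, with c1 and c2 the per-row means of x_hat * w * dy and w * dy. For cross-checking the kernel on small inputs, a hedged eager-mode reference of the same math (an illustrative sketch, not part of the original script) could look like:

def layer_norm_bwd_reference(
    grad_out: torch.Tensor,
    x: torch.Tensor,
    mean: torch.Tensor,
    rstd: torch.Tensor,
    weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Eager-mode sketch of the same backward math, for verification only."""
    x32 = x.to(torch.float32)
    dy = grad_out.to(torch.float32)
    x_hat = (x32 - mean[:, None]) * rstd[:, None]
    wdy = weight.to(torch.float32)[None, :] * dy
    c1 = (x_hat * wdy).mean(dim=-1, keepdim=True)
    c2 = wdy.mean(dim=-1, keepdim=True)
    grad_x = (wdy - (x_hat * c1 + c2)) * rstd[:, None]
    grad_weight = (dy * x_hat).sum(dim=0).to(weight.dtype)
    grad_bias = dy.sum(dim=0).to(weight.dtype)
    return grad_x.to(x.dtype), grad_weight, grad_bias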
class LayerNormFunction(torch.autograd.Function):
@staticmethod
def forward( # pyrefly: ignore [bad-override]
ctx: Any, # noqa: ANN401
x: torch.Tensor,
normalized_shape: list[int],
weight: torch.Tensor,
bias: torch.Tensor | None,
eps: float,
) -> torch.Tensor:
"""Forward pass for layer normalization."""
y, mean, rstd = layer_norm_fwd(x, normalized_shape, weight, bias, eps)
ctx.save_for_backward(x, weight, bias, mean, rstd) # type: ignore[arg-type]
ctx.normalized_shape = normalized_shape # type: ignore[attr-defined]
return y
@staticmethod
def backward( # type: ignore[override]
ctx: Any, # noqa: ANN401
grad_output: torch.Tensor,
) -> tuple[
torch.Tensor | None, None, torch.Tensor | None, torch.Tensor | None, None
]:
"""Backward pass for layer normalization split into two separate kernels for efficiency."""
grad_out = grad_output # Use common name internally
x, weight, bias, mean, rstd = ctx.saved_tensors # type: ignore[attr-defined]
# Check if bias gradient is needed
compute_bias_grad = bias is not None
grad_x, grad_weight, grad_bias = layer_norm_bwd(
grad_out, x, mean, rstd, weight, compute_bias_grad
)
return grad_x, None, grad_weight, grad_bias, None
def layer_norm(
x: torch.Tensor,
normalized_shape: list[int],
weight: torch.Tensor,
bias: torch.Tensor | None = None,
eps: float = 1e-5,
) -> torch.Tensor:
"""Layer normalization with forward + backward support."""
return LayerNormFunction.apply(x, normalized_shape, weight, bias, eps) # type: ignore[no-any-return]
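End to end, layer_norm is used like its torch.nn.functional counterpart; a minimal sketch of the autograd round trip (shapes arbitrary):

x_ag = torch.randn([8, 128], device=DEVICE, dtype=torch.float16, requires_grad=True)
w_ag = torch.randn([128], device=DEVICE, dtype=torch.float16, requires_grad=True)
b_ag = torch.randn([128], device=DEVICE, dtype=torch.float16, requires_grad=True)
out = layer_norm(x_ag, [128], w_ag, b_ag)
out.sum().backward()  # populates x_ag.grad, w_ag.grad, and b_ag.grad via layer_norm_bwd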
Benchmark Wrapper#
def layer_norm_tritonbench(
tb_op: object,
x: torch.Tensor,
normalized_shape: list[int],
weight: torch.Tensor,
bias: torch.Tensor | None = None,
eps: float = 1e-5,
) -> Callable[[], torch.Tensor]:
"""
Wrapper for tritonbench that matches expected interface.
Args:
tb_op: TritonBench operator instance
x: Input tensor
normalized_shape: Shape to normalize over
weight: Weight parameter
bias: Bias parameter (optional)
eps: Small constant for numerical stability
Returns:
Callable that returns normalized tensor
"""
return lambda: layer_norm(x, normalized_shape, weight, bias, eps)
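Outside of tritonbench, the wrapper can be exercised directly; tb_op is not used by this wrapper, so passing None below is purely for illustration:

x_tb = torch.randn([16, 256], device=DEVICE, dtype=torch.float16)
w_tb = torch.randn([256], device=DEVICE, dtype=torch.float16)
bench_fn = layer_norm_tritonbench(None, x_tb, [256], w_tb)
y_tb = bench_fn()  # invokes the Helion-backed layer_norm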
def main() -> None:
"""
Main execution function for the layer normalization example.
- Generates random input, weight, and bias tensors.
- Runs the Helion layer normalization kernel and compares its output to PyTorch's
built-in layer_norm function using the run_example utility.
- Prints comparison results and checks for correctness within specified tolerances.
"""
batch_size = 4096
dim = 10240
device = DEVICE
# Test forward pass only
print("\n=== Forward Pass Test ===")
x = -2.3 + 0.5 * torch.randn([batch_size, dim], device=device, dtype=torch.float16)
weight = torch.randn([dim], device=device, dtype=torch.float16)
bias = torch.randn([dim], device=device, dtype=torch.float16)
eps = 1e-4
for b in [bias, None]:
run_example(
layer_norm,
torch.nn.functional.layer_norm,
(x, [dim], weight, b, eps),
rtol=1e-3,
atol=1e-3,
)
# Test forward + backward pass
print("\n\n=== Forward + Backward Pass Test ===")
x_grad = torch.randn(
[batch_size, dim], device=device, dtype=torch.float16, requires_grad=True
)
weight_grad = torch.randn(
[dim], device=device, dtype=torch.float16, requires_grad=True
)
bias_grad = torch.randn(
[dim], device=device, dtype=torch.float16, requires_grad=True
)
for b in [bias_grad, None]:
run_example(
layer_norm,
torch.nn.functional.layer_norm,
(x_grad, [dim], weight_grad, b, eps),
rtol=1e-3,
atol=1e-3,
bwd=True,
)
if __name__ == "__main__":
main()