Helion GEGLU MLP Example#
This example demonstrates a Helion kernel implementation of GEGLU MLP (GELU-Gated Linear Unit MLP). GEGLU MLP is a common pattern in transformer architectures such as Gemma, where:
- The input x is projected through gate_proj and up_proj
- The GEGLU operation computes GELU(gate_proj(x)) * up_proj(x)
- The result is projected through down_proj
GELU uses the tanh approximation: 0.5 * a * (1 + tanh(sqrt(2/π) * (a + 0.044715 * a³)))
Based on liger_kernel’s GEGLU implementation used in Gemma and other gated feedforward networks; a plain PyTorch reference sketch follows below.
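For reference, the same operation can be expressed directly in eager PyTorch. The sketch below is independent of the Helion kernel defined later and relies on torch.nn.functional.gelu with approximate="tanh", which implements the same tanh approximation; the shapes are illustrative only.
import torch
import torch.nn.functional as F
def geglu_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # GELU(a) * b with the tanh approximation, matching the formula above
    return F.gelu(a, approximate="tanh") * b
a = torch.randn(8, 128)  # illustrative shape
b = torch.randn(8, 128)
out = geglu_reference(a, b)  # same shape as the inputs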
Imports#
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch
from torch import Tensor
import torch.nn as nn
import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl
if TYPE_CHECKING:
from collections.abc import Callable
from typing import Any
GEGLU Kernel#
@helion.kernel()
def geglu(a: Tensor, b: Tensor) -> Tensor:
"""
Performs GEGLU operation: GELU(a) * b using tanh approximation for GELU.
GELU(a) = 0.5 * a * (1 + tanh(sqrt(2/π) * (a + 0.044715 * a³)))
GEGLU(a, b) = GELU(a) * b
Args:
a (Tensor): Input tensor for GELU activation of any shape.
b (Tensor): Input tensor for multiplication, must have same shape as a.
Returns:
Tensor: Result of GEGLU operation with same shape as inputs.
"""
# Ensure tensors have the same shape
assert a.shape == b.shape, (
f"Input tensors must have same shape, got {a.shape} != {b.shape}"
)
# Create output tensor
out = torch.empty_like(a, dtype=torch.promote_types(a.dtype, b.dtype))
# Get the total number of elements and process in tiles
total_elements = a.numel()
# Flatten tensors for easier processing
a_flat = a.view(-1)
b_flat = b.view(-1)
out_flat = out.view(-1)
# Process elements in tiles
for tile_idx in hl.tile(total_elements):
# Load input values and convert to float32 for computation
a_vals = a_flat[tile_idx].to(torch.float32)
b_vals = b_flat[tile_idx]
# GELU computation using tanh approximation
# Constants for tanh approximation
sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / π)
# Compute a cubed
a_cubed = a_vals * a_vals * a_vals
# Compute tanh argument: sqrt(2/π) * (a + 0.044715 * a^3)
tanh_arg = sqrt_2_over_pi * (a_vals + 0.044715 * a_cubed)
# Compute tanh and GELU
tanh_result = torch.tanh(tanh_arg)
gelu_a = 0.5 * a_vals * (1.0 + tanh_result)
# GEGLU: GELU(a) * b
result = gelu_a.to(b_vals.dtype) * b_vals
# Store result
out_flat[tile_idx] = result
return out
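For instance, the kernel can be invoked directly on device tensors. This is a minimal sketch; the shapes and dtype are illustrative and assume the DEVICE imported above points at a GPU:
a = torch.randn(8, 2048, 4096, device=DEVICE, dtype=torch.float16)
b = torch.randn(8, 2048, 4096, device=DEVICE, dtype=torch.float16)
out = geglu(a, b)  # elementwise GELU(a) * b, same shape as the inputs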
@helion.kernel()
def geglu_bwd(grad_out: Tensor, a: Tensor, b: Tensor) -> tuple[Tensor, Tensor]:
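    """
    Backward pass for GEGLU: given grad_out and the forward inputs a and b,
    returns (grad_a, grad_b) where grad_b = grad_out * GELU(a) and
    grad_a = grad_out * b * dGELU(a)/da (tanh approximation).
    """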
grad_a = torch.empty_like(a)
grad_b = torch.empty_like(b)
grad_out_flat = grad_out.view(-1)
a_flat = a.view(-1)
b_flat = b.view(-1)
grad_a_flat = grad_a.view(-1)
grad_b_flat = grad_b.view(-1)
for tile_idx in hl.tile(a.numel()):
a_vals = a_flat[tile_idx].to(torch.float32)
b_vals = b_flat[tile_idx].to(torch.float32)
grad_out_vals = grad_out_flat[tile_idx].to(torch.float32)
sqrt_2_over_pi = 0.7978845608028654
a_cubed = a_vals * a_vals * a_vals
tanh_arg = sqrt_2_over_pi * (a_vals + 0.044715 * a_cubed)
tanh_result = torch.tanh(tanh_arg)
gelu_a = 0.5 * a_vals * (1.0 + tanh_result)
grad_b_vals = grad_out_vals * gelu_a
grad_b_flat[tile_idx] = grad_b_vals.to(b.dtype)
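        # Derivative of the tanh-approximation GELU:
        # z = sqrt(2/π) * (a + 0.044715 * a³), so dz/da = sqrt(2/π) * (1 + 0.134145 * a²)
        # dGELU/da = 0.5 * (1 + tanh(z)) + 0.5 * a * sech²(z) * dz/da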
dz_da = sqrt_2_over_pi * (1.0 + 0.134145 * a_vals * a_vals)
sech_sq = 1.0 - tanh_result * tanh_result
dgelu_da = 0.5 * (1.0 + tanh_result) + 0.5 * a_vals * sech_sq * dz_da
grad_a_vals = grad_out_vals * b_vals * dgelu_da
grad_a_flat[tile_idx] = grad_a_vals.to(a.dtype)
return grad_a, grad_b
class GEGLUFunction(torch.autograd.Function):
@staticmethod
def forward( # pyrefly: ignore [bad-override]
ctx: Any, # noqa: ANN401
a: Tensor,
b: Tensor,
) -> Tensor:
"""Forward pass for GEGLU."""
out = geglu(a, b)
ctx.save_for_backward(a, b)
return out
@staticmethod
def backward( # type: ignore[override]
ctx: Any, # noqa: ANN401
grad_out: Tensor,
) -> tuple[Tensor, Tensor]:
"""Backward pass for GEGLU."""
a, b = ctx.saved_tensors
grad_a, grad_b = geglu_bwd(grad_out, a, b)
return grad_a, grad_b
def geglu_autograd(a: Tensor, b: Tensor) -> Tensor:
"""GEGLU with forward + backward support."""
return GEGLUFunction.apply(a, b) # type: ignore[no-any-return]
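As a quick sanity check on the hand-derived gradients, geglu_autograd can be compared against the eager tanh-GELU baseline under torch autograd. This is a sketch only; the shapes, dtype, and tolerances are illustrative:
a = torch.randn(64, 64, device=DEVICE, dtype=torch.float32, requires_grad=True)
b = torch.randn(64, 64, device=DEVICE, dtype=torch.float32, requires_grad=True)
a_ref = a.detach().clone().requires_grad_(True)
b_ref = b.detach().clone().requires_grad_(True)
geglu_autograd(a, b).sum().backward()
(nn.functional.gelu(a_ref, approximate="tanh") * b_ref).sum().backward()
torch.testing.assert_close(a.grad, a_ref.grad, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(b.grad, b_ref.grad, rtol=1e-4, atol=1e-4)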
GEGLU MLP Module (matches liger_kernel structure)#
@dataclass
class Config:
    """
    Configuration class for MLP.
    """
    hidden_size: int
    intermediate_size: int
    hidden_act: str = "gelu_pytorch_tanh"
class HelionGEGLUMLP(nn.Module):
"""
Helion implementation of GEGLU MLP matching liger_kernel.LigerGEGLUMLP structure.
This implements the complete MLP used in transformer architectures:
down_proj(GEGLU(gate_proj(x), up_proj(x)))
"""
def __init__(self, config: Config) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
def forward(self, x: Tensor) -> Tensor:
"""
Forward pass: down_proj(GEGLU(gate_proj(x), up_proj(x)))
"""
gate_output = self.gate_proj(x)
up_output = self.up_proj(x)
geglu_output = geglu(gate_output, up_output)
return self.down_proj(geglu_output)
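A minimal instantiation sketch (the hidden and intermediate sizes here are illustrative, not any particular model's configuration):
config = Config(hidden_size=1024, intermediate_size=4096)
mlp = HelionGEGLUMLP(config).to(DEVICE).to(torch.float16)
x = torch.randn(2, 128, 1024, device=DEVICE, dtype=torch.float16)
y = mlp(x)  # shape (2, 128, 1024)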
Verification Function#
def check_geglu_kernel(shape: tuple[int, ...]) -> None:
"""
Verify the GEGLU kernel implementation against PyTorch's baseline.
Args:
shape: Shape of the input tensors to test.
"""
def baseline_geglu(a: Tensor, b: Tensor) -> Tensor:
"""
PyTorch baseline implementation using tanh approximation GELU.
This matches the liger_kernel implementation.
"""
return nn.functional.gelu(a, approximate="tanh").to(b.dtype) * b
print("\n=== Forward Pass Test ===")
a = torch.randn(shape, device=DEVICE, dtype=torch.float16)
b = torch.randn(shape, device=DEVICE, dtype=torch.float16)
run_example(geglu, baseline_geglu, (a, b))
# Test forward + backward pass
print("\n\n=== Forward + Backward Pass Test ===")
a_grad = torch.randn(shape, device=DEVICE, dtype=torch.float16, requires_grad=True)
b_grad = torch.randn(shape, device=DEVICE, dtype=torch.float16, requires_grad=True)
run_example(
geglu_autograd,
baseline_geglu,
(a_grad, b_grad),
kernel_name="helion_autograd",
baseline_name="torch",
rtol=1e-2,
atol=1e-1,
bwd=True,
)
class BaselineMLP(nn.Module):
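    """Eager PyTorch baseline MLP with the same gate/up/down projection structure as HelionGEGLUMLP."""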
def __init__(self, config: Config) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
def forward(self, x: Tensor) -> Tensor:
"""
Forward pass: down_proj(GEGLU(gate_proj(x), up_proj(x)))
"""
gate_output = self.gate_proj(x)
up_output = self.up_proj(x)
geglu_output = (
nn.functional.gelu(gate_output, approximate="tanh").to(up_output.dtype)
* up_output
)
return self.down_proj(geglu_output)
def check_geglu_mlp(
batch_size: int, seq_len: int, hidden_size: int, intermediate_size: int
) -> None:
"""
Verify the GEGLU MLP implementation against PyTorch's baseline MLP.
Args:
batch_size: Batch size
seq_len: Sequence length
hidden_size: Hidden dimension size
intermediate_size: Intermediate dimension size
"""
config = Config(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
hidden_act="gelu_pytorch_tanh",
)
# Create test input
x = torch.randn(
batch_size, seq_len, hidden_size, device=DEVICE, dtype=torch.float16
)
# Create models
helion_mlp = HelionGEGLUMLP(config).to(DEVICE).to(torch.float16)
baseline_mlp = BaselineMLP(config).to(DEVICE).to(torch.float16)
# Copy weights to ensure same parameters
baseline_mlp.gate_proj.weight.data = helion_mlp.gate_proj.weight.data.clone()
baseline_mlp.up_proj.weight.data = helion_mlp.up_proj.weight.data.clone()
baseline_mlp.down_proj.weight.data = helion_mlp.down_proj.weight.data.clone()
# Run comparison
run_example(lambda x: helion_mlp(x), lambda x: baseline_mlp(x), (x,))
Tritonbench Integration#
def geglu_tritonbench(tb_op: object, x: Tensor) -> Callable:
"""
Wrapper for tritonbench that matches its interface.
Copies weights from tritonbench operator models to ensure fair comparison.
Args:
tb_op: TritonBench operator instance with baseline_model and liger_model
x (Tensor): Input tensor for the GEGLU MLP.
Returns:
Callable: A callable that runs the GEGLU kernel with copied weights.
"""
# Extract configuration from tritonbench operator
config = Config(
# pyrefly: ignore [missing-attribute]
hidden_size=tb_op.hidden_size,
# pyrefly: ignore [missing-attribute]
intermediate_size=tb_op.intermediate_size,
# pyrefly: ignore [missing-attribute]
hidden_act=tb_op.hidden_act,
)
# Create Helion model
helion_mlp = HelionGEGLUMLP(config).to(x.device).to(x.dtype)
# Copy weights from tritonbench baseline model (LlamaMLP) to ensure fairness
# LlamaMLP has: gate_proj, up_proj, down_proj (same structure as our HelionGEGLUMLP)
# pyrefly: ignore [missing-attribute]
baseline_model = tb_op.baseline_model
# Copy gate projection weights
helion_mlp.gate_proj.weight.data.copy_(baseline_model.gate_proj.weight.data)
# Copy up projection weights
helion_mlp.up_proj.weight.data.copy_(baseline_model.up_proj.weight.data)
# Copy down projection weights
helion_mlp.down_proj.weight.data.copy_(baseline_model.down_proj.weight.data)
return lambda: helion_mlp(x)
Main Function#
def main() -> None:
"""
Main entry point that runs the GEGLU kernel and MLP verification.
Tests various shapes including typical transformer sizes.
"""
print("Testing GEGLU kernel...")
# Test GEGLU kernel with different shapes
kernel_test_shapes = [(8, 2048, 4096), (8, 4096, 8192)]
for shape in kernel_test_shapes:
print(f"\nTesting GEGLU kernel shape: {shape}")
check_geglu_kernel(shape)
print(f"✓ GEGLU kernel shape {shape} passed")
print("\n\nTesting GEGLU MLP...")
# Test GEGLU MLP with transformer-typical sizes
mlp_test_configs = [
(8, 2048, 4096, 11008),
(8, 4096, 8192, 11008),
]
for batch_size, seq_len, hidden_size, intermediate_size in mlp_test_configs:
print(
f"\nTesting GEGLU MLP: B={batch_size}, T={seq_len}, H={hidden_size}, I={intermediate_size}"
)
check_geglu_mlp(batch_size, seq_len, hidden_size, intermediate_size)
print("✓ GEGLU MLP config passed")
if __name__ == "__main__":
main()