Helion Softmax Kernel Examples
This example demonstrates multiple Helion kernel implementations of the softmax function, including a simple wrapper around PyTorch’s softmax, a decomposed version using explicit exponentiation and normalization, and a numerically optimized two-pass version. The example also includes a check function to compare these kernels against PyTorch’s built-in softmax for correctness.
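All three kernels compute softmax along the last dimension: for each row i, out[i, j] = exp(x[i, j]) / sum_k exp(x[i, k]). The correctness check below compares each kernel's output against torch.nn.functional.softmax(x, dim=1).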
from __future__ import annotations
import torch
import helion
from helion._testing import run_example
import helion.language as hl
@helion.kernel()
def softmax(x: torch.Tensor) -> torch.Tensor:
    """
    Simple Helion kernel wrapping PyTorch's softmax function.

    Args:
        x (torch.Tensor): Input tensor of shape [n, m].

    Returns:
        torch.Tensor: Softmax output tensor of the same shape.
    """
    n, _m = x.size()
    out = torch.empty_like(x)
    for tile_n in hl.tile(n):
        out[tile_n, :] = torch.nn.functional.softmax(x[tile_n, :], dim=1)
    return out
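
# A rough usage sketch for the kernel above (illustrative only; the shape,
# dtype, and tolerances are assumptions rather than requirements of the API):
#
#   x = torch.randn([1024, 1024], device="cuda", dtype=torch.float16)
#   y = softmax(x)  # the @helion.kernel-decorated function is called like a plain function
#   torch.testing.assert_close(y, torch.softmax(x, dim=1), rtol=1e-2, atol=1e-2)
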
@helion.kernel()
def softmax_decomposed(x: torch.Tensor) -> torch.Tensor:
    """
    Helion kernel implementing softmax by decomposing it into max, exp, and normalization steps.

    Rather than calling PyTorch's built-in softmax, this version spells out the
    numerically stable decomposition explicitly.

    Args:
        x (torch.Tensor): Input tensor of shape [n, m].

    Returns:
        torch.Tensor: Softmax output tensor of the same shape.
    """
    n, _m = x.size()
    out = torch.empty_like(x)
    for tile_n in hl.tile(n):
        values = x[tile_n, :]
        # Subtract the row max before exponentiating for numerical stability.
        amax = torch.amax(values, dim=1, keepdim=True)
        exp = torch.exp(values - amax)
        sum_exp = torch.sum(exp, dim=1, keepdim=True)
        out[tile_n, :] = exp / sum_exp
    return out
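
# Why subtract the row max before exponentiating? softmax(x) == softmax(x - c)
# for any constant c, and choosing c = max(x) keeps exp() from overflowing.
# A quick plain-PyTorch illustration (values chosen to overflow float32
# without the shift):
#
#   t = torch.tensor([1000.0, 1001.0, 1002.0])
#   torch.exp(t)             # tensor([inf, inf, inf])
#   torch.exp(t - t.amax())  # tensor([0.1353, 0.3679, 1.0000]) -- all finite
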
@helion.kernel()
def softmax_two_pass(x: torch.Tensor) -> torch.Tensor:
    """
    Helion kernel performing softmax in two passes using an online (running-max) formulation.

    Compared to the decomposed version, this makes fewer passes over the data,
    at the cost of slightly reduced numerical stability: the running sum is
    rescaled each time the running maximum is updated, which introduces extra
    rounding.

    Args:
        x (torch.Tensor): Input tensor of shape [m, n].

    Returns:
        torch.Tensor: Softmax output tensor of the same shape.
    """
    m, n = x.size()
    out = torch.empty_like(x)
    block_size_m = hl.register_block_size(m)
    block_size_n = hl.register_block_size(n)
    for tile_m in hl.tile(m, block_size=block_size_m):
        # First pass: accumulate the running maximum (mi) and the running sum
        # of exponentials (di), rescaling di whenever mi increases.
        mi = hl.full([tile_m], float("-inf"), dtype=torch.float32)
        di = hl.zeros([tile_m], dtype=torch.float32)
        for tile_n in hl.tile(n, block_size=block_size_n):
            values = x[tile_m, tile_n]
            local_amax = torch.amax(values, dim=1)
            mi_next = torch.maximum(mi, local_amax)
            di = di * torch.exp(mi - mi_next) + torch.exp(
                values - mi_next[:, None]
            ).sum(dim=1)
            mi = mi_next
        # Second pass: normalize using the final maximum and sum.
        for tile_n in hl.tile(n, block_size=block_size_n):
            values = x[tile_m, tile_n]
            out[tile_m, tile_n] = torch.exp(values - mi[:, None]) / di[:, None]
    return out
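
# The inner loop above relies on the "online softmax" rescaling identity: when
# the running maximum changes from m_old to m_new, a sum accumulated as
# exp(v - m_old) must be multiplied by exp(m_old - m_new) to stay relative to
# the new maximum. Below is a plain-PyTorch sketch of that accumulation for a
# single 1-D row (illustrative only; the helper name and chunk size are made
# up for this example and are not part of the kernel above):
def _online_softmax_row(row: torch.Tensor, chunk_size: int = 128) -> torch.Tensor:
    # Accumulate in float32, mirroring the mi/di accumulators in the kernel above.
    row = row.float()
    m = row.new_tensor(float("-inf"))  # running maximum
    d = row.new_tensor(0.0)  # running sum of exp(row - m)
    for start in range(0, row.numel(), chunk_size):
        chunk = row[start : start + chunk_size]
        m_new = torch.maximum(m, chunk.max())
        # Rescale the old sum to the new maximum, then add this chunk's contribution.
        d = d * torch.exp(m - m_new) + torch.exp(chunk - m_new).sum()
        m = m_new
    return torch.exp(row - m) / d  # equals torch.softmax(row, dim=0) up to rounding
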
def check(m: int, n: int) -> None:
    """
    Runs correctness checks comparing Helion softmax kernels against PyTorch's softmax.

    Args:
        m (int): Number of rows in input tensor.
        n (int): Number of columns in input tensor.
    """
    x = torch.randn([m, n], device="cuda", dtype=torch.float16)
    kernels = {
        "helion simple": softmax,
        # "helion decomposed": softmax_decomposed,  # Disabled due to possible issues
        "helion two pass": softmax_two_pass,
    }
    run_example(kernels, lambda x: torch.nn.functional.softmax(x, dim=1), (x,))
def main() -> None:
    """
    Main function to run the softmax kernel correctness check with an example input size.
    """
    check(1024, 1024)


if __name__ == "__main__":
    main()