Language Module#
The helion.language module contains the core DSL constructs for writing GPU kernels.
Loop Constructs#
tile()#
- helion.language.tile(begin_or_end, end_or_none=None, /, block_size=None)[source]#
Break up an iteration space defined by a size or sequence of sizes into tiles.
The generated tiles can flatten the iteration space into the product of the sizes, perform multidimensional tiling, swizzle the indices for cache locality, reorder dimensions, etc. The only invariant is that every index in the range of the given sizes is covered exactly once.
The exact tiling strategy is determined by a Config object, typically created through autotuning.
If used at the top level of a function, this becomes the grid of the kernel. Otherwise, it becomes a loop in the output kernel.
The key difference from grid() is that tile gives you Tile objects that load a slice of elements, while grid gives you scalar integer indices. It is recommended to use tile in most cases, since it allows more choices in autotuning.
- Parameters:
begin_or_end (int | Tensor | Sequence[int | Tensor]) – If 2+ positional args provided, the start of iteration space. Otherwise, the end of iteration space.
end_or_none (int | Tensor | Sequence[int | Tensor] | None) – If 2+ positional args provided, the end of iteration space.
block_size (object) – Fixed block size (overrides autotuning) or None for autotuned size
- Returns:
Iterator over tile objects
- Return type:
Iterator[Tile] or Iterator[Sequence[Tile]]
Examples
One dimensional tiling:
@helion.kernel
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    result = torch.zeros_like(x)
    for tile in hl.tile(x.size(0)):
        # tile processes multiple elements at once
        result[tile] = x[tile] + y[tile]
    return result
Multi-dimensional tiling:
@helion.kernel()
def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.size()
    k, n = y.size()
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
        out[tile_m, tile_n] = acc
    return out
Fixed block size:
@helion.kernel
def process_with_fixed_block(x: torch.Tensor) -> torch.Tensor:
    result = torch.zeros_like(x)
    for tile in hl.tile(x.size(0), block_size=64):
        # Process with fixed block size of 64
        result[tile] = x[tile] * 2
    return result
Using tile properties:
@helion.kernel
def tile_info_example(x: torch.Tensor) -> torch.Tensor:
    result = torch.zeros([x.size(0)], dtype=x.dtype, device=x.device)
    for tile in hl.tile(x.size(0)):
        # Access tile properties
        start = tile.begin
        end = tile.end
        size = tile.block_size
        indices = tile.index  # [start, start+1, ..., end-1]
        # Use in computation
        result[tile] = x[tile] + indices
    return result
See also
grid(): For explicit control over the launch grid
tile_index(): For getting tile indices
register_block_size(): For registering block sizes
Note
Similar to range() with multiple forms:

tile(end) iterates 0 to end-1, autotuned block_size
tile(begin, end) iterates begin to end-1, autotuned block_size
tile(begin, end, block_size) iterates begin to end-1, fixed block_size
tile(end, block_size=block_size) iterates 0 to end-1, fixed block_size

Block sizes can be registered for autotuning explicitly with register_block_size() and passed as the block_size argument if one needs two loops to use the same block size. Passing block_size=None is equivalent to calling register_block_size().

Use tile in most cases. Use grid when you need explicit control over the launch grid.
The tile() function is the primary way to create parallel loops in Helion kernels. It provides several key features:
Tiling Strategies: The exact tiling strategy is determined by a Config object, typically created through autotuning. This allows for:
Multidimensional tiling
Index swizzling for cache locality
Dimension reordering
Flattening of iteration spaces
Usage Patterns:
# Simple 1D tiling
for tile in hl.tile(1000):
    # tile.begin, tile.end, tile.block_size are available
    # Load the entire tile (not just the first element)
    data = tensor[tile]  # or hl.load(tensor, tile) for explicit loading

# 2D tiling
for tile_i, tile_j in hl.tile([height, width]):
    # Each tile represents a portion of the 2D space
    pass

# With explicit begin/end/block_size
for tile in hl.tile(0, 1000, block_size=64):
    pass
Grid vs Loop Behavior:
When used at the top level of a kernel function, tile() becomes the grid of the kernel (parallel blocks)
When used nested inside another loop, it becomes a sequential loop within each block
grid()#
- helion.language.grid(begin_or_end, end_or_none=None, /, step=None)[source]#
Iterate over individual indices of the given iteration space.
The key difference from tile() is that grid gives you scalar integer indices (torch.SymInt), while tile gives you Tile objects that load a slice of elements. Use tile in most cases. Use grid when you need explicit control over the launch grid or when processing one element at a time.

Semantics are equivalent to:
for i in hl.tile(...):
    # i is a Tile object, accesses multiple elements
    data = tensor[i]  # loads slice of elements (1D tensor)
vs:
for i in hl.grid(...):
    # i is a scalar index, accesses single element
    data = tensor[i]  # loads single element (0D scalar)
When used at the top level of a function, this becomes the grid of the kernel. Otherwise, it becomes a loop in the output kernel.
- Parameters:
begin_or_end (int | Tensor | ConstExpr | Sequence[int | Tensor | ConstExpr]) – If 2+ positional args provided, the start of iteration space. Otherwise, the end of iteration space.
end_or_none (int | Tensor | ConstExpr | Sequence[int | Tensor | ConstExpr] | None) – If 2+ positional args provided, the end of iteration space.
step (int | Tensor | ConstExpr | Sequence[int | Tensor | ConstExpr] | None) – Step size for iteration (default: 1)
- Returns:
Iterator over scalar indices
- Return type:
Iterator[torch.SymInt] or Iterator[Sequence[torch.SymInt]]
See also
tile(): For processing multiple elements at once
tile_index(): For getting tile indices
arange(): For creating index sequences
Note
Similar to range() with multiple forms:

grid(end) iterates from 0 to end-1, step 1
grid(begin, end) iterates from begin to end-1, step 1
grid(begin, end, step) iterates from begin to end-1, given step
grid(end, step=step) iterates from 0 to end-1, given step

Use tile in most cases. Use grid when you need explicit control over the launch grid.
The grid() function iterates over individual indices rather than tiles. It’s equivalent to tile(size, block_size=1) but returns scalar indices instead of tile objects.
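A short sketch (illustrative names) of one-element-at-a-time processing with a stride, using the grid(begin, end, step) form documented above:

@helion.kernel
def copy_even_indices(x: torch.Tensor) -> torch.Tensor:
    out = torch.zeros_like(x)
    # Scalar indices 0, 2, 4, ...
    for i in hl.grid(0, x.size(0), step=2):
        out[i] = x[i]  # single-element load and store
    return out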
jagged_tile()#
- helion.language.jagged_tile(parent)[source]#
Iterate over a jagged inner dimension using an N-D parent tensor of per-lane ends.
jagged_tile is the jagged counterpart to tile(). Instead of taking a scalar upper bound, it takes a tensor whose every axis comes from an enclosing parent tile context. Each element of parent gives the true end of the jagged child loop for the corresponding parent lane.

Conceptually, Helion lowers:
for tile_k in hl.jagged_tile(parent):
    ...
to:
end = parent.amax()
for tile_k in hl.tile(end):
    mask = tile_k.index[None, :] < parent[:, None]
    ...
while automatically masking out indices where tile_k.index >= parent for each parent lane. This lets you write ragged loops directly instead of writing a dense loop and manually constructing masks.

- Parameters:
parent (object) – N-D tensor whose every axis is an enclosing tile axis. parent[i, ...] is the true end of the jagged child loop for that combination of parent lanes. The 1-D case is the common scalar-of-rows pattern.
- Returns:
Iterator over tile objects for the jagged child dimension
- Return type:
Iterator[Tile]
Examples
Before jagged_tile: dense loop plus manual mask:

@helion.kernel
def jagged_row_sum_masked(
    x: torch.Tensor, row_lengths: torch.Tensor
) -> torch.Tensor:
    b = row_lengths.size(0)
    out = torch.zeros([b], dtype=x.dtype, device=x.device)
    for tile_b in hl.tile(b):
        lengths = row_lengths[tile_b]
        max_len = lengths.amax()
        acc = hl.zeros([tile_b], dtype=x.dtype)
        for tile_k in hl.tile(max_len):
            mask = tile_k.index[None, :] < lengths[:, None]
            vals = hl.load(x, [tile_b, tile_k], extra_mask=mask)
            acc = acc + vals.sum(dim=1)
        out[tile_b] = acc
    return out
With jagged_tile: the mask becomes implicit:

@helion.kernel
def jagged_row_sum(
    x: torch.Tensor, row_lengths: torch.Tensor
) -> torch.Tensor:
    b = row_lengths.size(0)
    out = torch.zeros([b], dtype=x.dtype, device=x.device)
    for tile_b in hl.tile(b):
        lengths = row_lengths[tile_b]
        acc = hl.zeros([tile_b], dtype=x.dtype)
        for tile_k in hl.jagged_tile(lengths):
            acc = acc + x[tile_b, tile_k].sum(dim=1)
        out[tile_b] = acc
    return out
Packed jagged data with offsets:
@helion.kernel
def jagged_sum(
    x_data: torch.Tensor,
    x_offsets: torch.Tensor,
) -> torch.Tensor:
    b = x_offsets.size(0) - 1
    out = torch.zeros([b], dtype=x_data.dtype, device=x_data.device)
    for tile_b in hl.tile(b):
        starts = x_offsets[tile_b]
        ends = x_offsets[tile_b.index + 1]
        lengths = ends - starts
        acc = hl.zeros([tile_b], dtype=x_data.dtype)
        for tile_k in hl.jagged_tile(lengths):
            idx = starts[:, None] + tile_k.index[None, :]
            acc = acc + x_data[idx].sum(dim=1)
        out[tile_b] = acc
    return out
See also
tile(): For dense or uniform iteration spaces
Note
jagged_tile currently has a few important restrictions:

The input must be a tensor of rank >= 1. Scalars are not allowed, and every axis of the parent tensor must come from an enclosing tile context.
jagged_tile cannot be used as the outermost loop of a kernel.
A jagged child tile must be indexed together with its parent axes. For example, x[tile_k] is invalid if tile_k comes from hl.jagged_tile(lengths) under tile_b. Use x[tile_b, tile_k] or another indexing expression that preserves the parent context.
Use tile() when the loop bound is uniform across lanes.

See more jagged kernels using hl.jagged_tile in the examples/ directory.
The jagged_tile() function is the jagged counterpart to tile(). It iterates an inner dimension whose extent varies per lane of an enclosing parent tile, using an N-D tensor (commonly 1D) of per-lane end positions from that parent context.
Instead of writing a dense inner loop and manually building a mask, jagged_tile() lets Helion apply the masking implicitly for indices beyond each lane's true length.
static_range()#
- helion.language.static_range(begin_or_end, end_or_none=None, /, step=1)[source]#
Create a range that gets unrolled at compile time by iterating over constant integer values.
This function is similar to Python’s built-in range(), but it generates a sequence of integer constants that triggers loop unrolling behavior in Helion kernels. The loop is completely unrolled at compile time, with each iteration becoming separate instructions in the generated code.
- Parameters:
begin_or_end (int) – If 2+ positional args provided, the start of the range. Otherwise, the end of the range.
end_or_none (int | None) – If 2+ positional args provided, the end of the range.
step (int) – Step size for iteration (default: 1)
- Returns:
Iterator over constant integer values
- Return type:
Iterator[int]
Examples
Simple unrolled loop:
@helion.kernel
def unrolled_example(x: torch.Tensor) -> torch.Tensor:
    result = torch.zeros_like(x)
    for tile in hl.tile(x.size(0)):
        acc = torch.zeros([tile], dtype=x.dtype, device=x.device)
        # This loop gets completely unrolled
        for i in hl.static_range(3):
            acc += x[tile] * i
        result[tile] = acc
    return result
Range with start and step:
@helion.kernel
def kernel_stepped_unroll(x: torch.Tensor) -> torch.Tensor:
    result = torch.zeros_like(x)
    for tile in hl.tile(x.size(0)):
        acc = torch.zeros([tile], dtype=x.dtype, device=x.device)
        # Unroll loop from 2 to 8 with step 2: [2, 4, 6]
        for i in hl.static_range(2, 8, 2):
            acc += x[tile] * i
        result[tile] = acc
    return result
Note
Only constant integer values are supported
The range must be small enough to avoid compilation timeouts
Each iteration becomes separate instructions in the generated Triton code
Use for small, fixed iteration counts where unrolling is beneficial
static_range() behaves like a compile-time range() for small loops: the loop is fully unrolled, with each iteration emitted as separate instructions in the generated code.
barrier()#
- helion.language.barrier()[source]#
Grid-wide barrier separating top-level hl.tile / hl.grid loops.
- Return type:
None
barrier() inserts a grid-wide synchronization point between top-level hl.tile or hl.grid loops. It forces persistent kernel execution so that all blocks complete one phase before the next begins.
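A minimal hedged sketch of a two-phase pattern (assumes total is a single-element tensor pre-initialized to zero; names are illustrative):

@helion.kernel
def normalize_by_sum(x: torch.Tensor, total: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    # Phase 1: accumulate a global sum into total[0]
    for tile in hl.tile(x.size(0)):
        hl.atomic_add(total, [0], x[tile].sum())
    hl.barrier()  # every block finishes phase 1 before any block reads total
    # Phase 2: scale each element by the global sum
    for tile in hl.tile(x.size(0)):
        out[tile] = x[tile] / total[0]
    return out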
Memory Operations#
load()#
- helion.language.load(tensor, index, extra_mask=None, eviction_policy=None)[source]#
Load a value from a tensor using a list of indices.
This function is equivalent to tensor[index] but allows setting extra_mask= to mask elements beyond the default masking based on the hl.tile range. It also accepts an optional eviction_policy which is forwarded to the underlying Triton tl.load call to control the cache eviction behavior (e.g., “evict_last”).
- Parameters:
tensor (Tensor | StackTensor) – The tensor / stack tensor to load from
index (list[object]) – The indices to use to index into the tensor
extra_mask (Tensor | None) – The extra mask (beyond automatic tile bounds masking) to apply to the tensor
eviction_policy (str | None) – Optional Triton load eviction policy to hint cache behavior
- Returns:
The loaded value
- Return type:
torch.Tensor
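A hedged sketch of extra_mask and eviction_policy in use (a lower-triangular load; names are illustrative):

@helion.kernel
def load_lower_triangle(x: torch.Tensor) -> torch.Tensor:
    n = x.size(0)
    out = torch.zeros_like(x)
    for tile_i, tile_j in hl.tile([n, n]):
        # Mask out elements above the diagonal, on top of automatic tile bounds masking
        mask = tile_i.index[:, None] >= tile_j.index[None, :]
        out[tile_i, tile_j] = hl.load(
            x, [tile_i, tile_j], extra_mask=mask, eviction_policy="evict_last"
        )
    return out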
store()#
- helion.language.store(tensor, index, value, extra_mask=None)[source]#
Store a value to a tensor using a list of indices.
This function is equivalent to tensor[index] = value but allows setting extra_mask= to mask elements beyond the default masking based on the hl.tile range.
- Parameters:
tensor (Tensor | StackTensor) – The tensor / stack tensor to store into
index (list[object]) – The indices to use to index into the tensor
value (Tensor | float) – The value to store
extra_mask (Tensor | None) – The extra mask (beyond automatic tile bounds masking) to apply to the store
- Returns:
None
- Return type:
None
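A hedged sketch of a masked store (only positive values are written; names are illustrative):

@helion.kernel
def store_positive(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    for tile in hl.tile(x.size(0)):
        vals = x[tile]
        # Elements failing the extra mask are left untouched in `out`
        hl.store(out, [tile], vals, extra_mask=vals > 0)
    return out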
atomic_add()#
- helion.language.atomic_add(target, index, value, sem='relaxed')[source]#
Atomically add a value to a target tensor.
Performs an atomic read-modify-write that adds value to target[index]. This is safe for concurrent access from multiple threads/blocks.
- Parameters:
target (Tensor) – Tensor to update.
index (list[object]) – Indices selecting elements to update. Can include tiles.
value (Tensor | float) – Value(s) to add.
sem (str) – Memory ordering semantics. One of "relaxed", "acquire", "release", "acq_rel". Defaults to "relaxed".
- Returns:
The previous value(s) stored at target[index] before the update.
- Return type:
torch.Tensor
Example
@helion.kernel
def global_sum(x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
    for tile in hl.tile(x.size(0)):
        hl.atomic_add(result, [0], x[tile].sum())
    return result
Notes
Use for race-free accumulation across parallel execution.
Stronger memory ordering semantics may reduce performance.
atomic_and()#
- helion.language.atomic_and(target, index, value, sem='relaxed')[source]#
Atomically apply bitwise AND with value to target[index].
- Parameters:
- Returns:
The previous value(s) stored at target[index] before the update.
- Return type:
torch.Tensor
atomic_or()#
- helion.language.atomic_or(target, index, value, sem='relaxed')[source]#
Atomically apply bitwise OR with value to target[index].
- Parameters:
- Returns:
The previous value(s) stored at target[index] before the update.
- Return type:
torch.Tensor
atomic_xor()#
- helion.language.atomic_xor(target, index, value, sem='relaxed')[source]#
Atomically apply bitwise XOR with value to target[index].
- Parameters:
- Returns:
The previous value(s) stored at target[index] before the update.
- Return type:
torch.Tensor
atomic_xchg()#
atomic_max()#
atomic_min()#
atomic_cas()#
- helion.language.atomic_cas(target, index, expected, value, sem='relaxed')[source]#
Atomically compare-and-swap a value at target[index].

If the current value equals expected, writes value. Otherwise leaves memory unchanged.

- Parameters:
target (Tensor) – Tensor to update.
index (list[object]) – Indices selecting elements to update. Can include tiles.
expected (Tensor | float | bool) – Expected current value(s) used for comparison.
value (Tensor | float | bool) – New value(s) to write if comparison succeeds.
sem (str) – Memory ordering semantics. One of "relaxed", "acquire", "release", "acq_rel". Defaults to "relaxed".
- Returns:
The previous value(s) stored at target[index] before the compare-and-swap.
- Return type:
torch.Tensor
Note
Triton CAS doesn’t support a masked form; our generated code uses an unmasked CAS and relies on index masking to avoid OOB.
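A hedged sketch of write-once semantics via compare-and-swap (assumes out is pre-filled with the sentinel 0.0; names are illustrative):

@helion.kernel
def write_once(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    for tile in hl.tile(x.size(0)):
        # Swap in x[tile] only where out[tile] still holds the sentinel 0.0
        hl.atomic_cas(out, [tile], 0.0, x[tile])
    return out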
Inline Assembly#
inline_asm_elementwise()#
- helion.language.inline_asm_elementwise(asm, constraints, args, dtype, is_pure, pack)[source]#
Execute inline assembly over a tensor. Essentially, this is map where the function is inline assembly.
The input tensors args are implicitly broadcasted to the same shape. dtype can be a tuple of types, in which case the output is a tuple of tensors.
Each invocation of the inline asm processes pack elements at a time. Exactly which set of inputs a block receives is unspecified. Input elements of size less than 4 bytes are packed into 4-byte registers.
This op does not support empty dtype – the inline asm must return at least one tensor, even if you don’t need it. You can work around this by returning a dummy tensor of arbitrary type; it shouldn’t cost you anything if you don’t use it.
- Parameters:
asm (str) – assembly to run. Must match target's assembly format.
constraints (str) – asm constraints in LLVM format
args (Sequence[Tensor]) – the input tensors, whose values are passed to the asm block
dtype (Union[dtype, Sequence[dtype]]) – the element type(s) of the returned tensor(s)
is_pure (bool) – if true, the compiler assumes the asm block has no side-effects
pack (int) – the number of elements to be processed by one instance of inline assembly
- Return type:
torch.Tensor or tuple[torch.Tensor, ...]
- Returns:
one tensor or a tuple of tensors of the given dtypes
Executes target-specific inline assembly on elements of one or more tensors with broadcasting and optional packed processing.
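A hedged sketch (assumes an NVIDIA target, so the asm string is PTX; the instruction and constraints are illustrative): an identity copy of int32 elements through a register move.

@helion.kernel
def asm_identity(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)  # x assumed to be int32
    for tile in hl.tile(x.size(0)):
        out[tile] = hl.inline_asm_elementwise(
            "mov.u32 $0, $1;",  # PTX: copy the input register to the output register
            "=r,r",             # one 32-bit register output, one input
            [x[tile]],
            dtype=torch.int32,
            is_pure=True,
            pack=1,
        )
    return out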
inline_triton()#
- helion.language.inline_triton(triton_source, args, output_like)[source]#
Inline a raw Triton snippet inside a Helion kernel.
- Parameters:
triton_source (str) – The Triton code snippet. The last statement must be an expression representing the return value. The snippet may be indented, and common indentation is stripped automatically.
args (Sequence[object] | Mapping[str, object]) – Positional or keyword placeholders that will be substituted via str.format before code generation. Provide a tuple/list for positional placeholders ({0}, {1}, …) or a mapping for named placeholders ({x}, {y}, …).
output_like (TypeVar(_T)) – Example tensors describing the expected outputs. A single tensor indicates a single output; a tuple/list of tensors indicates multiple outputs.
- Return type:
TypeVar(_T)
- Returns:
The value(s) produced by the snippet. Matches the structure of output_like.
Embeds small Triton code snippets directly inside a Helion kernel. Common indentation is removed automatically, placeholders are replaced using str.format with tuple or dict arguments, and the final line in the snippet becomes the return value. Provide tensors (or tuples of tensors) via output_like so Helion knows the type of the return value.
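A hedged sketch (names illustrative): a fused multiply-add written as a raw Triton snippet, with named placeholders filled from Helion values.

@helion.kernel
def fma_kernel(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size(0)):
        out[tile] = hl.inline_triton(
            """
            tmp = {x} * {y}
            tmp + {z}
            """,  # the last statement is an expression: the return value
            args={"x": x[tile], "y": y[tile], "z": z[tile]},
            output_like=x[tile],
        )
    return out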
triton_kernel()#
- helion.language.triton_kernel(triton_source_or_fn, args, output_like)[source]#
Define (once) and call a @triton.jit function from Helion device code.
- Parameters:
triton_source_or_fn (object) – Source for a single @triton.jit function definition, or a Python function object defining a @triton.jit kernel.
args (Sequence[object] | Mapping[str, object]) – Positional or keyword placeholders that will be substituted via name resolution of Helion variables.
output_like (TypeVar(_T)) – Example tensor(s) describing the expected outputs for shape/dtype checks.
- Return type:
TypeVar(_T)
Define (once) and call a @triton.jit function from Helion device code.
Accepts either:
a source string containing a single Triton function definition,
a function name string referring to a @triton.jit function in the kernel's module, or
a Python function object (or Triton JITFunction; unwrapped via .fn).
The function is emitted at module scope once and then invoked from the kernel body.
Pass output_like tensors for shape/dtype checks identical to inline_triton.
Example (by name):
@triton.jit
def add_pairs(a, b):
    return a + b

@helion.kernel()
def k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.shape):
        out[tile] = hl.triton_kernel(
            "add_pairs", args=(x[tile], y[tile]), output_like=x[tile]
        )
    return out
Tensor Creation#
zeros()#
- helion.language.zeros(shape, dtype=torch.float32, device=None)[source]#
Return a device-tensor filled with zeros.
Equivalent to hl.full(shape, 0.0 if dtype.is_floating_point else 0, dtype=dtype).

Note

Only use within hl.tile() loops for creating local tensors. For output tensor creation, use torch.zeros() with proper device placement.

- Parameters:
- Returns:
A device tensor of the given shape and dtype filled with zeros
- Return type:
torch.Tensor
Examples
@helion.kernel
def process_kernel(input: torch.Tensor) -> torch.Tensor:
    result = torch.empty_like(input)
    for tile in hl.tile(input.size(0)):
        buffer = hl.zeros([tile], dtype=input.dtype)  # Local buffer
        buffer += input[tile]  # Add input values to buffer
        result[tile] = buffer
    return result
full()#
- helion.language.full(shape, value, dtype=torch.float32, device=None)[source]#
Create a device-tensor filled with a specified value.
Note
Only use within hl.tile() loops for creating local tensors. For output tensor creation, use torch.full() with proper device placement.

- Parameters:
- Returns:
A device tensor of the given shape and dtype filled with value
- Return type:
torch.Tensor
Examples
@helion.kernel
def process_kernel(input: torch.Tensor) -> torch.Tensor:
    result = torch.empty_like(input)
    for tile in hl.tile(input.size(0)):
        # Create local buffer filled with initial value
        buffer = hl.full([tile], 0.0, dtype=input.dtype)
        buffer += input[tile]  # Add input values to buffer
        result[tile] = buffer
    return result
arange()#
- helion.language.arange(*args, dtype=None, device=None, **kwargs)[source]#
Same as torch.arange(), but defaults to same device as the current kernel.
Creates a 1D tensor containing a sequence of integers in the specified range, automatically using the current kernel’s device and index dtype.
- Parameters:
args (int) – Positional arguments passed to torch.arange(start, end, step).
dtype (dtype | None) – Data type of the result tensor (defaults to kernel's index dtype)
device (device | None) – Device must match the current compile environment device
kwargs (object) – Additional keyword arguments passed to torch.arange
- Returns:
1D tensor containing the sequence
- Return type:
torch.Tensor
See also
tile_index(): For getting tile indices
zeros(): For creating zero-filled tensors
full(): For creating constant-filled tensors
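A hedged sketch (names illustrative; mirrors the hl.rand example below, which passes a tile directly to hl.arange):

@helion.kernel
def add_lane_offsets(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size(0)):
        # One integer per lane of the tile, on the kernel's device
        offsets = hl.arange(tile).to(x.dtype)
        out[tile] = x[tile] + offsets
    return out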
rand()#
- helion.language.rand(shape, seed, device=None, offsets=None)[source]#
hl.rand provides a Philox-based pseudorandom number generator (PRNG) that operates independently of PyTorch's global random seed. Instead, it requires an explicit seed argument. By default, offsets are derived from the full logical sizes of the tiles specified in the shape argument. An explicit offsets tensor may be supplied to bypass the implicit offset computation; the output will then have offsets.shape and shape is ignored (an empty list [] is fine).
- Parameters:
shape (list[object]) – A list of sizes for the output tensor. Ignored when offsets is provided.
seed (int | Tensor) – A single element int64 tensor or int literal
device (device | None) – Device must match the current compile environment device
offsets (Tensor | None) – Optional explicit int64 offset tensor fed directly into the philox RNG. When provided, the output shape equals offsets.shape.
- Returns:
A device tensor of float32 dtype filled with uniform random values in [0, 1)
- Return type:
torch.Tensor
Examples
@helion.kernel
def process_kernel(x: torch.Tensor) -> torch.Tensor:
    output = torch.zeros_like(x)
    (m,) = x.shape
    for tile_m in hl.tile(m):
        output[tile_m] = hl.rand([tile_m], seed=42)
    return output
With explicit offsets (e.g. spaced out so sibling streams may use offsets +1 and +2):

@helion.kernel
def spaced_rand_kernel(x: torch.Tensor) -> torch.Tensor:
    output = torch.zeros_like(x)
    (m,) = x.shape
    for tile_m in hl.tile(m):
        base = hl.arange(tile_m).to(torch.int64) * 3
        output[tile_m] = hl.rand([], seed=42, offsets=base)
    return output
rand4x()#
- helion.language.rand4x(seed, offsets, device=None)[source]#
hl.rand4x returns four independent uniform float32 tensors in [0, 1) per offset from a single Philox round (~4× cheaper than four separate hl.rand() calls). Mirrors Triton's tl.rand4x.
- Parameters:
seed (int | Tensor) – A single element int64 tensor or int literal
offsets (Tensor) – int64 offset tensor fed directly into the philox RNG
device (device | None) – Device must match the current compile environment device
- Returns:
Four float32 tensors, each with shape offsets.shape and values in [0, 1).
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
Examples
Three sibling dropout masks per element with a single Philox call:
@helion.kernel
def triple_dropout_kernel(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    (n,) = x.shape
    for tile in hl.tile(n):
        base = hl.tile_index(tile).to(torch.int64)
        r0, r1, r2, _ = hl.rand4x(seed=42, offsets=base)
        keep = (r0 > 0.1) & (r1 > 0.2) & (r2 > 0.3)
        out[tile] = x[tile] * keep.to(x.dtype)
    return out
randint()#
- helion.language.randint(shape, low, high, seed, device=None)[source]#
hl.randint provides a Philox-based pseudorandom integer generator (PRNG) that operates independently of PyTorch’s global random seed. Instead, it requires an explicit seed argument. Offsets are derived from the full logical sizes of the tiles specified in the shape argument.
- Parameters:
shape (list[object]) – A list of sizes for the output tensor
low (int) – Lowest integer to be drawn from the distribution (inclusive)
high (int) – One above the highest integer to be drawn from the distribution (exclusive)
seed (int | Tensor) – A single element int64 tensor or int literal
device (device | None) – Device must match the current compile environment device
- Returns:
A device tensor of int32 dtype filled with random integers in [low, high)
- Return type:
torch.Tensor
Examples
@helion.kernel
def process_kernel(x: torch.Tensor) -> torch.Tensor:
    output = torch.zeros(x.shape, dtype=torch.int32, device=x.device)
    (m,) = x.shape
    for tile_m in hl.tile(m):
        output[tile_m] = hl.randint([tile_m], low=0, high=10, seed=42)
    return output
Tunable Parameters#
register_block_size()#
- helion.language.register_block_size(min_or_max, max_or_none=None, /)[source]#
Explicitly register a block size that should be autotuned and can be used for allocations and inside hl.tile(…, block_size=…).
This is useful if you have two loops where you want them to share a block size, or if you need to allocate a kernel tensor before the hl.tile() loop.
- The signature can be one of:

hl.register_block_size(max)
hl.register_block_size(min, max)
Where min and max are integers that control the range of block_sizes searched by the autotuner. Max may be a symbolic shape, but min must be a constant integer.
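A hedged sketch (names illustrative): registering a block size so a buffer can be allocated before the hl.tile() loop that uses it.

@helion.kernel
def row_sums(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.zeros([m], dtype=x.dtype, device=x.device)
    block_n = hl.register_block_size(n)  # autotuned; usable for allocation below
    for tile_m in hl.tile(m):
        # Allocate with the registered block size, before the inner tile loop
        acc = hl.zeros([tile_m, block_n], dtype=x.dtype)
        for tile_n in hl.tile(n, block_size=block_n):
            acc += x[tile_m, tile_n]
        out[tile_m] = acc.sum(dim=1)
    return out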
register_tunable()#
Tile Operations#
Tile Class#
- class helion.language.Tile(block_id)[source]#
This class should not be instantiated directly, it is the result of hl.tile(…) and represents a single tile of the iteration space.
Tiles can be used as indices to tensors, e.g. tensor[tile]. Tiles can also be used as sizes for allocations, e.g. torch.empty([tile]). There are also properties such as tile.index, tile.begin, tile.end, tile.id, tile.block_size, and tile.count that can be used to retrieve various information about the tile.

Masking is implicit for tiles, so if the final tile is smaller than the block size, loading that tile will only load the valid elements, and reduction operations know to ignore the invalid elements.
- Parameters:
block_id (int) –
The Tile class represents a portion of an iteration space with the following key attributes:
begin: Starting indices of the tile
end: Ending indices of the tile
block_size: Size of the tile in each dimension
View Operations#
subscript()#
- helion.language.subscript(tensor, index)[source]#
Equivalent to tensor[index] where tensor is a kernel-tensor (not a host-tensor).
Can be used to add dimensions to the tensor, e.g. tensor[None, :] or tensor[:, None].
- Parameters:
- Returns:
The indexed tensor with potentially modified dimensions
- Return type:
torch.Tensor
Examples
@helion.kernel
def broadcast_multiply(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x has shape (N,), y has shape (M,)
    result = torch.empty(
        [x.size(0), y.size(0)], dtype=x.dtype, device=x.device
    )
    for tile_i, tile_j in hl.tile([x.size(0), y.size(0)]):
        # Get tile data
        x_tile = x[tile_i]
        y_tile = y[tile_j]
        # Make x broadcastable: (tile_size, 1)
        # same as hl.subscript(x_tile, [slice(None), None])
        x_expanded = x_tile[:, None]
        # Make y broadcastable: (1, tile_size)
        # same as hl.subscript(y_tile, [None, slice(None)])
        y_expanded = y_tile[None, :]
        result[tile_i, tile_j] = x_expanded * y_expanded
    return result
Note
Only supports None and : (slice(None)) indexing
Used for reshaping kernel tensors by adding dimensions
Prefer direct indexing syntax when possible: tensor[None, :]
Does not support integer indexing or slicing with start/stop
split()#
- helion.language.split(tensor)[source]#
Split the last dimension of a tensor with size two into two separate tensors.
- Parameters:
tensor (Tensor) – The input tensor whose last dimension has length two.
- Return type:
tuple[torch.Tensor, torch.Tensor]
- Returns:
A tuple (lo, hi) where each tensor has the same shape as tensor without its last dimension.
See also
join()#
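A hedged sketch (assumes the last, size-2 dimension can be loaded with a full slice; names are illustrative):

@helion.kernel
def complex_abs(pairs: torch.Tensor) -> torch.Tensor:
    # pairs: [n, 2] holding (real, imag) components
    n = pairs.size(0)
    out = torch.empty([n], dtype=pairs.dtype, device=pairs.device)
    for tile in hl.tile(n):
        real, imag = hl.split(pairs[tile, :])
        out[tile] = torch.sqrt(real * real + imag * imag)
    return out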
StackTensor#
StackTensor class#
- class helion.language.StackTensor(tensor_like: torch.Tensor, dev_ptrs: torch.Tensor)[source]#
This class should not be instantiated directly. It is the result of hl.stacktensor_like(…). It represents a batch of tensors with the same properties (shape, dtype and stride) that reside at different memory locations, virtually stacked together.
StackTensor provides a way to perform parallel memory accesses to multiple tensors with a single subscription.
Core Concept:
Instead of performing separate memory operations on each tensor individually, StackTensor allows you to broadcast a single memory operation (hl.load, hl.store, hl.atomic_add, etc.) to multiple tensor buffers in parallel. This is particularly useful for batch processing scenarios where the same operation needs to be applied to multiple tensors.
Memory Operation Behavior:
Loads: When you index into a StackTensor (e.g., stack_tensor[i]), it performs the same indexing operation on all underlying tensor buffers and returns a new tensor where the results are stacked according to the shape of dev_ptrs.
Stores: When you assign to a StackTensor (e.g., stack_tensor[i] = value), the value tensor is "unstacked": each slice of the value tensor is written to the respective underlying tensor buffer. This is the reverse of loading (e.g. value[j] is written to tensor_j[i]).
Shape Semantics:
The StackTensor’s shape is dev_ptrs.shape + tensor_like.shape, where:
dev_ptrs.shape becomes the stacking dimensions
tensor_like.shape represents the shape of each individual tensor
- tensor_like: Tensor#
A template host tensor that defines the shape, dtype, and other properties for all tensors in the stack group. Must be a host tensor (created outside of the device loop).
stacktensor_like#
- helion.language.stacktensor_like(tensor_like, dev_ptrs)[source]#
Creates a StackTensor from a tensor of data pointers (dev_ptrs) pointing to tensors alike residing at different memory locations.
This function creates a StackTensor that allows you to broadcast memory operations to multiple tensor buffers in parallel.
Must be called inside a helion kernel with dev_ptrs as a device tensor and tensor_like as a host tensor.
- Parameters:
tensor_like (Tensor) – A template host tensor that defines the shape, dtype, and other properties that each buffer in the stack group should have. Must be a host tensor.
dev_ptrs (Tensor) – A tensor containing device pointers (memory addresses) to data buffers. Must be of dtype torch.uint64 and must be a device tensor.
Examples
Basic Load Operation:
@helion.kernel
def stack_load(dev_ptrs: torch.Tensor, example: torch.Tensor):
    for tile in hl.tile(example.size(0)):
        ptr_tile = dev_ptrs[:]  # Shape: [num_tensors]
        stack_tensor = hl.stacktensor_like(example, ptr_tile)
        # Load from all tensors simultaneously
        data = stack_tensor[tile]  # Shape: [num_tensors, tile_size]
    return data
Store Operation:
@helion.kernel
def stack_store(
    dev_ptrs: torch.Tensor, example: torch.Tensor, values: torch.Tensor
):
    ptr_tile = dev_ptrs[:]  # Shape: [num_tensors]
    stack_tensor = hl.stacktensor_like(example, ptr_tile)
    # Store values of shape [num_tensors, N] to all tensors in parallel
    stack_tensor[:] = values  # slice values[i, :] goes to tensor i
Usage Setup:
# Create list of tensors to process
tensor_list = [torch.randn(16, device="cuda") for _ in range(4)]
tensor_ptrs = torch.as_tensor(
    [p.data_ptr() for p in tensor_list], dtype=torch.uint64, device="cuda"
)
result = stack_load(tensor_ptrs, tensor_list[0])
- Return type:
StackTensor
- Returns:
A StackTensor object that broadcasts memory operations to all data buffers pointed to by dev_ptrs.
Reduction Operations#
reduce()#
- helion.language.reduce(combine_fn, input_tensor, dim=None, other=0, keep_dims=False)[source]#
Applies a reduction operation along a specified dimension or all dimensions.
This function is only needed for user-defined combine functions. Standard PyTorch reductions (such as sum, mean, amax, etc.) work directly in Helion without requiring this function.
- Parameters:
combine_fn (Union[Callable[[Tensor, Tensor], Tensor], Callable[..., tuple[Tensor, ...]]]) – A binary function that combines two elements element-wise. Must be associative and commutative for correct results. Can be tensor->tensor or tuple->tuple function.
input_tensor (Tensor | tuple[Tensor, ...]) – Input tensor or tuple of tensors to reduce
dim (int | None) – The dimension along which to reduce (None for all dimensions)
other (float | tuple[float, ...]) – Value for masked/padded elements (default: 0). For tuple inputs, can be a tuple of values with the same length.
keep_dims (bool) – If True, reduced dimensions are retained with size 1
- Returns:
Tensor(s) with reduced dimensions
- Return type:
torch.Tensor or tuple[torch.Tensor, …]
See also
associative_scan(): For prefix operations
Note
combine_fn must be associative and commutative
For standard reductions, use PyTorch functions directly (faster)
Masked elements use the ‘other’ value during reduction
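A hedged sketch of a user-defined reduction (a row maximum via a custom combine function; per the note above, torch.amax would normally be preferred, and full-slice row loading is assumed):

def combine_max(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return torch.maximum(a, b)

@helion.kernel
def row_max(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty([m], dtype=x.dtype, device=x.device)
    for tile_m in hl.tile(m):
        # Masked/padded elements take the `other` value during the reduction
        out[tile_m] = hl.reduce(combine_max, x[tile_m, :], dim=1, other=float("-inf"))
    return out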
Scan Operations#
associative_scan()#
- helion.language.associative_scan(combine_fn, input_tensor, dim, reverse=False)[source]#
Applies an associative scan operation along a specified dimension.
Computes the prefix scan (cumulative operation) along a dimension using a custom combine function. Unlike reduce(), this preserves the input shape.
- Parameters:
combine_fn (Union[Callable[[Tensor, Tensor], Tensor], Callable[..., tuple[Tensor, ...]]]) – A binary function that combines two elements element-wise. Must be associative for correct results. Can be tensor->tensor or tuple->tuple function.
input_tensor (Tensor | tuple[Tensor, ...]) – Input tensor or tuple of tensors to scan
dim (int) – The dimension along which to scan
reverse (bool) – If True, performs the scan in reverse order
- Returns:
Tensor(s) with same shape as input containing the scan result
- Return type:
torch.Tensor or tuple[torch.Tensor, …]
See also
cumsum(): For cumulative sum
cumprod(): For cumulative product
reduce(): For dimension-reducing operations
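A hedged sketch of a custom scan (a running maximum along each row; names are illustrative, and full-slice row loading is assumed):

def combine_max(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return torch.maximum(a, b)

@helion.kernel
def running_max(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty_like(x)
    for tile_m in hl.tile(m):
        row = x[tile_m, :]
        # Prefix maximum along dim=1; the output keeps the input shape
        out[tile_m, :] = hl.associative_scan(combine_max, row, dim=1)
    return out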
cumsum()#
- helion.language.cumsum(input_tensor, dim, reverse=False)[source]#
Compute the cumulative sum along a specified dimension.
Equivalent to hl.associative_scan(torch.add, input_tensor, dim, reverse).
- Parameters:
- Returns:
Tensor with same shape as input containing cumulative sum
- Return type:
torch.Tensor
See also
associative_scan(): For custom scan operations
cumprod(): For cumulative product
reduce(): For dimension-reducing operations
Note
Output has same shape as input
Reverse=True computes cumsum from right to left
Equivalent to torch.cumsum
cumprod()#
- helion.language.cumprod(input_tensor, dim, reverse=False)[source]#
Compute the cumulative product along a specified dimension.
Equivalent to hl.associative_scan(torch.mul, input_tensor, dim, reverse).
- Parameters:
- Returns:
Tensor with same shape as input containing cumulative product
- Return type:
torch.Tensor
See also
associative_scan(): For custom scan operations
cumsum(): For cumulative sum
reduce(): For dimension-reducing operations
Note
Output has same shape as input
Reverse=True computes cumprod from right to left
Equivalent to torch.cumprod
tile_index()#
- helion.language.tile_index(tile)[source]#
Retrieve the index (a 1D tensor containing offsets) of the given tile. This can also be written as: tile.index.
Example usage:
@helion.kernel
def arange(length: int, device: torch.device) -> torch.Tensor:
    out = torch.empty(length, dtype=torch.int32, device=device)
    for tile in hl.tile(length):
        out[tile] = tile.index
    return out
- Parameters:
tile (TileInterface) –
- Return type:
torch.Tensor
tile_begin()#
tile_end()#
- helion.language.tile_end(tile)[source]#
Retrieve the end offset of the given tile. For every tile except the last, this is equivalent to tile.begin + tile.block_size. For the last tile, this is the end offset passed to hl.tile(). This can also be written as: tile.end.
- Parameters:
tile (TileInterface) –
- Return type:
tile_block_size()#
tile_id()#
Utilities#
device_print()#
Constexpr Operations#
constexpr()#
- helion.language.constexpr#
alias of ConstExpr
specialize()#
- helion.language.specialize(value)[source]#
Turn dynamic shapes into compile-time constants. Examples:
channels = hl.specialize(tensor.size(1))
height, width = hl.specialize(tensor.shape[-2:])
- Parameters:
value (TypeVar(_T)) – The symbolic value or sequence of symbolic values to specialize on.
- Return type:
TypeVar(_T)
- Returns:
A Python int or a sequence containing only Python ints.
See also
ConstExpr: Create compile-time constants for kernel parameters
Matrix Operations#
dot()#
- helion.language.dot(mat1, mat2, acc=None, out_dtype=None)[source]#
Performs a matrix multiplication of tensors with support for multiple dtypes.
This operation performs matrix multiplication with inputs of various dtypes including float16, bfloat16, float32, int8, and FP8 formats (e4m3fn, e5m2). The computation is performed with appropriate precision based on the input dtypes.
- Parameters:
mat1 (Tensor) – First matrix (2D or 3D tensor of torch.float16, torch.bfloat16, torch.float32, torch.int8, torch.float8_e4m3fn, or torch.float8_e5m2)
mat2 (Tensor) – Second matrix (2D or 3D tensor of torch.float16, torch.bfloat16, torch.float32, torch.int8, torch.float8_e4m3fn, or torch.float8_e5m2)
acc (Tensor | None) – The accumulator tensor (2D or 3D tensor of torch.float16, torch.float32, or torch.int32). If not None, the result is added to this tensor. If None, a new tensor is created with appropriate dtype based on inputs.
out_dtype (dtype | None) – Optional dtype that controls the output type of the multiplication prior to any accumulation. This maps directly to the Triton tl.dot out_dtype argument and overrides the default promotion rules when provided.
- Return type:
torch.Tensor
- Returns:
Result of matrix multiplication. If acc is provided, returns acc + (mat1 @ mat2). Otherwise returns (mat1 @ mat2) with promoted dtype.
Example
>>> # FP8 example
>>> a = torch.randn(32, 64, device="cuda").to(torch.float8_e4m3fn)
>>> b = torch.randn(64, 128, device="cuda").to(torch.float8_e4m3fn)
>>> c = torch.zeros(32, 128, device="cuda", dtype=torch.float32)
>>> result = hl.dot(a, b, acc=c)  # result is c + (a @ b)

>>> # Float16 example
>>> a = torch.randn(32, 64, device="cuda", dtype=torch.float16)
>>> b = torch.randn(64, 128, device="cuda", dtype=torch.float16)
>>> result = hl.dot(a, b)  # result dtype will be torch.float16

>>> # Int8 example
>>> a = torch.randint(-128, 127, (32, 64), device="cuda", dtype=torch.int8)
>>> b = torch.randint(-128, 127, (64, 128), device="cuda", dtype=torch.int8)
>>> acc = torch.zeros(32, 128, device="cuda", dtype=torch.int32)
>>> result = hl.dot(a, b, acc=acc)  # int8 x int8 -> int32
dot_scaled()#
- helion.language.dot_scaled(mat1, mat1_scale, mat1_format, mat2, mat2_scale, mat2_format, acc=None, out_dtype=None)[source]#
Performs a block-scaled matrix multiplication using Triton’s tl.dot_scaled.
This operation performs matrix multiplication with block-scaled inputs in formats such as FP4 (e2m1), FP8 (e4m3, e5m2), BF16, and FP16. Each input tensor has an associated scale factor tensor and format string.
- Parameters:
mat1 (Tensor) – First matrix (2D tensor of packed data)
mat1_scale (Tensor) – Scale factors for mat1 (2D tensor)
mat1_format (str) – Format string for mat1 (one of "e2m1", "e4m3", "e5m2", "bf16", "fp16")
mat2 (Tensor) – Second matrix (2D tensor of packed data)
mat2_scale (Tensor) – Scale factors for mat2 (2D tensor)
mat2_format (str) – Format string for mat2 (one of "e2m1", "e4m3", "e5m2", "bf16", "fp16")
acc (Tensor | None) – Optional accumulator tensor (2D, float32 or float16)
out_dtype (dtype | None) – Optional output dtype for the multiplication
- Return type:
torch.Tensor
- Returns:
Result of block-scaled matrix multiplication.
dot_scaled() performs block-scaled matrix multiplication using low-precision formats (e.g., e2m1, e4m3, e5m2). Each input matrix has an associated per-block scale tensor and format string. This maps to Triton’s tl.dot_scaled for hardware-accelerated scaled dot products on supported architectures.