from __future__ import annotations
import ast
import builtins
import inspect
import itertools
from itertools import starmap
from typing import TYPE_CHECKING
from typing import Iterator
from typing import Sequence
from typing import TypeGuard
from typing import cast
from typing import overload
import torch
from torch._inductor.runtime.triton_heuristics import (
get_max_y_grid, # type: ignore[import-untyped]
)
from triton import cdiv
import triton.language
from .. import exc
from .._compiler.ast_extension import ExtendedAST
from .._compiler.ast_extension import LoopType
from .._compiler.ast_extension import expr_from_string
from .._compiler.compile_environment import CompileEnvironment
from .._compiler.type_propagation import GridIndexType
from .._compiler.type_propagation import IterType
from .._compiler.type_propagation import LiteralType
from .._compiler.type_propagation import Origin
from .._compiler.type_propagation import SequenceType
from .._compiler.type_propagation import TileIndexType
from .._compiler.type_propagation import TypeInfo
from .._compiler.variable_origin import GetItemOrigin
from ..autotuner.config_spec import ConfigSpec
from ..autotuner.config_spec import FlattenLoopSpec
from ..autotuner.config_spec import L2GroupingSpec
from ..autotuner.config_spec import LoopOrderSpec
from ..autotuner.config_spec import RangeFlattenSpec
from ..autotuner.config_spec import RangeMultiBufferSpec
from ..autotuner.config_spec import RangeNumStagesSpec
from ..autotuner.config_spec import RangeUnrollFactorSpec
from ..autotuner.config_spec import RangeWarpSpecializeSpec
from ..autotuner.config_spec import StaticRangeSpec
from . import _decorators
from .ref_tile import RefTile
from .tile_proxy import Tile
if TYPE_CHECKING:
from collections.abc import Sequence
from .._compiler.inductor_lowering import CodegenState
# Public API of this module.
__all__ = [
    "grid",
    "static_range",
    "tile",
]
@overload
@_decorators.api(
    is_device_loop=True,
    is_device_only=False,
    cache_type=True,
    tiles_as_sizes=True,
)
def tile(
    begin_or_end: int | torch.Tensor,
    end_or_none: int | torch.Tensor | None = None,
    /,
    block_size: object = None,
) -> Iterator[Tile]:
    """Overload: scalar bounds yield a single ``Tile`` per iteration."""
@overload
@_decorators.api(
    is_device_loop=True,
    is_device_only=False,
    cache_type=True,
    tiles_as_sizes=True,
)
def tile(
    begin_or_end: Sequence[int | torch.Tensor],
    end_or_none: Sequence[int | torch.Tensor] | None = None,
    /,
    block_size: object = None,
) -> Iterator[Sequence[Tile]]:
    """Overload: sequence bounds yield one ``Tile`` per dimension each iteration."""
@_decorators.api(
    is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
)
def tile(
    begin_or_end: int | torch.Tensor | Sequence[int | torch.Tensor],
    end_or_none: int | torch.Tensor | Sequence[int | torch.Tensor] | None = None,
    /,
    block_size: object = None,
) -> Iterator[Tile] | Iterator[Sequence[Tile]]:
    """
    Break up an iteration space defined by a size or sequence of sizes into tiles.

    The generated tiles can flatten the iteration space into the product of the sizes,
    perform multidimensional tiling, swizzle the indices for cache locality, reorder
    dimensions, etc. The only invariant is that every index in the range of the given
    sizes is covered exactly once.

    The exact tiling strategy is determined by a Config object, typically created
    through autotuning.

    If used at the top level of a function, this becomes the grid of the kernel.
    Otherwise, it becomes a loop in the output kernel.

    The key difference from :func:`~helion.language.grid` is that ``tile`` gives you
    ``Tile`` objects that load a slice of elements, while ``grid`` gives you scalar
    integer indices. It is recommended to use ``tile`` in most cases, since it allows
    more choices in autotuning.

    Args:
        begin_or_end: If 2+ positional args provided, the start of iteration space.
            Otherwise, the end of iteration space.
        end_or_none: If 2+ positional args provided, the end of iteration space.
        block_size: Fixed block size (overrides autotuning) or None for autotuned size

    Returns:
        Iterator[Tile] or Iterator[Sequence[Tile]]: Iterator over tile objects

    Examples:
        One dimensional tiling:

        .. code-block:: python

            @helion.kernel
            def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                result = torch.zeros_like(x)
                for tile in hl.tile(x.size(0)):
                    # tile processes multiple elements at once
                    result[tile] = x[tile] + y[tile]
                return result

        Multi-dimensional tiling:

        .. code-block:: python

            @helion.kernel()
            def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                m, k = x.size()
                k, n = y.size()
                out = torch.empty([m, n], dtype=x.dtype, device=x.device)
                for tile_m, tile_n in hl.tile([m, n]):
                    acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
                    for tile_k in hl.tile(k):
                        acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
                    out[tile_m, tile_n] = acc
                return out

        Fixed block size:

        .. code-block:: python

            @helion.kernel
            def process_with_fixed_block(x: torch.Tensor) -> torch.Tensor:
                result = torch.zeros_like(x)
                for tile in hl.tile(x.size(0), block_size=64):
                    # Process with fixed block size of 64
                    result[tile] = x[tile] * 2
                return result

        Using tile properties:

        .. code-block:: python

            @helion.kernel
            def tile_info_example(x: torch.Tensor) -> torch.Tensor:
                result = torch.zeros([x.size(0)], dtype=x.dtype, device=x.device)
                for tile in hl.tile(x.size(0)):
                    # Access tile properties
                    start = tile.begin
                    end = tile.end
                    size = tile.block_size
                    indices = tile.index  # [start, start+1, ..., end-1]
                    # Use in computation
                    result[tile] = x[tile] + indices
                return result

    See Also:
        - :func:`~helion.language.grid`: For explicit control over the launch grid
        - :func:`~helion.language.tile_index`: For getting tile indices
        - :func:`~helion.language.register_block_size`: For registering block sizes

    Note:
        Similar to ``range()`` with multiple forms:

        * tile(end) iterates 0 to end-1, autotuned block_size
        * tile(begin, end) iterates begin to end-1, autotuned block_size
        * tile(begin, end, block_size) iterates begin to end-1, fixed block_size
        * tile(end, block_size=block_size) iterates 0 to end-1, fixed block_size

        Block sizes can be registered for autotuning explicitly with
        :func:`~helion.language.register_block_size` and passed as the
        ``block_size`` argument if one needs two loops to use the same block size.
        Passing ``block_size=None`` is equivalent to calling register_block_size.

        Use ``tile`` in most cases. Use ``grid`` when you need explicit control over
        the launch grid.
    """
    # Calling hl.tile outside a compiled kernel is an error; inside a kernel
    # the call is handled by the registered type-propagation/codegen hooks.
    raise exc.NotInsideKernel
def _not_none(value: TypeInfo | None) -> TypeGuard[TypeInfo]:
return not (value is None or value.is_literal() and value.as_literal() is None)
def _to_proxy(value: TypeInfo) -> object:
try:
return value.proxy()
except NotImplementedError:
raise exc.IncorrectTileUsage(
f"expected IntLike or list[IntLike], got {value!s}"
) from None
def _check_matching(a: object, b: object) -> None:
"""Check that the types of `a` and `b` match for use in hl.tile."""
if isinstance(a, (list, tuple)):
if not isinstance(b, (list, tuple)):
raise exc.IncorrectTileUsage(
f"expected type hl.tile args to match, got {type(a)} and {type(b)}"
)
if len(a) != len(b):
raise exc.IncorrectTileUsage(
f"expected dims for hl.tile args to match, got {len(a)} and {len(b)}"
)
elif isinstance(a, (int, torch.SymInt, torch.Tensor)):
if not isinstance(b, (int, torch.SymInt, torch.Tensor)):
raise exc.IncorrectTileUsage(
f"expected type hl.tile args to match, got {type(a)} and {type(b)}"
)
else:
raise exc.IncorrectTileUsage(
f"expected type hl.tile args to be IntLike or list[IntLike], got {type(a)}"
)
def _allow_static_range(begin: object, end: object, step: object) -> bool:
"""
Only enable tl.stagic_range when:
1) The ranges are statically known at compile time.
2) The range is small enough to be unrolled without blowing up the compile time.
"""
if begin is None:
begin = 0
elif not isinstance(begin, int):
return False
if not isinstance(end, int):
return False
if step is None:
count = end - begin
elif isinstance(step, int):
count = cdiv(begin - end, step)
else:
return False
# Unrolling a long static range leads to compile timeouts
return count <= 8
def _normalize_begin_end(
    begin_or_end: TypeInfo,
    end_or_none: TypeInfo | None,
    origin: Origin,
) -> tuple[TypeInfo, TypeInfo]:
    """Resolve ``range``-style arguments into an explicit ``(begin, end)`` pair.

    With two arguments they are returned unchanged.  With one argument it is
    the end, and begin is built from it by mapping every leaf to ``0`` via
    ``tree_map`` and wrapping the result with ``TypeInfo.from_example``.

    Raises:
        exc.TypeInferenceError: if the zero-valued begin cannot be built.
    """
    if _not_none(end_or_none):
        return begin_or_end, end_or_none

    try:
        zero_begin = TypeInfo.from_example(
            begin_or_end.tree_map(lambda _leaf: 0), origin
        )
    except NotImplementedError:
        raise exc.TypeInferenceError(
            f"expected IntLike or list[IntLike], got {begin_or_end!s}"
        ) from None
    return zero_begin, begin_or_end
@_decorators.type_propagation(tile)
def _(
    begin_or_end: TypeInfo,
    end_or_none: TypeInfo | None = None,
    /,
    block_size: TypeInfo | None = None,
    *,
    origin: Origin,
) -> TypeInfo:
    """Type propagation for ``hl.tile``.

    Normalizes ``(begin, end, block_size)``, allocates or reuses one block id
    per tiled dimension, registers the autotuning choices for those block ids
    via ``_add_config_choices``, and returns an ``IterType`` whose element is
    a single ``TileIndexType`` (scalar bounds) or a ``SequenceType`` of them
    (list/tuple bounds).

    Raises:
        exc.LoopFunctionNotInFor: if the node enclosing this call is not an
            ``ast.For``.
        exc.IncorrectTileUsage: if the arguments are not int-likes / lists of
            int-likes, or their kinds or lengths do not match.
        exc.TypeInferenceError: if a default begin cannot be built.
    """
    # ``hl.tile`` is only valid as the iterator of a ``for`` statement.
    parent = ExtendedAST.current()[-2]
    if not isinstance(parent, ast.For):
        raise exc.LoopFunctionNotInFor("tile")
    begin, end = _normalize_begin_end(begin_or_end, end_or_none, origin=origin)
    proxy_begin = _to_proxy(begin)
    proxy_end = _to_proxy(end)
    _check_matching(proxy_begin, proxy_end)
    if _not_none(block_size):
        # Explicit block size(s).  ``_tiles_to_sizes`` presumably replaces any
        # Tile objects with their block sizes — TODO confirm.
        proxy_block_size = Tile._tiles_to_sizes(_to_proxy(block_size))
        _check_matching(proxy_end, proxy_block_size)
    else:
        # No block size given: ``None`` for every dimension (autotuned).
        proxy_block_size = begin.tree_map(lambda n: None)
    # Scalar bounds are wrapped into one-element lists so the per-dimension
    # loop below handles both forms; ``unpack`` records that the single
    # result must be unwrapped at the end.
    if unpack := not isinstance(proxy_end, (list, tuple)):
        begin_list: list[int | torch.SymInt | torch.Tensor] = [
            cast("int | torch.SymInt | torch.Tensor", proxy_begin)
        ]
        end_list: list[int | torch.SymInt | torch.Tensor] = [
            cast("int | torch.SymInt | torch.Tensor", proxy_end)
        ]
        block_size_list: list[int | torch.SymInt | torch.Tensor | None] = [
            cast("int | torch.SymInt | torch.Tensor | None", proxy_block_size)
        ]
    else:
        begin_list = cast("list[int | torch.SymInt | torch.Tensor]", proxy_begin)
        end_list = cast("list[int | torch.SymInt | torch.Tensor]", proxy_end)
        block_size_list = cast(
            "list[int | torch.SymInt | torch.Tensor | None]", proxy_block_size
        )
    results = []
    for begin_part, end_part, bs in zip(
        begin_list,
        end_list,
        block_size_list,
        strict=True,
    ):
        size = end_part - begin_part  # type: ignore[operator]
        if isinstance(size, torch.Tensor):
            size = None  # data dependent size
        if bs is None:
            # Autotuned: allocate a fresh block size for this dimension.
            results.append(TileIndexType.allocate(size, origin))
        elif isinstance(bs, int):
            # Fixed integer block size.
            results.append(TileIndexType.allocate(size, origin, bs))
        elif isinstance(bs, torch.SymInt):
            from .._compiler.compile_environment import CompileEnvironment

            # A SymInt that is already a known block size (e.g. one created by
            # register_block_size) reuses that block id; otherwise allocate.
            index = CompileEnvironment.current().get_block_id(bs)
            if index is None:
                results.append(TileIndexType.allocate(size, origin, bs))
            else:
                results.append(TileIndexType(origin=origin, block_id=index))
                # NOTE(review): presumably records ``size`` as an additional
                # extent this shared block size is used for — confirm.
                CompileEnvironment.current().block_sizes[index].mark_alternate_size(
                    size
                )
        # NOTE(review): a block size of any other type (e.g. a torch.Tensor,
        # which _check_matching accepts) appends nothing here, leaving
        # ``results`` shorter than the number of dimensions — confirm that
        # case is unreachable.
    _add_config_choices(
        [x.block_id for x in results],
        is_tile=True,
        # Any begin that is not the literal int 0 counts as "has a begin",
        # which disables the loop-flattening choice.
        has_begin=not all((isinstance(x, int) and x == 0) for x in begin_list),
        allow_static_ranges=[
            *starmap(
                _allow_static_range,
                zip(begin_list, end_list, block_size_list, strict=True),
            )
        ],
    )
    if unpack:
        (result,) = results
    else:
        result = SequenceType(origin, tuple(results))
    return IterType(origin, result)
def _add_config_choices(
    block_ids: list[int],
    *,
    is_tile: bool = False,
    has_begin: bool = False,
    allow_static_ranges: list[bool] | None = None,
) -> None:
    """Register the autotuning search-space entries for a new loop nest.

    Args:
        block_ids: Block ids of the loop's dimensions.
        is_tile: True for ``hl.tile`` loops; only those may be flattened.
        has_begin: True if some dimension does not start at literal 0, which
            rules out the loop-flattening choice.
        allow_static_ranges: For inner loops, one flag per dimension saying
            whether ``tl.static_range`` may be used (default: all False).
    """
    spec = CompileEnvironment.current().config_spec

    if len(block_ids) >= 2:
        # Add loop reordering choice
        spec.loop_orders.append(LoopOrderSpec(block_ids))
        if is_tile and not has_begin:
            spec.flatten_loops.append(FlattenLoopSpec(block_ids))

    # The loop becomes the kernel grid when no node currently being visited
    # is already a grid loop.
    is_outermost = all(
        node._loop_type != LoopType.GRID for node in ExtendedAST.current()
    )

    if not is_outermost:
        # Inner loop: every dimension gets its own set of range choices.
        if allow_static_ranges is None:
            allow_static_ranges = [False] * len(block_ids)
        for block_id, static_ok in zip(block_ids, allow_static_ranges, strict=True):
            _add_config_range_choice([block_id], allow_static_range=static_ok)
        return

    # Track which block_ids come from grids, without duplicates.
    already_tracked = set(spec.grid_block_ids)
    spec.grid_block_ids.extend(
        [bid for bid in block_ids if bid not in already_tracked]
    )
    if len(block_ids) >= 2:
        # L2 grouping now supports 3D+ grids by applying to innermost 2 dimensions
        spec.l2_groupings.append(L2GroupingSpec(block_ids))
    if not _allow_use_yz_grid(spec, block_ids):
        spec.disallow_pid_type("xyz")
    # just one set of choices for when we have persistent kernel loop
    _add_config_range_choice(block_ids)
def _add_config_range_choice(
    block_ids: list[int], allow_static_range: bool = False
) -> None:
    """Register ``tl.range`` tuning knobs for the loop over ``block_ids``.

    Each knob is registered only when the installed Triton's
    ``triton.language.range`` signature has the matching keyword argument,
    so the search space adapts to the Triton version in use.
    """
    supported_kwargs = inspect.signature(triton.language.range).parameters
    spec = CompileEnvironment.current().config_spec

    if allow_static_range:
        spec.static_ranges.append(StaticRangeSpec(block_ids))
    if "loop_unroll_factor" in supported_kwargs:
        spec.range_unroll_factors.append(RangeUnrollFactorSpec(block_ids))
    # Device support is checked before the keyword, as before.
    if _supports_warp_specialize() and "warp_specialize" in supported_kwargs:
        spec.range_warp_specialize.append(RangeWarpSpecializeSpec(block_ids))
    if "num_stages" in supported_kwargs:
        spec.range_num_stages.append(RangeNumStagesSpec(block_ids))
    if "disallow_acc_multi_buffer" in supported_kwargs:
        spec.range_multi_buffers.append(RangeMultiBufferSpec(block_ids))
    if "flatten" in supported_kwargs:
        spec.range_flattens.append(RangeFlattenSpec(block_ids))
def _supports_warp_specialize() -> bool:
    """Return True if warp specialization may be tuned for the current kernel.

    Requires a CUDA device, the ``allow_warp_specialize`` setting, and a CUDA
    compute capability of at least (12, 0).
    """
    env = CompileEnvironment.current()
    if env.device.type == "cuda" and env.settings.allow_warp_specialize:
        # NOTE(review): the capability is queried for the current CUDA device,
        # not ``env.device``, and (12, 0) excludes sm_90 / sm_100 parts —
        # confirm both are intended.
        return torch.cuda.get_device_capability() >= (12, 0)
    return False
def _allow_use_yz_grid(config_spec: ConfigSpec, block_ids: list[int]) -> bool:
"""Check if the yz grid is allowed based on the block sizes."""
if not (1 < len(block_ids) <= 3):
return False
hint = 1
try:
for block_id in block_ids:
hint *= config_spec.block_sizes.block_id_lookup(block_id).size_hint
except KeyError:
return False
return hint < get_max_y_grid()
@_decorators.codegen(tile)
def _(state: CodegenState) -> ast.AST:
    """Generate code for an ``hl.tile`` loop using the shared tile/grid helper."""
    generated = _codegen_loop_helper(state)
    return generated
def _to_int(value: int | torch.Tensor | None) -> int | None:
"""Convert tensor values to int."""
if value is None:
return None
if isinstance(value, torch.Tensor):
return int(value.item())
return int(value)
def _normalize_to_list(
value: int | torch.Tensor | list[int | torch.Tensor],
) -> list[int | torch.Tensor]:
"""Convert single values to lists for uniform handling."""
if isinstance(value, (list, tuple)):
return list(value)
return [value]
def _normalize_begin_end_ref(
begin_or_end: int | torch.Tensor | list[int | torch.Tensor],
end_or_none: int | torch.Tensor | list[int | torch.Tensor] | None = None,
) -> tuple[
int | torch.Tensor | list[int | torch.Tensor],
int | torch.Tensor | list[int | torch.Tensor],
]:
if end_or_none is not None:
# Two positional args: begin_or_end is begin, end_or_none is end
return begin_or_end, end_or_none
# One positional arg: begin_or_end is end, begin defaults to 0
end = begin_or_end
if isinstance(end, (list, tuple)):
begin = cast("int | torch.Tensor | list[int | torch.Tensor]", [0] * len(end))
else:
begin = 0
return begin, end
@_decorators.ref(tile)
def _(
    begin_or_end: int | torch.Tensor | list[int | torch.Tensor],
    end_or_none: int | torch.Tensor | list[int | torch.Tensor] | None = None,
    block_size: int | torch.Tensor | list[int | torch.Tensor] | None = None,
) -> Iterator[RefTile | tuple[RefTile, ...]]:
    """Eager reference implementation of ``hl.tile``.

    Reference mode does not split dimensions: it yields exactly one tile that
    covers each full dimension (``block_size`` is accepted but ignored).

    * Scalar bounds yield a single ``RefTile``.
    * List/tuple bounds yield one tuple with a ``RefTile`` per dimension.
    * If any dimension is empty (``end <= begin``) nothing is yielded, matching
      ``range`` semantics and the ``hl.grid`` reference implementation.
    """
    begin, end = _normalize_begin_end_ref(begin_or_end, end_or_none)
    begin_ints = [_to_int(b) for b in _normalize_to_list(begin)]
    end_ints = [_to_int(e) for e in _normalize_to_list(end)]

    tiles = []
    for b, e in zip(begin_ints, end_ints, strict=True):
        assert b is not None and e is not None
        if e <= b:
            # One empty dimension makes the whole iteration space empty;
            # yielding a partial tuple would break the caller's unpacking.
            return
        tiles.append(RefTile(b, e, e - b))

    if not tiles:
        return
    # Sequence bounds (list *or* tuple) always yield a tuple of tiles so that
    # ``for a, b in hl.tile((0, 0), (m, n))`` unpacks correctly.
    if isinstance(begin, (list, tuple)) or isinstance(end, (list, tuple)):
        yield tuple(tiles)
    else:
        yield tiles[0]
def _codegen_loop_helper(
    state: CodegenState,
) -> ast.AST:
    """Shared codegen for ``hl.tile`` / ``hl.grid`` calls used as loop iterators.

    Only grid loops (``LoopType.GRID``) are handled here: the enclosing ``for``
    is registered with the environment's loop dependency checker, the tile
    strategy emits the grid code for the loop's block ids, and the call itself
    is replaced by the expression ``None``.

    Raises:
        AssertionError: if the enclosing ``for`` loop is not a grid loop.
    """
    ast_stack = ExtendedAST.current()
    for_loop = ast_stack[-2]
    loop_type = for_loop._loop_type
    type_info = ast_stack[-1]._type_info
    assert isinstance(for_loop, ast.For)
    assert isinstance(type_info, IterType)

    # Multi-dimensional loops carry a SequenceType of index types; a
    # one-dimensional loop carries the index type directly.
    if isinstance(type_info.inner, SequenceType):
        indices_raw = type_info.inner.unpack()
    else:
        indices_raw = [type_info.inner]
    assert all(isinstance(t, (TileIndexType, GridIndexType)) for t in indices_raw)
    indices = cast("list[TileIndexType | GridIndexType]", indices_raw)

    if loop_type != LoopType.GRID:
        raise AssertionError(f"Unexpected loop type: {loop_type}")

    env = CompileEnvironment.current()
    env.loop_dependency_checker.register_loop(for_loop)
    block_ids = [t.block_id for t in indices]
    state.tile_strategy.codegen_grid(state, block_ids)
    return expr_from_string("None")
@overload
@_decorators.device_func_replacement(builtins.range)
@_decorators.api(
    is_device_loop=True,
    is_device_only=False,
    cache_type=True,
    tiles_as_sizes=True,
)
def grid(
    begin_or_end: int | torch.Tensor,
    end_or_none: int | torch.Tensor | None = None,
    /,
    step: int | torch.Tensor | Sequence[int | torch.Tensor] | None = None,
) -> Iterator[torch.SymInt]:
    """Overload: scalar bounds yield one scalar index per iteration."""
@overload
@_decorators.device_func_replacement(builtins.range)
@_decorators.api(
    is_device_loop=True,
    is_device_only=False,
    cache_type=True,
    tiles_as_sizes=True,
)
def grid(
    begin_or_end: Sequence[int | torch.Tensor],
    end_or_none: Sequence[int | torch.Tensor] | None = None,
    /,
    step: int | torch.Tensor | Sequence[int | torch.Tensor] | None = None,
) -> Iterator[Sequence[torch.SymInt]]:
    """Overload: sequence bounds yield one scalar index per dimension."""
@_decorators.device_func_replacement(builtins.range)
@_decorators.api(
    is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
)
def grid(
    begin_or_end: int | torch.Tensor | Sequence[int | torch.Tensor],
    end_or_none: int | torch.Tensor | Sequence[int | torch.Tensor] | None = None,
    /,
    step: int | torch.Tensor | Sequence[int | torch.Tensor] | None = None,
) -> Iterator[torch.SymInt] | Iterator[Sequence[torch.SymInt]]:  # type: ignore[type-arg]
    """Iterate over individual indices of the given iteration space.

    The key difference from :func:`~helion.language.tile` is that ``grid`` gives you
    scalar integer indices (``torch.SymInt``), while ``tile`` gives you ``Tile`` objects
    that load a slice of elements. Use ``tile`` in most cases. Use ``grid`` when you need
    explicit control over the launch grid or when processing one element at a time.

    The difference shows up when indexing a tensor:

    .. code-block:: python

        for i in hl.tile(...):
            # i is a Tile object, accesses multiple elements
            data = tensor[i]  # loads slice of elements (1D tensor)

    vs:

    .. code-block:: python

        for i in hl.grid(...):
            # i is a scalar index, accesses single element
            data = tensor[i]  # loads single element (0D scalar)

    When used at the top level of a function, this becomes the grid of the kernel.
    Otherwise, it becomes a loop in the output kernel.

    Args:
        begin_or_end: If 2+ positional args provided, the start of iteration space.
            Otherwise, the end of iteration space.
        end_or_none: If 2+ positional args provided, the end of iteration space.
        step: Step size for iteration (default: 1)

    Returns:
        Iterator[torch.SymInt] or Iterator[Sequence[torch.SymInt]]: Iterator over scalar indices

    See Also:
        - :func:`~helion.language.tile`: For processing multiple elements at once
        - :func:`~helion.language.tile_index`: For getting tile indices
        - :func:`~helion.language.arange`: For creating index sequences

    Note:
        Similar to ``range()`` with multiple forms:

        * grid(end) iterates from 0 to end-1, step 1
        * grid(begin, end) iterates from begin to end-1, step 1
        * grid(begin, end, step) iterates from begin to end-1, given step
        * grid(end, step=step) iterates from 0 to end-1, given step

        Use ``tile`` in most cases. Use ``grid`` when you need explicit control over
        the launch grid.
    """
    # Calling hl.grid outside a compiled kernel is an error; inside a kernel
    # the call is handled by the registered type-propagation/codegen hooks.
    raise exc.NotInsideKernel
@_decorators.type_propagation(grid)
def _(
    begin_or_end: TypeInfo,
    end_or_none: TypeInfo | None = None,
    /,
    step: TypeInfo | None = None,
    *,
    origin: Origin,
) -> TypeInfo:
    """Type propagation for ``hl.grid``.

    Normalizes ``(begin, end, step)``, allocates one ``GridIndexType`` per
    dimension (default step 1), registers the autotuning choices for their
    block ids via ``_add_config_choices``, and returns an ``IterType`` whose
    element is a single ``GridIndexType`` (scalar bounds) or a
    ``SequenceType`` of them (list/tuple bounds).

    Raises:
        exc.LoopFunctionNotInFor: if the node enclosing this call is not an
            ``ast.For``.
        exc.IncorrectTileUsage: if the arguments are not int-likes / lists of
            int-likes, or their kinds or lengths do not match.
        exc.TypeInferenceError: if a default begin cannot be built.
    """
    # ``hl.grid`` is only valid as the iterator of a ``for`` statement.
    parent = ExtendedAST.current()[-2]
    if not isinstance(parent, ast.For):
        raise exc.LoopFunctionNotInFor("grid")
    begin, end = _normalize_begin_end(begin_or_end, end_or_none, origin=origin)
    proxy_begin = _to_proxy(begin)
    proxy_end = _to_proxy(end)
    _check_matching(proxy_begin, proxy_end)
    if _not_none(step):
        proxy_step = Tile._tiles_to_sizes(_to_proxy(step))
        _check_matching(proxy_end, proxy_step)
    else:
        # No step given: ``None`` for every dimension (replaced by 1 below).
        proxy_step = begin.tree_map(lambda n: None)
    # Scalar bounds are wrapped into one-element lists so the per-dimension
    # loop below handles both forms; ``unpack`` records that the single
    # result must be unwrapped at the end.
    if unpack := not isinstance(proxy_end, (list, tuple)):
        begin_list: list[int | torch.SymInt | torch.Tensor] = [
            cast("int | torch.SymInt | torch.Tensor", proxy_begin)
        ]
        end_list: list[int | torch.SymInt | torch.Tensor] = [
            cast("int | torch.SymInt | torch.Tensor", proxy_end)
        ]
        step_list: list[int | torch.SymInt | torch.Tensor | None] = [
            cast("int | torch.SymInt | torch.Tensor | None", proxy_step)
        ]
    else:
        begin_list = cast("list[int | torch.SymInt | torch.Tensor]", proxy_begin)
        end_list = cast("list[int | torch.SymInt | torch.Tensor]", proxy_end)
        step_list = cast("list[int | torch.SymInt | torch.Tensor | None]", proxy_step)
    results = []
    for begin_part, end_part, step_part in zip(
        begin_list,
        end_list,
        step_list,
        strict=True,
    ):
        size = end_part - begin_part  # type: ignore[operator]
        if isinstance(size, torch.Tensor):
            size = None  # data dependent size
        if step_part is None:
            # Default step, as with builtins.range.
            step_part = 1
        results.append(GridIndexType.allocate(size, origin, step_part))  # pyright: ignore[reportArgumentType]
    _add_config_choices(
        [x.block_id for x in results],
        is_tile=False,
        has_begin=not all((isinstance(x, int) and x == 0) for x in begin_list),
        # ``step_list`` still holds ``None`` for defaulted steps here; the
        # loop above only rebinds its local ``step_part``.
        allow_static_ranges=[
            *starmap(
                _allow_static_range, zip(begin_list, end_list, step_list, strict=True)
            )
        ],
    )
    if unpack:
        (result,) = results
    else:
        result = SequenceType(origin, tuple(results))
    return IterType(origin, result)
@_decorators.codegen(grid)
def _(state: CodegenState) -> ast.AST:
    """Generate code for an ``hl.grid`` loop using the shared tile/grid helper."""
    generated = _codegen_loop_helper(state)
    return generated
def _extract_step_value(
step: int | torch.Tensor | Sequence[int | torch.Tensor] | None,
index: int = 0,
) -> int | torch.Tensor | None:
"""Extract step value from various input formats."""
if step is None:
return None
if isinstance(step, (list, tuple)):
# Extract from sequence at index
if index < len(step):
val = step[index]
# Type narrow to valid types for _to_int
if isinstance(val, (int, torch.Tensor, type(None))):
return val
return None
# Single value - type narrow to valid types
if isinstance(step, (int, torch.Tensor)):
return step
return None
def _normalize_step_values(
    step: int | torch.Tensor | Sequence[int | torch.Tensor] | None,
    num_dims: int,
) -> list[int | None]:
    """Expand ``step`` into one optional int per dimension.

    ``None`` means "default step" for every dimension.  Otherwise ``step``
    must be a list or tuple (asserted); each entry is converted to an int,
    and missing or unsupported entries become ``None``.
    """
    if step is None:
        return [None] * num_dims
    assert isinstance(step, (list, tuple))
    return [_to_int(_extract_step_value(step, dim)) for dim in range(num_dims)]
def _create_ranges(
begin_ints: list[int | None],
end_ints: list[int | None],
step_ints: list[int | None] | None = None,
) -> list[range]:
"""Create range objects from begin, end, and optional step values."""
ranges = []
if step_ints is None:
# No steps provided - use default ranges
for b, e in zip(begin_ints, end_ints, strict=True):
assert b is not None and e is not None
ranges.append(range(b, e))
else:
# Steps provided - use them where available
for b, e, s in zip(begin_ints, end_ints, step_ints, strict=True):
assert b is not None and e is not None
if s is not None:
ranges.append(range(b, e, s))
else:
ranges.append(range(b, e))
return ranges
@_decorators.ref(grid)
def _(
    begin_or_end: int | torch.Tensor | list[int | torch.Tensor],
    end_or_none: int | torch.Tensor | list[int | torch.Tensor] | None = None,
    step: int | torch.Tensor | Sequence[int | torch.Tensor] | None = None,
) -> range | Iterator[tuple[int, ...]]:
    """Eager reference implementation of ``hl.grid``.

    Scalar bounds return a plain ``range``; list/tuple bounds return the
    ``itertools.product`` of one ``range`` per dimension.
    """
    begin, end = _normalize_begin_end_ref(begin_or_end, end_or_none)

    if isinstance(begin, (list, tuple)):
        # Multi-dimensional: Cartesian product of per-dimension ranges.
        assert isinstance(end, (list, tuple))
        begin_ints = [_to_int(b) for b in begin]
        end_ints = [_to_int(e) for e in end]
        if step is None:
            step_ints = None
        else:
            step_ints = _normalize_step_values(step, len(begin_ints))
        return itertools.product(*_create_ranges(begin_ints, end_ints, step_ints))

    # One dimension: a plain range, honoring an explicit step if given.
    lo = _to_int(begin)
    assert not isinstance(end, (list, tuple))
    hi = _to_int(end)
    assert lo is not None and hi is not None
    stride = _to_int(_extract_step_value(step, 0))
    if stride is None:
        return range(lo, hi)
    return range(lo, hi, stride)
@_decorators.device_func_replacement(builtins.zip)
@_decorators.api(is_device_only=True, cache_type=True)
def _zip_replacement(
    *args: tuple[object, ...] | list[object],
    strict: bool = False,
) -> tuple[tuple[object, ...], ...]:
    """
    Device replacement for zip() that returns tuples for unrolling.

    This replacement enables zip() to work in device kernels by converting
    the zip result to a tuple of tuples, which can then be unrolled by the
    existing tuple iteration logic.

    Args:
        *args: Sequences to zip together
        strict: Forwarded to ``builtins.zip`` when the call is lowered to
            device IR.  Note that type propagation rejects sequences of
            unequal length regardless of this flag.

    Returns:
        Tuple of tuples containing zipped elements

    Examples:
        .. code-block:: python

            @helion.kernel
            def kernel_with_zip(a_tensors, b_tensors, out):
                for tile in hl.tile(out.size(0)):
                    acc = hl.zeros([tile], dtype=torch.float32)
                    for a, b in zip(a_tensors, b_tensors):
                        # This gets unrolled at compile time
                        acc += a[tile] * b[tile]
                    out[tile] = acc
                return out
    """
    raise exc.NotInsideKernel
@_decorators.type_propagation(_zip_replacement)
def _(
    *args: TypeInfo,
    origin: Origin,
    **kwargs: object,
) -> TypeInfo:
    """Type propagation for in-kernel ``zip()``.

    Every positional argument must be a ``SequenceType`` and all must have
    the same length.  The result is a ``SequenceType`` whose i-th element is
    a ``SequenceType`` of the i-th elements, reusing the original element
    ``TypeInfo`` objects.  Keyword arguments (e.g. ``strict``) are accepted
    and ignored.
    """
    if not args:
        return SequenceType(origin, ())

    unpacked = []
    for arg in args:
        if not isinstance(arg, SequenceType):
            raise exc.TypeInferenceError(
                f"zip() argument must be a sequence, got {arg}"
            )
        unpacked.append(arg.unpack())

    length = len(unpacked[0])
    for i, seq in enumerate(unpacked[1:], 1):
        if len(seq) != length:
            raise exc.TypeInferenceError(
                f"zip() argument {i} has length {len(seq)}, expected {length}"
            )

    rows = []
    for position in range(length):
        column = tuple(seq[position] for seq in unpacked)
        rows.append(SequenceType(GetItemOrigin(origin, position), column))
    return SequenceType(origin, tuple(rows))
@_decorators.register_to_device_ir(_zip_replacement)
def _(
    tracer: object,
    *flat_args: object,
) -> object:
    """Device IR handler for ``zip()``: materialize the zipped rows eagerly.

    ``flat_args`` is ``(sequences,)`` or ``(sequences, strict)`` where
    ``sequences`` holds the iterables to zip and ``strict`` is a bool
    (asserted).  Returns ``()`` when called with no arguments, otherwise a
    list of tuples.
    """
    if not flat_args:
        return ()
    if len(flat_args) == 2:
        sequences, strict = flat_args
        assert isinstance(strict, bool)
    else:
        assert len(flat_args) == 1
        (sequences,) = flat_args
        strict = False
    return list(builtins.zip(*sequences, strict=strict))  # type: ignore[arg-type]
@_decorators.device_func_replacement(builtins.enumerate)
@_decorators.api(is_device_only=True, cache_type=True)
def _enumerate_replacement(
    iterable: tuple[object, ...] | list[object],
    start: int = 0,
) -> tuple[tuple[int, object], ...]:
    """
    In-kernel stand-in for ``enumerate()`` that produces unrollable tuples.

    Inside a device kernel, ``enumerate(seq)`` is rewritten into a tuple of
    ``(index, value)`` pairs, which the existing tuple-iteration logic then
    unrolls at compile time.

    Args:
        iterable: Sequence to enumerate
        start: Value of the first index (default: 0)

    Returns:
        Tuple of ``(index, value)`` tuples
    """
    raise exc.NotInsideKernel
@_decorators.type_propagation(_enumerate_replacement)
def _(
    iterable: TypeInfo,
    start: TypeInfo | None = None,
    *,
    origin: Origin,
) -> TypeInfo:
    """Type propagation for in-kernel ``enumerate()``.

    ``iterable`` must be a ``SequenceType``.  The result is a ``SequenceType``
    of ``(LiteralType index, element)`` pairs, reusing the original element
    ``TypeInfo`` objects.  ``start`` is honored only when it is a literal
    int; otherwise indices start at 0.
    """
    if not isinstance(iterable, SequenceType):
        raise exc.TypeInferenceError(
            f"enumerate() argument must be a sequence, got {iterable}"
        )

    offset = 0
    if start is not None and start.is_literal():
        literal_start = start.as_literal()
        if isinstance(literal_start, int):
            offset = literal_start
    # NOTE(review): a non-literal ``start`` silently falls back to 0 here,
    # while the device-IR handler asserts an int start — confirm intended.

    pairs = []
    for i, element in enumerate(iterable.unpack()):
        index_type = LiteralType(origin, offset + i)
        pairs.append(SequenceType(GetItemOrigin(origin, i), (index_type, element)))
    return SequenceType(origin, tuple(pairs))
@_decorators.register_to_device_ir(_enumerate_replacement)
def _(
    tracer: object,
    *flat_args: object,
) -> object:
    """Device IR handler for ``enumerate()``: materialize ``(index, value)`` pairs.

    ``flat_args`` is ``(iterable,)`` or ``(iterable, start)`` with an int
    ``start`` (asserted).  Returns a list of tuples.
    """
    if len(flat_args) != 2:
        assert len(flat_args) == 1
        (iterable,) = flat_args
        start = 0
    else:
        iterable, start = flat_args
        assert isinstance(start, int)
    return list(builtins.enumerate(iterable, start=start))  # type: ignore[arg-type]
@_decorators.api(is_device_only=True, cache_type=True)
def static_range(
    begin_or_end: int,
    end_or_none: int | None = None,
    /,
    step: int = 1,
) -> Iterator[int]:
    """
    Range over compile-time integer constants that is fully unrolled.

    Works like Python's built-in ``range()``, except that the values are
    integer constants known at compile time, so Helion unrolls the loop
    completely: every iteration becomes its own set of instructions in the
    generated code.

    Args:
        begin_or_end: With 2+ positional args, the start of the range (integer).
            Otherwise, the end of the range (integer).
        end_or_none: With 2+ positional args, the end of the range (integer).
        step: Step size for iteration (integer, default: 1)

    Returns:
        Iterator[int]: Iterator over constant integer values

    Examples:
        Simple unrolled loop:

        .. code-block:: python

            @helion.kernel
            def unrolled_example(x: torch.Tensor) -> torch.Tensor:
                result = torch.zeros_like(x)
                for tile in hl.tile(x.size(0)):
                    acc = torch.zeros([tile], dtype=x.dtype, device=x.device)
                    # This loop gets completely unrolled
                    for i in hl.static_range(3):
                        acc += x[tile] * i
                    result[tile] = acc
                return result

        Range with start and step:

        .. code-block:: python

            @helion.kernel
            def kernel_stepped_unroll(x: torch.Tensor) -> torch.Tensor:
                result = torch.zeros_like(x)
                for tile in hl.tile(x.size(0)):
                    acc = torch.zeros([tile], dtype=x.dtype, device=x.device)
                    # Unroll loop from 2 to 8 with step 2: [2, 4, 6]
                    for i in hl.static_range(2, 8, 2):
                        acc += x[tile] * i
                    result[tile] = acc
                return result

    Note:
        - Only constant integer arguments are supported
        - Keep the range small; long unrolled loops can cause compilation timeouts
        - Each iteration becomes separate instructions in the generated Triton code
        - Best suited to small, fixed iteration counts where unrolling pays off
    """
    raise exc.NotInsideKernel
@_decorators.register_fake(static_range)
def _(
    begin_or_end: int,
    end_or_none: int | None = None,
    /,
    step: int = 1,
) -> tuple[int, ...]:
    """Fake implementation of ``static_range``.

    Validates that all arguments are compile-time ints and returns
    ``tuple(range(...))``, which the existing tuple/list unrolling logic then
    handles.

    Raises:
        exc.TypeInferenceError: if begin, end or step is not an ``int``.
    """
    if end_or_none is None:
        begin_val, end_val = 0, begin_or_end
    else:
        begin_val, end_val = begin_or_end, end_or_none
    if all(isinstance(arg, int) for arg in (begin_val, end_val, step)):
        return tuple(range(begin_val, end_val, step))
    raise exc.TypeInferenceError("static_range requires constant integer arguments")
@_decorators.ref(static_range)
def _(
    begin_or_end: int,
    end_or_none: int | None = None,
    step: int = 1,
) -> range:
    """Eager reference implementation of ``hl.static_range``.

    Mirrors ``builtins.range`` and the fake implementation:
    ``static_range(end, step=s)`` yields ``0, s, 2s, ...`` below ``end``, and
    ``static_range(begin, end, s)`` yields ``begin, begin + s, ...`` below ``end``.
    """
    if end_or_none is None:
        # Single positional argument: it is the end, begin defaults to 0, and
        # the step must still be honored.
        return range(0, begin_or_end, step)
    return range(begin_or_end, end_or_none, step)