Config#
The Config class represents kernel optimization parameters that control how Helion kernels are compiled and executed.
- class helion.Config(*, block_sizes=None, loop_orders=None, flatten_loops=None, l2_groupings=None, reduction_loops=None, range_unroll_factors=None, range_warp_specializes=None, range_num_stages=None, range_multi_buffers=None, range_flattens=None, static_ranges=None, load_eviction_policies=None, num_warps=None, num_stages=None, pid_type=None, indexing=None, **kwargs)[source]#
- Parameters:
  - load_eviction_policies (list[Literal['', 'first', 'last']] | None)
  - pid_type (Optional[Literal['flat', 'xyz', 'persistent_blocked', 'persistent_interleaved']])
  - indexing (Union[Literal['pointer', 'tensor_descriptor', 'block_ptr'], list[Literal['pointer', 'tensor_descriptor', 'block_ptr']], None])
  - kwargs (object)
- __init__(*, block_sizes=None, loop_orders=None, flatten_loops=None, l2_groupings=None, reduction_loops=None, range_unroll_factors=None, range_warp_specializes=None, range_num_stages=None, range_multi_buffers=None, range_flattens=None, static_ranges=None, load_eviction_policies=None, num_warps=None, num_stages=None, pid_type=None, indexing=None, **kwargs)[source]#
Initialize a Config object.
- Parameters:
  - block_sizes (list[int] | None) – Controls tile sizes for hl.tile invocations.
  - loop_orders (list[list[int]] | None) – Permutes iteration order of tiles.
  - flatten_loops (list[bool] | None) – Whether to flatten nested loops for each hl.tile invocation.
  - l2_groupings (list[int] | None) – Reorders program IDs for L2 cache locality.
  - reduction_loops (list[int | None] | None) – Configures reduction loop behavior.
  - range_unroll_factors (list[int] | None) – Loop unroll factors for tl.range calls.
  - range_warp_specializes (list[bool | None] | None) – Warp specialization for tl.range calls.
  - range_num_stages (list[int] | None) – Number of stages for tl.range calls.
  - range_multi_buffers (list[bool | None] | None) – Controls disallow_acc_multi_buffer for tl.range calls.
  - range_flattens (list[bool | None] | None) – Controls flatten parameter for tl.range calls.
  - static_ranges (list[bool] | None) – Whether to use tl.static_range instead of tl.range.
  - load_eviction_policies (list[Literal['', 'first', 'last']] | None) – Eviction policies for load operations ("", "first", "last").
  - num_warps (int | None) – Number of warps per thread block.
  - num_stages (int | None) – Number of stages for software pipelining.
  - pid_type (Optional[Literal['flat', 'xyz', 'persistent_blocked', 'persistent_interleaved']]) – Program ID type strategy ("flat", "xyz", "persistent_blocked", "persistent_interleaved").
  - indexing (Union[Literal['pointer', 'tensor_descriptor', 'block_ptr'], list[Literal['pointer', 'tensor_descriptor', 'block_ptr']], None]) – Indexing strategy for load and store operations. Can be a single strategy string applied to all loads/stores (indexing="block_ptr", backward compatible), a list of strategies (one per load/store operation, must specify all, e.g. indexing=["pointer", "block_ptr", "tensor_descriptor"]), or empty/omitted (all loads/stores default to "pointer"). Valid strategies: "pointer", "tensor_descriptor", "block_ptr".
  - **kwargs (object) – Additional user-defined configuration parameters.
Overview#
Config objects specify optimization parameters that control how Helion kernels run on the hardware.
Key Characteristics#
- Performance-focused: Control GPU resource allocation, memory access patterns, and execution strategies
- Autotuned: The autotuner searches through different Config combinations to find optimal performance
- Kernel-specific: Each kernel can have different optimal Config parameters based on its computation pattern
- Hardware-dependent: Optimal configs vary based on GPU architecture and problem size
Config vs Settings#
| Aspect | Config | Settings |
|---|---|---|
| Purpose | Control execution performance | Control compilation behavior |
| Autotuning | ✅ Automatically optimized | ❌ Never autotuned |
| Examples | block_sizes, num_warps, indexing | Compilation settings and environment variables |
| When to use | Performance optimization | Development, debugging, environment setup |
Configs are typically discovered automatically through autotuning, but can also be manually specified for more control.
Configuration Parameters#
Block Sizes and Resources#
- Config.block_sizes#
List of tile sizes for hl.tile() loops. Each value controls the number of elements processed per GPU thread block for the corresponding tile dimension.
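For intuition, a dimension of length n tiled with block size b is split into ceil(n / b) tiles, with the last tile possibly partial. A minimal plain-Python sketch of this arithmetic (illustrative only, not the Helion API):

```python
# Sketch of the partitioning that block_sizes controls: a dimension of
# length n split into tiles of size b (the last tile may be partial).
def tile_ranges(n: int, b: int) -> list[range]:
    return [range(start, min(start + b, n)) for start in range(0, n, b)]

# block_sizes=[64, 32] over a (100, 70) iteration space:
rows = tile_ranges(100, 64)   # 2 row tiles: [0, 64) and [64, 100)
cols = tile_ranges(70, 32)    # 3 col tiles: [0, 32), [32, 64), [64, 70)
print(len(rows) * len(cols))  # 6 (row, col) tile pairs
```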
- Config.reduction_loops#
Configuration for reduction operations within loops.
- Config.num_warps#
Number of warps (groups of 32 threads) per thread block. Higher values increase parallelism but may reduce occupancy.
- Config.num_stages#
Number of pipeline stages for software pipelining. Higher values can improve memory bandwidth utilization.
Loop Optimizations#
- Config.loop_orders#
Permutation of loop iteration order for each hl.tile() loop. Used to optimize memory access patterns.
- Config.flatten_loops#
Whether to flatten nested loops for each hl.tile() invocation.
- Config.range_unroll_factors#
Unroll factors for tl.range loops in generated Triton code.
- Config.range_warp_specializes#
Whether to enable warp specialization for tl.range loops.
- Config.range_num_stages#
Number of pipeline stages for tl.range loops.
- Config.range_multi_buffers#
Controls the disallow_acc_multi_buffer parameter for tl.range loops.
- Config.range_flattens#
Controls the flatten parameter for tl.range loops.
- Config.static_ranges#
Whether to use tl.static_range instead of tl.range.
Execution and Indexing#
- Config.pid_type#
Program ID layout strategy:
- "flat": Standard linear program ID assignment
- "xyz": 3D program ID layout
- "persistent_blocked": Persistent kernels with blocked work distribution
- "persistent_interleaved": Persistent kernels with interleaved distribution
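As a rough illustration of the first two strategies (plain Python, not Helion internals): a "flat" layout assigns one linear ID per tile and decomposes it arithmetically, while an "xyz"-style layout addresses the same tiles via separate grid axes:

```python
# Illustrative only: addressing a 2D tile grid with a flat linear
# program ID vs. separate per-axis ("xyz"-style) program IDs.
grid_m, grid_n = 4, 3  # 4 x 3 = 12 tiles

# flat: one linear ID, decomposed with divmod
flat_ids = [divmod(pid, grid_n) for pid in range(grid_m * grid_n)]

# xyz-style: independent IDs per axis
xyz_ids = [(x, y) for x in range(grid_m) for y in range(grid_n)]

print(flat_ids == xyz_ids)  # True: same tiles, different ID encoding
```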
- Config.l2_groupings#
Controls reordering of program IDs to improve L2 cache locality.
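The idea can be sketched in plain Python (a grouped-ordering illustration in the spirit of grouped launch orders; not Helion's actual implementation): consecutive program IDs are remapped so that a small group of tile rows is visited column-by-column, keeping nearby IDs on the same few rows so their data stays resident in L2:

```python
# Illustrative grouped reordering of a 2D tile grid. With group=2,
# consecutive program IDs walk 2 rows of tiles column-by-column, so
# nearby IDs reuse the same rows (better L2 locality) instead of
# streaming across an entire row before moving to the next.
def grouped_order(grid_m: int, grid_n: int, group: int) -> list[tuple[int, int]]:
    order = []
    for pid in range(grid_m * grid_n):
        group_size = group * grid_n           # tiles per row-group
        group_id = pid // group_size          # which row-group this pid is in
        first_m = group_id * group
        rows_in_group = min(group, grid_m - first_m)  # last group may be short
        m = first_m + (pid % group_size) % rows_in_group
        n = (pid % group_size) // rows_in_group
        order.append((m, n))
    return order

print(grouped_order(4, 4, 2)[:4])  # [(0, 0), (1, 0), (0, 1), (1, 1)]
```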
- Config.indexing#
Memory indexing strategy for load and store operations. Can be specified as:
- Single strategy (applies to all loads and stores - backward compatible):
  indexing="block_ptr"  # All loads and stores use block pointers
- Per-operation strategies (list, one per load/store in execution order):
  # 2 loads + 1 store = 3 indexing strategies
  indexing=["pointer", "pointer", "block_ptr"]  # loads use pointer, store uses block_ptr
- Empty/omitted (defaults to "pointer" for all operations):
  # indexing not specified - all loads and stores use pointer indexing
Valid strategies:
- "pointer": Pointer-based indexing (default)
- "tensor_descriptor": Tensor descriptor indexing (requires Hopper+ GPU)
- "block_ptr": Block pointer indexing
Note
When using a list, provide one strategy for each load and store operation in the order they appear in the kernel. The indexing list is ordered as:
[load1, load2, ..., loadN, store1, store2, ..., storeM]
Memory and Caching#
- Config.load_eviction_policies#
Eviction policies for load operations issued from device loops. Provide one policy per hl.load site discovered in the kernel. Allowed values:
- "": No eviction policy (omitted)
- "first": Maps to Triton eviction_policy='evict_first'
- "last": Maps to Triton eviction_policy='evict_last'
Notes:
- The number of entries must match the number of load sites considered tunable by the kernel.
- An explicit eviction_policy=... argument passed to hl.load overrides this config.
Usage Examples#
Manual Config Creation#
import torch
import helion
import helion.language as hl
# Create a specific configuration
config = helion.Config(
block_sizes=[64, 32], # 64 elements per tile in dim 0, 32 in dim 1
num_warps=8, # Use 8 warps (256 threads) per block
num_stages=4, # 4-stage pipeline
pid_type="xyz" # Use 3D program ID layout
)
# Use with kernel
@helion.kernel(config=config)
def my_kernel(x: torch.Tensor) -> torch.Tensor:
result = torch.zeros_like(x)
for i, j in hl.tile(x.shape):
result[i, j] = x[i, j] * 2
return result
Eviction Policy Example#
import torch
import helion
import helion.language as hl
@helion.kernel(
config={
"block_size": 16,
"load_eviction_policies": ["", "last"], # second load uses evict_last
}
)
def kernel_with_eviction(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
for tile in hl.tile(x.size(0)):
a = hl.load(x, [tile]) # No eviction policy
b = hl.load(y, [tile]) # Will use evict_last from config
out[tile] = a + b
return out
# Explicit policy on hl.load overrides config:
# hl.load(x, [tile], eviction_policy="evict_first")
Per-Load Indexing Example#
import torch
import helion
import helion.language as hl
# Single indexing strategy for all loads and stores (backward compatible)
@helion.kernel(config={"indexing": "block_ptr"})
def kernel_uniform_indexing(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
for tile in hl.tile(x.size(0)):
a = hl.load(x, [tile]) # Load: uses block_ptr
b = hl.load(y, [tile]) # Load: uses block_ptr
out[tile] = a + b # Store: uses block_ptr
return out
# Per-operation indexing strategies for fine-grained control
# Indexing list is ordered: [load1, load2, ..., store1, store2, ...]
@helion.kernel(
config={
"block_size": 16,
"indexing": ["pointer", "pointer", "block_ptr"], # 2 loads + 1 store
}
)
def kernel_mixed_indexing(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
for tile in hl.tile(x.size(0)):
a = hl.load(x, [tile]) # First load: pointer indexing
b = hl.load(y, [tile]) # Second load: pointer indexing
out[tile] = a + b # Store: block_ptr indexing
return out
Config Serialization#
# Save config to file
config.save("my_config.json")
# Load config from file
loaded_config = helion.Config.load("my_config.json")
# JSON serialization
config_dict = config.to_json()
restored_config = helion.Config.from_json(config_dict)
Autotuning with Restricted Configs#
# Restrict autotuning to specific configurations
configs = [
helion.Config(block_sizes=[32, 32], num_warps=4),
helion.Config(block_sizes=[64, 16], num_warps=8),
helion.Config(block_sizes=[16, 64], num_warps=4),
]
@helion.kernel(configs=configs)
def matrix_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
m, k = a.size()
k2, n = b.size()
assert k == k2, f"size mismatch {k} != {k2}"
out = torch.empty([m, n], dtype=a.dtype, device=a.device)
for tile_m, tile_n in hl.tile([m, n]):
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
for tile_k in hl.tile(k):
acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
out[tile_m, tile_n] = acc
return out
See Also#
Settings - Compilation settings and environment variables
Kernel - Kernel execution and autotuning
Autotuner Module - Autotuning configuration