Rate this Page

Source code for helion.runtime.config

from __future__ import annotations

from collections.abc import Iterator
from collections.abc import Mapping
import json
import os
from pathlib import Path
from typing import Literal
from typing import cast
import uuid

IndexingLiteral = Literal["pointer", "tensor_descriptor", "block_ptr"]
PidTypeLiteral = Literal["flat", "xyz", "persistent_blocked", "persistent_interleaved"]
EvictionPolicyLiteral = Literal["", "first", "last"]


class Config(Mapping[str, object]):
    """Read-only mapping of kernel configuration values.

    All values live in the plain ``config`` dict.  Core properties passed as
    ``None`` are omitted from that dict; the typed properties below supply a
    fallback when a key is missing (except ``block_sizes``, which is required
    on access).
    """

    config: dict[str, object]

    def __init__(
        self,
        *,
        # Core properties
        block_sizes: list[int] | None = None,
        loop_orders: list[list[int]] | None = None,
        flatten_loops: list[bool] | None = None,
        l2_groupings: list[int] | None = None,
        reduction_loops: list[int | None] | None = None,
        range_unroll_factors: list[int] | None = None,
        range_warp_specializes: list[bool | None] | None = None,
        range_num_stages: list[int] | None = None,
        range_multi_buffers: list[bool | None] | None = None,
        range_flattens: list[bool | None] | None = None,
        static_ranges: list[bool] | None = None,
        load_eviction_policies: list[EvictionPolicyLiteral] | None = None,
        num_warps: int | None = None,
        num_stages: int | None = None,
        pid_type: PidTypeLiteral | None = None,
        indexing: IndexingLiteral | list[IndexingLiteral] | None = None,
        # For user-defined properties
        **kwargs: object,
    ) -> None:
        """
        Initialize a Config object.

        Args:
            block_sizes: Controls tile sizes for hl.tile invocations.
            loop_orders: Permutes iteration order of tiles.
            flatten_loops: Boolean flags selecting loop flattening
                (stored as given; interpreted by the consumer of the config).
            l2_groupings: Reorders program IDs for L2 cache locality.
            reduction_loops: Configures reduction loop behavior.
            range_unroll_factors: Loop unroll factors for tl.range calls.
            range_warp_specializes: Warp specialization for tl.range calls.
            range_num_stages: Number of stages for tl.range calls.
            range_multi_buffers: Controls disallow_acc_multi_buffer for tl.range calls.
            range_flattens: Controls flatten parameter for tl.range calls.
            static_ranges: Whether to use tl.static_range instead of tl.range.
            load_eviction_policies: Eviction policies for load operations ("", "first", "last").
            num_warps: Number of warps per block.
            num_stages: Number of stages for software pipelining.
            pid_type: Program ID type strategy ("flat", "xyz", "persistent_blocked", "persistent_interleaved").
            indexing: Indexing strategy for load and store operations. Can be:
                - A single strategy string (all loads/stores use this strategy):
                  indexing="block_ptr"  # backward compatible
                - A list of strategies (one per load/store operation, must specify all):
                  indexing=["pointer", "block_ptr", "tensor_descriptor"]
                - Empty/omitted (all loads/stores default to "pointer")
                Valid strategies: "pointer", "tensor_descriptor", "block_ptr"
            **kwargs: Additional user-defined configuration parameters.
        """
        self.config = {}
        core_props = {
            "block_sizes": block_sizes,
            "loop_orders": loop_orders,
            "flatten_loops": flatten_loops,
            "l2_groupings": l2_groupings,
            "reduction_loops": reduction_loops,
            "range_unroll_factors": range_unroll_factors,
            "range_warp_specializes": range_warp_specializes,
            "range_num_stages": range_num_stages,
            "range_multi_buffers": range_multi_buffers,
            "range_flattens": range_flattens,
            "static_ranges": static_ranges,
            "load_eviction_policies": load_eviction_policies,
            "num_warps": num_warps,
            "num_stages": num_stages,
            "indexing": indexing,
            "pid_type": pid_type,
        }
        # Only explicitly provided core properties are stored, so an omitted
        # key is distinguishable from one set by the caller.
        for key, value in core_props.items():
            if value is not None:
                self.config[key] = value
        # User-defined entries are stored verbatim (None values included).
        self.config.update(kwargs)

    def __getitem__(self, key: str) -> object:
        return self.config[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self.config)

    def __len__(self) -> int:
        return len(self.config)

    def __repr__(self) -> str:
        return f"helion.{self.__str__()}"

    def __str__(self) -> str:
        # Keys are sorted so the rendering does not depend on insertion order.
        args = [f"{key}={value!r}" for key, value in sorted(self.config.items())]
        return f"Config({', '.join(args)})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Config):
            return NotImplemented
        return self.config == other.config

    def __hash__(self) -> int:
        # Values may be lists/dicts, so convert them to hashable equivalents;
        # a frozenset makes the hash independent of key order, matching __eq__.
        return hash(frozenset((k, _to_hashable(v)) for k, v in self.config.items()))

    def __getstate__(self) -> dict[str, object]:
        return dict(self.config)

    def __setstate__(self, state: dict[str, object]) -> None:
        self.config = dict(state)

    def to_json(self) -> str:
        """Convert the config to a JSON string."""
        return json.dumps(self.config, indent=2)

    @classmethod
    def from_json(cls, json_str: str) -> Config:
        """Create a Config object from a JSON string.

        The decoded JSON object's members are passed to the constructor as
        keyword arguments.
        """
        config_dict = json.loads(json_str)
        return cls(**config_dict)

    def save(self, path: str | Path) -> None:
        """Save the config to a JSON file.

        The JSON is written to a uniquely named temporary file in the target
        directory and then moved over ``path``.  Parent directories are
        created as needed.
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        tmp = target.parent / f"tmp.{uuid.uuid4()!s}"
        try:
            tmp.write_text(self.to_json())
            # os.replace overwrites an existing destination on every platform;
            # os.rename raises FileExistsError on Windows in that case.
            os.replace(tmp, target)
        finally:
            # After a successful replace the temp file is gone and this is a
            # no-op; after a failure it removes the partially written file.
            tmp.unlink(missing_ok=True)

    @classmethod
    def load(cls, path: str | Path) -> Config:
        """Load a config from a JSON file."""
        return cls.from_json(Path(path).read_text())

    @property
    def block_sizes(self) -> list[int]:
        # No fallback here: raises KeyError when block_sizes was never set.
        return cast("list[int]", self.config["block_sizes"])

    @property
    def loop_orders(self) -> list[list[int]]:
        return cast("list[list[int]]", self.config.get("loop_orders", []))

    @property
    def flatten_loops(self) -> list[bool]:
        return cast("list[bool]", self.config.get("flatten_loops", []))

    @property
    def reduction_loops(self) -> list[int | None]:
        return cast("list[int | None]", self.config.get("reduction_loops", []))

    @property
    def num_warps(self) -> int:
        # Imported at call time rather than module level
        # (presumably to avoid an import cycle -- confirm).
        from ..autotuner.config_spec import DEFAULT_NUM_WARPS

        return cast("int", self.config.get("num_warps", DEFAULT_NUM_WARPS))

    @property
    def num_stages(self) -> int:
        # Imported at call time rather than module level
        # (presumably to avoid an import cycle -- confirm).
        from ..autotuner.config_spec import DEFAULT_NUM_STAGES

        return cast("int", self.config.get("num_stages", DEFAULT_NUM_STAGES))

    @property
    def l2_groupings(self) -> list[int]:
        return cast("list[int]", self.config.get("l2_groupings", []))

    @property
    def pid_type(self) -> PidTypeLiteral:
        return cast("PidTypeLiteral", self.config.get("pid_type", "flat"))

    @property
    def range_unroll_factors(self) -> list[int]:
        return cast("list[int]", self.config.get("range_unroll_factors", []))

    @property
    def range_warp_specializes(self) -> list[bool | None]:
        return cast("list[bool | None]", self.config.get("range_warp_specializes", []))

    @property
    def range_num_stages(self) -> list[int]:
        return cast("list[int]", self.config.get("range_num_stages", []))

    @property
    def range_multi_buffers(self) -> list[bool | None]:
        return cast("list[bool | None]", self.config.get("range_multi_buffers", []))

    @property
    def range_flattens(self) -> list[bool | None]:
        return cast("list[bool | None]", self.config.get("range_flattens", []))

    @property
    def static_ranges(self) -> list[bool]:
        return cast("list[bool]", self.config.get("static_ranges", []))

    @property
    def load_eviction_policies(self) -> list[EvictionPolicyLiteral]:
        return cast(
            "list[EvictionPolicyLiteral]", self.config.get("load_eviction_policies", [])
        )

    @property
    def indexing(self) -> IndexingLiteral | list[IndexingLiteral]:
        # Falls back to an empty list when no indexing strategy was given.
        return cast(
            "IndexingLiteral | list[IndexingLiteral]", self.config.get("indexing", [])
        )
def _to_hashable(x: object) -> object:
    """Recursively convert lists, tuples and dicts into hashable tuples.

    Lists and tuples become tuples of converted elements; dicts become a
    tuple of ``(key, converted_value)`` pairs sorted by key.  Any other value
    is returned unchanged.
    """
    # Tuples are recursed into as well: a tuple holding a list is not hashable
    # as-is, and rebuilding an already-hashable tuple leaves its hash unchanged.
    if isinstance(x, (list, tuple)):
        return tuple(_to_hashable(i) for i in x)
    if isinstance(x, dict):
        # Sorting makes the result independent of dict insertion order.
        return tuple(sorted((k, _to_hashable(v)) for k, v in x.items()))
    return x