from __future__ import annotations
from collections.abc import Iterator
from collections.abc import Mapping
import json
import os
from pathlib import Path
from typing import Literal
from typing import cast
import uuid
from ..autotuner.config_spec import DEFAULT_NUM_STAGES
from ..autotuner.config_spec import DEFAULT_NUM_WARPS
# String values accepted for the ``indexing`` configuration entry.
IndexingLiteral = Literal[
    "pointer",
    "tensor_descriptor",
    "block_ptr",
]

# String values accepted for the ``pid_type`` configuration entry.
PidTypeLiteral = Literal[
    "flat",
    "xyz",
    "persistent_blocked",
    "persistent_interleaved",
]
# [docs]  (Sphinx source-link marker left over from HTML extraction; not code)
class Config(Mapping[str, object]):
    """Read-only mapping of configuration values backed by a plain dict.

    Core properties passed as ``None`` are not stored; the typed property
    accessors below supply a default for most of them when they are absent.
    Extra keyword arguments are stored verbatim as user-defined entries.
    """

    # Backing storage; insertion order drives ``str()``/``to_json()`` output.
    config: dict[str, object]

    def __init__(
        self,
        *,
        # Core properties
        block_sizes: list[int] | None = None,
        loop_orders: list[list[int]] | None = None,
        flatten_loops: list[bool] | None = None,
        l2_groupings: list[int] | None = None,
        reduction_loops: list[int | None] | None = None,
        range_unroll_factors: list[int] | None = None,
        range_warp_specializes: list[bool | None] | None = None,
        range_num_stages: list[int] | None = None,
        range_multi_buffers: list[bool | None] | None = None,
        range_flattens: list[bool | None] | None = None,
        static_ranges: list[bool] | None = None,
        num_warps: int | None = None,
        num_stages: int | None = None,
        pid_type: PidTypeLiteral | None = None,
        indexing: IndexingLiteral | None = None,
        # For user-defined properties
        **kwargs: object,
    ) -> None:
        """
        Initialize a Config object.

        Args:
            block_sizes: Controls tile sizes for hl.tile invocations.
            loop_orders: Permutes iteration order of tiles.
            flatten_loops: Boolean flags for loop flattening (presumably one
                per tiled loop -- confirm against the autotuner spec).
            l2_groupings: Reorders program IDs for L2 cache locality.
            reduction_loops: Configures reduction loop behavior.
            range_unroll_factors: Loop unroll factors for tl.range calls.
            range_warp_specializes: Warp specialization for tl.range calls.
            range_num_stages: Number of stages for tl.range calls.
            range_multi_buffers: Controls disallow_acc_multi_buffer for tl.range calls.
            range_flattens: Controls flatten parameter for tl.range calls.
            static_ranges: Whether to use tl.static_range instead of tl.range.
            num_warps: Number of warps per block.
            num_stages: Number of stages for software pipelining.
            pid_type: Program ID type strategy ("flat", "xyz", "persistent_blocked", "persistent_interleaved").
            indexing: Indexing strategy ("pointer", "tensor_descriptor", "block_ptr").
            **kwargs: Additional user-defined configuration parameters.
        """
        # Key order here is the order entries appear in str()/to_json().
        core_props = {
            "block_sizes": block_sizes,
            "loop_orders": loop_orders,
            "flatten_loops": flatten_loops,
            "l2_groupings": l2_groupings,
            "reduction_loops": reduction_loops,
            "range_unroll_factors": range_unroll_factors,
            "range_warp_specializes": range_warp_specializes,
            "range_num_stages": range_num_stages,
            "range_multi_buffers": range_multi_buffers,
            "range_flattens": range_flattens,
            "static_ranges": static_ranges,
            "num_warps": num_warps,
            "num_stages": num_stages,
            "indexing": indexing,
            "pid_type": pid_type,
        }
        # Unset core properties are omitted so the accessors can apply defaults.
        self.config = {
            key: value for key, value in core_props.items() if value is not None
        }
        # User-defined entries are stored as given (including None values).
        self.config.update(kwargs)

    def __getitem__(self, key: str) -> object:
        """Return the stored value for ``key``; raises ``KeyError`` if unset."""
        return self.config[key]

    def __iter__(self) -> Iterator[str]:
        """Iterate over the stored keys in insertion order."""
        return iter(self.config)

    def __len__(self) -> int:
        """Return the number of stored entries."""
        return len(self.config)

    def __repr__(self) -> str:
        """Return the ``str()`` form prefixed with ``helion.``."""
        return f"helion.{self!s}"

    def __str__(self) -> str:
        """Return ``Config(key=value, ...)`` listing the stored entries."""
        args = [f"{key}={value!r}" for key, value in self.config.items()]
        return f"Config({', '.join(args)})"

    def __eq__(self, other: object) -> bool:
        """Compare stored entries; defers to ``other`` if it is not a Config."""
        if not isinstance(other, Config):
            return NotImplemented
        return self.config == other.config

    def __hash__(self) -> int:
        """Hash the stored entries independent of their insertion order."""
        # Values are passed through _list_to_tuple so list values can be hashed.
        return hash(
            frozenset((k, _list_to_tuple(v)) for k, v in self.config.items())
        )

    def to_json(self) -> str:
        """Convert the config to a JSON string."""
        return json.dumps(self.config, indent=2)

    @classmethod
    def from_json(cls, json_str: str) -> Config:
        """Create a Config object from a JSON string.

        The decoded JSON object's members are passed to the constructor as
        keyword arguments.
        """
        config_dict = json.loads(json_str)
        return cls(**config_dict)

    def save(self, path: str | Path) -> None:
        """Save the config to a JSON file, creating parent directories.

        The JSON is written to a uniquely named temporary file in the target
        directory and then moved over ``path``, so readers never observe a
        partially written file.
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        tmp = target.parent / f"tmp.{uuid.uuid4()!s}"
        try:
            tmp.write_text(self.to_json(), encoding="utf-8")
            # os.replace overwrites an existing target on every platform;
            # os.rename raises FileExistsError on Windows in that case.
            os.replace(tmp, target)
        finally:
            # No-op after a successful replace; removes the leftover temp
            # file when writing or replacing failed.
            tmp.unlink(missing_ok=True)

    @classmethod
    def load(cls, path: str | Path) -> Config:
        """Load a config from a JSON file."""
        return cls.from_json(Path(path).read_text(encoding="utf-8"))

    @property
    def block_sizes(self) -> list[int]:
        """Stored ``block_sizes``; raises ``KeyError`` when it was not set."""
        return cast("list[int]", self.config["block_sizes"])

    @property
    def loop_orders(self) -> list[list[int]]:
        """Stored ``loop_orders``, or an empty list when unset."""
        return cast("list[list[int]]", self.config.get("loop_orders", []))

    @property
    def flatten_loops(self) -> list[bool]:
        """Stored ``flatten_loops``, or an empty list when unset."""
        return cast("list[bool]", self.config.get("flatten_loops", []))

    @property
    def reduction_loops(self) -> list[int | None]:
        """Stored ``reduction_loops``, or an empty list when unset."""
        return cast("list[int | None]", self.config.get("reduction_loops", []))

    @property
    def num_warps(self) -> int:
        """Stored ``num_warps``, or ``DEFAULT_NUM_WARPS`` when unset."""
        return cast("int", self.config.get("num_warps", DEFAULT_NUM_WARPS))

    @property
    def num_stages(self) -> int:
        """Stored ``num_stages``, or ``DEFAULT_NUM_STAGES`` when unset."""
        return cast("int", self.config.get("num_stages", DEFAULT_NUM_STAGES))

    @property
    def l2_groupings(self) -> list[int]:
        """Stored ``l2_groupings``, or an empty list when unset."""
        return cast("list[int]", self.config.get("l2_groupings", []))

    @property
    def pid_type(self) -> PidTypeLiteral:
        """Stored ``pid_type``, or ``"flat"`` when unset."""
        return cast("PidTypeLiteral", self.config.get("pid_type", "flat"))

    @property
    def range_unroll_factors(self) -> list[int]:
        """Stored ``range_unroll_factors``, or an empty list when unset."""
        return cast("list[int]", self.config.get("range_unroll_factors", []))

    @property
    def range_warp_specializes(self) -> list[bool | None]:
        """Stored ``range_warp_specializes``, or an empty list when unset."""
        return cast("list[bool | None]", self.config.get("range_warp_specializes", []))

    @property
    def range_num_stages(self) -> list[int]:
        """Stored ``range_num_stages``, or an empty list when unset."""
        return cast("list[int]", self.config.get("range_num_stages", []))

    @property
    def range_multi_buffers(self) -> list[bool | None]:
        """Stored ``range_multi_buffers``, or an empty list when unset."""
        return cast("list[bool | None]", self.config.get("range_multi_buffers", []))

    @property
    def range_flattens(self) -> list[bool | None]:
        """Stored ``range_flattens``, or an empty list when unset."""
        return cast("list[bool | None]", self.config.get("range_flattens", []))

    @property
    def static_ranges(self) -> list[bool]:
        """Stored ``static_ranges``, or an empty list when unset."""
        return cast("list[bool]", self.config.get("static_ranges", []))

    @property
    def indexing(self) -> IndexingLiteral:
        """Stored ``indexing``, or ``"pointer"`` when unset."""
        return cast("IndexingLiteral", self.config.get("indexing", "pointer"))
def _list_to_tuple(x: object) -> object:
if isinstance(x, list):
return tuple([_list_to_tuple(i) for i in x])
return x