Source code for helion.runtime.settings

from __future__ import annotations

import dataclasses
import functools
import json
import logging
import os
import time
from typing import TYPE_CHECKING
from typing import Callable
from typing import Literal
from typing import Protocol
from typing import Sequence
from typing import TypeVar
from typing import cast

import torch
from torch._environment import is_fbcode

from .. import exc
from ..autotuner.effort_profile import AutotuneEffort
from ..autotuner.effort_profile import get_effort_profile
from .ref_mode import RefMode

if TYPE_CHECKING:
    from ..autotuner.base_search import BaseAutotuner
    from .kernel import BoundKernel

    _T = TypeVar("_T")

    class AutotunerFunction(Protocol):
        def __call__(
            self, bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
        ) -> BaseAutotuner: ...


DotPrecision = Literal["tf32", "tf32x3", "ieee"]
PrecompileMode = Literal["spawn", "fork"] | None
_TRUE_LITERALS = frozenset({"1", "true", "yes", "on"})
_FALSE_LITERALS = frozenset({"0", "false", "no", "off"})


def _resolve_warning_name(name: str) -> type[exc.BaseWarning]:
    attr = name.strip()
    if not attr:
        raise ValueError("HELION_IGNORE_WARNINGS entries must be non-empty names")
    try:
        warning_cls = getattr(exc, attr)
    except AttributeError as err:
        raise ValueError(
            f"HELION_IGNORE_WARNINGS entry {name!r} is not a warning defined in helion.exc"
        ) from err
    if not isinstance(warning_cls, type) or not issubclass(
        warning_cls, exc.BaseWarning
    ):
        raise ValueError(
            f"HELION_IGNORE_WARNINGS entry {name!r} does not refer to a helion.exc.BaseWarning subclass"
        )
    return warning_cls


def _get_ignore_warnings() -> list[type[exc.BaseWarning]]:
    value = os.environ.get("HELION_IGNORE_WARNINGS")
    if not value:
        return []
    result: list[type[exc.BaseWarning]] = []
    for entry in value.split(","):
        entry = entry.strip()
        if not entry:
            continue
        result.append(_resolve_warning_name(entry))
    return result
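
# Illustrative sketch (not part of the module source): HELION_IGNORE_WARNINGS is a
# comma-separated list of warning class names defined in helion.exc.  "WarningA" and
# "WarningB" below are placeholders, not real helion warning classes.
#
#     HELION_IGNORE_WARNINGS="WarningA,WarningB"
#
# Each name is resolved with getattr(exc, name) and must be a subclass of
# exc.BaseWarning; anything else raises ValueError.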


def _env_get_optional_int(var_name: str) -> int | None:
    value = os.environ.get(var_name)
    if value is None or (value := value.strip()) == "":
        return None
    try:
        parsed = int(value)
    except ValueError as err:
        raise ValueError(f"{var_name} must be an integer, got {value!r}") from err
    return parsed


def _env_get_int(var_name: str, default: int) -> int:
    result = _env_get_optional_int(var_name)
    return default if result is None else result


def _env_get_optional_float(var_name: str) -> float | None:
    value = os.environ.get(var_name)
    if value is None or (value := value.strip()) == "":
        return None
    try:
        return float(value)
    except ValueError as err:
        raise ValueError(f"{var_name} must be a float, got {value!r}") from err


def _env_get_bool(var_name: str, default: bool) -> bool:
    value = os.environ.get(var_name)
    if value is None or (value := value.strip()) == "":
        return default
    lowered = value.lower()
    if lowered in _TRUE_LITERALS:
        return True
    if lowered in _FALSE_LITERALS:
        return False
    raise ValueError(
        f"{var_name} must be one of {_TRUE_LITERALS | _FALSE_LITERALS}, got {value!r}"
    )


def _env_get_literal(
    var_name: str,
    default: _T,
    *,
    mapping: dict[str, object],
) -> _T:
    value = os.environ.get(var_name)
    if value is None:
        return default
    value = value.strip()
    if value in mapping:
        return cast("_T", mapping[value])
    if value == "":
        return default
    raise ValueError(
        f"{var_name} must be one of {', '.join(sorted(mapping))}, got {value!r}"
    )


def _env_get_str(var_name: str, default: str) -> str:
    value = os.environ.get(var_name)
    if value is None or (value := value.strip()) == "":
        return default
    return value
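
# Illustrative sketch (not part of the module source): the _env_get_* helpers above
# all read a variable from os.environ, strip it, and fall back to the default when it
# is unset or empty.  Expected behavior, assuming the environment shown:
#
#     os.environ["HELION_AUTOTUNE_COMPILE_TIMEOUT"] = "120"
#     os.environ["HELION_STATIC_SHAPES"] = "off"
#     _env_get_int("HELION_AUTOTUNE_COMPILE_TIMEOUT", 60)   # -> 120
#     _env_get_bool("HELION_STATIC_SHAPES", True)           # -> False ("off" is a false literal)
#     _env_get_bool("HELION_SOME_UNSET_VAR", True)          # -> True (unset, default used)
#     _env_get_str("HELION_AUTOTUNE_CACHE", "LocalAutotuneCache")  # -> default if unset
#
# Unparseable values (e.g. HELION_AUTOTUNE_COMPILE_TIMEOUT="fast") raise ValueError
# naming the variable and the offending value.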


def _get_index_dtype() -> torch.dtype:
    value = os.environ.get("HELION_INDEX_DTYPE")
    if value is None or (token := value.strip()) == "":
        return torch.int32
    try:
        dtype = getattr(torch, token)
    except AttributeError as err:
        raise ValueError(
            f"HELION_INDEX_DTYPE must map to a torch dtype attribute, got {value!r}"
        ) from err
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"HELION_INDEX_DTYPE {value!r} is not a torch.dtype")
    return dtype


def _get_autotune_log_level() -> int:
    value = os.environ.get("HELION_AUTOTUNE_LOG_LEVEL")
    if value is None or value.strip() == "":
        return logging.INFO
    text = value.strip()
    if text.lstrip("+-").isdigit():
        return int(text)
    upper = text.upper()
    level = logging.getLevelName(upper)
    if isinstance(level, int):
        return level
    raise ValueError(
        f"HELION_AUTOTUNE_LOG_LEVEL must be an integer or logging level name, got {value!r}"
    )


def _get_autotune_config_overrides() -> dict[str, object]:
    value = os.environ.get("HELION_AUTOTUNE_CONFIG_OVERRIDES")
    if not value or (value := value.strip()) == "":
        return {}
    if not value.startswith("{") and os.path.exists(value):
        value = open(value).read()
    try:
        parsed = json.loads(value)
    except json.JSONDecodeError as err:
        raise ValueError(
            "HELION_AUTOTUNE_CONFIG_OVERRIDES must be valid JSON mapping of config keys to values"
        ) from err
    if not isinstance(parsed, dict):
        raise ValueError(
            "HELION_AUTOTUNE_CONFIG_OVERRIDES must decode to a JSON dictionary"
        )
    return parsed
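
# Illustrative sketch (not part of the module source): HELION_AUTOTUNE_CONFIG_OVERRIDES
# accepts either inline JSON or (when the value does not start with "{") a path to a
# JSON file; either way it must decode to a dictionary of config keys to values.
# "num_warps" below is only an example key.
#
#     HELION_AUTOTUNE_CONFIG_OVERRIDES='{"num_warps": 4}'
#     HELION_AUTOTUNE_CONFIG_OVERRIDES=/path/to/overrides.json
#
# Malformed JSON or a non-dict payload (e.g. a JSON list) raises ValueError.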


def default_autotuner_fn(
    bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
) -> BaseAutotuner:
    from ..autotuner import cache_classes
    from ..autotuner import search_algorithms

    autotuner_name = os.environ.get("HELION_AUTOTUNER", "PatternSearch")
    autotuner_cls = search_algorithms.get(autotuner_name)
    if autotuner_cls is None:
        raise ValueError(
            f"Unknown HELION_AUTOTUNER value: {autotuner_name}, valid options are: "
            f"{', '.join(search_algorithms.keys())}"
        )

    # Use autotune_max_generations from settings if kwarg is not explicitly provided
    if autotuner_name in ("PatternSearch", "DifferentialEvolutionSearch"):
        if bound_kernel.settings.autotune_max_generations is not None:
            kwargs.setdefault(
                "max_generations", bound_kernel.settings.autotune_max_generations
            )

    profile = get_effort_profile(bound_kernel.settings.autotune_effort)

    if autotuner_cls.__name__ == "PatternSearch":
        assert profile.pattern_search is not None
        kwargs.setdefault(
            "initial_population", profile.pattern_search.initial_population
        )
        kwargs.setdefault("copies", profile.pattern_search.copies)
        kwargs.setdefault("max_generations", profile.pattern_search.max_generations)
    elif autotuner_cls.__name__ == "DifferentialEvolutionSearch":
        assert profile.differential_evolution is not None
        kwargs.setdefault(
            "population_size", profile.differential_evolution.population_size
        )
        kwargs.setdefault(
            "max_generations", profile.differential_evolution.max_generations
        )
    elif autotuner_cls.__name__ == "RandomSearch":
        assert profile.random_search is not None
        kwargs.setdefault("count", profile.random_search.count)

    settings = bound_kernel.settings
    cache_name = settings.autotune_cache
    cache_cls = cache_classes.get(cache_name)
    if cache_cls is None:
        raise ValueError(
            f"Unknown HELION_AUTOTUNE_CACHE value: {cache_name}, valid options are: "
            f"{', '.join(cache_classes.keys())}"
        )

    return cache_cls(autotuner_cls(bound_kernel, args, **kwargs))  # pyright: ignore[reportArgumentType]
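
# Illustrative sketch (not part of the module source): default_autotuner_fn is driven
# by two environment variables.  HELION_AUTOTUNER selects the registered search
# algorithm (e.g. "PatternSearch", "DifferentialEvolutionSearch", "RandomSearch") and
# HELION_AUTOTUNE_CACHE selects the cache wrapper; unknown names raise ValueError
# listing the valid options.  For example (benchmark.py is a hypothetical script):
#
#     HELION_AUTOTUNER=DifferentialEvolutionSearch \
#     HELION_AUTOTUNE_EFFORT=quick \
#     python benchmark.py
#
# Per-algorithm defaults (population size, generations, etc.) come from the effort
# profile unless overridden via kwargs or HELION_AUTOTUNE_MAX_GENERATIONS.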


def _get_autotune_random_seed() -> int:
    if (seed := _env_get_optional_int("HELION_AUTOTUNE_RANDOM_SEED")) is not None:
        return seed
    return int(time.time() * 1000) % 2**32


def _get_ref_mode() -> RefMode:
    interpret = _env_get_bool("HELION_INTERPRET", False)
    return RefMode.EAGER if interpret else RefMode.OFF


@dataclasses.dataclass
class _Settings:
    # see __slots__ below for the doc strings that show up in help(Settings)
    ignore_warnings: list[type[exc.BaseWarning]] = dataclasses.field(
        default_factory=_get_ignore_warnings
    )
    index_dtype: torch.dtype = dataclasses.field(default_factory=_get_index_dtype)
    dot_precision: DotPrecision = dataclasses.field(
        default_factory=functools.partial(
            _env_get_literal,
            "TRITON_F32_DEFAULT",
            cast("DotPrecision", "tf32"),
            mapping={k: k for k in ("tf32", "tf32x3", "ieee")},
        )
    )  # pyright: ignore[reportAssignmentType]
    static_shapes: bool = dataclasses.field(
        default_factory=functools.partial(_env_get_bool, "HELION_STATIC_SHAPES", True)
    )
    autotune_log_level: int = dataclasses.field(default_factory=_get_autotune_log_level)
    autotune_compile_timeout: int = dataclasses.field(
        default_factory=functools.partial(
            _env_get_int, "HELION_AUTOTUNE_COMPILE_TIMEOUT", 60
        )
    )
    autotune_precompile: PrecompileMode = dataclasses.field(
        default_factory=functools.partial(
            _env_get_literal,
            "HELION_AUTOTUNE_PRECOMPILE",
            cast("PrecompileMode", "fork"),
            mapping={
                "spawn": "spawn",
                "fork": "fork",
                "": None,
                "0": None,
            },
        )
    )  # pyright: ignore[reportAssignmentType]
    autotune_precompile_jobs: int | None = dataclasses.field(
        default_factory=functools.partial(
            _env_get_optional_int,
            "HELION_AUTOTUNE_PRECOMPILE_JOBS",
        )
    )
    autotune_random_seed: int = dataclasses.field(
        default_factory=_get_autotune_random_seed
    )
    autotune_accuracy_check: bool = dataclasses.field(
        default_factory=functools.partial(
            _env_get_bool, "HELION_AUTOTUNE_ACCURACY_CHECK", True
        )
    )
    autotune_rebenchmark_threshold: float | None = dataclasses.field(
        default_factory=functools.partial(
            _env_get_optional_float,
            "HELION_REBENCHMARK_THRESHOLD",
        )
    )
    autotune_progress_bar: bool = dataclasses.field(
        default_factory=functools.partial(
            _env_get_bool, "HELION_AUTOTUNE_PROGRESS_BAR", True
        )
    )
    autotune_max_generations: int | None = dataclasses.field(
        default_factory=functools.partial(
            _env_get_optional_int,
            "HELION_AUTOTUNE_MAX_GENERATIONS",
        )
    )
    autotune_ignore_errors: bool = dataclasses.field(
        default_factory=functools.partial(
            _env_get_bool, "HELION_AUTOTUNE_IGNORE_ERRORS", False
        )
    )
    print_output_code: bool = dataclasses.field(
        default_factory=functools.partial(
            _env_get_bool, "HELION_PRINT_OUTPUT_CODE", False
        )
    )
    print_repro: bool = dataclasses.field(
        default_factory=functools.partial(_env_get_bool, "HELION_PRINT_REPRO", False)
    )
    output_origin_lines: bool = dataclasses.field(
        default_factory=functools.partial(
            _env_get_bool, "HELION_OUTPUT_ORIGIN_LINES", True
        )
    )
    force_autotune: bool = dataclasses.field(
        default_factory=functools.partial(_env_get_bool, "HELION_FORCE_AUTOTUNE", False)
    )
    autotune_config_overrides: dict[str, object] = dataclasses.field(
        default_factory=_get_autotune_config_overrides
    )
    autotune_effort: AutotuneEffort = dataclasses.field(
        default_factory=functools.partial(
            _env_get_literal,
            "HELION_AUTOTUNE_EFFORT",
            cast("AutotuneEffort", "full"),
            mapping={key: key for key in ("none", "quick", "full")},
        )
    )  # pyright: ignore[reportAssignmentType]
    allow_warp_specialize: bool = dataclasses.field(
        default_factory=functools.partial(
            _env_get_bool, "HELION_ALLOW_WARP_SPECIALIZE", True
        )
    )
    debug_dtype_asserts: bool = dataclasses.field(
        default_factory=functools.partial(
            _env_get_bool, "HELION_DEBUG_DTYPE_ASSERTS", False
        )
    )
    ref_mode: RefMode = dataclasses.field(default_factory=_get_ref_mode)
    autotune_cache: str = dataclasses.field(
        default_factory=functools.partial(
            _env_get_str, "HELION_AUTOTUNE_CACHE", "LocalAutotuneCache"
        )
    )
    autotuner_fn: AutotunerFunction = default_autotuner_fn
    autotune_baseline_fn: Callable[..., object] | None = None


class Settings(_Settings):
    """
    Settings can be passed to hl.kernel as kwargs and control the behavior of the
    compilation process.  Unlike a Config, settings are not auto-tuned; they are set
    by the user.
    """

    __slots__ = {
        "ignore_warnings": (
            "Subtypes of exc.BaseWarning to ignore when compiling. "
            "Set HELION_IGNORE_WARNINGS=WarningA,WarningB (names from helion.exc) to configure via env."
        ),
        "index_dtype": (
            "The dtype to use for index variables. Default is torch.int32. "
            "Override with HELION_INDEX_DTYPE=torch.int64, etc."
        ),
        "dot_precision": "Precision for dot products, see `triton.language.dot`. Can be 'tf32', 'tf32x3', or 'ieee'.",
        "static_shapes": (
            "If True, use static shapes for all tensors. This is a performance optimization. "
            "Set HELION_STATIC_SHAPES=0 to disable."
        ),
        "autotune_log_level": (
            "Log level for autotuning using Python logging levels. Default is logging.INFO. "
            "Use HELION_AUTOTUNE_LOG_LEVEL to override or set 0 to disable output."
        ),
        "autotune_compile_timeout": "Timeout for Triton compilation in seconds used for autotuning. Default is 60 seconds.",
        "autotune_precompile": "Autotuner precompile mode: 'fork', 'spawn', or falsy/None to disable. Defaults to 'fork' on non-Windows platforms.",
        "autotune_precompile_jobs": "Maximum number of concurrent Triton precompile processes; defaults to the CPU count.",
        "autotune_random_seed": "Seed used for autotuner random number generation. Defaults to HELION_AUTOTUNE_RANDOM_SEED or a time-based seed.",
        "autotune_accuracy_check": "If True, validate candidate configs against the baseline kernel output before accepting them during autotuning.",
        "autotune_rebenchmark_threshold": "If a config is within threshold*best_perf, re-benchmark it to avoid outliers. Defaults to the effort profile value. Set HELION_REBENCHMARK_THRESHOLD to override.",
        "autotune_progress_bar": "If True, show progress bar during autotuning. Default is True. Set HELION_AUTOTUNE_PROGRESS_BAR=0 to disable.",
        "autotune_max_generations": "Override the maximum number of generations for Pattern Search and Differential Evolution Search autotuning algorithms with HELION_AUTOTUNE_MAX_GENERATIONS=N or @helion.kernel(autotune_max_generations=N).",
        "autotune_ignore_errors": (
            "If True, skip logging and raising autotune errors. "
            "Set HELION_AUTOTUNE_IGNORE_ERRORS=1 to enable globally."
        ),
        "print_output_code": "If True, print the output code of the kernel to stderr.",
        "print_repro": "If True, print Helion kernel code, config, and caller code to stderr as a standalone repro script.",
        "output_origin_lines": (
            "If True, annotate generated Triton code with source-origin comments. "
            "Set HELION_OUTPUT_ORIGIN_LINES=0 to disable."
        ),
        "force_autotune": "If True, force autotuning even if a config is provided.",
        "autotune_config_overrides": (
            "Dictionary of config key/value pairs forced during autotuning. "
            "Accepts HELION_AUTOTUNE_CONFIG_OVERRIDES='{\"num_warps\":4}'."
        ),
        "allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.",
        "debug_dtype_asserts": "If True, emit tl.static_assert checks for dtype after each device node.",
        "ref_mode": "Reference mode for kernel execution. Can be RefMode.OFF or RefMode.EAGER.",
        "autotuner_fn": (
            "Function to create an autotuner. "
            "Override by passing a callable to @helion.kernel(..., autotuner_fn=...)."
        ),
        "autotune_effort": "Autotuning effort preset. One of 'none', 'quick', 'full'.",
        "autotune_baseline_fn": (
            "Custom baseline function for computing baseline output during autotuning. "
            "If provided, this function will be called instead of running the default config. "
            "Should have the same signature as the kernel function. "
            "Pass as @helion.kernel(..., autotune_baseline_fn=my_baseline_fn)."
        ),
        "autotune_cache": (
            "The name of the autotuner cache class to use. "
            "Set HELION_AUTOTUNE_CACHE=StrictLocalAutotuneCache to enable strict caching. "
            "Defaults to 'LocalAutotuneCache'."
        ),
    }

    def __init__(self, **settings: object) -> None:
        """
        Initialize the Settings object with the provided dictionary of settings.
        """
        super().__init__(**settings)  # pyright: ignore[reportArgumentType]
        self._check_ref_eager_mode_before_print_output_code()

    def to_dict(self) -> dict[str, object]:
        """
        Convert the Settings object to a dictionary.

        Returns:
            dict[str, object]: A dictionary representation of the Settings object.
        """

        def shallow_copy(x: object) -> object:
            if isinstance(x, (list, dict)):
                return x.copy()
            return x

        # Only include fields that are meant to be public (repr=True)
        public_fields = {f.name for f in dataclasses.fields(self) if f.repr}
        return {
            k: shallow_copy(v)
            for k, v in dataclasses.asdict(self).items()
            if k in public_fields
        }

    def check_autotuning_disabled(self) -> None:
        msg = None
        if os.environ.get("HELION_DISALLOW_AUTOTUNING", "0") == "1":
            msg = "by HELION_DISALLOW_AUTOTUNING=1"
        if is_fbcode():
            from aiplatform.runtime_environment.runtime_environment_pybind import (  # type: ignore[import-untyped]
                RuntimeEnvironment,
            )

            if RuntimeEnvironment().get_mast_job_name() is not None:
                msg = "because autotuning is not allowed in MAST environment"
        if msg:
            raise exc.AutotuningDisallowedInEnvironment(msg)

    def get_rebenchmark_threshold(self) -> float:
        """
        Get the effective rebenchmark threshold.

        Uses the explicit setting if provided, otherwise falls back to the
        effort profile default.

        Returns:
            float: The rebenchmark threshold value.
        """
        if self.autotune_rebenchmark_threshold is not None:
            return self.autotune_rebenchmark_threshold
        from ..autotuner.effort_profile import get_effort_profile

        return get_effort_profile(self.autotune_effort).rebenchmark_threshold

    def _check_ref_eager_mode_before_print_output_code(self) -> None:
        """
        Check if ref eager mode is enabled before printing output code.
        If ref eager mode is enabled, raise an error.
        """
        if self.ref_mode == RefMode.EAGER and self.print_output_code:
            raise exc.RefEagerModeCodePrintError
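
# Illustrative sketch (not part of the module source): settings are normally supplied
# as keyword arguments to the kernel decorator rather than constructed directly.  The
# kernel body and module aliases below are assumptions for the example; only the
# setting names come from this file.
#
#     import torch
#     import helion
#     import helion.language as hl
#
#     @helion.kernel(static_shapes=True, autotune_effort="quick", print_output_code=True)
#     def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
#         ...
#
# Settings can also be instantiated directly; __init__ validates incompatible
# combinations, e.g. ref_mode=RefMode.EAGER together with print_output_code=True
# raises exc.RefEagerModeCodePrintError.
#
#     s = Settings(static_shapes=False, autotune_effort="none")
#     assert s.to_dict()["autotune_effort"] == "none"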