Source code for helion.runtime.settings

from __future__ import annotations

import dataclasses
import logging
import os
import sys
import threading
from typing import TYPE_CHECKING
from typing import Literal
from typing import Protocol
from typing import Sequence
from typing import cast

import torch
from torch._environment import is_fbcode

from helion import exc
from helion.runtime.ref_mode import RefMode

if TYPE_CHECKING:
    from contextlib import AbstractContextManager

    from ..autotuner.base_search import BaseAutotuner
    from .kernel import BoundKernel

    class _TLS(Protocol):
        default_settings: Settings | None

    class AutotunerFunction(Protocol):
        def __call__(
            self, bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
        ) -> BaseAutotuner: ...


_tls: _TLS = cast("_TLS", threading.local())


def set_default_settings(settings: Settings) -> AbstractContextManager[None, None]:
    """
    Set the default settings for the current thread and return a context manager
    that restores the previous settings upon exit.

    Args:
        settings: The Settings object to set as the default.

    Returns:
        AbstractContextManager[None, None]: A context manager that restores the previous settings upon exit.
    """
    prior = getattr(_tls, "default_settings", None)
    _tls.default_settings = settings

    class _RestoreContext:
        def __enter__(self) -> None:
            pass

        def __exit__(self, *args: object) -> None:
            _tls.default_settings = prior

    return _RestoreContext()
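# Illustrative usage, not part of this module: set_default_settings installs a
# thread-local default and returns a context manager that restores the prior
# value on exit, so overrides can be scoped to a block. static_shapes is a real
# field defined below; everything else here is a sketch.
#
#     from helion.runtime.settings import Settings, set_default_settings
#
#     with set_default_settings(Settings(static_shapes=True)):
#         # Settings() constructed here inherit static_shapes=True.
#         assert Settings().static_shapes
#     # The previous thread-local default (possibly None) is restored on exit.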
def default_autotuner_fn(
    bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
) -> BaseAutotuner:
    from ..autotuner import DifferentialEvolutionSearch
    from ..autotuner import LocalAutotuneCache

    return LocalAutotuneCache(DifferentialEvolutionSearch(bound_kernel, args, **kwargs))  # pyright: ignore[reportArgumentType]


@dataclasses.dataclass
class _Settings:
    # see __slots__ below for the doc strings that show up in help(Settings)
    ignore_warnings: list[type[exc.BaseWarning]] = dataclasses.field(
        default_factory=list
    )
    index_dtype: torch.dtype = torch.int32
    dot_precision: Literal["tf32", "tf32x3", "ieee"] = cast(
        "Literal['tf32', 'tf32x3', 'ieee']",
        os.environ.get("TRITON_F32_DEFAULT", "tf32"),
    )
    static_shapes: bool = False
    use_default_config: bool = os.environ.get("HELION_USE_DEFAULT_CONFIG", "0") == "1"
    autotune_log_level: int = logging.INFO
    autotune_compile_timeout: int = int(
        os.environ.get("HELION_AUTOTUNE_COMPILE_TIMEOUT", "60")
    )
    autotune_precompile: bool = sys.platform != "win32"
    print_output_code: bool = os.environ.get("HELION_PRINT_OUTPUT_CODE", "0") == "1"
    force_autotune: bool = os.environ.get("HELION_FORCE_AUTOTUNE", "0") == "1"
    allow_warp_specialize: bool = (
        os.environ.get("HELION_ALLOW_WARP_SPECIALIZE", "1") == "1"
    )
    ref_mode: RefMode = (
        RefMode.EAGER if os.environ.get("HELION_INTERPRET", "") == "1" else RefMode.OFF
    )
    autotuner_fn: AutotunerFunction = default_autotuner_fn
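# Illustrative sketch, not part of this module: autotuner_fn accepts any callable
# matching the AutotunerFunction protocol above, so the default search or cache
# can be swapped out. The name uncached_autotuner is hypothetical; the classes
# are the same ones imported by default_autotuner_fn.
#
#     def uncached_autotuner(
#         bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
#     ) -> BaseAutotuner:
#         from helion.autotuner import DifferentialEvolutionSearch
#
#         # Same search as the default, but without the LocalAutotuneCache wrapper.
#         return DifferentialEvolutionSearch(bound_kernel, args, **kwargs)
#
#     settings = Settings(autotuner_fn=uncached_autotuner)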
class Settings(_Settings):
    """
    Settings can be passed to hl.kernel as kwargs and control the behavior of
    the compilation process. Unlike a Config, settings are not auto-tuned and
    are instead set directly by the user.
    """

    __slots__: dict[str, str] = {
        "ignore_warnings": "Subtypes of exc.BaseWarning to ignore when compiling.",
        "index_dtype": "The dtype to use for index variables. Default is torch.int32.",
        "dot_precision": "Precision for dot products, see `triton.language.dot`. Can be 'tf32', 'tf32x3', or 'ieee'.",
        "static_shapes": "If True, use static shapes for all tensors. This is a performance optimization.",
        "use_default_config": "For development only, skips all autotuning and uses the default config (which may be slow).",
        "autotune_log_level": "Log level for autotuning using Python logging levels. Default is logging.INFO. Use 0 to disable all output.",
        "autotune_compile_timeout": "Timeout for Triton compilation in seconds used for autotuning. Default is 60 seconds.",
        "autotune_precompile": "If True, precompile the kernel before autotuning. Requires a fork-safe environment.",
        "print_output_code": "If True, print the output code of the kernel to stderr.",
        "force_autotune": "If True, force autotuning even if a config is provided.",
        "allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.",
        "ref_mode": "Reference mode for kernel execution. Can be RefMode.OFF or RefMode.EAGER.",
        "autotuner_fn": "Function used to create the autotuner.",
    }

    assert __slots__.keys() == {field.name for field in dataclasses.fields(_Settings)}
    def __init__(self, **settings: object) -> None:
        """
        Initialize the Settings object with the provided keyword arguments.
        Any setting not provided falls back to the current thread-local
        defaults (see `set_default_settings`).

        Args:
            settings: Keyword arguments representing various settings.
        """
        if defaults := getattr(_tls, "default_settings", None):
            settings = {**defaults.to_dict(), **settings}
        super().__init__(**settings)  # pyright: ignore[reportArgumentType]
    def to_dict(self) -> dict[str, object]:
        """
        Convert the Settings object to a dictionary.

        Returns:
            dict[str, object]: A dictionary representation of the Settings object.
        """

        def shallow_copy(x: object) -> object:
            if isinstance(x, (list, dict)):
                return x.copy()
            return x

        return {k: shallow_copy(v) for k, v in dataclasses.asdict(self).items()}
    def check_autotuning_disabled(self) -> None:
        """Raise AutotuningDisallowedInEnvironment if the current environment forbids autotuning."""
        msg = None
        if os.environ.get("HELION_DISALLOW_AUTOTUNING", "0") == "1":
            msg = "by HELION_DISALLOW_AUTOTUNING=1"
        if is_fbcode():
            from aiplatform.runtime_environment.runtime_environment_pybind import (  # type: ignore[import-untyped]
                RuntimeEnvironment,
            )

            if RuntimeEnvironment().get_mast_job_name() is not None:
                msg = "because autotuning is not allowed in the MAST environment"
        if msg:
            raise exc.AutotuningDisallowedInEnvironment(msg)
    @staticmethod
    def default() -> Settings:
        """
        Get the default Settings object. If no default settings are set, create a new one.

        Returns:
            Settings: The default Settings object.
        """
        result = getattr(_tls, "default_settings", None)
        if result is None:
            _tls.default_settings = result = Settings()
        return result
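# Illustrative usage, not part of this module: Settings.default() returns the
# thread-local default (creating one on first use), and to_dict() round-trips
# through the keyword constructor, which is a convenient way to tweak one field.
#
#     base = Settings.default()
#     tweaked = Settings(**{**base.to_dict(), "print_output_code": True})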