Source code for helion.runtime.kernel

from __future__ import annotations

import contextlib
import dataclasses
import functools
import inspect
import logging
import operator
import re
import sys
import types
from typing import TYPE_CHECKING
from typing import Callable
from typing import Generic
from typing import TypeVar
from typing import cast
from typing import overload
from typing_extensions import Protocol

import torch
from torch._dynamo.source import LocalSource
from torch._dynamo.source import TensorProperty
from torch._dynamo.source import TensorPropertySource
from torch._inductor.codecache import PyCodeCache
from torch._subclasses import FakeTensor

from .. import exc
from .._compiler.ast_extension import unparse
from .._compiler.compile_environment import CompileEnvironment
from .._compiler.generate_ast import generate_ast
from .._compiler.host_function import HostFunction
from .._compiler.inductor_lowering_extra import patch_inductor_lowerings
from .._compiler.output_header import assert_no_conflicts
from .._compiler.output_header import get_needed_imports
from .._compiler.variable_origin import ArgumentOrigin
from .._logging import LazyString
from ..language.constexpr import ConstExpr
from .config import Config
from .ref_mode import RefModeContext
from .ref_mode import is_ref_mode_enabled
from .settings import Settings

if TYPE_CHECKING:
    from collections.abc import Hashable
    from collections.abc import Sequence

    from torch._guards import Source

    from ..autotuner import ConfigSpec
    from ..autotuner.base_cache import BoundKernelInMemoryCacheKey

    ConfigLike = Config | dict[str, object]

log: logging.Logger = logging.getLogger(__name__)
_R = TypeVar("_R")
CompiledConfig = Callable[..., _R]


class Kernel(Generic[_R]):
    def __init__(
        self,
        fn: Callable[..., _R],
        *,
        configs: list[ConfigLike] | None = None,
        settings: Settings | None,
    ) -> None:
        """
        Initialize the Kernel object.  This is typically called from the
        `@helion.kernel` decorator.

        Args:
            fn: The function to be compiled as a Helion kernel.
            configs: A list of configurations to use for the kernel.
            settings: The settings to be used by the Kernel. If None, default settings are used.
        """
        super().__init__()
        assert isinstance(fn, types.FunctionType)
        assert_no_conflicts(fn)
        self.name: str = fn.__name__
        self.fn: types.FunctionType = fn
        self.signature: inspect.Signature = inspect.signature(fn)
        self.settings: Settings = settings or Settings.default()
        self.configs: list[Config] = [
            Config(**c) if isinstance(c, dict) else c  # pyright: ignore[reportArgumentType]
            for c in configs or []
        ]
        self._bound_kernels: dict[BoundKernelInMemoryCacheKey, BoundKernel] = {}
        self._specialize_extra: dict[
            Hashable, list[Callable[[Sequence[object]], Hashable]]
        ] = {}
        if any(
            param.kind
            in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
                inspect.Parameter.KEYWORD_ONLY,
            )
            for param in self.signature.parameters.values()
        ):
            raise TypeError(
                f"Kernel({self.name}) cannot have *args, **kwargs, or keyword-only arguments"
            )
        self._annotations: list[object] = []
        for param in self.signature.parameters.values():
            ann = param.annotation
            if isinstance(ann, str) and re.search(r"constexpr", ann, re.IGNORECASE):
                self._annotations.append(ConstExpr)
            else:
                self._annotations.append(ann)
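    # Illustrative note (not part of the original source): any parameter whose
    # string annotation mentions "constexpr" is recorded as ConstExpr, so a
    # hypothetical kernel such as
    #
    #     @helion.kernel
    #     def repeat(x: torch.Tensor, n: hl.constexpr) -> torch.Tensor: ...
    #
    # specializes (and recompiles) on the concrete value of `n` rather than
    # tracing it symbolically.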
    def _get_bound_kernel_cache_key(
        self, args: tuple[object, ...], signature: tuple[Hashable, ...]
    ) -> BoundKernelInMemoryCacheKey | None:
        from ..autotuner.base_cache import BoundKernelInMemoryCacheKey

        extra_fns = self._specialize_extra.get(signature)
        if extra_fns is not None:
            extra_results: tuple[Hashable, ...] = tuple([s(args) for s in extra_fns])
            return BoundKernelInMemoryCacheKey(signature, extra_results)
        return None

    def _create_bound_kernel_cache_key(
        self,
        bound_kernel: BoundKernel,
        args: tuple[object, ...],
        signature: tuple[Hashable, ...],
    ) -> BoundKernelInMemoryCacheKey:
        from ..autotuner.base_cache import BoundKernelInMemoryCacheKey

        self._specialize_extra[signature] = extra_fns = bound_kernel._specialize_extra()
        extra_results: tuple[Hashable, ...] = tuple([s(args) for s in extra_fns])
        return BoundKernelInMemoryCacheKey(signature, extra_results)
    def bind(self, args: tuple[object, ...]) -> BoundKernel[_R]:
        """
        Bind the given arguments to the Kernel and return a BoundKernel object.

        Args:
            args: The arguments to bind to the Kernel.

        Returns:
            BoundKernel: A BoundKernel object with the given arguments bound.
        """
        if not isinstance(args, tuple):
            assert isinstance(args, list), "args must be a tuple or list"
            args = tuple(args)
        signature = self.specialization_key(args)
        cache_key = self._get_bound_kernel_cache_key(args, signature)
        bound_kernel = (
            None if cache_key is None else self._bound_kernels.get(cache_key, None)
        )
        if bound_kernel is None:
            normalized_args: tuple[object, ...] = self.normalize_args(*args)
            if len(normalized_args) != len(args):
                # we had default args that needed to be applied
                bound_kernel = self.bind(normalized_args)
            else:
                bound_kernel = BoundKernel(self, args)
            if cache_key is None:
                cache_key = self._create_bound_kernel_cache_key(
                    bound_kernel, args, signature
                )
            self._bound_kernels[cache_key] = bound_kernel
        return bound_kernel
    def specialization_key(self, args: Sequence[object]) -> tuple[Hashable, ...]:
        """
        Generate a specialization key for the given arguments.

        This method generates a unique key for the arguments based on their
        types and the corresponding extractor functions defined in
        `_specialization_extractors`.

        Args:
            args: The arguments to generate a specialization key for.

        Returns:
            Hashable: A hashable key representing the specialization of the arguments.
        """
        result = []
        assert len(args) <= len(self._annotations)
        for value, annotation in zip(args, self._annotations, strict=False):
            if isinstance(value, ConstExpr):
                result.append(value.value)
            elif annotation is ConstExpr:
                result.append(value)
            else:
                result.append(self._specialization_key(value))
        return tuple(result)
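    # Illustrative sketch (not part of the original source): with static_shapes
    # disabled, a hypothetical call with a float32 CUDA tensor of shape
    # (128, 1) plus an int would produce a key roughly like
    #
    #     ((torch.float32, "cuda", (2, 1)), int)
    #
    # because each size is bucketed to 0, 1, or >=2 (see _tensor_key below).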
    def _specialization_key(self, obj: object) -> Hashable:
        """
        Helper used to generate a specialization key for the given object.

        This method determines a unique key for the object based on its type
        and the corresponding extractor function defined in
        `_specialization_extractors`.

        Args:
            obj: The argument to generate a specialization key for.

        Returns:
            Hashable: A hashable key representing the specialization of the object.
        """
        try:
            extractor = _specialization_extractors[type(obj)]
        except KeyError:
            if isinstance(obj, tuple) and hasattr(obj, "_fields"):
                # this is a namedtuple
                extractor = _specialization_extractors["namedtuple"]
            elif dataclasses.is_dataclass(obj):
                extractor = _specialization_extractors["dataclass"]
            else:
                raise TypeError(
                    f"unsupported argument type: {type(obj).__name__}"
                ) from None
        return extractor(self, obj)
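    # Illustrative note (not part of the original source): namedtuples and
    # dataclasses fall back to mapping-style keys.  A hypothetical
    # `Point(x=1, y=2)` namedtuple would key as
    #
    #     (Point, (("x", int), ("y", int)))
    #
    # while a genuinely unsupported type raises TypeError.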
    def normalize_args(self, *args: object, **kwargs: object) -> tuple[object, ...]:
        """
        Normalize the given arguments and keyword arguments according to the
        function signature.

        Args:
            args: The positional arguments to normalize.
            kwargs: The keyword arguments to normalize.

        Returns:
            tuple[object, ...]: A tuple of normalized positional arguments.
        """
        bound_args = self.signature.bind(*args, **kwargs)
        bound_args.apply_defaults()
        return tuple(bound_args.args)
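    # Example (hypothetical signature): for `def fn(x, y, alpha=1.0)`,
    # `normalize_args(a, y=b)` returns `(a, b, 1.0)` -- keyword arguments are
    # folded into positional order and defaults are applied.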
    def autotune(
        self,
        args: Sequence[object],
        *,
        force: bool = False,
        **options: object,
    ) -> Config:
        """
        Perform autotuning to find the optimal configuration for the kernel.
        This uses the default settings; call helion.autotune.* directly for
        more customization.

        If config= or configs= is provided to helion.kernel(), the search will
        be restricted to the provided configs.  Use force=True to ignore the
        provided configs.

        Mutates (the bound version of) self so that `__call__` will run the
        best config found.

        Args:
            args: Example arguments used for benchmarking during autotuning.
            force: If True, force full autotuning even if a config is provided.
            **options: Additional options for autotuning.

        Returns:
            Config: The best configuration found during autotuning.
        """
        args = self.normalize_args(*args)
        return self.bind(args).autotune(args, force=force, **options)
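    # Example (hypothetical, assumes a decorated kernel `my_kernel`): autotune
    # once up front, then every later call reuses the winning config:
    #
    #     args = (torch.randn(8192, device="cuda"),)
    #     best = my_kernel.autotune(args, force=True)
    #     print(best)  # the chosen helion.Config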
    def __call__(self, *args: object, **kwargs: object) -> _R:
        """
        Call the Kernel with the given arguments and keyword arguments.

        Args:
            args: The positional arguments to pass to the Kernel.
            kwargs: The keyword arguments to pass to the Kernel.

        Returns:
            _R: The result of the Kernel function call.
        """
        if kwargs:
            args = self.normalize_args(*args, **kwargs)
        return self.bind(args)(*args)
    def reset(self) -> None:
        """
        Clears the cache of bound kernels, meaning subsequent calls will
        recompile and re-autotune.
        """
        self._bound_kernels.clear()

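# Illustrative sketch, not part of the original module: how the bound-kernel
# cache behaves across calls.  The kernel body and shapes are hypothetical;
# use_default_config=True skips autotuning so the example is self-contained.
def _example_kernel_caching() -> None:
    import helion
    import helion.language as hl

    @helion.kernel(use_default_config=True)
    def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        for tile in hl.tile(out.size(0)):
            out[tile] = x[tile] + y[tile]
        return out

    a = torch.randn(128, device="cuda")
    add(a, a)  # first call: binds, compiles with the default config
    add(a, a)  # same specialization key -> cached BoundKernel is reused
    b = a.to(torch.float16)
    add(b, b)  # different dtype -> new key -> a second BoundKernel is compiled
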
class BoundKernel(Generic[_R]):
    def __init__(
        self,
        kernel: Kernel[_R],
        args: tuple[object, ...],
    ) -> None:
        """
        Initialize a BoundKernel object.

        This constructor sets up the environment, compiles the kernel
        function, and prepares the arguments for execution.

        Args:
            kernel: The Kernel object to bind.
            args: A tuple of arguments to bind to the kernel.
        """
        super().__init__()
        self.kernel = kernel
        self._run: Callable[..., _R] | None = None
        self._config: Config | None = None
        self._compile_cache: dict[Config, CompiledConfig] = {}
        self.env = CompileEnvironment(_find_device(args), self.kernel.settings)

        if is_ref_mode_enabled(self.kernel.settings):
            self.fake_args = []  # type: ignore[assignment]
            self.host_function = None  # type: ignore[assignment]
            return

        with self.env:
            assert len(args) == len(self.kernel.signature.parameters)
            self.fake_args: list[object] = []
            constexpr_args = {}
            for name, arg, annotation in zip(
                self.kernel.signature.parameters,
                args,
                self.kernel._annotations,
                strict=False,
            ):
                if isinstance(arg, ConstExpr):
                    assert not isinstance(arg.value, torch.Tensor), (
                        "ConstExpr cannot be a tensor"
                    )
                    self.fake_args.append(arg.value)
                    constexpr_args[name] = arg.value
                elif annotation is ConstExpr:
                    assert not isinstance(arg, torch.Tensor), (
                        "ConstExpr cannot be a tensor"
                    )
                    self.fake_args.append(arg)
                    constexpr_args[name] = arg
                else:
                    self.fake_args.append(self.env.to_fake(arg, ArgumentOrigin(name)))
            with (
                _maybe_skip_dtype_check_in_meta_registrations(),
                patch_inductor_lowerings(),
            ):
                self.host_function: HostFunction = HostFunction(
                    self.kernel.fn, self.fake_args, constexpr_args
                )

    @property
    def settings(self) -> Settings:
        """
        Retrieve the settings associated with the kernel.

        Returns:
            Settings: The settings of the kernel.
        """
        return self.kernel.settings

    @property
    def config_spec(self) -> ConfigSpec:
        """
        Retrieve the configuration specification for the kernel.

        Returns:
            ConfigSpec: The configuration specification.
        """
        return self.env.config_spec

    @property
    def configs(self) -> list[Config]:
        """
        Alias for `self.kernel.configs`.

        Returns:
            list[Config]: The list of configurations.
        """
        return self.kernel.configs

    def to_triton_code(self, config: ConfigLike | None = None) -> str:
        """
        Generate Triton code for the kernel based on the given configuration.

        Args:
            config: The configuration to use for code generation.

        Returns:
            str: The generated Triton code as a string.
        """
        if config is None:
            config = self._require_implicit_config()
        with self.env:
            if not isinstance(config, Config):
                config = Config(**config)  # pyright: ignore[reportArgumentType]
            self.env.config_spec.normalize(config)
            root = generate_ast(self.host_function, config)
            return get_needed_imports(root) + unparse(root)

    def compile_config(
        self, config: ConfigLike | None = None, *, allow_print: bool = True
    ) -> CompiledConfig:
        """
        Compile the kernel for a specific configuration.

        Args:
            config: The configuration to compile the kernel with.
            allow_print: Set to False to suppress printing the output code when autotuning.

        Returns:
            CompiledConfig: A callable object representing the compiled kernel.
""" if config is None: config = self._require_implicit_config() if not isinstance(config, Config): config = Config( **config # pyright: ignore[reportArgumentType] ) if (rv := self._compile_cache.get(config)) is not None: return rv triton_code = self.to_triton_code(config) if allow_print: log.info("Output code: \n%s", triton_code) log.debug("Debug string: \n%s", LazyString(lambda: self._debug_str())) if self.settings.print_output_code: print(triton_code, file=sys.stderr) module = PyCodeCache.load(triton_code) rv = getattr(module, self.kernel.name) self._compile_cache[config] = rv return rv def _debug_str(self) -> str: """ Generate a debug string for the kernel. Returns: str: A string containing debug information about the kernel. """ if self.host_function is None: # In ref mode, host_function is not created return f"<BoundKernel {self.kernel.fn.__name__} in ref mode>" with self.env: return self.host_function.debug_str() def autotune( self, args: Sequence[object], *, force: bool = False, **kwargs: object, ) -> Config: """ Perform autotuning to find the optimal configuration for the kernel. This uses the default setting, you can call helion.autotune.* directly for more customization. If config= or configs= is provided to helion.kernel(), the search will be restricted to the provided configs. Use force=True to ignore the provided configs. Mutates self so that `__call__` will run the best config found. Args: args: Example arguments used for benchmarking during autotuning. force: If True, force full autotuning even if a config is provided. **kwargs: Additional options for autotuning. Returns: Config: The best configuration found during autotuning. """ force = force or self.settings.force_autotune if not force and self.kernel.configs: if len(self.kernel.configs) == 1: (config,) = self.kernel.configs else: # We have finite predetermined configs, no need to precompile self.settings.autotune_precompile = False from ..autotuner import FiniteSearch config = FiniteSearch(self, args, self.configs).autotune() else: self.settings.check_autotuning_disabled() config = self.settings.autotuner_fn(self, args, **kwargs).autotune() self.set_config(config) return config def set_config(self, config: ConfigLike) -> None: """ Set the configuration for the kernel and compile it. Mutates self so that `__call__` will run the provided config. Args: config: The configuration to set. """ if not isinstance(config, Config): config = Config( **config # pyright: ignore[reportArgumentType] ) self._run = self.compile_config(config) self._config = config def _specialize_extra(self) -> list[Callable[[Sequence[object]], Hashable]]: """ Returns a list of functions that will be called to generate extra specialization keys. This is used to specialize on the values hl.specialize()'ed arguments. Returns: list[Callable[[Sequence[object]], Hashable]]: A list of functions that generate extra specialization keys. 
""" if not self.env.specialized_vars: return [] def make_extractor(v: Source) -> Callable[[Sequence[object]], Hashable]: if isinstance(v, TensorPropertySource): assert v.prop == TensorProperty.SIZE index = v.idx assert index is not None inner = make_extractor(v.base) return lambda args: cast("torch.Tensor", inner(args)).size(index) if isinstance(v, LocalSource): index = arg_name_to_index[v.local_name] return operator.itemgetter(index) raise exc.SpecializeArgType(v) arg_name_to_index: dict[str, int] = { n: i for i, n in enumerate(self.kernel.signature.parameters.keys()) } extractors = [] for v in sorted(self.env.specialized_vars, key=lambda v: v.name): source = self.env.shape_env.var_to_sources[v][0] extractors.append(make_extractor(source)) return extractors def _implicit_config(self) -> Config | None: """ Returns a single config that is implicitly used by this kernel, if any. """ configs = self.kernel.configs if self._config is not None: return self._config if len(configs) == 1: return configs[0] if len(configs) == 0 and self.kernel.settings.use_default_config: return self.config_spec.default_config() return None def _require_implicit_config(self) -> Config: """ Returns the implicit config for this kernel, or raises an error if no implicit config is available. """ if (config := self._implicit_config()) is None: raise RuntimeError("no config provided and no implicit config available") return config def run_ref(self, *args: object) -> _R: # pyright: ignore[reportReturnType] # Unwrap ConstExpr arguments clean_args = [] for arg in args: if isinstance(arg, ConstExpr): clean_args.append(arg.value) else: clean_args.append(arg) # Pass the config to RefModeContext with RefModeContext(self.env, self._config): result = self.kernel.fn(*clean_args) return cast("_R", result) def __call__(self, *args: object) -> _R: """ Execute the kernel with the given arguments. Args: args: The arguments to pass to the kernel. Returns: _R: The result of the kernel execution. """ if is_ref_mode_enabled(self.kernel.settings): if (config := self._implicit_config()) is not None: self._config = config return self.run_ref(*args) if self._run is None: if (config := self._implicit_config()) is not None: self.set_config(config) else: self.autotune(args) assert self._run is not None return self._run(*args) class _KernelDecorator(Protocol): def __call__( self, fn: Callable[..., _R], ) -> Kernel[_R]: ... @overload def kernel( fn: Callable[..., _R], *, config: ConfigLike | None = None, configs: list[ConfigLike] | None = None, **settings: object, ) -> Kernel[_R]: ... @overload def kernel( fn: None = None, *, config: ConfigLike | None = None, configs: list[ConfigLike] | None = None, **settings: object, ) -> _KernelDecorator: ...
def kernel(
    fn: Callable[..., _R] | None = None,
    *,
    config: ConfigLike | None = None,
    configs: list[ConfigLike] | None = None,
    **settings: object,
) -> Kernel[_R] | _KernelDecorator:
    """
    Decorator to create a Kernel object from a Python function.

    Args:
        fn: The function to be wrapped by the Kernel. If None, a decorator is returned.
        config: A single configuration to use for the kernel.
            See :class:`~helion.Config` for details.
        configs: A list of configurations to use for the kernel.
            Can only specify one of config or configs.
            See :class:`~helion.Config` for details.
        settings: Keyword arguments representing settings for the Kernel.
            Can also use settings=Settings(...) to pass a Settings object directly.
            See :class:`~helion.Settings` for available options.

    Returns:
        object: A Kernel object or a decorator that returns a Kernel object.

    See Also:
        - :class:`~helion.Settings`: Controls compilation behavior and debugging options
        - :class:`~helion.Config`: Controls GPU execution parameters and optimization strategies
    """
    if config is not None:
        assert not configs, "Cannot specify both config and configs"
        configs = [config]
    elif configs is None:
        configs = []
    if settings_obj := settings.get("settings"):
        assert len(settings) == 1, "settings must be the only keyword argument"
        assert isinstance(settings_obj, Settings), "settings must be a Settings object"
    else:
        settings_obj = Settings(**settings)
    if fn is None:
        return functools.partial(kernel, configs=configs, settings=settings_obj)
    return Kernel(fn, configs=configs, settings=settings_obj)

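# Illustrative sketch, not part of the original module: the two common ways to
# invoke the decorator.  Kernel bodies and config values are hypothetical.
def _example_kernel_decorator() -> None:
    import helion
    import helion.language as hl

    # Bare form: the config is found by autotuning on first call.
    @helion.kernel
    def double(x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        for tile in hl.tile(out.size(0)):
            out[tile] = x[tile] * 2
        return out

    # Parameterized form: a single pinned config becomes the implicit config
    # returned by BoundKernel._implicit_config(), so no autotuning happens.
    @helion.kernel(config=helion.Config(block_sizes=[64]))
    def triple(x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        for tile in hl.tile(out.size(0)):
            out[tile] = x[tile] * 3
        return out
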
def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable:
    # NOTE: if a machine has two different GPU types, obj.device.type alone
    # cannot tell them apart, so the cache key will incorrectly hit
    if fn.settings.static_shapes:
        return (
            obj.dtype,
            obj.device.type,
            (*obj.size(),),
            (*obj.stride(),),
        )
    return (
        obj.dtype,
        obj.device.type,
        # 0, 1, or >=2 specialization
        tuple([min(s, 2) for s in obj.size()]),
    )


def _sequence_key(fn: Kernel, obj: Sequence) -> Hashable:
    return type(obj), tuple([fn._specialization_key(item) for item in obj])


def _mapping_key(
    fn: Kernel, obj: dict[str | int, object], real_type: type[object]
) -> Hashable:
    return real_type, tuple(
        sorted((k, fn._specialization_key(v)) for k, v in obj.items())
    )


def _number_key(fn: Kernel, n: float | bool) -> object:
    return type(n)


def _function_key(fn: Kernel, obj: types.FunctionType) -> object:
    if obj.__closure__:
        closures = [
            fn._specialization_key(cell.cell_contents) for cell in obj.__closure__
        ]
        return (obj.__code__, *closures)
    return obj.__code__


_specialization_extractors: dict[
    type[object] | str, Callable[[Kernel, object], Hashable]
] = {  # pyright: ignore[reportAssignmentType]
    torch.Tensor: _tensor_key,
    torch.nn.Parameter: _tensor_key,
    FakeTensor: _tensor_key,
    torch.dtype: lambda fn, x: x,
    torch.device: lambda fn, x: x,
    int: _number_key,
    float: _number_key,
    bool: _number_key,
    str: lambda fn, x: x,
    list: _sequence_key,
    tuple: _sequence_key,
    dict: lambda fn, x: _mapping_key(fn, x, type(x)),  # pyright: ignore[reportArgumentType]
    "namedtuple": lambda fn, x: _mapping_key(fn, x._asdict(), type(x)),  # pyright: ignore[reportAttributeAccessIssue]
    "dataclass": lambda fn, x: _mapping_key(fn, dataclasses.asdict(x), type(x)),  # pyright: ignore[reportArgumentType]
    types.FunctionType: _function_key,
    types.BuiltinFunctionType: lambda fn, x: x,
    ConstExpr: lambda fn, x: x.value,  # pyright: ignore[reportAttributeAccessIssue]
}


def _find_device(args: tuple[object, ...]) -> torch.device:
    """
    Extract the device from the arguments.

    Args:
        args: The arguments to extract the device from.

    Returns:
        torch.device: The extracted device.
    """
    for arg in args:
        if isinstance(arg, torch.device):
            return arg
        if isinstance(arg, torch.Tensor):
            return arg.device
        if isinstance(arg, (tuple, list)):
            for item in arg:
                try:
                    # wrap the item in a 1-tuple so the recursive call always
                    # receives a tuple, even for non-iterable items
                    return _find_device((item,))
                except exc.NoTensorArgs:
                    pass
        elif isinstance(arg, dict):
            for item in arg.values():
                try:
                    return _find_device((item,))
                except exc.NoTensorArgs:
                    pass
    raise exc.NoTensorArgs


def _maybe_skip_dtype_check_in_meta_registrations() -> (
    contextlib.AbstractContextManager[None, None]
):
    if hasattr(torch.fx.experimental._config, "skip_dtype_check_in_meta_registrations"):  # pyright: ignore[reportAttributeAccessIssue]
        return torch.fx.experimental._config.patch(  # pyright: ignore[reportAttributeAccessIssue]
            skip_dtype_check_in_meta_registrations=True
        )
    return contextlib.nullcontext()

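# Illustrative sketch, not part of the original module: _find_device scans the
# positional arguments (including nested tuples, lists, and dict values) and
# returns the first device it can infer; with no tensor or device anywhere it
# raises exc.NoTensorArgs.  Values below are hypothetical.
def _example_find_device() -> None:
    t = torch.randn(4)  # a CPU tensor
    assert _find_device((1.0, t, "label")) == torch.device("cpu")
    assert _find_device(({"x": torch.device("cuda:0")},)) == torch.device("cuda", 0)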