Source code for helion.runtime.kernel

from __future__ import annotations

import contextlib
import dataclasses
import functools
import inspect
import itertools
import logging
import operator
import re
import sys
import textwrap
import types
from typing import TYPE_CHECKING
from typing import Callable
from typing import Generic
from typing import Hashable
from typing import Sequence
from typing import TypeVar
from typing import cast
from typing import overload
from typing_extensions import Protocol

import torch
from torch._dynamo.source import LocalSource
from torch._dynamo.source import TensorProperty
from torch._dynamo.source import TensorPropertySource
from torch._inductor.codecache import PyCodeCache
from torch._inductor.codecache import compiled_fx_graph_hash
from torch._subclasses import FakeTensor
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakIdKeyDictionary

from .. import exc
from .._compiler.ast_extension import unparse
from .._compiler.compile_environment import CompileEnvironment
from .._compiler.generate_ast import generate_ast
from .._compiler.host_function import HostFunction
from .._compiler.inductor_lowering_extra import patch_inductor_lowerings
from .._compiler.output_header import assert_no_conflicts
from .._compiler.output_header import get_needed_imports
from .._compiler.variable_origin import ArgumentOrigin
from .._logging import LazyString
from .._utils import counters
from ..language.constexpr import ConstExpr
from .config import Config
from .ref_mode import RefModeContext
from .ref_mode import is_ref_mode_enabled
from .settings import Settings

if TYPE_CHECKING:
    from collections.abc import Hashable
    from collections.abc import Sequence

    from torch._guards import Source

    from ..autotuner import ConfigSpec
    from ..autotuner.base_cache import BoundKernelInMemoryCacheKey

    ConfigLike = Config | dict[str, object]

log: logging.Logger = logging.getLogger(__name__)
_R = TypeVar("_R")
CompiledConfig = Callable[..., _R]

# Cache for GraphModule hashes
_graph_module_hash_cache: WeakIdKeyDictionary = WeakIdKeyDictionary()

_INT32_INDEX_LIMIT = torch.iinfo(torch.int32).max


def _resolve_index_dtype(
    settings: Settings,
    args: Sequence[object] | tuple[object, ...],
) -> torch.dtype:
    if (index_dtype := settings.index_dtype) is not None:
        limit = torch.iinfo(index_dtype).max
    else:
        limit = _INT32_INDEX_LIMIT
    over_limit = False

    def _check(tensor: torch.Tensor) -> None:
        nonlocal over_limit
        if over_limit:
            return
        try:
            over_limit = bool(tensor.numel() > limit)
        except RuntimeError:  # unbacked SymInt
            if index_dtype is None:
                over_limit = True

    tree_map_only(torch.Tensor, _check, args)
    # pyrefly: ignore [unbound-name]
    if index_dtype is None:  # Auto-select when not provided
        return torch.int64 if over_limit else torch.int32
    if over_limit:
        # pyrefly: ignore [unbound-name]
        raise exc.InputTensorNumelExceedsIndexType(index_dtype=index_dtype)
    # pyrefly: ignore [unbound-name]
    return index_dtype
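
# Illustrative sketch (not part of the library): how _resolve_index_dtype
# auto-selects an index dtype, assuming Settings() defaults to index_dtype=None.
# This helper is never called; it only documents the rule above.
def _example_index_dtype_selection() -> None:
    small = torch.empty(128, 128)  # numel() == 16384, well under 2**31 - 1
    assert _resolve_index_dtype(Settings(), (small,)) == torch.int32
    # A tensor with more than 2**31 - 1 elements (or an unbacked symbolic
    # numel) would instead select torch.int64.
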


class Kernel(Generic[_R]):
    def __init__(
        self,
        fn: Callable[..., _R],
        *,
        configs: list[ConfigLike] | None = None,
        settings: Settings | None,
        key: Callable[..., Hashable] | None = None,
    ) -> None:
        """
        Initialize the Kernel object.  This is typically called from the
        `@helion.kernel` decorator.

        Args:
            fn: The function to be compiled as a Helion kernel.
            configs: A list of configurations to use for the kernel.
            settings: The settings to be used by the Kernel.  If None, a new
                `Settings()` instance is created.
            key: Optional callable that returns an extra hashable component
                for specialization.
        """
        super().__init__()
        assert isinstance(fn, types.FunctionType)
        assert_no_conflicts(fn)
        self.name: str = fn.__name__
        # pyrefly: ignore [read-only]
        self.fn: types.FunctionType = fn
        self.signature: inspect.Signature = inspect.signature(fn)
        self.settings: Settings = settings or Settings()
        self._key_fn: Callable[..., Hashable] | None = key
        self.configs: list[Config] = [
            # pyrefly: ignore [bad-argument-type]
            Config(**c) if isinstance(c, dict) else c
            for c in configs or []
        ]
        self._bound_kernels: dict[BoundKernelInMemoryCacheKey, BoundKernel] = {}
        self._specialize_extra: dict[
            Hashable, list[Callable[[Sequence[object]], Hashable]]
        ] = {}
        if any(
            param.kind
            in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
                inspect.Parameter.KEYWORD_ONLY,
            )
            for param in self.signature.parameters.values()
        ):
            raise TypeError(
                f"Kernel({self.name}) cannot have *args, **kwargs, or keyword-only arguments"
            )
        self._annotations: list[object] = []
        for param in self.signature.parameters.values():
            ann = param.annotation
            if isinstance(ann, str) and re.search(r"constexpr", ann, re.IGNORECASE):
                self._annotations.append(ConstExpr)
            else:
                self._annotations.append(ann)
        # Expose function attributes for compatibility with torch.library.custom_op.
        # These are set as instance attributes to allow the Kernel to be used
        # as if it were a regular function for introspection purposes.
        functools.update_wrapper(self, fn)
        # Manually add function-specific attributes not copied by update_wrapper
        self.__globals__ = fn.__globals__
        self.__code__ = fn.__code__
        self.__defaults__ = fn.__defaults__
        self.__kwdefaults__ = fn.__kwdefaults__

    def _get_bound_kernel_cache_key(
        self, args: tuple[object, ...], signature: tuple[Hashable, ...]
    ) -> BoundKernelInMemoryCacheKey | None:
        from ..autotuner.base_cache import BoundKernelInMemoryCacheKey

        extra_fns = self._specialize_extra.get(signature)
        if extra_fns is not None:
            extra_results: tuple[Hashable, ...] = tuple([s(args) for s in extra_fns])
            return BoundKernelInMemoryCacheKey(signature, extra_results)
        return None

    def _create_bound_kernel_cache_key(
        self,
        bound_kernel: BoundKernel,
        args: tuple[object, ...],
        signature: tuple[Hashable, ...],
    ) -> BoundKernelInMemoryCacheKey:
        from ..autotuner.base_cache import BoundKernelInMemoryCacheKey

        self._specialize_extra[signature] = extra_fns = bound_kernel._specialize_extra()
        extra_results: tuple[Hashable, ...] = tuple([s(args) for s in extra_fns])
        return BoundKernelInMemoryCacheKey(signature, extra_results)

    def bind(self, args: tuple[object, ...]) -> BoundKernel[_R]:
        """
        Bind the given arguments to the Kernel and return a BoundKernel object.

        Args:
            args: The arguments to bind to the Kernel.

        Returns:
            BoundKernel: A BoundKernel object with the given arguments bound.
        """
        if not isinstance(args, tuple):
            assert isinstance(args, list), "args must be a tuple or list"
            args = tuple(args)
        if len(args) > len(self.signature.parameters):
            raise TypeError(
                f"Too many arguments passed to the kernel, expected: {len(self.signature.parameters)} got: {len(args)}."
            )
        signature = self.specialization_key(args)
        cache_key = self._get_bound_kernel_cache_key(args, signature)
        bound_kernel = (
            None if cache_key is None else self._bound_kernels.get(cache_key, None)
        )
        if bound_kernel is None:
            normalized_args: tuple[object, ...] = self.normalize_args(*args)
            if len(normalized_args) != len(args):
                # we had default args that needed to be applied
                bound_kernel = self.bind(normalized_args)
            else:
                bound_kernel = BoundKernel(self, args)
            if cache_key is None:
                cache_key = self._create_bound_kernel_cache_key(
                    bound_kernel, args, signature
                )
            self._bound_kernels[cache_key] = bound_kernel
        return bound_kernel

    def specialization_key(self, args: Sequence[object]) -> tuple[Hashable, ...]:
        """
        Generate a specialization key for the given arguments.

        This method generates a unique key for the arguments based on their
        types and the corresponding extractor functions defined in
        `_specialization_extractors`.

        Args:
            args: The arguments to generate a specialization key for.

        Returns:
            Hashable: A hashable key representing the specialization of the arguments.
        """
        result = []
        assert len(args) <= len(self._annotations)
        for value, annotation in zip(args, self._annotations, strict=False):
            if isinstance(value, ConstExpr):
                result.append(value.value)
            elif annotation is ConstExpr:
                result.append(value)
            else:
                result.append(self._specialization_key(value))
        if self._key_fn is not None:
            return (*result, self._key_fn(*args))
        return (*result,)

    def _specialization_key(self, obj: object) -> Hashable:
        """
        Helper used to generate a specialization key for the given object.

        This method determines a unique key for the object based on its type
        and the corresponding extractor function defined in
        `_specialization_extractors`.

        Args:
            obj: The argument to generate a specialization key for.

        Returns:
            Hashable: A hashable key representing the specialization of the object.
        """
        try:
            extractor = _specialization_extractors[type(obj)]
        except KeyError:
            if isinstance(obj, torch.fx.GraphModule):
                # GraphModule subclasses need special handling
                extractor = _specialization_extractors[torch.fx.GraphModule]
            elif isinstance(obj, tuple) and hasattr(obj, "_fields"):
                # this is a namedtuple
                extractor = _specialization_extractors["namedtuple"]
            elif dataclasses.is_dataclass(obj):
                extractor = _specialization_extractors["dataclass"]
            else:
                raise TypeError(
                    f"unsupported argument type: {type(obj).__name__}"
                ) from None
        return extractor(self, obj)

    def normalize_args(self, *args: object, **kwargs: object) -> tuple[object, ...]:
        """
        Normalize the given arguments and keyword arguments according to the
        function signature.

        Args:
            args: The positional arguments to normalize.
            kwargs: The keyword arguments to normalize.

        Returns:
            tuple[object, ...]: A tuple of normalized positional arguments.
        """
        bound_args = self.signature.bind(*args, **kwargs)
        bound_args.apply_defaults()
        return tuple(bound_args.args)

    def autotune(
        self,
        args: Sequence[object],
        *,
        force: bool = True,
        **options: object,
    ) -> Config:
        """
        Perform autotuning to find the optimal configuration for the kernel.
        This uses the default autotuner settings; call helion.autotune.*
        directly for more customization.

        If config= or configs= is provided to helion.kernel(), the search will
        be restricted to the provided configs.  Use force=True to ignore the
        provided configs.

        Mutates (the bound version of) self so that `__call__` will run the
        best config found.

        Args:
            args: Example arguments used for benchmarking during autotuning.
            force: If True, force full autotuning even if a config is provided.
            options: Additional keyword options forwarded to the autotuner.

        Returns:
            Config: The best configuration found during autotuning.
        """
        args = self.normalize_args(*args)
        return self.bind(args).autotune(args, force=force, **options)

    def __call__(self, *args: object, **kwargs: object) -> _R:
        """
        Call the Kernel with the given arguments and keyword arguments.

        Args:
            args: The positional arguments to pass to the Kernel.
            kwargs: The keyword arguments to pass to the Kernel.

        Returns:
            _R: The result of the Kernel function call.
        """
        if kwargs:
            args = self.normalize_args(*args, **kwargs)
        return self.bind(args)(*args)

    def reset(self) -> None:
        """
        Clears the cache of bound kernels, meaning subsequent calls will
        recompile and re-autotune.
        """
        self._bound_kernels.clear()
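
# Illustrative sketch (not part of the library): with the default
# static_shapes=False, _tensor_key (defined near the end of this file) buckets
# every dimension via min(s, 2), so same-dtype/device tensors of different
# sizes map to one specialization key and reuse one BoundKernel.  Assumes a
# single-tensor-argument Helion kernel `k`, a CUDA device, and no key= or
# hl.specialize() customization.
def _example_bound_kernel_reuse(k: Kernel) -> None:
    a = torch.randn(4, 8, device="cuda")
    b = torch.randn(16, 32, device="cuda")
    assert k.specialization_key((a,)) == k.specialization_key((b,))
    assert k.bind((a,)) is k.bind((b,))  # cache hit: same BoundKernel
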

class BoundKernel(Generic[_R]):
    def __init__(
        self,
        kernel: Kernel[_R],
        args: tuple[object, ...],
    ) -> None:
        """
        Initialize a BoundKernel object.

        This constructor sets up the environment, compiles the kernel
        function, and prepares the arguments for execution.

        Args:
            kernel: The Kernel object to bind.
            args: A tuple of arguments to bind to the kernel.
        """
        super().__init__()
        self.kernel = kernel
        self._run: Callable[..., _R] | None = None
        self._config: Config | None = None
        self._compile_cache: dict[Config, CompiledConfig] = {}
        self._cache_path_map: dict[Config, str | None] = {}
        self.env = CompileEnvironment(
            _find_device(args),
            self.kernel.settings,
            index_dtype=_resolve_index_dtype(self.kernel.settings, args),
        )
        if is_ref_mode_enabled(self.kernel.settings):
            self.fake_args = []  # type: ignore[assignment]
            self.host_function = None  # type: ignore[assignment]
            return
        with self.env:
            assert len(args) == len(self.kernel.signature.parameters)
            self.fake_args: list[object] = []
            constexpr_args = {}
            for name, arg, annotation in zip(
                self.kernel.signature.parameters,
                args,
                self.kernel._annotations,
                strict=False,
            ):
                if isinstance(arg, ConstExpr):
                    assert not isinstance(arg.value, torch.Tensor), (
                        "ConstExpr cannot be a tensor"
                    )
                    self.fake_args.append(arg.value)
                    constexpr_args[name] = arg.value
                elif annotation is ConstExpr:
                    assert not isinstance(arg, torch.Tensor), (
                        "ConstExpr cannot be a tensor"
                    )
                    self.fake_args.append(arg)
                    constexpr_args[name] = arg
                else:
                    self.fake_args.append(self.env.to_fake(arg, ArgumentOrigin(name)))
            self._apply_mark_static(args)
            with (
                _maybe_skip_dtype_check_in_meta_registrations(),
                patch_inductor_lowerings(),
            ):
                try:
                    # pyrefly: ignore [bad-assignment]
                    self.host_function: HostFunction = HostFunction(
                        # pyrefly: ignore [bad-argument-type]
                        self.kernel.fn,
                        self.fake_args,
                        constexpr_args,
                    )
                except Exception:
                    config = self.env.config_spec.default_config()
                    self.maybe_log_repro(log.warning, args, config=config)
                    raise

    def _apply_mark_static(self, args: tuple[object, ...]) -> None:
        """
        Apply torch._dynamo.mark_static() markings from input tensors.

        This reads _dynamo_static_indices from each tensor argument and marks
        the corresponding dimensions as specialized (constant) in the kernel.
        """
        for arg, fake_arg in zip(args, self.fake_args, strict=True):
            if isinstance(arg, torch.Tensor) and isinstance(fake_arg, torch.Tensor):
                for dim in getattr(arg, "_dynamo_static_indices", ()):
                    size = fake_arg.size(dim)
                    if isinstance(size, torch.SymInt):
                        self.env.specialized_vars.update(size._sympy_().free_symbols)

    @property
    def settings(self) -> Settings:
        """
        Retrieve the settings associated with the kernel.

        Returns:
            Settings: The settings of the kernel.
        """
        return self.kernel.settings

    @property
    def config_spec(self) -> ConfigSpec:
        """
        Retrieve the configuration specification for the kernel.

        Returns:
            ConfigSpec: The configuration specification.
        """
        return self.env.config_spec

    @property
    def configs(self) -> list[Config]:
        """
        Alias for `self.kernel.configs`.

        Returns:
            list[Config]: The list of configurations.
""" return self.kernel.configs def format_kernel_decorator(self, config: Config, settings: Settings) -> str: """Return the @helion.kernel decorator snippet capturing configs and settings that influence Triton code generation.""" parts = [ f"config={config.__repr__()}", f"static_shapes={settings.static_shapes}", ] if settings.index_dtype is not None: parts.append(f"index_dtype={settings.index_dtype}") return f"@helion.kernel({', '.join(parts)})" def to_triton_code( self, config: ConfigLike | None = None, *, emit_repro_caller: bool = False, output_origin_lines: bool | None = None, ) -> str: """ Generate Triton code for the kernel based on the given configuration. Args: config: The configuration to use for code generation. emit_repro_caller: Emits a main function to call the triton kernel with example inputs. Returns: str: The generated Triton code as a string. """ if config is None: config = self._require_implicit_config() with self.env: if not isinstance(config, Config): # pyrefly: ignore [bad-argument-type] config = Config(**config) self.env.config_spec.normalize(config) # pyrefly: ignore [bad-argument-type] root = generate_ast(self.host_function, config, emit_repro_caller) if output_origin_lines is None: output_origin_lines = self.settings.output_origin_lines return get_needed_imports(root) + unparse( root, output_origin_lines=output_origin_lines ) def compile_config( self, config: ConfigLike | None = None, *, allow_print: bool = True ) -> CompiledConfig: """ Compile the kernel for a specific configuration. Args: config: The configuration to compile the kernel with. allow_print: Set to suppress printing the output code when autotuning. Returns: CompiledConfig: A callable object representing the compiled kernel. """ if config is None: config = self._require_implicit_config() if not isinstance(config, Config): config = Config( # pyrefly: ignore [bad-argument-type] **config ) if (rv := self._compile_cache.get(config)) is not None: return rv try: triton_code = self.to_triton_code( config, emit_repro_caller=self.settings.print_output_code ) module = PyCodeCache.load(triton_code) except Exception: log.warning( "Helion compiler triton codegen error for %s", self.format_kernel_decorator(config, self.settings), exc_info=True, ) self.maybe_log_repro(log.warning, self.fake_args, config=config) raise if allow_print: log.info("Output code written to: %s", module.__file__) log.debug("Debug string: \n%s", LazyString(lambda: self._debug_str())) if self.settings.print_output_code: log.info("Output code: \n%s", triton_code) print(triton_code, file=sys.stderr) rv = getattr(module, self.kernel.name) self._compile_cache[config] = rv self._cache_path_map[config] = module.__file__ return rv def get_cached_path(self, config: ConfigLike | None = None) -> str | None: """ Get the file path of the generated Triton code for a specific configuration. Args: config: The configuration to get the file path for. Returns: str | None: The file path of the generated Triton code, or None if not found. """ if config is None: config = self._require_implicit_config() if not isinstance(config, Config): config = Config( # pyrefly: ignore [bad-argument-type] **config ) return self._cache_path_map.get(config, None) def _debug_str(self) -> str: """ Generate a debug string for the kernel. Returns: str: A string containing debug information about the kernel. 
""" if self.host_function is None: # In ref mode, host_function is not created return f"<BoundKernel {self.kernel.fn.__name__} in ref mode>" with self.env: return self.host_function.debug_str() def autotune( self, args: Sequence[object], *, force: bool = True, **kwargs: object, ) -> Config: """ Perform autotuning to find the optimal configuration for the kernel. This uses the default setting, you can call helion.autotune.* directly for more customization. If config= or configs= is provided to helion.kernel(), the search will be restricted to the provided configs. Use force=True to ignore the provided configs. Mutates self so that `__call__` will run the best config found. Args: args: Example arguments used for benchmarking during autotuning. force: If True, force full autotuning even if a config is provided. kwargs: Additional keyword options forwarded to the autotuner. Returns: Config: The best configuration found during autotuning. """ force = force or self.settings.force_autotune if not force and self.kernel.configs: if len(self.kernel.configs) == 1: (config,) = self.kernel.configs else: # We have finite predetermined configs, no need to precompile self.settings.autotune_precompile = None from ..autotuner import FiniteSearch config = FiniteSearch(self, args, self.configs).autotune() else: self.settings.check_autotuning_disabled() config = self.settings.autotuner_fn(self, args, **kwargs).autotune( skip_cache=force ) self.set_config(config) return config def set_config(self, config: ConfigLike) -> None: """ Set the configuration for the kernel and compile it. Mutates self so that `__call__` will run the provided config. Args: config: The configuration to set. """ if not isinstance(config, Config): config = Config( # pyrefly: ignore [bad-argument-type] **config ) self._run = self.compile_config(config) self._config = config def _specialize_extra(self) -> list[Callable[[Sequence[object]], Hashable]]: """ Returns a list of functions that will be called to generate extra specialization keys. This is used to specialize on the values hl.specialize()'ed arguments. Returns: list[Callable[[Sequence[object]], Hashable]]: A list of functions that generate extra specialization keys. """ if not self.env.specialized_vars: return [] def make_extractor(v: Source) -> Callable[[Sequence[object]], Hashable]: if isinstance(v, TensorPropertySource): index = v.idx assert index is not None inner = make_extractor(v.base) if v.prop == TensorProperty.SIZE: return lambda args: cast("torch.Tensor", inner(args)).size(index) if v.prop == TensorProperty.STRIDE: return lambda args: cast("torch.Tensor", inner(args)).stride(index) raise exc.SpecializeArgType(v) if isinstance(v, LocalSource): index = arg_name_to_index[v.local_name] return operator.itemgetter(index) raise exc.SpecializeArgType(v) arg_name_to_index: dict[str, int] = { n: i for i, n in enumerate(self.kernel.signature.parameters.keys()) } extractors = [] for v in sorted(self.env.specialized_vars, key=lambda v: v.name): source = self.env.shape_env.var_to_sources[v][0] extractors.append(make_extractor(source)) return extractors def _implicit_config(self) -> Config | None: """ Returns a single config that is implicitly used by this kernel, if any. 
""" configs = self.kernel.configs if self._config is not None: return self._config if self.settings.force_autotune: # If force autotune is enabled, do not pick an implicit config return None if len(configs) == 1: return configs[0] if len(configs) == 0 and self.kernel.settings.autotune_effort == "none": config = self.config_spec.default_config() if not is_ref_mode_enabled(self.kernel.settings): kernel_decorator = self.format_kernel_decorator(config, self.settings) print( f"Using default config: {kernel_decorator}", file=sys.stderr, ) return config return None def _require_implicit_config(self) -> Config: """ Returns the implicit config for this kernel, or raises an error if no implicit config is available. """ if (config := self._implicit_config()) is None: raise RuntimeError("no config provided and no implicit config available") return config # pyrefly: ignore [bad-return] def run_ref(self, *args: object) -> _R: # Unwrap ConstExpr arguments clean_args = [] for arg in args: if isinstance(arg, ConstExpr): clean_args.append(arg.value) else: clean_args.append(arg) # Pass the config to RefModeContext with RefModeContext(self.env, self._config): result = self.kernel.fn(*clean_args) return cast("_R", result) def __call__(self, *args: object) -> _R: """ Execute the kernel with the given arguments. Args: args: The arguments to pass to the kernel. Returns: _R: The result of the kernel execution. """ if is_ref_mode_enabled(self.kernel.settings): if (config := self._implicit_config()) is not None: self._config = config return self.run_ref(*args) if self._run is None: if (config := self._implicit_config()) is not None: self.set_config(config) else: self.autotune(args, force=False) assert self._run is not None assert self._config is not None counters["best_config_decorator"][ self.format_kernel_decorator(self._config, self.settings) ] = 1 self.maybe_log_repro(log.warning, args) return self._run(*args) def maybe_log_repro( self, log_func: Callable[[str], None], args: Sequence[object], config: Config | None = None, ) -> None: if not self.settings.print_repro: return effective_config = config or self._config assert effective_config is not None # Get kernel source try: raw_source = inspect.getsource(self.kernel.fn) source_lines = textwrap.dedent(raw_source).splitlines() # Skip decorator lines (including multi-line decorators) start_idx = 0 while start_idx < len(source_lines) and not source_lines[ start_idx ].lstrip().startswith("def "): start_idx += 1 kernel_body = "\n".join(source_lines[start_idx:]) except (OSError, TypeError): kernel_body = f"# Source unavailable for {self.kernel.fn.__module__}.{self.kernel.fn.__qualname__}" # Format decorator decorator = self.format_kernel_decorator(effective_config, self.settings) # Build output output_lines = [ "# === HELION KERNEL REPRO ===", "import helion", "import helion.language as hl", "import torch", "from torch._dynamo.testing import rand_strided", "", decorator, kernel_body, ] # Generate caller function if args: def _render_input_arg_assignment(name: str, value: object) -> list[str]: if isinstance(value, torch.Tensor): shape = tuple(int(d) for d in value.shape) stride = tuple(int(s) for s in value.stride()) device = str(value.device) dtype = str(value.dtype) lines = [ f"{name} = rand_strided({shape!r}, {stride!r}, dtype={dtype}, device={device!r})" ] if value.requires_grad: lines.append(f"{name}.requires_grad_(True)") return lines return [f"{name} = {value!r}"] sig_param_names = list(self.kernel.signature.parameters.keys()) assert len(args) == 
            output_lines.extend(["", "def helion_repro_caller():"])
            output_lines.append("    torch.manual_seed(0)")
            arg_names: list[str] = []
            for i, value in enumerate(args):
                var_name = sig_param_names[i]
                arg_names.append(var_name)
                # Add assignment lines with indentation
                for line in _render_input_arg_assignment(var_name, value):
                    output_lines.append(f"    {line}")
            # Add return statement
            call_args = ", ".join(arg_names)
            output_lines.append(f"    return {self.kernel.name}({call_args})")
            output_lines.extend(["", "helion_repro_caller()"])
        output_lines.append("# === END HELION KERNEL REPRO ===")
        repro_text = "\n" + "\n".join(output_lines)
        log_func(repro_text)


class _KernelDecorator(Protocol):
    def __call__(
        self,
        fn: Callable[..., _R],
    ) -> Kernel[_R]: ...


@overload
def kernel(
    fn: Callable[..., _R],
    *,
    config: ConfigLike | None = None,
    configs: list[ConfigLike] | None = None,
    key: Callable[..., Hashable] | None = None,
    **settings: object,
) -> Kernel[_R]: ...


@overload
def kernel(
    fn: None = None,
    *,
    config: ConfigLike | None = None,
    configs: list[ConfigLike] | None = None,
    key: Callable[..., Hashable] | None = None,
    **settings: object,
) -> _KernelDecorator: ...

def kernel(
    fn: Callable[..., _R] | None = None,
    *,
    config: ConfigLike | None = None,
    configs: list[ConfigLike] | None = None,
    key: Callable[..., Hashable] | None = None,
    **settings: object,
) -> Kernel[_R] | _KernelDecorator:
    """
    Decorator to create a Kernel object from a Python function.

    Args:
        fn: The function to be wrapped by the Kernel. If None, a decorator is returned.
        config: A single configuration to use for the kernel.
            Refer to the ``helion.Config`` class for details.
        configs: A list of configurations to use for the kernel.
            Can only specify one of config or configs.
            Refer to the ``helion.Config`` class for details.
        key: Optional callable returning a hashable that augments the
            specialization key.
        settings: Keyword arguments representing settings for the Kernel.
            Can also use settings=Settings(...) to pass a Settings object directly.
            Refer to the ``helion.Settings`` class for available options.

    Returns:
        object: A Kernel object or a decorator that returns a Kernel object.

    See Also:
        - :class:`~helion.Settings`: Controls compilation behavior and debugging options
        - :class:`~helion.Config`: Controls GPU execution parameters and optimization strategies
    """
    if config is not None:
        assert not configs, "Cannot specify both config and configs"
        configs = [config]
    elif configs is None:
        configs = []
    if settings_obj := settings.get("settings"):
        assert len(settings) == 1, "settings must be the only keyword argument"
        assert isinstance(settings_obj, Settings), "settings must be a Settings object"
    else:
        settings_obj = Settings(**settings)
    if fn is None:
        return functools.partial(
            kernel, configs=configs, settings=settings_obj, key=key
        )
    return Kernel(fn, configs=configs, settings=settings_obj, key=key)
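
# Example usage (a minimal sketch of the public API, following the standard
# Helion "add" example; written as a comment to avoid importing helion from
# inside this module):
#
#     import torch
#     import helion
#     import helion.language as hl
#
#     @helion.kernel()  # returns a Kernel[_R] via the decorator above
#     def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
#         x, y = torch.broadcast_tensors(x, y)
#         out = torch.empty_like(x)
#         for tile in hl.tile(out.size()):
#             out[tile] = x[tile] + y[tile]
#         return out
#
#     # The first call binds, autotunes (unless a config was given or
#     # autotuning is disabled), compiles, and caches; later calls with
#     # compatible arguments reuse the compiled kernel.
#     out = add(torch.randn(1024, device="cuda"), torch.randn(1024, device="cuda"))
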

def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable:
    # NOTE: on a machine with two different GPU types, keying on
    # obj.device.type alone will incorrectly share cache entries across them
    static_indices = frozenset(getattr(obj, "_dynamo_static_indices", ()))
    if fn.settings.static_shapes:
        return (
            obj.dtype,
            obj.device.type,
            (*obj.size(),),
            (*obj.stride(),),
            static_indices,
        )
    bucketed = tuple([min(s, 2) for s in obj.size()])
    if fn.settings.index_dtype is None:
        try:
            needs_int64 = bool(obj.numel() > _INT32_INDEX_LIMIT)
        except RuntimeError:
            needs_int64 = True  # unbacked SymInt
        return (
            obj.dtype,
            obj.device.type,
            bucketed,
            needs_int64,
            static_indices,
        )
    return (
        obj.dtype,
        obj.device.type,
        bucketed,
        static_indices,
    )


def _sequence_key(fn: Kernel, obj: Sequence) -> Hashable:
    return type(obj), tuple([fn._specialization_key(item) for item in obj])


def _mapping_key(
    fn: Kernel, obj: dict[str | int, object], real_type: type[object]
) -> Hashable:
    return real_type, tuple(
        sorted((k, fn._specialization_key(v)) for k, v in obj.items())
    )


def _number_key(fn: Kernel, n: float | bool) -> object:
    return type(n)


def _function_key(fn: Kernel, obj: types.FunctionType) -> object:
    if obj.__closure__:
        closures = [
            fn._specialization_key(cell.cell_contents) for cell in obj.__closure__
        ]
        return (obj.__code__, *closures)
    return obj.__code__


def _graph_module_key(fn: Kernel, obj: torch.fx.GraphModule) -> Hashable:
    """Generate a specialization key for GraphModule arguments."""
    # Check if already cached
    if obj in _graph_module_hash_cache:
        return _graph_module_hash_cache[obj]

    # Check for unsupported operations
    unsupported_ops = {
        node.op
        for node in itertools.chain(
            obj.graph.find_nodes(op="call_module"),
            obj.graph.find_nodes(op="get_attr"),
        )
    }
    if unsupported_ops:
        raise exc.GraphModuleUnsupportedOps(", ".join(sorted(unsupported_ops)))

    _graph_module_hash_cache[obj] = rv = str(compiled_fx_graph_hash(obj, [], {}, []))
    return rv


_specialization_extractors: dict[
    type[object] | str,
    Callable[[Kernel, object], Hashable],
    # pyrefly: ignore [bad-assignment]
] = {
    torch.Tensor: _tensor_key,
    torch.nn.Parameter: _tensor_key,
    FakeTensor: _tensor_key,
    torch.dtype: lambda fn, x: x,
    torch.device: lambda fn, x: x,
    int: _number_key,
    float: _number_key,
    bool: _number_key,
    str: lambda fn, x: x,
    list: _sequence_key,
    tuple: _sequence_key,
    # pyrefly: ignore [bad-argument-type]
    dict: lambda fn, x: _mapping_key(fn, x, type(x)),
    # pyrefly: ignore [missing-attribute]
    "namedtuple": lambda fn, x: _mapping_key(fn, x._asdict(), type(x)),
    # pyrefly: ignore [no-matching-overload]
    "dataclass": lambda fn, x: _mapping_key(fn, dataclasses.asdict(x), type(x)),
    types.FunctionType: _function_key,
    types.BuiltinFunctionType: lambda fn, x: x,
    torch.fx.GraphModule: _graph_module_key,
    # pyrefly: ignore [missing-attribute]
    ConstExpr: lambda fn, x: x.value,
    type(None): lambda fn, x: None,
}
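
# Illustrative sketch (not part of the library): how the extractor table above
# dispatches for an arbitrary Kernel `k`.  Numbers specialize on their Python
# type only, strings on their value; namedtuples and dataclasses fall back to
# the string-keyed entries via _specialization_key's KeyError handler.
def _example_extractor_dispatch(k: Kernel) -> None:
    assert k._specialization_key(3) == int  # _number_key returns type(n)
    assert k._specialization_key("lhs") == "lhs"  # strings key by value
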

def _find_device(args: tuple[object, ...]) -> torch.device:
    """
    Extract the device from the arguments.

    Args:
        args: The arguments to extract the device from.

    Returns:
        torch.device: The extracted device.
    """
    for arg in args:
        if isinstance(arg, torch.device):
            return arg
        if isinstance(arg, torch.Tensor):
            return arg.device
        if isinstance(arg, (tuple, list)):
            for item in arg:
                try:
                    return _find_device(item)
                except exc.NoTensorArgs:
                    pass
        elif isinstance(arg, dict):
            for item in arg.values():
                try:
                    return _find_device(item)
                except exc.NoTensorArgs:
                    pass
    raise exc.NoTensorArgs


def _maybe_skip_dtype_check_in_meta_registrations() -> (
    contextlib.AbstractContextManager[None, None]
):
    # pyrefly: ignore [implicit-import]
    if hasattr(
        torch.fx.experimental._config, "skip_dtype_check_in_meta_registrations"
    ):
        # pyrefly: ignore [implicit-import, missing-attribute]
        return torch.fx.experimental._config.patch(
            skip_dtype_check_in_meta_registrations=True
        )
    return contextlib.nullcontext()
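
# Illustrative sketch (not part of the library): _find_device scans the
# arguments in order and returns the first device it finds, recursing into
# tuples/lists and dict values.
def _example_find_device() -> None:
    t = torch.empty(2, 2)  # CPU tensor
    assert _find_device((1.0, [t], "label")) == torch.device("cpu")
    assert _find_device((torch.device("cuda:1"),)) == torch.device("cuda:1")
    # With no tensor or device anywhere in args, exc.NoTensorArgs is raised.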