from __future__ import annotations
import contextlib
import dataclasses
import functools
import inspect
import itertools
import logging
import operator
import re
import sys
import textwrap
import types
from typing import TYPE_CHECKING
from typing import Callable
from typing import Generic
from typing import Hashable
from typing import TypeVar
from typing import cast
from typing import overload
from typing_extensions import Protocol
import torch
from torch._dynamo.source import LocalSource
from torch._dynamo.source import TensorProperty
from torch._dynamo.source import TensorPropertySource
from torch._inductor.codecache import PyCodeCache
from torch._inductor.codecache import compiled_fx_graph_hash
from torch._subclasses import FakeTensor
from torch.utils.weak import WeakIdKeyDictionary
from .. import exc
from .._compiler.ast_extension import unparse
from .._compiler.compile_environment import CompileEnvironment
from .._compiler.generate_ast import generate_ast
from .._compiler.host_function import HostFunction
from .._compiler.inductor_lowering_extra import patch_inductor_lowerings
from .._compiler.output_header import assert_no_conflicts
from .._compiler.output_header import get_needed_imports
from .._compiler.variable_origin import ArgumentOrigin
from .._logging import LazyString
from .._utils import counters
from ..language.constexpr import ConstExpr
from .config import Config
from .ref_mode import RefModeContext
from .ref_mode import is_ref_mode_enabled
from .settings import Settings
if TYPE_CHECKING:
from collections.abc import Hashable
from collections.abc import Sequence
from torch._guards import Source
from ..autotuner import ConfigSpec
from ..autotuner.base_cache import BoundKernelInMemoryCacheKey
ConfigLike = Config | dict[str, object]
log: logging.Logger = logging.getLogger(__name__)
_R = TypeVar("_R")
CompiledConfig = Callable[..., _R]
# Cache for GraphModule hashes
_graph_module_hash_cache: WeakIdKeyDictionary = WeakIdKeyDictionary()
class Kernel(Generic[_R]):
def __init__(
self,
fn: Callable[..., _R],
*,
configs: list[ConfigLike] | None = None,
settings: Settings | None,
key: Callable[..., Hashable] | None = None,
) -> None:
"""
Initialize the Kernel object. This is typically called from the `@helion.kernel` decorator.
Args:
fn: The function to be compiled as a Helion kernel.
configs: A list of configurations to use for the kernel.
settings: The settings to be used by the Kernel. If None, a new `Settings()` instance is created.
key: Optional callable that returns an extra hashable component for specialization.
"""
super().__init__()
assert isinstance(fn, types.FunctionType)
assert_no_conflicts(fn)
self.name: str = fn.__name__
self.fn: types.FunctionType = fn
self.signature: inspect.Signature = inspect.signature(fn)
self.settings: Settings = settings or Settings()
self._key_fn: Callable[..., Hashable] | None = key
self.configs: list[Config] = [
Config(**c) if isinstance(c, dict) else c # pyright: ignore[reportArgumentType]
for c in configs or []
]
self._bound_kernels: dict[BoundKernelInMemoryCacheKey, BoundKernel] = {}
self._specialize_extra: dict[
Hashable, list[Callable[[Sequence[object]], Hashable]]
] = {}
if any(
param.kind
in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
)
for param in self.signature.parameters.values()
):
raise TypeError(
f"Kernel({self.name}) cannot have *args, **kwargs, or keyword-only arguments"
)
self._annotations: list[object] = []
for param in self.signature.parameters.values():
ann = param.annotation
if isinstance(ann, str) and re.search(r"constexpr", ann, re.IGNORECASE):
self._annotations.append(ConstExpr)
else:
self._annotations.append(ann)
def _get_bound_kernel_cache_key(
self, args: tuple[object, ...], signature: tuple[Hashable, ...]
) -> BoundKernelInMemoryCacheKey | None:
from ..autotuner.base_cache import BoundKernelInMemoryCacheKey
extra_fns = self._specialize_extra.get(signature)
if extra_fns is not None:
extra_results: tuple[Hashable, ...] = tuple([s(args) for s in extra_fns])
return BoundKernelInMemoryCacheKey(signature, extra_results)
return None
def _create_bound_kernel_cache_key(
self,
bound_kernel: BoundKernel,
args: tuple[object, ...],
signature: tuple[Hashable, ...],
) -> BoundKernelInMemoryCacheKey:
from ..autotuner.base_cache import BoundKernelInMemoryCacheKey
self._specialize_extra[signature] = extra_fns = bound_kernel._specialize_extra()
extra_results: tuple[Hashable, ...] = tuple([s(args) for s in extra_fns])
return BoundKernelInMemoryCacheKey(signature, extra_results)
def bind(self, args: tuple[object, ...]) -> BoundKernel[_R]:
"""
Bind the given arguments to the Kernel and return a BoundKernel object.
Args:
args: The arguments to bind to the Kernel.
Returns:
BoundKernel: A BoundKernel object with the given arguments bound.
"""
if not isinstance(args, tuple):
assert isinstance(args, list), "args must be a tuple or list"
args = tuple(args)
if len(args) > len(self.signature.parameters):
            raise TypeError(
                f"Too many arguments passed to the kernel: expected at most {len(self.signature.parameters)}, got {len(args)}."
            )
signature = self.specialization_key(args)
cache_key = self._get_bound_kernel_cache_key(args, signature)
bound_kernel = (
None if cache_key is None else self._bound_kernels.get(cache_key, None)
)
if bound_kernel is None:
normalized_args: tuple[object, ...] = self.normalize_args(*args)
if len(normalized_args) != len(args):
# we had default args that needed to be applied
bound_kernel = self.bind(normalized_args)
else:
bound_kernel = BoundKernel(self, args)
if cache_key is None:
cache_key = self._create_bound_kernel_cache_key(
bound_kernel, args, signature
)
self._bound_kernels[cache_key] = bound_kernel
return bound_kernel
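    # Usage sketch (illustrative; assumes a kernel `k` taking one CUDA tensor).
    # Binding fixes the fake args and compile environment once, so a second
    # bind with same-spec arguments is an in-memory cache hit:
    #   bound = k.bind((torch.randn(8, 8, device="cuda"),))
    #   bound is k.bind((torch.randn(8, 8, device="cuda"),))  # True, cached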
def specialization_key(self, args: Sequence[object]) -> tuple[Hashable, ...]:
"""
Generate a specialization key for the given arguments.
This method generates a unique key for the arguments based on their types
and the corresponding extractor functions defined in `_specialization_extractors`.
Args:
args: The arguments to generate a specialization key for.
Returns:
            tuple[Hashable, ...]: A hashable tuple representing the specialization of the arguments.
"""
result = []
assert len(args) <= len(self._annotations)
for value, annotation in zip(args, self._annotations, strict=False):
if isinstance(value, ConstExpr):
result.append(value.value)
elif annotation is ConstExpr:
result.append(value)
else:
result.append(self._specialization_key(value))
if self._key_fn is not None:
return (*result, self._key_fn(*args))
return (*result,)
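    # Example sketch (values illustrative): for a kernel over (tensor, int),
    #   k.specialization_key((torch.randn(3, 1, device="cuda"), 5))
    # is roughly ((torch.float32, "cuda", (2, 1)), int) with default settings,
    # since _tensor_key below buckets each dimension to 0, 1, or >=2.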
def _specialization_key(self, obj: object) -> Hashable:
"""
Helper used to generate a specialization key for the given object.
This method determines a unique key for the object based on its type
and the corresponding extractor function defined in `_specialization_extractors`.
Args:
obj: The argument to generate a specialization key for.
Returns:
Hashable: A hashable key representing the specialization of the object.
"""
try:
extractor = _specialization_extractors[type(obj)]
except KeyError:
if isinstance(obj, torch.fx.GraphModule):
# GraphModule subclasses need special handling
extractor = _specialization_extractors[torch.fx.GraphModule]
elif isinstance(obj, tuple) and hasattr(obj, "_fields"):
# this is a namedtuple
extractor = _specialization_extractors["namedtuple"]
elif dataclasses.is_dataclass(obj):
extractor = _specialization_extractors["dataclass"]
else:
raise TypeError(
f"unsupported argument type: {type(obj).__name__}"
) from None
return extractor(self, obj)
def normalize_args(self, *args: object, **kwargs: object) -> tuple[object, ...]:
"""
Normalize the given arguments and keyword arguments according to the function signature.
Args:
args: The positional arguments to normalize.
kwargs: The keyword arguments to normalize.
Returns:
tuple[object, ...]: A tuple of normalized positional arguments.
"""
bound_args = self.signature.bind(*args, **kwargs)
bound_args.apply_defaults()
return tuple(bound_args.args)
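    # Example sketch: for `def fn(x, y=1)`, defaults are applied and keywords
    # are flattened into positional order:
    #   k.normalize_args(torch.randn(4))        -> (tensor, 1)
    #   k.normalize_args(y=2, x=torch.randn(4)) -> (tensor, 2)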
def autotune(
self,
args: Sequence[object],
*,
force: bool = True,
**options: object,
) -> Config:
"""
        Perform autotuning to find the optimal configuration for the kernel. This uses the
        default settings; call helion.autotune.* directly for more customization.
        If config= or configs= is provided to helion.kernel(), the search is restricted to
        the provided configs. Use force=True to ignore the provided configs and autotune anyway.
Mutates (the bound version of) self so that `__call__` will run the best config found.
Args:
args: Example arguments used for benchmarking during autotuning.
force: If True, force full autotuning even if a config is provided.
options: Additional keyword options forwarded to the autotuner.
Returns:
Config: The best configuration found during autotuning.
"""
args = self.normalize_args(*args)
return self.bind(args).autotune(args, force=force, **options)
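    # Usage sketch (shapes illustrative):
    #   best = k.autotune((torch.randn(1024, device="cuda"),))
    #   print(best)  # the winning helion.Config; later calls to k(...) reuse it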
def __call__(self, *args: object, **kwargs: object) -> _R:
"""
Call the Kernel with the given arguments and keyword arguments.
Args:
args: The positional arguments to pass to the Kernel.
kwargs: The keyword arguments to pass to the Kernel.
Returns:
_R: The result of the Kernel function call.
"""
if kwargs:
args = self.normalize_args(*args, **kwargs)
return self.bind(args)(*args)
def reset(self) -> None:
"""
Clears the cache of bound kernels, meaning subsequent calls will
recompile and re-autotune.
"""
self._bound_kernels.clear()
class BoundKernel(Generic[_R]):
def __init__(
self,
kernel: Kernel[_R],
args: tuple[object, ...],
) -> None:
"""
Initialize a BoundKernel object.
This constructor sets up the environment, compiles the kernel function, and prepares
the arguments for execution.
Args:
kernel: The Kernel object to bind.
args: A tuple of arguments to bind to the kernel.
"""
super().__init__()
self.kernel = kernel
self._run: Callable[..., _R] | None = None
self._config: Config | None = None
self._compile_cache: dict[Config, CompiledConfig] = {}
self.env = CompileEnvironment(_find_device(args), self.kernel.settings)
if is_ref_mode_enabled(self.kernel.settings):
self.fake_args = [] # type: ignore[assignment]
self.host_function = None # type: ignore[assignment]
return
with self.env:
assert len(args) == len(self.kernel.signature.parameters)
self.fake_args: list[object] = []
constexpr_args = {}
for name, arg, annotation in zip(
self.kernel.signature.parameters,
args,
self.kernel._annotations,
strict=False,
):
if isinstance(arg, ConstExpr):
assert not isinstance(arg.value, torch.Tensor), (
"ConstExpr cannot be a tensor"
)
self.fake_args.append(arg.value)
constexpr_args[name] = arg.value
elif annotation is ConstExpr:
assert not isinstance(arg, torch.Tensor), (
"ConstExpr cannot be a tensor"
)
self.fake_args.append(arg)
constexpr_args[name] = arg
else:
self.fake_args.append(self.env.to_fake(arg, ArgumentOrigin(name)))
with (
_maybe_skip_dtype_check_in_meta_registrations(),
patch_inductor_lowerings(),
):
self.host_function: HostFunction = HostFunction(
self.kernel.fn, self.fake_args, constexpr_args
)
@property
def settings(self) -> Settings:
"""
Retrieve the settings associated with the kernel.
Returns:
Settings: The settings of the kernel.
"""
return self.kernel.settings
@property
def config_spec(self) -> ConfigSpec:
"""
Retrieve the configuration specification for the kernel.
Returns:
ConfigSpec: The configuration specification.
"""
return self.env.config_spec
@property
def configs(self) -> list[Config]:
"""
Alias for `self.kernel.configs`.
Returns:
list[Config]: The list of configurations.
"""
return self.kernel.configs
def format_kernel_decorator(self, config: Config, settings: Settings) -> str:
"""Return the @helion.kernel decorator snippet capturing configs and settings that influence Triton code generation."""
return f"@helion.kernel(config={config.__repr__()}, static_shapes={settings.static_shapes})"
def to_triton_code(
self,
config: ConfigLike | None = None,
*,
emit_repro_caller: bool = False,
output_origin_lines: bool | None = None,
) -> str:
"""
Generate Triton code for the kernel based on the given configuration.
Args:
config: The configuration to use for code generation.
            emit_repro_caller: If True, also emit a main function that calls the Triton kernel with example inputs.
            output_origin_lines: Whether to annotate the output with source-origin comments; defaults to the kernel's `output_origin_lines` setting.
Returns:
str: The generated Triton code as a string.
"""
if config is None:
config = self._require_implicit_config()
with self.env:
if not isinstance(config, Config):
config = Config(**config) # pyright: ignore[reportArgumentType]
self.env.config_spec.normalize(config)
root = generate_ast(self.host_function, config, emit_repro_caller)
if output_origin_lines is None:
output_origin_lines = self.settings.output_origin_lines
return get_needed_imports(root) + unparse(
root, output_origin_lines=output_origin_lines
)
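    # Usage sketch: inspect the generated Triton source without running it,
    # here using the default config from the kernel's config spec:
    #   print(bound.to_triton_code(bound.config_spec.default_config()))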
def compile_config(
self, config: ConfigLike | None = None, *, allow_print: bool = True
) -> CompiledConfig:
"""
Compile the kernel for a specific configuration.
Args:
config: The configuration to compile the kernel with.
            allow_print: If False, suppress logging/printing of the output code (used when autotuning).
Returns:
CompiledConfig: A callable object representing the compiled kernel.
"""
if config is None:
config = self._require_implicit_config()
if not isinstance(config, Config):
config = Config(
**config # pyright: ignore[reportArgumentType]
)
if (rv := self._compile_cache.get(config)) is not None:
return rv
try:
triton_code = self.to_triton_code(
config, emit_repro_caller=self.settings.print_output_code
)
module = PyCodeCache.load(triton_code)
except Exception:
log.warning(
"Helion compiler triton codegen error for %s",
self.format_kernel_decorator(config, self.settings),
exc_info=True,
)
raise
if allow_print:
log.info("Output code: \n%s", triton_code)
log.info("Output code written to: %s", module.__file__)
log.debug("Debug string: \n%s", LazyString(lambda: self._debug_str()))
if self.settings.print_output_code:
print(triton_code, file=sys.stderr)
rv = getattr(module, self.kernel.name)
self._compile_cache[config] = rv
return rv
def _debug_str(self) -> str:
"""
Generate a debug string for the kernel.
Returns:
str: A string containing debug information about the kernel.
"""
if self.host_function is None:
# In ref mode, host_function is not created
return f"<BoundKernel {self.kernel.fn.__name__} in ref mode>"
with self.env:
return self.host_function.debug_str()
def autotune(
self,
args: Sequence[object],
*,
force: bool = True,
**kwargs: object,
) -> Config:
"""
        Perform autotuning to find the optimal configuration for the kernel. This uses the
        default settings; call helion.autotune.* directly for more customization.
        If config= or configs= is provided to helion.kernel(), the search is restricted to
        the provided configs. Use force=True to ignore the provided configs and autotune anyway.
Mutates self so that `__call__` will run the best config found.
Args:
args: Example arguments used for benchmarking during autotuning.
force: If True, force full autotuning even if a config is provided.
kwargs: Additional keyword options forwarded to the autotuner.
Returns:
Config: The best configuration found during autotuning.
"""
force = force or self.settings.force_autotune
if not force and self.kernel.configs:
if len(self.kernel.configs) == 1:
(config,) = self.kernel.configs
else:
# We have finite predetermined configs, no need to precompile
self.settings.autotune_precompile = None
from ..autotuner import FiniteSearch
config = FiniteSearch(self, args, self.configs).autotune()
else:
self.settings.check_autotuning_disabled()
config = self.settings.autotuner_fn(self, args, **kwargs).autotune(
skip_cache=force
)
self.set_config(config)
return config
def set_config(self, config: ConfigLike) -> None:
"""
Set the configuration for the kernel and compile it.
Mutates self so that `__call__` will run the provided config.
Args:
config: The configuration to set.
"""
if not isinstance(config, Config):
config = Config(
**config # pyright: ignore[reportArgumentType]
)
self._run = self.compile_config(config)
self._config = config
def _specialize_extra(self) -> list[Callable[[Sequence[object]], Hashable]]:
"""
Returns a list of functions that will be called to generate extra specialization keys.
        This is used to specialize on the values of hl.specialize()'d arguments.
Returns:
list[Callable[[Sequence[object]], Hashable]]: A list of functions that generate extra specialization keys.
"""
if not self.env.specialized_vars:
return []
def make_extractor(v: Source) -> Callable[[Sequence[object]], Hashable]:
if isinstance(v, TensorPropertySource):
assert v.prop == TensorProperty.SIZE
index = v.idx
assert index is not None
inner = make_extractor(v.base)
return lambda args: cast("torch.Tensor", inner(args)).size(index)
if isinstance(v, LocalSource):
index = arg_name_to_index[v.local_name]
return operator.itemgetter(index)
raise exc.SpecializeArgType(v)
arg_name_to_index: dict[str, int] = {
n: i for i, n in enumerate(self.kernel.signature.parameters.keys())
}
extractors = []
for v in sorted(self.env.specialized_vars, key=lambda v: v.name):
source = self.env.shape_env.var_to_sources[v][0]
extractors.append(make_extractor(source))
return extractors
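    # Example sketch: if the kernel body calls hl.specialize(x.size(0)) for a
    # parameter `x` at index 0, the extractor built above behaves like
    #   lambda args: args[0].size(0)
    # so the in-memory cache re-specializes whenever that dimension changes.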
def _implicit_config(self) -> Config | None:
"""
Returns a single config that is implicitly used by this kernel, if any.
"""
configs = self.kernel.configs
if self._config is not None:
return self._config
if self.settings.force_autotune:
# If force autotune is enabled, do not pick an implicit config
return None
if len(configs) == 1:
return configs[0]
if len(configs) == 0 and self.kernel.settings.autotune_effort == "none":
config = self.config_spec.default_config()
if not is_ref_mode_enabled(self.kernel.settings):
kernel_decorator = self.format_kernel_decorator(config, self.settings)
print(
f"Using default config: {kernel_decorator}",
file=sys.stderr,
)
return config
return None
def _require_implicit_config(self) -> Config:
"""
Returns the implicit config for this kernel, or raises an error if no implicit config is available.
"""
if (config := self._implicit_config()) is None:
raise RuntimeError("no config provided and no implicit config available")
return config
def run_ref(self, *args: object) -> _R: # pyright: ignore[reportReturnType]
# Unwrap ConstExpr arguments
clean_args = []
for arg in args:
if isinstance(arg, ConstExpr):
clean_args.append(arg.value)
else:
clean_args.append(arg)
# Pass the config to RefModeContext
with RefModeContext(self.env, self._config):
result = self.kernel.fn(*clean_args)
return cast("_R", result)
def __call__(self, *args: object) -> _R:
"""
Execute the kernel with the given arguments.
Args:
args: The arguments to pass to the kernel.
Returns:
_R: The result of the kernel execution.
"""
if is_ref_mode_enabled(self.kernel.settings):
if (config := self._implicit_config()) is not None:
self._config = config
return self.run_ref(*args)
if self._run is None:
if (config := self._implicit_config()) is not None:
self.set_config(config)
else:
self.autotune(args, force=False)
assert self._run is not None
assert self._config is not None
counters["best_config_decorator"][
self.format_kernel_decorator(self._config, self.settings)
] = 1
self.maybe_log_repro(log.warning, args)
return self._run(*args)
def maybe_log_repro(
self,
log_func: Callable[[str], None],
args: Sequence[object],
config: Config | None = None,
) -> None:
if not self.settings.print_repro:
return
effective_config = config or self._config
assert effective_config is not None
# Get kernel source
try:
raw_source = inspect.getsource(self.kernel.fn)
source_lines = textwrap.dedent(raw_source).splitlines()
# Skip decorator lines (including multi-line decorators)
start_idx = 0
while start_idx < len(source_lines) and not source_lines[
start_idx
].lstrip().startswith("def "):
start_idx += 1
kernel_body = "\n".join(source_lines[start_idx:])
except (OSError, TypeError):
kernel_body = f"# Source unavailable for {self.kernel.fn.__module__}.{self.kernel.fn.__qualname__}"
# Format decorator
decorator = self.format_kernel_decorator(effective_config, self.settings)
# Build output
output_lines = [
"# === HELION KERNEL REPRO ===",
"import helion",
"import helion.language as hl",
"import torch",
"from torch._dynamo.testing import rand_strided",
"",
decorator,
kernel_body,
]
# Generate caller function
if args:
def _render_input_arg_assignment(name: str, value: object) -> list[str]:
if isinstance(value, torch.Tensor):
shape = tuple(int(d) for d in value.shape)
stride = tuple(int(s) for s in value.stride())
device = str(value.device)
dtype = str(value.dtype)
lines = [
f"{name} = rand_strided({shape!r}, {stride!r}, dtype={dtype}, device={device!r})"
]
if value.requires_grad:
lines.append(f"{name}.requires_grad_(True)")
return lines
return [f"{name} = {value!r}"]
sig_param_names = list(self.kernel.signature.parameters.keys())
assert len(args) == len(sig_param_names)
output_lines.extend(["", "def helion_repro_caller():"])
output_lines.append(" torch.manual_seed(0)")
arg_names = []
for i, value in enumerate(args):
var_name = sig_param_names[i]
arg_names.append(var_name)
# Add assignment lines with indentation
for line in _render_input_arg_assignment(var_name, value):
output_lines.append(f" {line}")
# Add return statement
call_args = ", ".join(arg_names)
output_lines.append(f" return {self.kernel.name}({call_args})")
output_lines.extend(["", "helion_repro_caller()"])
output_lines.append("# === END HELION KERNEL REPRO ===")
repro_text = "\n".join(output_lines)
log_func(repro_text)
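# Sample of the repro emitted by maybe_log_repro above (kernel name, shapes,
# and config are illustrative):
#   # === HELION KERNEL REPRO ===
#   import helion
#   import helion.language as hl
#   import torch
#   from torch._dynamo.testing import rand_strided
#
#   @helion.kernel(config=Config(...), static_shapes=False)
#   def my_kernel(x: torch.Tensor) -> torch.Tensor:
#       ...
#
#   def helion_repro_caller():
#       torch.manual_seed(0)
#       x = rand_strided((8, 8), (8, 1), dtype=torch.float32, device='cuda:0')
#       return my_kernel(x)
#
#   helion_repro_caller()
#   # === END HELION KERNEL REPRO ===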
class _KernelDecorator(Protocol):
def __call__(
self,
fn: Callable[..., _R],
) -> Kernel[_R]: ...
@overload
def kernel(
fn: Callable[..., _R],
*,
config: ConfigLike | None = None,
configs: list[ConfigLike] | None = None,
key: Callable[..., Hashable] | None = None,
**settings: object,
) -> Kernel[_R]: ...
@overload
def kernel(
fn: None = None,
*,
config: ConfigLike | None = None,
configs: list[ConfigLike] | None = None,
key: Callable[..., Hashable] | None = None,
**settings: object,
) -> _KernelDecorator: ...
def kernel(
fn: Callable[..., _R] | None = None,
*,
config: ConfigLike | None = None,
configs: list[ConfigLike] | None = None,
key: Callable[..., Hashable] | None = None,
**settings: object,
) -> Kernel[_R] | _KernelDecorator:
"""
Decorator to create a Kernel object from a Python function.
Args:
fn: The function to be wrapped by the Kernel. If None, a decorator is returned.
config: A single configuration to use for the kernel. Refer to the
``helion.Config`` class for details.
        configs: A list of configurations to use for the kernel. Only one of
            config or configs may be specified. Refer to the ``helion.Config``
            class for details.
key: Optional callable returning a hashable that augments the specialization key.
settings: Keyword arguments representing settings for the Kernel.
Can also use settings=Settings(...) to pass a Settings object
directly. Refer to the ``helion.Settings`` class for available
options.
Returns:
object: A Kernel object or a decorator that returns a Kernel object.
See Also:
- :class:`~helion.Settings`: Controls compilation behavior and debugging options
- :class:`~helion.Config`: Controls GPU execution parameters and optimization strategies
"""
if config is not None:
assert not configs, "Cannot specify both config and configs"
configs = [config]
elif configs is None:
configs = []
if settings_obj := settings.get("settings"):
assert len(settings) == 1, "settings must be the only keyword argument"
assert isinstance(settings_obj, Settings), "settings must be a Settings object"
else:
settings_obj = Settings(**settings)
if fn is None:
return functools.partial(
kernel, configs=configs, settings=settings_obj, key=key
)
return Kernel(fn, configs=configs, settings=settings_obj, key=key)
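# Usage sketch (a minimal example; the body uses hl.tile in the style of
# typical Helion kernels and is illustrative, not prescribed by this module):
#
#   @kernel(static_shapes=True)
#   def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
#       out = torch.empty_like(x)
#       for tile in hl.tile(out.size()):
#           out[tile] = x[tile] + y[tile]
#       return out
#
#   result = add(torch.randn(8, 8, device="cuda"), torch.randn(8, 8, device="cuda"))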
def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable:
    # NOTE: if a machine has two different GPU types, obj.device.type alone
    # cannot distinguish them, so tensors on the two device kinds will
    # incorrectly hit the same cache entry
if fn.settings.static_shapes:
return (
obj.dtype,
obj.device.type,
(*obj.size(),),
(*obj.stride(),),
)
return (
obj.dtype,
obj.device.type,
# 0, 1, or >=2 specialization
tuple([min(s, 2) for s in obj.size()]),
)
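# Example sketch of the default (non-static) size bucketing: tensors with the
# same dtype and device type share a cache entry as long as each dimension
# stays in the same {0, 1, >=2} bucket:
#   sizes (1, 7, 300) and (1, 9, 16) -> both key on (1, 2, 2)
#   sizes (1, 7) and (2, 7)          -> distinct keys: (1, 2) vs (2, 2)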
def _sequence_key(fn: Kernel, obj: Sequence) -> Hashable:
return type(obj), tuple([fn._specialization_key(item) for item in obj])
def _mapping_key(
fn: Kernel, obj: dict[str | int, object], real_type: type[object]
) -> Hashable:
return real_type, tuple(
sorted((k, fn._specialization_key(v)) for k, v in obj.items())
)
def _number_key(fn: Kernel, n: float | bool) -> object:
return type(n)
def _function_key(fn: Kernel, obj: types.FunctionType) -> object:
if obj.__closure__:
closures = [
fn._specialization_key(cell.cell_contents) for cell in obj.__closure__
]
return (obj.__code__, *closures)
return obj.__code__
def _graph_module_key(fn: Kernel, obj: torch.fx.GraphModule) -> Hashable:
"""Generate a specialization key for GraphModule arguments."""
# Check if already cached
if obj in _graph_module_hash_cache:
return _graph_module_hash_cache[obj]
# Check for unsupported operations
unsupported_ops = {
node.op
for node in itertools.chain(
obj.graph.find_nodes(op="call_module"),
obj.graph.find_nodes(op="get_attr"),
)
}
if unsupported_ops:
raise exc.GraphModuleUnsupportedOps(", ".join(sorted(unsupported_ops)))
_graph_module_hash_cache[obj] = rv = str(compiled_fx_graph_hash(obj, [], {}, []))
return rv
_specialization_extractors: dict[
type[object] | str, Callable[[Kernel, object], Hashable]
] = { # pyright: ignore[reportAssignmentType]
torch.Tensor: _tensor_key,
torch.nn.Parameter: _tensor_key,
FakeTensor: _tensor_key,
torch.dtype: lambda fn, x: x,
torch.device: lambda fn, x: x,
int: _number_key,
float: _number_key,
bool: _number_key,
str: lambda fn, x: x,
list: _sequence_key,
tuple: _sequence_key,
dict: lambda fn, x: _mapping_key(fn, x, type(x)), # pyright: ignore[reportArgumentType]
"namedtuple": lambda fn, x: _mapping_key(fn, x._asdict(), type(x)), # pyright: ignore[reportAttributeAccessIssue]
"dataclass": lambda fn, x: _mapping_key(fn, dataclasses.asdict(x), type(x)), # pyright: ignore[reportArgumentType]
types.FunctionType: _function_key,
types.BuiltinFunctionType: lambda fn, x: x,
torch.fx.GraphModule: _graph_module_key,
ConstExpr: lambda fn, x: x.value, # pyright: ignore[reportAttributeAccessIssue]
type(None): lambda fn, x: None,
}
def _find_device(args: tuple[object, ...]) -> torch.device:
"""
Extract the device from the arguments.
Args:
args: The arguments to extract the device from.
Returns:
        torch.device: The extracted device.
    Raises:
        exc.NoTensorArgs: If no device can be inferred from the arguments.
    """
for arg in args:
if isinstance(arg, torch.device):
return arg
if isinstance(arg, torch.Tensor):
return arg.device
if isinstance(arg, (tuple, list)):
for item in arg:
try:
                    # wrap in a tuple so a nested tensor is inspected directly
                    # rather than iterated row by row
                    return _find_device((item,))
except exc.NoTensorArgs:
pass
elif isinstance(arg, dict):
for item in arg.values():
try:
                    return _find_device((item,))
except exc.NoTensorArgs:
pass
raise exc.NoTensorArgs
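# Example sketch of the search order (the first device found wins, recursing
# into tuple/list/dict containers):
#   _find_device((3, [torch.randn(2, device="cuda")]))  -> the tensor's cuda device
#   _find_device((1, "x"))                              -> raises exc.NoTensorArgs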
def _maybe_skip_dtype_check_in_meta_registrations() -> (
contextlib.AbstractContextManager[None, None]
):
if hasattr(torch.fx.experimental._config, "skip_dtype_check_in_meta_registrations"): # pyright: ignore[reportAttributeAccessIssue]
return torch.fx.experimental._config.patch( # pyright: ignore[reportAttributeAccessIssue]
skip_dtype_check_in_meta_registrations=True
)
return contextlib.nullcontext()