from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from torch._inductor.codegen.simd import constant_repr
from torch._inductor.runtime.runtime_utils import next_power_of_2
from .. import exc
from .._compiler.ast_extension import ExtendedAST
from .._compiler.ast_extension import expr_from_string
from .._compiler.compile_environment import AutoSize
from .._compiler.compile_environment import CompileEnvironment
from .._compiler.type_propagation import TileIndexType
from .._compiler.type_propagation import TypeInfo
from .._compiler.type_propagation import _to_proxy
from ..autotuner.config_fragment import BaseIntegerFragment
from ..autotuner.config_fragment import ConfigSpecFragment
from ..autotuner.config_fragment import assert_integer_power_of_two
from ..autotuner.config_spec import VALID_KEYS
from ..exc import NotInsideKernel
from . import _decorators
from .loops import _normalize_begin_end
if TYPE_CHECKING:
import ast
from .._compiler.inductor_lowering import CodegenState
from .._compiler.variable_origin import Origin
__all__ = [
"register_block_size",
"register_reduction_dim",
"register_tunable",
]
[docs]
@_decorators.api(is_device_only=False, cache_type=True, tiles_as_sizes=True)
def register_block_size(min_or_max: int, max_or_none: int | None = None, /) -> int:
"""
Explicitly register a block size that should be autotuned and can be used for
allocations and inside hl.tile(..., block_size=...).
This is useful if you have two loops where you want them to share a block size,
or if you need to allocate a kernel tensor before the hl.tile() loop.
The signature can one of:
hl.register_block_size(max)
hl.register_block_size(min, max)
Where min and max are integers that control the range of block_sizes searched by
the autotuner. Max may be a symbolic shape, but min must be a constant integer.
"""
raise exc.NotInsideKernel
@_decorators.ref(register_block_size)
def _(min_or_max: int, max_or_none: int | None = None, /) -> int:
# In ref mode, always return the maximum value (full dimension size)
if max_or_none is None:
return min_or_max
return max_or_none
@_decorators.type_propagation(register_block_size)
def _(
min_or_max: TypeInfo, max_or_none: TypeInfo | None = None, /, *, origin: Origin
) -> TypeInfo:
from .._compiler.type_propagation import SymIntType
min_type, max_type = _normalize_begin_end(min_or_max, max_or_none, origin=origin)
min_proxy = _to_proxy(min_type)
max_proxy = _to_proxy(max_type)
if not isinstance(max_proxy, (int, torch.SymInt)):
raise exc.IncorrectTileUsage(
f"expected max to be an integer or size, got {max_proxy!s}"
)
if not isinstance(min_proxy, int):
raise exc.IncorrectTileUsage(
f"expected min to be an integer constant, got {min_proxy!s}"
)
env = CompileEnvironment.current()
result = TileIndexType.allocate(AutoSize(), origin)
loop_spec = env.config_spec.block_sizes.block_id_lookup(result.block_id)
loop_spec.min_size = assert_integer_power_of_two(max(1, min_proxy))
loop_spec.max_size = next_power_of_2(env.size_hint(max_proxy))
block_id = result.block_id
return SymIntType(origin, env.block_sizes[block_id].var)
def _block_id_from_state(state: CodegenState) -> int:
"""Extract the block_id from the current state for nodes hl.register_block_size."""
from .._compiler.type_propagation import SymIntType
env = CompileEnvironment.current()
if state.fx_node is not None:
val = state.fx_node.meta["val"]
assert isinstance(val, SymIntType)
block_id = env.get_block_id(val.value)
assert block_id is not None
return block_id
current_node = ExtendedAST.current()[-1]
type_info = current_node._type_info
assert isinstance(type_info, SymIntType)
block_id = env.get_block_id(type_info.value)
assert block_id is not None
return block_id
@_decorators.codegen(register_block_size)
def _(state: CodegenState) -> ast.AST:
env = CompileEnvironment.current()
block_size = env.config_spec.block_sizes.config_get(
state.config.block_sizes, _block_id_from_state(state)
)
assert block_size is not None
return expr_from_string(constant_repr(block_size))
[docs]
@_decorators.api(is_device_only=False, cache_type=True, tiles_as_sizes=True)
def register_reduction_dim(
size: int,
) -> int:
"""
Explicitly register a reduction dimension that should be used for reduction operations.
This is useful when you need to allocate a dimension for reduction that isn't
automatically inferred from a slice operation. The registered dimension can be
used for allocations and operations that require knowing the reduction size upfront.
Args:
size: An integer representing the reduction dimension size.
Returns:
torch.SymInt: A SymInt object representing the reduction dimension size.
"""
raise exc.NotInsideKernel
@_decorators.ref(register_reduction_dim)
def _(size: int) -> int:
# In ref mode, simply return the size as-is
return size
@_decorators.type_propagation(register_reduction_dim)
def _(sizes: TypeInfo, *, origin: Origin) -> TypeInfo:
from .._compiler.compile_environment import CompileEnvironment
from .._compiler.type_propagation import SymIntType
try:
proxy_sizes = sizes.proxy()
if not isinstance(proxy_sizes, int | torch.SymInt):
raise NotImplementedError
except NotImplementedError:
raise exc.TypeInferenceError(
f"register_reduction_dim() expected int or list[int], got {sizes!s}"
) from None
env = CompileEnvironment.current()
rdim = env.allocate_reduction_dimension(proxy_sizes)
return SymIntType(origin, rdim.var)
@_decorators.codegen(register_reduction_dim)
def _(state: CodegenState) -> ast.AST:
"""Generate code for register_reduction_dim - return the size expression"""
from .._compiler.type_propagation import SymIntType
current_node = ExtendedAST.current()[-1]
type_info = current_node._type_info
assert isinstance(type_info, SymIntType)
return current_node.args[ # pyright: ignore[reportAttributeAccessIssue]
0
]
[docs]
@_decorators.api(is_device_only=False)
def register_tunable(name: str, fragment: ConfigSpecFragment) -> int:
"""
Register a tunable parameter for autotuning.
This function allows you to define parameters that can be automatically tuned
during the autotuning process. The fragment defines the search space and default value.
Args:
name: The key for the tunable parameter in the Config().
fragment: A ConfigSpecFragment that defines the search space (e.g., PowerOfTwoFragment)
Returns:
int: The value assigned to this tunable parameter in the current configuration.
"""
raise NotInsideKernel
@_decorators.type_propagation(register_tunable)
def _register_tunable_type(
name: TypeInfo, fragment: TypeInfo, *, origin: Origin
) -> TypeInfo:
# During type propagation, register the tunable parameter and return unbacked symint
from .._compiler.compile_environment import CompileEnvironment
from .._compiler.type_propagation import NumericType
env = CompileEnvironment.current()
try:
fragment_val = fragment.as_literal()
name_val = name.as_literal()
except NotImplementedError:
fragment_val = None
name_val = None
if not (isinstance(name_val, str) and isinstance(fragment_val, ConfigSpecFragment)):
raise exc.RegisterTunableArgTypes(name, fragment)
del name, fragment
if name_val in VALID_KEYS or f"{name_val}s" in VALID_KEYS:
raise exc.TunableNameConflict(name_val)
if (
name_val in env.config_spec.user_defined_tunables
and env.config_spec.user_defined_tunables[name_val] != fragment_val
):
raise exc.TunableNameConflict(name_val)
# register the value for tuning
env.config_spec.user_defined_tunables[name_val] = fragment_val
python_type = type(fragment_val.default())
if not issubclass(python_type, (int, float, bool)):
raise exc.TunableTypeNotSupported(python_type)
return NumericType.subtype(python_type).new_unbacked(origin)
@_decorators.codegen(register_tunable)
def _register_tunable_codegen(state: CodegenState) -> ast.AST:
name = state.proxy_arg(0)
assert isinstance(name, str)
config_value = state.config[name]
assert isinstance(config_value, (int, float, bool))
return expr_from_string(constant_repr(config_value))
@_decorators.ref(register_tunable)
def _(name: str, fragment: ConfigSpecFragment) -> int:
"""Reference implementation of register_tunable."""
from ..runtime.ref_mode import RefModeContext
# Get config from the current RefModeContext
context = RefModeContext.current()
config = context.config
# Determine the value to use
value: object
assert config is not None
if name in config:
value = config[name]
else:
value = fragment.default()
# For BaseIntegerFragment subclasses (IntegerFragment, PowerOfTwoFragment), apply clamp
# This ensures the value is within valid bounds
if isinstance(fragment, BaseIntegerFragment) and isinstance(value, (int, bool)):
value = fragment.clamp(int(value))
# Convert to int if needed
if isinstance(value, bool):
return int(value)
if isinstance(value, int):
return value
# For other types (like float), convert to int
return int(value) # type: ignore[arg-type]