Rate this Page

Source code for helion.language.constexpr

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Callable
from typing import NamedTuple
from typing_extensions import TypeVar

import torch

from .. import exc
from .._compiler.ast_extension import expr_from_string
from . import _decorators

if TYPE_CHECKING:
    import ast

    from .._compiler.inductor_lowering import CodegenState
    from .._compiler.type_propagation import TypeInfo
    from .._compiler.variable_origin import Origin

    _T = TypeVar("_T")


class ConstExpr(NamedTuple):
    """
    Typically used as a type annotation for kernels:

    .. code-block:: python

        @helion.kernel()
        def fn(v: hl.constexpr, ...):
            ...

    Can also be used when calling a kernel:

    .. code-block:: python

        some_kernel(..., hl.constexpr(5.0))

    Causes the generated code to specialize on the value of `v`, where a different
    kernel, hardcoding the value of v, will be generated every time `v` changes.

    See Also:
        - :func:`specialize`: Convert dynamic shapes to compile-time constants
    """

    # The wrapped constant; any Python object is accepted.
    value: object

    def __index__(self) -> int:
        """
        Return the wrapped value as a plain ``int`` so the wrapper can be used
        wherever Python expects an index (e.g. ``seq[ConstExpr(1)]``).

        Raises:
            TypeError: If the wrapped value is not an ``int``.
        """
        if isinstance(self.value, int):
            # Normalize int subclasses (e.g. bool) to an exact int: CPython
            # deprecates __index__ returning a strict subclass of int and
            # warns when such a result is used as an index.
            return int(self.value)
        raise TypeError(f"ConstExpr cannot be indexed: {self.value}")

    def __bool__(self) -> bool:
        """
        Return the truthiness of the wrapped value.

        Without this override the one-element tuple would always be truthy,
        regardless of what it wraps.
        """
        return bool(self.value)


@_decorators.api(is_device_only=False)
def specialize(value: _T) -> _T:
    """
    Convert dynamic (symbolic) shape values into compile-time constants.

    Examples::

        channels = hl.specialize(tensor.size(1))
        height, width = hl.specialize(tensor.shape[-2:])

    Args:
        value: A symbolic value, or a sequence of symbolic values, to
            specialize on.

    Returns:
        A Python int, or a sequence whose elements are all Python ints.

    Raises:
        exc.NotInsideKernel: Always, when this undecorated body itself is
            executed; the working behavior lives in the handlers registered
            for this function through ``_decorators``.

    See Also:
        - :class:`ConstExpr`: Create compile-time constants for kernel parameters
    """
    raise exc.NotInsideKernel
@_decorators.type_propagation(specialize)
def _(value: TypeInfo, *, origin: Origin) -> TypeInfo:
    """
    Type-propagation handler for :func:`specialize`.

    Rejects calls whose origin is on the device, then converts the proxy of
    ``value`` to concrete ints. For every ``torch.SymInt`` encountered, the
    free symbols of its sympy expression are added to the current compile
    environment's ``specialized_vars`` before it is turned into an int.

    Raises:
        exc.SpecializeOnDevice: If ``origin.is_device()`` is true.
        exc.SpecializeArgType: If the proxy value cannot be specialized.
    """
    from .._compiler.compile_environment import CompileEnvironment
    from .._compiler.type_propagation import TypeInfo

    if origin.is_device():
        raise exc.SpecializeOnDevice

    example = value.proxy()
    compile_env = CompileEnvironment.current()

    def record_and_concretize(sym: torch.SymInt) -> int:
        # Remember which symbols were specialized on, then make the value concrete.
        compile_env.specialized_vars.update(sym._sympy_().free_symbols)
        return sym.__int__()

    return TypeInfo.from_example(
        _convert_specializable(example, on_symint=record_and_concretize),
        origin=origin,
    )


@_decorators.codegen(specialize)
def _(state: CodegenState) -> ast.AST:
    """
    Codegen handler for :func:`specialize`.

    Converts the first proxy argument to its concrete form and emits an AST
    expression built from the ``repr`` of that concrete value.
    """
    concrete = _convert_specializable(state.proxy_arg(0))
    return expr_from_string(repr(concrete))


@_decorators.ref(specialize)
def _(value: _T) -> _T:
    """Reference handler for :func:`specialize`: convert ``value`` directly."""
    return _convert_specializable(value)


def _convert_specializable(
    value: _T,
    *,
    on_symint: Callable[[torch.SymInt], int] = lambda symint: symint.__int__(),
) -> _T:
    """
    Recursively replace every ``torch.SymInt`` in ``value`` with an int.

    Args:
        value: A ``torch.SymInt``, an ``int``, or a ``torch.Size`` / ``tuple`` /
            ``list`` (possibly nested) of such values.
        on_symint: Callback used to turn each ``torch.SymInt`` into an int;
            by default it calls ``__int__`` on the SymInt.

    Returns:
        The result of ``on_symint`` for a SymInt, the value itself for an int
        (``bool`` included, since it subclasses ``int``), or a new container
        of the same type as ``value`` holding the converted elements.

    Raises:
        exc.SpecializeArgType: If ``value`` or any nested element is of an
            unsupported type. The error always carries the container this
            call received rather than the offending inner element.
    """
    # SymInt is checked before int, matching the original dispatch order.
    if isinstance(value, torch.SymInt):
        return on_symint(value)
    if isinstance(value, int):
        return value
    if not isinstance(value, (torch.Size, tuple, list)):
        raise exc.SpecializeArgType(value)

    try:
        converted = [
            _convert_specializable(item, on_symint=on_symint) for item in value
        ]
        return type(value)(converted)
    except exc.SpecializeArgType:
        # Re-raise with this level's container; `from None` drops the
        # inner exception from the traceback chain.
        raise exc.SpecializeArgType(value) from None