# Source code for helion.language.constexpr

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import NamedTuple

import torch

from .. import exc
from .._compiler.ast_extension import expr_from_string
from . import _decorators

if TYPE_CHECKING:
    import ast

    from .._compiler.inductor_lowering import CodegenState
    from .._compiler.type_propagation import TypeInfo
    from .._compiler.variable_origin import Origin


class ConstExpr(NamedTuple):
    """
    Typically used as a type annotation for kernels:

    .. code-block:: python

        @helion.kernel()
        def fn(v: hl.constexpr, ...):
            ...

    Can also be used when calling a kernel:

    .. code-block:: python

        some_kernel(..., hl.constexpr(5.0))

    Causes the generated code to specialize on the value of `v`, where a different
    kernel, hardcoding the value of v, will be generated every time `v` changes.

    An integer-valued ``ConstExpr`` can be used wherever Python expects an
    index (subscripts, ``range`` bounds, ...) via ``__index__``. Its truthiness
    follows the wrapped value rather than the always non-empty 1-tuple.

    See Also:
        - :func:`specialize`: Convert dynamic shapes to compile-time constants
    """

    # The wrapped compile-time constant (any Python object, e.g. an int or float).
    value: object

    def __index__(self) -> int:
        """
        Return the wrapped value as an exact ``int``.

        Raises:
            TypeError: If the wrapped value is not an ``int`` (``bool`` is accepted,
                since it is an ``int`` subclass).
        """
        value = self.value
        if isinstance(value, int):
            # Coerce int subclasses (notably ``bool``) to a plain ``int``: CPython
            # deprecates ``__index__`` returning a strict subclass of ``int``.
            return int(value)
        raise TypeError(f"ConstExpr cannot be indexed: {self.value}")

    def __bool__(self) -> bool:
        """Return the truthiness of the wrapped value (not of the 1-tuple)."""
        return bool(self.value)


@_decorators.api(is_device_only=False)
def specialize(value: int | torch.SymInt) -> int:
    """
    Turn a dynamic shape into a compile-time constant.

    Example:

    .. code-block:: python

        hl.specialize(tensor.size(1))

    Args:
        value: The symbolic value to specialize on.

    Returns:
        int: The specialized value.

    Raises:
        exc.NotInsideKernel: If called directly rather than from within a kernel.
        exc.SpecializeOnDevice: If used in device code (checked during type
            propagation).
        exc.SpecializeArgType: If ``value`` is neither an ``int`` nor a
            ``torch.SymInt`` (checked during type propagation).

    See Also:
        - :class:`ConstExpr`: Create compile-time constants for kernel parameters
    """
    # Only reached on a direct call; presumably the ``api`` decorator routes
    # in-kernel uses to the hooks registered for ``specialize``.
    raise exc.NotInsideKernel
@_decorators.type_propagation(specialize)
def _(value: TypeInfo, *, origin: Origin) -> TypeInfo:
    """
    Type-propagation hook for :func:`specialize`.

    Replaces a symbolic ``torch.SymInt`` with its concrete integer. It also
    records the sympy symbols the value depends on in the current compile
    environment's ``specialized_vars``. A plain ``int`` passes through unchanged.

    Raises:
        exc.SpecializeOnDevice: If ``origin`` is a device origin.
        exc.SpecializeArgType: If the proxied value is neither a ``torch.SymInt``
            nor an ``int``.
    """
    # Local imports; presumably to avoid an import cycle with the compiler package.
    from .._compiler.compile_environment import CompileEnvironment
    from .._compiler.type_propagation import TypeInfo

    if origin.is_device():
        raise exc.SpecializeOnDevice
    proxy = value.proxy()
    if isinstance(proxy, torch.SymInt):
        # Mark every symbol this expression depends on as specialized.
        CompileEnvironment.current().specialized_vars.update(
            proxy._sympy_().free_symbols
        )
        return TypeInfo.from_example(int(proxy), origin=origin)
    if isinstance(proxy, int):
        # Already a concrete integer: nothing to specialize.
        return TypeInfo.from_example(proxy, origin=origin)
    raise exc.SpecializeArgType(value)


@_decorators.codegen(specialize)
def _(state: CodegenState) -> ast.AST:
    """Codegen hook for :func:`specialize`: emit the value as an integer literal."""
    value = state.proxy_arg(0)
    if isinstance(value, torch.SymInt):
        value = int(value)
    assert isinstance(value, int)
    return expr_from_string(repr(value))


@_decorators.ref(specialize)
def _(value: int | torch.SymInt) -> int:
    """Reference implementation for :func:`specialize`: convert ``value`` to ``int``."""
    return int(value)