Source code for helion.language.inline_triton_ops

from __future__ import annotations

import ast
from collections.abc import Mapping
from collections.abc import Sequence
import textwrap
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import cast

import torch
from torch._inductor.utils import triton_type
from torch.fx import has_side_effect

from .. import exc
from .._compiler.ast_extension import convert
from .._compiler.ast_extension import create
from .._compiler.ast_extension import expr_from_string
from .._compiler.ast_extension import statement_from_string
from . import _decorators

if TYPE_CHECKING:
    from .._compiler.inductor_lowering import CodegenState

    _T = TypeVar("_T")

__all__ = ["inline_triton"]


@has_side_effect
@_decorators.api(is_device_only=True, allow_host_tensor=True)
def inline_triton(
    triton_source: str,
    args: Sequence[object] | Mapping[str, object],
    output_like: _T,
) -> _T:
    """Inline a raw Triton snippet inside a Helion kernel.

    Args:
        triton_source: The Triton code snippet. The last statement must be an
            expression representing the return value. The snippet may be
            indented, and common indentation is stripped automatically.
        args: Positional or keyword placeholders that will be substituted via
            ``str.format`` before code generation. Provide a tuple/list for
            positional placeholders (``{0}``, ``{1}``, ...) or a mapping for
            named placeholders (``{x}``, ``{y}``, ...).
        output_like: Example tensors describing the expected outputs. A single
            tensor indicates a single output; a tuple/list of tensors indicates
            multiple outputs.

    Returns:
        The value(s) produced by the snippet. Matches the structure of
        ``output_like``.
    """
    raise exc.NotInsideKernel
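
# A minimal usage sketch of the API documented above, assuming that
# ``inline_triton`` is re-exported as ``hl.inline_triton`` and that the
# surrounding kernel follows the usual ``@helion.kernel`` / ``hl.tile``
# pattern. The kernel name and shapes are illustrative, not part of this
# module.
#
#     import helion
#     import helion.language as hl
#     import torch
#
#     @helion.kernel
#     def inline_scale(x: torch.Tensor) -> torch.Tensor:
#         out = torch.empty_like(x)
#         for tile in hl.tile(x.size(0)):
#             val = x[tile]
#             # ``{0}`` is replaced with the generated Triton name for ``val``
#             # before the snippet is parsed; the last line is an expression
#             # and becomes the return value.
#             out[tile] = hl.inline_triton(
#                 "{0} * 2.0",
#                 args=(val,),
#                 output_like=val,
#             )
#         return out
#
#     # e.g. result = inline_scale(torch.randn(1024, device="cuda"))
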
def _validate_args(args: object) -> None:
    if isinstance(args, Mapping):
        return
    if isinstance(args, Sequence) and not isinstance(args, (str, bytes)):
        return
    raise exc.InvalidAPIUsage("inline_triton args must be a tuple/list or a mapping")


def _fake_outputs(output_like: object) -> object:
    if output_like is None:
        return None
    if isinstance(output_like, torch.Tensor):
        return torch.empty_like(output_like)
    if isinstance(output_like, Sequence) and not isinstance(output_like, (str, bytes)):
        outputs = []
        for i, item in enumerate(output_like):
            if not isinstance(item, torch.Tensor):
                raise exc.InvalidAPIUsage(
                    f"output_like[{i}] must be a torch.Tensor, got {type(item)}"
                )
            outputs.append(torch.empty_like(item))
        return type(output_like)(outputs) if isinstance(output_like, tuple) else outputs
    raise exc.InvalidAPIUsage(
        "output_like must be a tensor or a sequence of tensors or None"
    )


@_decorators.register_fake(inline_triton)
def _(
    triton_source: str,
    args: object,
    output_like: object,
) -> object:
    if not isinstance(triton_source, str):
        raise exc.InvalidAPIUsage(
            f"triton_source must be a string, got {type(triton_source)}"
        )
    _validate_args(args)
    return _fake_outputs(output_like)


def _ensure_name(state: CodegenState, node: ast.AST) -> str:
    lifted = state.codegen.lift(node)
    assert isinstance(lifted, ast.Name)
    return lifted.id


def _format_triton_source(
    state: CodegenState,
    triton_source: str,
    args_obj: object,
    args_ast: object,
) -> str:
    source = textwrap.dedent(triton_source).strip()
    if not source:
        raise exc.InvalidAPIUsage("triton_source must contain code")

    if isinstance(args_obj, Mapping):
        if not isinstance(args_ast, dict):
            raise exc.InvalidAPIUsage(
                "inline_triton expects a dict literal when args is a mapping"
            )
        assert args_obj.keys() == args_ast.keys()
        format_args: dict[str, str] = {
            key: _ensure_name(state, args_ast[key]) for key in args_ast
        }
        try:
            return source.format(**format_args)
        except (KeyError, IndexError, ValueError) as exc_value:
            raise exc.InvalidAPIUsage(
                f"Failed to format triton_source with mapping args: {exc_value}"
            ) from exc_value

    if isinstance(args_obj, Sequence) and not isinstance(args_obj, (str, bytes)):
        if not isinstance(args_ast, (ast.List, ast.Tuple, list, tuple)):
            raise exc.InvalidAPIUsage(
                "inline_triton expects a list/tuple literal when args is a sequence"
            )
        arg_nodes = (
            args_ast.elts
            if isinstance(args_ast, (ast.List, ast.Tuple))
            else list(args_ast)
        )
        names = [_ensure_name(state, node) for node in arg_nodes]
        try:
            expected_len = len(args_obj)
        except TypeError:  # pragma: no cover - defensive
            expected_len = len(names)
        if expected_len != len(names):
            raise exc.InvalidAPIUsage(
                "inline_triton sequence args must be provided as a literal"
            )
        try:
            return source.format(*names)
        except (IndexError, ValueError) as exc_value:
            raise exc.InvalidAPIUsage(
                f"Failed to format triton_source with positional args: {exc_value}"
            ) from exc_value

    raise exc.InvalidAPIUsage("inline_triton args must be a tuple/list or a mapping")


def _parse_triton_source(source: str) -> tuple[list[ast.stmt], ast.AST]:
    try:
        module = ast.parse(source)
    except SyntaxError as exc_value:
        raise exc.InvalidAPIUsage(
            f"Failed to parse triton_source: {exc_value}"
        ) from exc_value
    if not module.body:
        raise exc.InvalidAPIUsage("triton_source must contain at least one expression")
    *prefix, last = module.body
    if not isinstance(last, ast.Expr):
        raise exc.InvalidAPIUsage(
            "The last line of triton_source must be an expression"
        )
    converted_prefix = [cast("ast.stmt", convert(stmt)) for stmt in prefix]
    return converted_prefix, convert(last.value)


def _normalize_output_ast(output_ast: object) -> list[ast.AST]:
    if isinstance(output_ast, (ast.Tuple, ast.List)):
        return [cast("ast.AST", elem) for elem in output_ast.elts]
    if isinstance(output_ast, (tuple, list)):
        nodes: list[ast.AST] = []
        for elem in output_ast:
            if not isinstance(elem, ast.AST):
                raise exc.InvalidAPIUsage(
                    "output_like literal must reference tensors directly"
                )
            nodes.append(elem)
        return nodes
    if isinstance(output_ast, ast.AST):
        return [output_ast]
    raise exc.InvalidAPIUsage(
        "output_like must be provided as a tensor or tuple/list literal"
    )


def _collect_output_metadata(
    output_like: object,
    output_ast: object,
) -> tuple[list[torch.dtype], list[ast.AST], bool]:
    if output_like is None:
        return [], [], False
    if isinstance(output_like, torch.Tensor):
        return [output_like.dtype], _normalize_output_ast(output_ast), False
    if isinstance(output_like, Sequence) and not isinstance(output_like, (str, bytes)):
        if not output_like:
            raise exc.InvalidAPIUsage("output_like sequence must not be empty")
        ast_nodes = _normalize_output_ast(output_ast)
        dtypes: list[torch.dtype] = []
        for i, item in enumerate(output_like):
            if not isinstance(item, torch.Tensor):
                raise exc.InvalidAPIUsage(
                    f"output_like[{i}] must be a torch.Tensor, got {type(item)}"
                )
            dtypes.append(item.dtype)
        if len(dtypes) != len(ast_nodes):
            raise exc.InvalidAPIUsage(
                "output_like literal must match the structure passed into inline_triton"
            )
        return dtypes, ast_nodes, True
    raise exc.InvalidAPIUsage(
        "output_like must be a tensor or a sequence of tensors or None"
    )


def _emit_output_assertions(
    state: CodegenState,
    result_name: str,
    dtypes: list[torch.dtype],
    output_nodes: list[ast.AST],
    is_multi: bool,
) -> None:
    if not dtypes:
        return
    if not is_multi:
        lhs = expr_from_string(f"{result_name}.dtype")
        rhs = expr_from_string(triton_type(dtypes[0]))
        msg = ast.Constant(
            value=f"inline_triton output dtype mismatch; expected {dtypes[0]}"
        )
        state.add_statement(
            statement_from_string(
                "tl.static_assert({lhs} == {rhs}, {msg})", lhs=lhs, rhs=rhs, msg=msg
            )
        )
        shape_lhs = expr_from_string(f"{result_name}.shape")
        shape_rhs = expr_from_string("{value}.shape", value=output_nodes[0])
        shape_msg = ast.Constant(value="inline_triton output shape mismatch")
        state.add_statement(
            statement_from_string(
                "tl.static_assert({lhs} == {rhs}, {msg})",
                lhs=shape_lhs,
                rhs=shape_rhs,
                msg=shape_msg,
            )
        )
        return

    count_msg = ast.Constant(value=f"inline_triton expected {len(dtypes)} outputs")
    state.add_statement(
        statement_from_string(
            "tl.static_assert(len({result}) == {count}, {msg})",
            result=expr_from_string(result_name),
            count=ast.Constant(value=len(dtypes)),
            msg=count_msg,
        )
    )
    for index, dtype in enumerate(dtypes):
        lhs = expr_from_string(f"{result_name}[{index}].dtype")
        rhs = expr_from_string(triton_type(dtype))
        msg = ast.Constant(
            value=f"inline_triton output {index} dtype mismatch; expected {dtype}"
        )
        state.add_statement(
            statement_from_string(
                "tl.static_assert({lhs} == {rhs}, {msg})",
                lhs=lhs,
                rhs=rhs,
                msg=msg,
            )
        )
        shape_lhs = expr_from_string(f"{result_name}[{index}].shape")
        shape_rhs = expr_from_string("{value}.shape", value=output_nodes[index])
        shape_msg = ast.Constant(value=f"inline_triton output {index} shape mismatch")
        state.add_statement(
            statement_from_string(
                "tl.static_assert({lhs} == {rhs}, {msg})",
                lhs=shape_lhs,
                rhs=shape_rhs,
                msg=shape_msg,
            )
        )


@_decorators.codegen(inline_triton)
def _(state: CodegenState) -> ast.AST | list[ast.AST]:
    triton_source = state.proxy_arg(0)
    args_obj = state.proxy_arg(1)
    output_like = state.proxy_arg(2)
    if not isinstance(triton_source, str):  # defensive; validated earlier
        raise exc.InvalidAPIUsage(
            f"triton_source must be a string, got {type(triton_source)}"
        )

    formatted = _format_triton_source(
        state,
        triton_source,
        args_obj,
        state.ast_args[1],
    )
    statements, result_expr = _parse_triton_source(formatted)
    for stmt in statements:
        state.add_statement(stmt)

    if output_like is None:
        state.add_statement(create(ast.Expr, value=result_expr))
        return create(ast.Constant, value=None)

    result_name = state.device_function.new_var("inline_triton_result")
    assign = create(
        ast.Assign,
        targets=[create(ast.Name, id=result_name, ctx=ast.Store())],
        value=result_expr,
    )
    state.add_statement(assign)

    dtypes, output_nodes, is_multi = _collect_output_metadata(
        output_like, state.ast_args[2]
    )
    _emit_output_assertions(state, result_name, dtypes, output_nodes, is_multi)

    if is_multi:
        return [expr_from_string(f"{result_name}[{i}]") for i in range(len(dtypes))]
    return expr_from_string(result_name)
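
# A second sketch (same assumptions as the first) showing named placeholders
# via a mapping and a multi-output ``output_like``; the returned values match
# the structure of ``output_like``. The kernel and variable names below are
# illustrative only.
#
#     @helion.kernel
#     def inline_shift(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
#         lo = torch.empty_like(x)
#         hi = torch.empty_like(x)
#         for tile in hl.tile(x.size(0)):
#             val = x[tile]
#             # ``{v}`` is substituted with the Triton name for ``val``; the
#             # final tuple expression provides the two return values.
#             lo_val, hi_val = hl.inline_triton(
#                 """
#                 a = {v} - 1.0
#                 b = {v} + 1.0
#                 (a, b)
#                 """,
#                 args={"v": val},
#                 output_like=(val, val),
#             )
#             lo[tile] = lo_val
#             hi[tile] = hi_val
#         return lo, hi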