Source code for helion.language.scan_ops

from __future__ import annotations

import ast
import operator
from typing import TYPE_CHECKING
from typing import cast
from typing import overload

import torch
import torch._higher_order_ops as higher_order_ops
from torch.fx.experimental import proxy_tensor

from .. import exc
from . import _decorators

if TYPE_CHECKING:
    from .._compiler.device_ir import HelperFunctionGraphInfo
    from .._compiler.helper_function import CombineFunction
    from .._compiler.inductor_lowering import CodegenState
    from .._compiler.type_propagation import Origin
    from .._compiler.type_propagation import TypeInfo


__all__ = ["associative_scan", "cumprod", "cumsum"]


@overload
@_decorators.device_func_replacement(higher_order_ops.associative_scan)
@_decorators.api(is_device_only=True)
def associative_scan(
    combine_fn: CombineFunction,
    input_tensor: torch.Tensor,
    dim: int,
    reverse: bool = False,
) -> torch.Tensor: ...


@overload
@_decorators.device_func_replacement(higher_order_ops.associative_scan)
@_decorators.api(is_device_only=True)
def associative_scan(
    combine_fn: CombineFunction,
    input_tensor: tuple[torch.Tensor, ...],
    dim: int,
    reverse: bool = False,
) -> tuple[torch.Tensor, ...]: ...


@_decorators.device_func_replacement(higher_order_ops.associative_scan)
@_decorators.api(is_device_only=True)
def associative_scan(
    combine_fn: CombineFunction,
    input_tensor: torch.Tensor | tuple[torch.Tensor, ...],
    dim: int,
    reverse: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
    """
    Applies an associative scan operation along a specified dimension.

    Computes the prefix scan (cumulative operation) along a dimension using a
    custom combine function. Unlike :func:`~helion.language.reduce`, this
    preserves the input shape.

    Args:
        combine_fn: A binary function that combines two elements element-wise.
            Must be associative for correct results. Can be a tensor->tensor
            or a tuple->tuple function.
        input_tensor: Input tensor or tuple of tensors to scan
        dim: The dimension along which to scan
        reverse: If True, performs the scan in reverse order

    Returns:
        torch.Tensor or tuple[torch.Tensor, ...]: Tensor(s) with the same
        shape as the input containing the scan result

    See Also:
        - :func:`~helion.language.reduce`: For dimension-reducing operations
        - :func:`~helion.language.cumsum`: For cumulative sum
        - :func:`~helion.language.cumprod`: For cumulative product

    Note:
        - combine_fn must be associative (not necessarily commutative)
        - Output has the same shape as the input (unlike reduce)
        - For standard scans, use :func:`~helion.language.cumsum` or
          :func:`~helion.language.cumprod` (faster)
        - A reverse scan applies the operation from right to left
    """
    raise exc.NotInsideKernel
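
# --- Illustrative usage sketch (not part of this module) ---
# A hypothetical Helion kernel computing a per-row running maximum with a
# custom combine function. The kernel and tensor names below are assumptions
# for demonstration; only hl.associative_scan itself is defined here.
#
#     import helion
#     import helion.language as hl
#
#     def combine_max(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
#         # torch.maximum is associative, so a running maximum is a valid scan.
#         return torch.maximum(a, b)
#
#     @helion.kernel
#     def running_max(x: torch.Tensor) -> torch.Tensor:
#         out = torch.empty_like(x)
#         for tile in hl.tile(x.size(0)):
#             out[tile, :] = hl.associative_scan(combine_max, x[tile, :], dim=1)
#         return out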
@_decorators.register_fake(associative_scan)
def _(
    combine_fn: CombineFunction,
    input_tensor: torch.Tensor | tuple[torch.Tensor, ...],
    dim: int,
    reverse: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
    """Fake implementation that returns fake tensors with the same shape as input."""
    if isinstance(input_tensor, (tuple, list)):
        return tuple(torch.empty_like(t) for t in input_tensor)
    return torch.empty_like(input_tensor)


@_decorators.ref(associative_scan)
def _(
    combine_fn: CombineFunction,
    input_tensor: torch.Tensor | tuple[torch.Tensor, ...],
    dim: int,
    reverse: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
    return higher_order_ops.associative_scan(
        combine_fn, input_tensor, dim, reverse=reverse
    )


@_decorators.register_to_device_ir(associative_scan)
def _(
    tracer: proxy_tensor.PythonKeyTracer,
    combine_fn: CombineFunction,
    input_tensor: torch.Tensor | tuple[torch.Tensor, ...],
    dim: int,
    reverse: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
    """
    Device IR implementation that handles tracing for associative_scan.

    We map associative_scan to _associative_scan, with a pre-traced graph
    for the combine function.
    """
    from .._compiler.device_ir import DeviceIR
    from .._compiler.device_ir import HelperFunctionGraphInfo
    from .._compiler.device_ir import args_to_proxies
    from .._compiler.device_ir import select_decomp_table
    from .._compiler.helper_function import create_combine_function_wrapper
    from .._compiler.helper_function import extract_helper_function_name

    is_tuple_input = isinstance(input_tensor, (tuple, list))
    if is_tuple_input:
        assert all(isinstance(t, torch.Tensor) for t in input_tensor), (
            "associative_scan input must be a tuple of tensors"
        )
    else:
        assert isinstance(input_tensor, torch.Tensor), (
            "associative_scan input must be a tensor"
        )
    assert isinstance(dim, int), "associative_scan dim must be an integer"
    assert callable(combine_fn), "combine_fn must be callable"

    # Extract the function name before wrapping
    original_function_name = extract_helper_function_name(combine_fn)
    combine_fn = create_combine_function_wrapper(
        combine_fn, is_tuple_input=is_tuple_input, target_format="unpacked"
    )

    # Create fake inputs for the combine function
    fake_inputs = []
    for tensor in input_tensor if is_tuple_input else [input_tensor]:
        fake_inputs.extend(
            [
                torch.empty([1], dtype=tensor.dtype, device=tensor.device),
                torch.empty([1], dtype=tensor.dtype, device=tensor.device),
            ]
        )
    combine_graph = proxy_tensor.make_fx(
        combine_fn, decomposition_table=select_decomp_table()
    )(*fake_inputs).graph
    combine_graph_id = DeviceIR.current().add_graph(
        combine_graph,
        HelperFunctionGraphInfo,
        node_args=[],
        original_function_name=original_function_name,
    )

    # Create the associative_scan tracing operation
    scan_args = (combine_graph_id, input_tensor, dim, reverse, is_tuple_input)
    proxy_args, proxy_kwargs = args_to_proxies(tracer, scan_args)
    proxy_out = tracer.create_proxy(
        "call_function",
        _associative_scan,
        proxy_args,
        proxy_kwargs,
    )

    # Create new output tensors to avoid aliasing input with output.
    # tl.associative_scan modifies its input in-place, so we must track
    # distinct output tensors to ensure the input remains usable after the scan.
    if is_tuple_input:
        output_tensors = []
        assert isinstance(input_tensor, (tuple, list))
        for i, tensor in enumerate(input_tensor):
            output_tensor = torch.empty_like(tensor)
            element_proxy = tracer.create_proxy(
                "call_function",
                operator.getitem,
                (proxy_out, i),
                {},
            )
            proxy_tensor.track_tensor_tree(
                output_tensor, element_proxy, constant=None, tracer=tracer
            )
            output_tensors.append(output_tensor)
        return tuple(output_tensors)

    output_tensor = torch.empty_like(input_tensor)
    proxy_tensor.track_tensor_tree(
        output_tensor, proxy_out, constant=None, tracer=tracer
    )
    return output_tensor


@_decorators.type_propagation(associative_scan)
def _(
    combine_fn: TypeInfo,
    input_tensor: TypeInfo,
    dim: TypeInfo,
    reverse: TypeInfo | None = None,
    *,
    origin: Origin,
) -> TypeInfo:
    """Type propagation for associative_scan - output has same type as input."""
    from .._compiler.type_propagation import CallableType
    from .._compiler.type_propagation import SequenceType
    from .._compiler.type_propagation import TensorType

    # Validate that combine_fn is callable
    if not isinstance(combine_fn, CallableType):
        raise exc.TypeInferenceError(f"combine_fn must be callable, got {combine_fn}")

    # Validate that input_tensor is a tensor or tuple of tensors
    if isinstance(input_tensor, TensorType):
        # Single tensor case
        return input_tensor
    if isinstance(input_tensor, SequenceType):
        # Tuple of tensors case - validate all elements are tensors
        for elem_type in input_tensor.unpack():
            if not isinstance(elem_type, TensorType):
                raise exc.TypeInferenceError(
                    f"All elements in tuple must be tensors, got {elem_type}"
                )
        # Return the same tuple type
        return input_tensor

    raise exc.TypeInferenceError(
        f"input_tensor must be a tensor or tuple of tensors, got {input_tensor}"
    )
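
# --- Illustrative sketch of the tuple-input paths above (not part of this
# module) ---
# A combine function over (value, index) pairs implements a cumulative argmax;
# it is associative because max is, with ties resolving to the rightmost
# index. All names below are hypothetical.
#
#     def combine_argmax(
#         left: tuple[torch.Tensor, torch.Tensor],
#         right: tuple[torch.Tensor, torch.Tensor],
#     ) -> tuple[torch.Tensor, torch.Tensor]:
#         left_val, left_idx = left
#         right_val, right_idx = right
#         take_right = right_val >= left_val
#         return (
#             torch.where(take_right, right_val, left_val),
#             torch.where(take_right, right_idx, left_idx),
#         )
#
#     # Inside a Helion device loop:
#     #     values, indices = hl.associative_scan(combine_argmax, (vals, idxs), dim=1)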
@_decorators.device_func_replacement(torch.cumsum)
def cumsum(
    input_tensor: torch.Tensor, dim: int, reverse: bool = False
) -> torch.Tensor:
    """
    Compute the cumulative sum along a specified dimension.

    Equivalent to ``hl.associative_scan(torch.add, input_tensor, dim, reverse)``.

    Args:
        input_tensor: Input tensor to compute cumulative sum
        dim: The dimension along which to compute cumulative sum
        reverse: If True, performs the cumsum in reverse order

    Returns:
        torch.Tensor: Tensor with same shape as input containing cumulative sum

    See Also:
        - :func:`~helion.language.associative_scan`: For custom scan operations
        - :func:`~helion.language.cumprod`: For cumulative product
        - :func:`~helion.language.reduce`: For dimension-reducing operations

    Note:
        - Output has the same shape as the input
        - reverse=True computes the cumsum from right to left
        - Equivalent to torch.cumsum
    """
    return associative_scan(torch.add, input_tensor, dim, reverse)
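
# --- Usage sketch (not part of this module) ---
# Inside a hypothetical Helion kernel, hl.cumsum matches torch.cumsum but is
# lowered through associative_scan(torch.add, ...). Names are illustrative.
#
#     @helion.kernel
#     def prefix_sums(x: torch.Tensor) -> torch.Tensor:
#         out = torch.empty_like(x)
#         for tile in hl.tile(x.size(0)):
#             out[tile, :] = hl.cumsum(x[tile, :], dim=1)
#         return out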
@_decorators.device_func_replacement(torch.cumprod)
def cumprod(
    input_tensor: torch.Tensor, dim: int, reverse: bool = False
) -> torch.Tensor:
    """
    Compute the cumulative product along a specified dimension.

    Equivalent to ``hl.associative_scan(torch.mul, input_tensor, dim, reverse)``.

    Args:
        input_tensor: Input tensor to compute cumulative product
        dim: The dimension along which to compute cumulative product
        reverse: If True, performs the cumprod in reverse order

    Returns:
        torch.Tensor: Tensor with same shape as input containing cumulative product

    See Also:
        - :func:`~helion.language.associative_scan`: For custom scan operations
        - :func:`~helion.language.cumsum`: For cumulative sum
        - :func:`~helion.language.reduce`: For dimension-reducing operations

    Note:
        - Output has the same shape as the input
        - reverse=True computes the cumprod from right to left
        - Equivalent to torch.cumprod
    """
    return associative_scan(torch.mul, input_tensor, dim, reverse)
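
# --- Usage sketch (not part of this module) ---
# reverse=True yields suffix products, e.g. the classic "product of all other
# elements" trick combining a forward and a reverse scan. Illustrative only:
#
#     # Inside a Helion device loop:
#     #     prefix = hl.cumprod(x[tile, :], dim=1)
#     #     suffix = hl.cumprod(x[tile, :], dim=1, reverse=True)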
@_decorators.api()
def _associative_scan(
    combine_graph_id: int,
    input_tensor: torch.Tensor | tuple[torch.Tensor, ...],
    dim: int,
    reverse: bool = False,
    is_tuple_input: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
    """Device IR implementation of associative scan, not meant to be called directly."""
    raise AssertionError("this should never be called")


@_decorators.register_fake(_associative_scan)
def _(
    combine_graph_id: int,
    input_tensor: torch.Tensor | tuple[torch.Tensor, ...],
    dim: int,
    reverse: bool = False,
    is_tuple_input: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
    """Fake implementation that returns a tensor/tuple with the same shape as input."""
    if is_tuple_input:
        assert isinstance(input_tensor, (tuple, list)), input_tensor
        return tuple(torch.empty_like(t) for t in input_tensor)
    assert isinstance(input_tensor, torch.Tensor), input_tensor
    return torch.empty_like(input_tensor)


@_decorators.codegen(_associative_scan, "triton")
def _(state: CodegenState) -> ast.AST | list[ast.AST]:
    """Generate code for associative scan with combine function."""
    combine_graph_id = state.proxy_arg(0)
    dim = state.proxy_arg(2)
    reverse = state.proxy_arg(3)
    is_tuple_input = state.proxy_arg(4)

    input_tensor = _get_input_tensor_ast(state, bool(is_tuple_input))
    helper_func_name = _register_helper_function(state, cast("int", combine_graph_id))
    scan_expr = _create_scan_expression(
        input_tensor, cast("int", dim), helper_func_name, bool(reverse)
    )

    if is_tuple_input:
        return _create_tuple_result_expressions(state, scan_expr)
    return scan_expr


@_decorators.codegen(_associative_scan, "cute")
def _(state: CodegenState) -> ast.AST:
    from torch.fx.node import Node

    from .._compiler.ast_extension import expr_from_string
    from .._compiler.ast_extension import statement_from_string
    from .._compiler.compile_environment import CompileEnvironment
    from .._compiler.cute.indexing import CuteSortableLoad
    from .._compiler.device_ir import HelperFunctionGraphInfo

    combine_graph_id = cast("int", state.proxy_arg(0))
    dim = cast("int", state.proxy_arg(2))
    reverse = bool(state.proxy_arg(3))
    is_tuple_input = bool(state.proxy_arg(4))
    if is_tuple_input:
        raise exc.BackendUnsupported("cute", "tuple associative_scan")

    helper_graph_info = state.get_graph(combine_graph_id)
    assert isinstance(helper_graph_info, HelperFunctionGraphInfo)
    op = _scan_combine_operator(helper_graph_info)
    if op not in ("add", "max", "min", "mul"):
        raise exc.BackendUnsupported("cute", "associative_scan combine function")

    fx_node = state.fx_node
    if fx_node is None:
        raise exc.BackendUnsupported("cute", "associative_scan without FX node")
    input_node = fx_node.args[1]
    input_tensor = fx_node.meta["val"]
    if dim < 0:
        dim += input_tensor.ndim
    if dim != input_tensor.ndim - 1:
        raise exc.BackendUnsupported("cute", "associative_scan non-last dimension")

    scan_source = state.ast_args[1]
    sorted_source: tuple[CuteSortableLoad, bool] | None = None
    if (
        isinstance(input_node, Node)
        and input_node.target is operator.getitem
        and isinstance(input_node.args[0], Node)
        and input_node.args[0].target is torch.ops.aten.sort.default
    ):
        sort_node = input_node.args[0]
        load = sort_node.meta.get("cute_sort_load")
        descending = sort_node.meta.get("cute_sort_descending")
        if isinstance(load, CuteSortableLoad) and isinstance(descending, bool):
            sorted_source = (load, descending)

    if sorted_source is None:
        if not isinstance(scan_source, CuteSortableLoad):
            if isinstance(input_node, Node):
                scan_source = input_node.meta.get("cute_sortable_load")
            if not isinstance(scan_source, CuteSortableLoad):
                raise exc.BackendUnsupported("cute", "associative_scan input")
        load = scan_source
    else:
        load = sorted_source[0]

    env = CompileEnvironment.current()
    n = input_tensor.shape[-1]
    n_hint = env.size_hint(n) if isinstance(n, torch.SymInt) else n
    if not isinstance(n_hint, int):
        raise exc.BackendUnsupported("cute", "dynamic associative_scan extent")

    dtype_str = env.backend.dtype_str(input_tensor.dtype)
    index_dtype = env.backend.dtype_str(env.index_dtype)

    out_pos = state.device_function.new_var("scan_out_pos")
    scan_i = state.device_function.new_var("scan_i")
    acc = state.device_function.new_var("scan_acc")
    initialized = state.device_function.new_var("scan_initialized")
    include = state.device_function.new_var("scan_include")
    value = state.device_function.new_var("scan_value")

    state.codegen.add_statement(
        statement_from_string(
            f"{out_pos} = {index_dtype}({load.index_exprs[load.sort_index_pos]})"
        )
    )
    identity = "1" if op == "mul" else "0"
    state.codegen.add_statement(
        statement_from_string(f"{acc} = {dtype_str}({identity})")
    )
    state.codegen.add_statement(statement_from_string(f"{initialized} = False"))

    if op == "add":
        combine_expr = f"{acc} + {value}"
    elif op == "mul":
        combine_expr = f"{acc} * {value}"
    elif op == "max":
        combine_expr = f"{acc} if {acc} > {value} else {value}"
    elif op == "min":
        combine_expr = f"{acc} if {acc} < {value} else {value}"
    else:
        raise AssertionError(op)

    if sorted_source is not None:
        value_lines = _cute_sorted_value_lines(
            state, load, sorted_source[1], scan_i, value, n_hint
        )
    else:
        value_lines = [f"    {value} = {_cute_scan_load_expr(load, scan_i)}"]

    include_expr = f"{scan_i} >= {out_pos}" if reverse else f"{scan_i} <= {out_pos}"
    state.codegen.add_statement(
        statement_from_string(
            "\n".join(
                [
                    f"for {scan_i} in range(cutlass.Int32(0), cutlass.Int32({n_hint}), cutlass.Int32(1)):",
                    f"    {include} = {include_expr}",
                    *value_lines,
                    f"    {acc} = ({combine_expr}) if ({include} and {initialized}) else ({value} if {include} else {acc})",
                    f"    {initialized} = True if {include} else {initialized}",
                ]
            )
        )
    )
    return expr_from_string(acc)


def _get_input_tensor_ast(state: CodegenState, is_tuple_input: bool) -> ast.AST:
    """Get the input tensor AST, handling tuple inputs specially."""
    if not is_tuple_input:
        return state.ast_arg(1)

    raw_input = state.ast_args[1]
    if isinstance(raw_input, tuple):
        from .._compiler.ast_extension import create

        tuple_elts = [
            elt if isinstance(elt, ast.AST) else ast.Constant(value=elt)
            for elt in raw_input
        ]
        return create(ast.Tuple, elts=tuple_elts, ctx=ast.Load())
    return state.ast_arg(1)


def _scan_combine_operator(helper_graph_info: HelperFunctionGraphInfo) -> str:
    import operator as operator_mod

    graph = helper_graph_info.graph
    for node in graph.nodes:
        if node.op != "call_function":
            continue
        if node.target in (
            operator_mod.add,
            torch.add,
            torch.ops.aten.add.Tensor,
            torch.ops.aten.add.Scalar,
        ):
            return "add"
        if node.target in (
            operator_mod.mul,
            torch.mul,
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.mul.Scalar,
        ):
            return "mul"
        if node.target in (
            torch.maximum,
            torch.ops.aten.maximum.default,
        ):
            return "max"
        if node.target in (
            torch.minimum,
            torch.ops.aten.minimum.default,
        ):
            return "min"
    raise exc.BackendUnsupported("cute", "associative_scan combine graph")


def _cute_scan_load_expr(load: object, index: str) -> str:
    from .._compiler.compile_environment import CompileEnvironment
    from .._compiler.cute.indexing import CuteSortableLoad

    assert isinstance(load, CuteSortableLoad)
    index_exprs = list(load.index_exprs)
    index_exprs[load.sort_index_pos] = index
    expr = f"{load.tensor_name}[{', '.join(index_exprs)}]"
    if load.mask_expr is not None:
        dtype_str = CompileEnvironment.current().backend.dtype_str(load.dtype)
        return f"({expr} if {load.mask_expr} else {dtype_str}(0))"
    return expr


def _cute_sorted_value_lines(
    state: CodegenState,
    load: object,
    descending: bool,
    out_pos: str,
    output_var: str,
    n_hint: int,
) -> list[str]:
    from .._compiler.compile_environment import CompileEnvironment
    from .._compiler.cute.indexing import CuteSortableLoad

    assert isinstance(load, CuteSortableLoad)
    env = CompileEnvironment.current()
    dtype_str = env.backend.dtype_str(load.dtype)
    index_dtype = env.backend.dtype_str(env.index_dtype)
    sorted_value = state.device_function.new_var("scan_sorted_value")
    candidate = state.device_function.new_var("scan_sort_k")
    probe = state.device_function.new_var("scan_sort_j")
    candidate_rank = state.device_function.new_var("scan_sort_rank")
    candidate_value = state.device_function.new_var("scan_sort_candidate")
    probe_value = state.device_function.new_var("scan_sort_probe")
    before = state.device_function.new_var("scan_sort_before")
    selected = state.device_function.new_var("scan_sort_selected")
    cmp_op = ">" if descending else "<"
    return [
        f"    {sorted_value} = {dtype_str}(0)",
        f"    for {candidate} in range(cutlass.Int32(0), cutlass.Int32({n_hint}), cutlass.Int32(1)):",
        f"        {candidate_value} = {_cute_scan_load_expr(load, candidate)}",
        f"        {candidate_rank} = {index_dtype}(0)",
        f"        for {probe} in range(cutlass.Int32(0), cutlass.Int32({n_hint}), cutlass.Int32(1)):",
        f"            {probe_value} = {_cute_scan_load_expr(load, probe)}",
        f"            {before} = ({probe_value} {cmp_op} {candidate_value}) or (({probe_value} == {candidate_value}) and ({probe} < {candidate}))",
        f"            {candidate_rank} = {candidate_rank} + ({index_dtype}(1) if {before} else {index_dtype}(0))",
        f"        {selected} = {candidate_rank} == {out_pos}",
        f"        {sorted_value} = {candidate_value} if {selected} else {sorted_value}",
        f"    {output_var} = {sorted_value}",
    ]


def _register_helper_function(state: CodegenState, combine_graph_id: int) -> str:
    """Register the helper function and return its final name."""
    from .._compiler.device_ir import HelperFunctionGraphInfo

    helper_graph_info = state.get_graph(combine_graph_id)
    assert isinstance(helper_graph_info, HelperFunctionGraphInfo)
    state.codegen.device_function.register_helper_function(helper_graph_info)
    # Get the final name from the helper manager (which uses the namespace)
    return state.codegen.device_function.helper_manager.get_final_name(
        helper_graph_info
    )


def _create_scan_expression(
    input_tensor: ast.AST, dim: int, helper_func_name: str, reverse: bool
) -> ast.AST:
    """Create the tl.associative_scan expression."""
    from .._compiler.ast_extension import expr_from_string

    template = (
        f"tl.associative_scan({{input_tensor}}, {{dim_value}}, {helper_func_name}, reverse=True)"
        if reverse
        else f"tl.associative_scan({{input_tensor}}, {{dim_value}}, {helper_func_name})"
    )
    return expr_from_string(
        template,
        input_tensor=input_tensor,
        dim_value=ast.Constant(value=dim),
    )


def _create_tuple_result_expressions(
    state: CodegenState, scan_expr: ast.AST
) -> list[ast.AST]:
    """Create getitem expressions for tuple results."""
    from .._compiler.ast_extension import expr_from_string

    raw_input = state.ast_args[1]
    num_elements = len(raw_input) if isinstance(raw_input, tuple) else 2

    return [
        expr_from_string(
            "{scan_result}[{index}]", scan_result=scan_expr, index=ast.Constant(value=i)
        )
        for i in range(num_elements)
    ]