Rate this Page

Source code for helion.language.reduce_ops

from __future__ import annotations

import ast
import inspect
import itertools
import operator
from typing import TYPE_CHECKING
from typing import Callable
from typing import cast
from typing import overload

import torch
from torch.fx.experimental import proxy_tensor

from .. import exc
from . import _decorators

if TYPE_CHECKING:
    from .._compiler.device_ir import HelperFunctionGraphInfo
    from .._compiler.helper_function import CombineFunction
    from .._compiler.helper_function import CombineFunctionBasic
    from .._compiler.helper_function import CombineFunctionTuple
    from .._compiler.inductor_lowering import CodegenState


__all__ = ["reduce"]


# Typing overload: a single input tensor reduces to a single tensor, and
# ``other`` (the fill value for masked elements) must be a scalar.
@overload
@_decorators.api(is_device_only=True)
def reduce(
    combine_fn: CombineFunction,
    input_tensor: torch.Tensor,
    dim: int | None = None,
    other: float = 0,
    keep_dims: bool = False,
) -> torch.Tensor: ...


# Typing overload: a tuple of tensors reduces to a tuple of tensors (one
# result per input). ``other`` may be one scalar broadcast to every tensor or
# a per-tensor tuple of fill values.
@overload
@_decorators.api(is_device_only=True)
def reduce(
    combine_fn: CombineFunction,
    input_tensor: tuple[torch.Tensor, ...],
    dim: int | None = None,
    other: float | tuple[float, ...] = 0,
    keep_dims: bool = False,
) -> tuple[torch.Tensor, ...]: ...


@_decorators.api(is_device_only=True)
def reduce(
    combine_fn: CombineFunction,
    input_tensor: torch.Tensor | tuple[torch.Tensor, ...],
    dim: int | None = None,
    other: float | tuple[float, ...] = 0,
    keep_dims: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
    """
    Applies a reduction operation along a specified dimension or all dimensions.

    This function is only needed for user-defined combine functions.
    Standard PyTorch reductions (such as sum, mean, amax, etc.) work
    directly in Helion without requiring this function.

    Args:
        combine_fn: A binary function that combines two elements element-wise.
                   Must be associative and commutative for correct results.
                   Can be tensor->tensor or tuple->tuple function.
        input_tensor: Input tensor or tuple of tensors to reduce
        dim: The dimension along which to reduce (None for all dimensions)
        other: Value for masked/padded elements (default: 0)
               For tuple inputs, can be tuple of values with same length
        keep_dims: If True, reduced dimensions are retained with size 1

    Returns:
        torch.Tensor or tuple[torch.Tensor, ...]: Tensor(s) with reduced dimensions

    Raises:
        exc.NotInsideKernel: This Python body always raises. The behavior
            inside a kernel comes from the fake, reference and device-IR
            implementations registered for ``reduce`` in this module.

    See Also:
        - :func:`~helion.language.associative_scan`: For prefix operations

    Note:
        - combine_fn must be associative and commutative
        - For standard reductions, use PyTorch functions directly (faster)
        - Masked elements use the 'other' value during reduction
    """
    raise exc.NotInsideKernel
@_decorators.register_fake(reduce)
def _(
    combine_fn: CombineFunction,
    input_tensor: torch.Tensor | tuple[torch.Tensor, ...],
    dim: int | None = None,
    other: float | tuple[float, ...] = 0,
    keep_dims: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
    """Fake implementation that returns fake tensors with reduced shape.

    Tuple (or list) inputs yield a tuple with one reduced fake tensor per
    element. ``combine_fn`` and ``other`` do not influence the output shape.
    """
    if isinstance(input_tensor, (tuple, list)):
        return tuple(_fake_reduce_tensor(t, dim, keep_dims) for t in input_tensor)
    return _fake_reduce_tensor(input_tensor, dim, keep_dims)


@_decorators.ref(reduce)
def _(
    combine_fn: CombineFunction,
    input_tensor: torch.Tensor | tuple[torch.Tensor, ...],
    dim: int | None = None,
    other: float | tuple[float, ...] = 0,
    keep_dims: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
    """Reference (eager) implementation of reduce.

    Folds ``combine_fn`` left-to-right over every position of the reduced
    dimension(s), one output element at a time. A single tensor is handled
    internally as a 1-tuple so one code path serves both input forms.
    Tuple and list inputs are both accepted, matching the fake and device-IR
    implementations.
    """
    from .._compiler.helper_function import extract_helper_function

    # Extract the raw function if it's wrapped in a @helion.kernel decorator
    combine_fn = extract_helper_function(combine_fn)

    # Lists are accepted as well as tuples (consistent with the fake impl).
    is_tuple = isinstance(input_tensor, (tuple, list))

    if not is_tuple:
        assert isinstance(other, (int, float)), (
            "other must be a scalar for single tensor input"
        )
        input_data = [input_tensor]
        other = (other,)

        # Wrap single-tensor combine function to work with tuples
        single_fn = cast("CombineFunctionBasic", combine_fn)

        def wrapped_combine_fn(
            left_tuple: tuple[torch.Tensor, ...], right_tuple: tuple[torch.Tensor, ...]
        ) -> tuple[torch.Tensor, ...]:
            result = single_fn(left_tuple[0], right_tuple[0])
            return cast("tuple[torch.Tensor, ...]", (result,))

        combine_fn = wrapped_combine_fn
    else:
        input_data = list(input_tensor)
        if isinstance(other, (tuple, list)):
            # Per-tensor fill values; a list is normalized to a tuple.
            other = tuple(other)
            assert len(other) == len(input_data), (
                "other tuple must match input tensor tuple length"
            )
        else:
            # Broadcast a single scalar fill value to every tensor.
            other = (other,) * len(input_data)

    # Get metadata from first tensor
    first_tensor = input_data[0]
    shape, ndim = first_tensor.shape, first_tensor.ndim

    if is_tuple:
        # A combine function taking every element as a separate argument is
        # adapted to the packed (left_tuple, right_tuple) convention.
        # NOTE(review): for a 1-element tuple input, 2 parameters is ambiguous
        # between the packed and unpacked forms; it is treated as unpacked.
        num_params = len(inspect.signature(combine_fn).parameters)
        if num_params == 2 * len(input_data):
            unpacked_fn = cast("CombineFunctionTuple", combine_fn)

            def wrapped_combine_fn2(
                left_tuple: tuple[torch.Tensor, ...],
                right_tuple: tuple[torch.Tensor, ...],
            ) -> tuple[torch.Tensor, ...]:
                # pyrefly: ignore [bad-return]
                return unpacked_fn(*left_tuple, *right_tuple)

            combine_fn = wrapped_combine_fn2

    # Prepare reduction parameters
    if dim is None:
        dims_to_reduce = list(range(ndim))
    else:
        if dim < 0:
            dim = ndim + dim
        dims_to_reduce = [dim]

    # Reduced dims become size 1 with keep_dims, otherwise they are dropped.
    output_shape = [
        1 if i in dims_to_reduce else size
        for i, size in enumerate(shape)
        if keep_dims or i not in dims_to_reduce
    ]

    # Outputs start filled with ``other``; positions whose reduction range is
    # empty keep that value.
    outputs = [
        torch.full(output_shape, other[i], dtype=t.dtype, device=t.device)
        for i, t in enumerate(input_data)
    ]

    # Reduced dims are a slice placeholder in ``idx``; the others iterate over
    # their concrete indices.
    index_iterators = [
        [slice(None)] if i in dims_to_reduce else list(range(size))
        for i, size in enumerate(shape)
    ]
    # Loop-invariant: the index ranges walked along the reduced dims.
    reduction_ranges = [range(shape[d]) for d in dims_to_reduce]
    tuple_combine_fn = cast(
        "Callable[[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]], tuple[torch.Tensor, ...]]",
        combine_fn,
    )

    # Iterate over all combinations of non-reduced dimensions
    # pyrefly: ignore [no-matching-overload]
    for idx in itertools.product(*index_iterators):
        # Gather values along reduction dimensions
        values_list = []
        for reduction_indices in itertools.product(*reduction_ranges):
            full_idx = list(idx)
            for d, pos in zip(dims_to_reduce, reduction_indices, strict=False):
                full_idx[d] = pos
            values_list.append(tuple(t[tuple(full_idx)] for t in input_data))

        if not values_list:
            continue  # A reduced dim has size 0; the output keeps ``other``.

        result = values_list[0]
        for values in values_list[1:]:
            result = tuple_combine_fn(result, values)

        # Reduced dims map to index 0 (keep_dims) or are dropped.
        output_idx = tuple(
            0 if isinstance(idx_val, slice) and keep_dims else idx_val
            for idx_val in idx
            if not isinstance(idx_val, slice) or keep_dims
        )
        for i, out in enumerate(outputs):
            out[output_idx] = result[i]

    # Convert back to single tensor if needed
    if not is_tuple:
        return outputs[0]
    return tuple(outputs)


def _fake_reduce_tensor(
    tensor: torch.Tensor, dim: int | None, keep_dims: bool
) -> torch.Tensor:
    """Helper to create a fake tensor with reduced dimensions.

    ``dim=None`` reduces every dimension (to a scalar, or to all-ones when
    ``keep_dims``). Otherwise ``dim`` (negative values count from the end) is
    removed, or set to size 1 when ``keep_dims``. The result is an
    uninitialized ``torch.empty`` with the input's dtype and device.
    """
    if dim is None:
        # Reduce all dimensions
        if keep_dims:
            return torch.empty(
                [1] * tensor.ndim, dtype=tensor.dtype, device=tensor.device
            )
        return torch.empty([], dtype=tensor.dtype, device=tensor.device)

    # Reduce specific dimension
    new_shape = [*tensor.shape]
    # Handle negative dimension indexing
    if dim < 0:
        dim = tensor.ndim + dim
    if keep_dims:
        new_shape[dim] = 1
    else:
        new_shape.pop(dim)
    return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)


@_decorators.register_to_device_ir(reduce)
def _(
    tracer: proxy_tensor.PythonKeyTracer,
    combine_fn: CombineFunction,
    input_tensor: torch.Tensor | tuple[torch.Tensor, ...],
    dim: int | None = None,
    other: float | tuple[float, ...] = 0,
    keep_dims: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
    """
    Device IR implementation that handles tracing for reduce.

    We map reduce to _reduce, with a pre-traced graph for the combine function.
    Masking of the inputs with ``other`` is attached to the emitted ``_reduce``
    node afterwards.

    Raises:
        ValueError: if a tuple ``other`` has a different length than the tuple
            input, or if a tuple/list ``other`` is given for a single tensor.
    """
    from .._compiler.device_ir import DeviceIR
    from .._compiler.device_ir import HelperFunctionGraphInfo
    from .._compiler.device_ir import args_to_proxies
    from .._compiler.device_ir import select_decomp_table
    from .._compiler.helper_function import create_combine_function_wrapper
    from .._compiler.helper_function import extract_helper_function_name

    is_tuple_input = isinstance(input_tensor, (tuple, list))
    if is_tuple_input:
        assert all(isinstance(t, torch.Tensor) for t in input_tensor), (
            "reduce input must be a tuple of tensors"
        )
    else:
        assert isinstance(input_tensor, torch.Tensor), "reduce input must be a tensor"
    assert callable(combine_fn), "combine_fn must be callable"

    # Validate ``other`` before tracing anything, so an invalid call does not
    # register a combine graph first.
    if is_tuple_input:
        assert isinstance(input_tensor, (tuple, list))
        if isinstance(other, (tuple, list)) and len(other) != len(input_tensor):
            raise ValueError(
                f"other tuple length {len(other)} must match input tensor length {len(input_tensor)}"
            )
        # A scalar ``other`` is broadcast to every tensor by mask_node_inputs.
    elif isinstance(other, (tuple, list)):
        raise ValueError("other must be a scalar for single tensor input")

    # Extract the function name before wrapping
    original_function_name = extract_helper_function_name(combine_fn)
    combine_fn = create_combine_function_wrapper(
        combine_fn, is_tuple_input=is_tuple_input, target_format="tuple"
    )

    # Create fake inputs for the combine function
    if is_tuple_input:
        # Two tuples of fake tensors for the left and right args
        left_fake_tensors = []
        right_fake_tensors = []
        for tensor in input_tensor:
            left_fake_tensors.append(
                torch.empty([1], dtype=tensor.dtype, device=tensor.device)
            )
            right_fake_tensors.append(
                torch.empty([1], dtype=tensor.dtype, device=tensor.device)
            )
        # The combine function expects (left_tuple, right_tuple)
        fake_inputs = [tuple(left_fake_tensors), tuple(right_fake_tensors)]
    else:
        # Two different fake tensors for the left and right args
        left_fake_tensor = torch.empty(
            [1], dtype=input_tensor.dtype, device=input_tensor.device
        )
        right_fake_tensor = torch.empty(
            [1], dtype=input_tensor.dtype, device=input_tensor.device
        )
        fake_inputs = [left_fake_tensor, right_fake_tensor]

    combine_graph = proxy_tensor.make_fx(
        combine_fn, decomposition_table=select_decomp_table()
    )(*fake_inputs).graph
    combine_graph_id = DeviceIR.current().add_graph(
        combine_graph,
        HelperFunctionGraphInfo,
        node_args=[],
        original_function_name=original_function_name,
    )

    # Create the reduce tracing operation without other values (masking is
    # applied to the node's inputs below)
    reduce_args = (
        combine_graph_id,
        input_tensor,
        dim,
        keep_dims,
        is_tuple_input,
    )
    proxy_args, proxy_kwargs = args_to_proxies(tracer, reduce_args)
    proxy_out = tracer.create_proxy(
        "call_function",
        _reduce,
        proxy_args,
        proxy_kwargs,
    )

    from .._compiler.node_masking import apply_masking

    actual_node = proxy_out.node
    if is_tuple_input and isinstance(other, (tuple, list)):
        # Per-tensor fill values: mask each input tensor separately.
        input_arg = actual_node.args[1]
        assert isinstance(input_arg, (tuple, list))
        masked_tensors = []
        for tensor_node, other_val in zip(input_arg, other, strict=True):
            assert isinstance(tensor_node, torch.fx.Node)
            masked_tensor = apply_masking(
                tensor_node, base_node=actual_node, other=other_val
            )
            masked_tensors.append(masked_tensor)
        # Update the args with masked tensors
        actual_node.args = (
            actual_node.args[0],
            tuple(masked_tensors),
            *actual_node.args[2:],
        )
    else:
        # For single tensor or single other value, use mask_node_inputs
        from .._compiler.node_masking import mask_node_inputs

        # pyrefly: ignore [bad-argument-type]
        mask_node_inputs(actual_node, other=other)

    # Create output tensors with reduced shape
    if is_tuple_input:
        output_tensors = []
        assert isinstance(input_tensor, (tuple, list))
        for i, tensor in enumerate(input_tensor):
            reduced_tensor = _fake_reduce_tensor(tensor, dim, keep_dims)
            element_proxy = tracer.create_proxy(
                "call_function",
                operator.getitem,
                (proxy_out, i),
                {},
            )
            proxy_tensor.track_tensor_tree(
                reduced_tensor, element_proxy, constant=None, tracer=tracer
            )
            output_tensors.append(reduced_tensor)
        return tuple(output_tensors)

    output_tensor = _fake_reduce_tensor(input_tensor, dim, keep_dims)
    proxy_tensor.track_tensor_tree(
        output_tensor, proxy_out, constant=None, tracer=tracer
    )
    return output_tensor


@_decorators.api()
def _reduce(
    combine_graph_id: int,
    input_tensor: torch.Tensor | tuple[torch.Tensor, ...],
    dim: int | None = None,
    keep_dims: bool = False,
    is_tuple_input: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
    """Device IR implementation of reduce, not meant to be called directly."""
    raise AssertionError("this should never be called")


@_decorators.register_fake(_reduce)
def _(
    combine_graph_id: int,
    input_tensor: torch.Tensor | tuple[torch.Tensor, ...],
    dim: int | None = None,
    keep_dims: bool = False,
    is_tuple_input: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
    """Fake implementation that returns tensors with reduced shape."""
    if is_tuple_input:
        assert isinstance(input_tensor, (tuple, list)), input_tensor
        return tuple(_fake_reduce_tensor(t, dim, keep_dims) for t in input_tensor)
    assert isinstance(input_tensor, torch.Tensor), input_tensor
    return _fake_reduce_tensor(input_tensor, dim, keep_dims)


@_decorators.codegen(_reduce, "triton")
def _(state: CodegenState) -> ast.AST | list[ast.AST]:
    """Generate code for reduce with combine function (a ``tl.reduce`` call)."""
    combine_graph_id = state.proxy_arg(0)
    dim = state.proxy_arg(2)
    keep_dims = state.proxy_arg(3)
    is_tuple_input = state.proxy_arg(4)

    # Input tensor is already masked, so we can use it directly
    if is_tuple_input:
        input_tensor = state.ast_args[1]
        if isinstance(input_tensor, (tuple, list)):
            from .._compiler.ast_extension import create

            input_tensor = create(ast.Tuple, elts=list(input_tensor), ctx=ast.Load())
        else:
            input_tensor = state.ast_arg(1)
    else:
        input_tensor = state.ast_arg(1)

    helper_func_name = _register_helper_function(state, cast("int", combine_graph_id))
    reduce_expr = _create_reduce_expression(
        input_tensor, dim, helper_func_name, bool(keep_dims)
    )

    if is_tuple_input:
        return _create_tuple_result_expressions(state, reduce_expr)
    return reduce_expr


def _is_direct_binary_combine(
    node: torch.fx.Node, left: torch.fx.Node, right: torch.fx.Node
) -> bool:
    """Return True if ``node`` applies its op directly to ``left`` and ``right``.

    Either operand order is accepted (the builtin ops recognized below are
    commutative). Keyword arguments other than a no-op ``alpha=1`` are
    rejected, since e.g. ``aten.add(a, b, alpha=2)`` is not a plain sum.
    """
    if node.args not in ((left, right), (right, left)):
        return False
    extra_kwargs = {
        key: value
        for key, value in node.kwargs.items()
        if not (key == "alpha" and value == 1)
    }
    return not extra_kwargs


def _infer_builtin_reduction_type_for_cute(
    state: CodegenState, combine_graph_id: int
) -> str | None:
    """Map a single-tensor combine graph to a builtin reduction name, if possible.

    Returns None unless the graph's sole output is one of the recognized
    binary ops applied directly to the two combine placeholders, so that
    combine functions such as ``a + 2 * b`` are not lowered to a plain sum.
    """
    helper_graph_info = _get_helper_graph_info(state, combine_graph_id)
    output_values = _helper_graph_output_values(helper_graph_info)
    if output_values is None or len(output_values) != 1:
        return None
    output_node = output_values[0]
    if output_node.op != "call_function":
        return None
    placeholders = list(helper_graph_info.graph.find_nodes(op="placeholder"))
    if len(placeholders) != 2:
        return None
    if not _is_direct_binary_combine(output_node, placeholders[0], placeholders[1]):
        return None
    return _target_to_builtin_reduction(output_node.target)


def _target_to_builtin_reduction(target: object) -> str | None:
    """Return the builtin reduction name for an aten op target, else None."""
    if target == torch.ops.aten.add.Tensor:
        return "sum"
    if target == torch.ops.aten.maximum.default:
        return "max"
    if target == torch.ops.aten.minimum.default:
        return "min"
    if target == torch.ops.aten.mul.Tensor:
        return "prod"
    return None


def _get_helper_graph_info(
    state: CodegenState, combine_graph_id: int
) -> HelperFunctionGraphInfo:
    """Fetch the registered combine graph, asserting it is a helper graph."""
    from .._compiler.device_ir import HelperFunctionGraphInfo

    helper_graph_info = state.get_graph(combine_graph_id)
    assert isinstance(helper_graph_info, HelperFunctionGraphInfo)
    return helper_graph_info


def _helper_graph_output_values(
    helper_graph_info: HelperFunctionGraphInfo,
) -> list[torch.fx.Node] | None:
    """Return the graph's output nodes as a list, or None if not all are Nodes."""
    output_nodes = list(helper_graph_info.graph.find_nodes(op="output"))
    if len(output_nodes) != 1:
        return None
    output_value = output_nodes[0].args[0]
    if isinstance(output_value, torch.fx.Node):
        return [output_value]
    if not isinstance(output_value, (tuple, list)):
        return None
    nodes: list[torch.fx.Node] = []
    for node in output_value:
        if not isinstance(node, torch.fx.Node):
            return None
        nodes.append(node)
    return nodes


def _infer_tuple_builtin_reduction_types_for_cute(
    state: CodegenState, combine_graph_id: int, tuple_arity: int
) -> tuple[str, ...] | None:
    """Map each tuple output to a builtin reduction, if all are independent.

    Placeholders are ordered ``left_0..left_{n-1}, right_0..right_{n-1}``.
    Output ``i`` must be a recognized binary op applied directly to
    ``left_i`` and ``right_i``; otherwise None is returned.
    """
    helper_graph_info = _get_helper_graph_info(state, combine_graph_id)
    output_values = _helper_graph_output_values(helper_graph_info)
    if output_values is None or len(output_values) != tuple_arity:
        return None
    placeholders = list(helper_graph_info.graph.find_nodes(op="placeholder"))
    if len(placeholders) != 2 * tuple_arity:
        return None
    reduction_types: list[str] = []
    for i, output_node in enumerate(output_values):
        if output_node.op != "call_function":
            return None
        reduction_type = _target_to_builtin_reduction(output_node.target)
        if reduction_type is None:
            return None
        left_node = placeholders[i]
        right_node = placeholders[i + tuple_arity]
        if not _is_direct_binary_combine(output_node, left_node, right_node):
            return None
        reduction_types.append(reduction_type)
    return tuple(reduction_types)


def _infer_tuple_argreduce_type_for_cute(
    state: CodegenState, combine_graph_id: int
) -> str | None:
    """Recognize a (value, index) combine graph as "argmax" or "argmin".

    The graph must return ``(where(c, x, y), where(c, i, j))`` over one shared
    ``gt``/``lt`` comparison of the two value placeholders, with both
    ``where`` calls choosing the same side. Returns None for anything else.
    """
    helper_graph_info = _get_helper_graph_info(state, combine_graph_id)
    output_values = _helper_graph_output_values(helper_graph_info)
    if output_values is None or len(output_values) != 2:
        return None
    value_where_node, index_where_node = output_values
    for where_node in (value_where_node, index_where_node):
        if (
            where_node.op != "call_function"
            or where_node.target != torch.ops.aten.where.self
            or len(where_node.args) != 3
        ):
            return None

    # Both selections must be driven by the same comparison node.
    compare_node = value_where_node.args[0]
    if compare_node is not index_where_node.args[0]:
        return None
    if (
        not isinstance(compare_node, torch.fx.Node)
        or compare_node.op != "call_function"
        or len(compare_node.args) != 2
    ):
        return None
    compare_target = compare_node.target
    if compare_target not in {torch.ops.aten.gt.Tensor, torch.ops.aten.lt.Tensor}:
        return None

    placeholders = list(helper_graph_info.graph.find_nodes(op="placeholder"))
    if len(placeholders) != 4:
        return None
    left_value, left_index, right_value, right_index = placeholders

    value_branches = value_where_node.args[1:]
    if value_branches == (right_value, left_value):
        choose_right_when_true = True
    elif value_branches == (left_value, right_value):
        choose_right_when_true = False
    else:
        return None
    # The index selection must pick the same side as the value selection.
    expected_index_branches = (
        (right_index, left_index)
        if choose_right_when_true
        else (left_index, right_index)
    )
    if index_where_node.args[1:] != expected_index_branches:
        return None

    compare_lhs, compare_rhs = compare_node.args
    if compare_lhs not in (left_value, right_value):
        return None
    if compare_rhs not in (left_value, right_value):
        return None
    if compare_lhs is compare_rhs:
        return None

    def selected_for_pair(left: int, right: int) -> int:
        # Evaluate the combine's selection on concrete stand-in values.
        lhs = left if compare_lhs is left_value else right
        rhs = left if compare_rhs is left_value else right
        if compare_target == torch.ops.aten.gt.Tensor:
            cond = lhs > rhs
        else:
            cond = lhs < rhs
        if cond:
            return right if choose_right_when_true else left
        return left if choose_right_when_true else right

    # Probe both orderings: always keeping the larger value means argmax,
    # always keeping the smaller one means argmin.
    selected_01 = selected_for_pair(0, 1)
    selected_10 = selected_for_pair(1, 0)
    if selected_01 == 1 and selected_10 == 1:
        return "argmax"
    if selected_01 == 0 and selected_10 == 0:
        return "argmin"
    return None


@_decorators.codegen(_reduce, "cute")
def _(state: CodegenState) -> ast.AST | list[ast.AST]:
    """Generate cute code for reduce.

    Only combine functions recognized as builtin reductions (sum/max/min/prod,
    per element for tuple inputs) or as an argmax/argmin over a (value, index)
    pair are supported. Everything else raises ``exc.BackendUnsupported``.
    """
    from .._compiler.ast_extension import expr_from_string

    combine_graph_id = state.proxy_arg(0)
    dim = state.proxy_arg(2)
    is_tuple_input = bool(state.proxy_arg(4))
    if dim is None:
        raise exc.BackendUnsupported("cute", "hl.reduce(..., dim=None)")

    from .._compiler.compile_environment import CompileEnvironment

    backend = CompileEnvironment.current().backend
    dim_int = cast("int", dim)
    combine_graph_id_int = cast("int", combine_graph_id)

    if not is_tuple_input:
        reduction_type = _infer_builtin_reduction_type_for_cute(
            state, combine_graph_id_int
        )
        if reduction_type is None:
            raise exc.BackendUnsupported(
                "cute",
                "hl.reduce custom combine function",
            )
        input_name = state.codegen.lift(
            state.ast_arg(1), dce=True, prefix="reduce_input"
        ).id
        return expr_from_string(
            backend.reduction_expr(
                input_name,
                reduction_type,
                dim_int,
            )
        )

    proxy_input = state.proxy_arg(1)
    ast_input = state.ast_args[1]
    if not isinstance(proxy_input, (tuple, list)) or not isinstance(
        ast_input, (tuple, list)
    ):
        raise exc.BackendUnsupported("cute", "hl.reduce tuple inputs")
    tuple_arity = len(proxy_input)
    if len(ast_input) != tuple_arity:
        raise exc.BackendUnsupported("cute", "hl.reduce tuple inputs")

    # Independent per-element builtin reductions.
    if reduction_types := _infer_tuple_builtin_reduction_types_for_cute(
        state, combine_graph_id_int, tuple_arity
    ):
        result_exprs: list[ast.AST] = []
        for i, reduction_type in enumerate(reduction_types):
            input_node = ast_input[i]
            assert isinstance(input_node, ast.AST), input_node
            input_name = state.codegen.lift(
                input_node, dce=True, prefix=f"reduce_input_{i}"
            ).id
            result_exprs.append(
                expr_from_string(
                    backend.reduction_expr(
                        input_name,
                        reduction_type,
                        dim_int,
                    )
                )
            )
        return result_exprs

    # Otherwise the only supported tuple form is an arg-reduction.
    argreduce_type = _infer_tuple_argreduce_type_for_cute(state, combine_graph_id_int)
    if argreduce_type is None:
        raise exc.BackendUnsupported("cute", "hl.reduce tuple custom combine function")
    if tuple_arity != 2:
        raise exc.BackendUnsupported(
            "cute",
            "hl.reduce tuple arg-reductions require 2 tuple elements",
        )
    if not isinstance(proxy_input[0], torch.Tensor) or not isinstance(
        proxy_input[1], torch.Tensor
    ):
        raise exc.BackendUnsupported("cute", "hl.reduce tuple arg-reduction inputs")
    if not isinstance(ast_input[0], ast.AST) or not isinstance(ast_input[1], ast.AST):
        raise exc.BackendUnsupported("cute", "hl.reduce tuple arg-reduction inputs")

    value_name = state.codegen.lift(ast_input[0], dce=True, prefix="reduce_value").id
    index_name = state.codegen.lift(ast_input[1], dce=True, prefix="reduce_index").id
    index_dtype = proxy_input[1].dtype
    value_reduction = "max" if argreduce_type == "argmax" else "min"
    reduced_value_expr = expr_from_string(
        backend.reduction_expr(
            value_name,
            value_reduction,
            dim_int,
        )
    )
    # NOTE(review): index_dtype is passed both positionally and by keyword;
    # confirm this matches backend.argreduce_result_expr's signature.
    reduced_index_expr = expr_from_string(
        backend.argreduce_result_expr(
            value_name,
            index_name,
            argreduce_type,
            dim_int,
            index_dtype,
            index_dtype=index_dtype,
        )
    )
    return [reduced_value_expr, reduced_index_expr]


def _register_helper_function(state: CodegenState, combine_graph_id: int) -> str:
    """Register the helper function and return its final name."""
    from .._compiler.device_ir import HelperFunctionGraphInfo

    helper_graph_info = state.get_graph(combine_graph_id)
    assert isinstance(helper_graph_info, HelperFunctionGraphInfo)
    state.codegen.device_function.register_helper_function(helper_graph_info)
    # Get the final name from the helper manager (which uses the namespace)
    return state.codegen.device_function.helper_manager.get_final_name(
        helper_graph_info
    )


def _create_reduce_expression(
    input_tensor: ast.AST, dim: object, helper_func_name: str, keep_dims: bool
) -> ast.AST:
    """Create the tl.reduce expression.

    Emits ``tl.reduce(<input>, <dim or None>, <helper>[, keep_dims=True])``.
    """
    from .._compiler.ast_extension import expr_from_string

    dim_part = "None" if dim is None else "{dim_value}"
    keep_dims_part = ", keep_dims=True" if keep_dims else ""
    template = (
        f"tl.reduce({{input_tensor}}, {dim_part}, {helper_func_name}{keep_dims_part})"
    )
    if dim is None:
        # Reduce all dimensions
        return expr_from_string(template, input_tensor=input_tensor)
    # Reduce specific dimension
    return expr_from_string(
        template,
        input_tensor=input_tensor,
        # pyrefly: ignore [bad-argument-type]
        dim_value=ast.Constant(value=dim),
    )


def _create_tuple_result_expressions(
    state: CodegenState, reduce_expr: ast.AST
) -> list[ast.AST]:
    """Create getitem expressions for tuple results.

    The element count comes from the AST input sequence when available, then
    from the proxy input sequence, and only as a last resort defaults to 2.
    """
    from .._compiler.ast_extension import expr_from_string

    raw_input = state.ast_args[1]
    if isinstance(raw_input, (tuple, list)):
        num_elements = len(raw_input)
    else:
        proxy_input = state.proxy_arg(1)
        num_elements = (
            len(proxy_input) if isinstance(proxy_input, (tuple, list)) else 2
        )
    return [
        expr_from_string(
            "{reduce_result}[{index}]",
            reduce_result=reduce_expr,
            index=ast.Constant(value=i),
        )
        for i in range(num_elements)
    ]