from __future__ import annotations

import ast
import operator
from typing import TYPE_CHECKING
from typing import cast
from typing import overload

import torch
import torch._higher_order_ops as higher_order_ops
from torch.fx.experimental import proxy_tensor

from .. import exc
from . import _decorators

if TYPE_CHECKING:
from .._compiler.device_ir import HelperFunctionGraphInfo
from .._compiler.helper_function import CombineFunction
from .._compiler.inductor_lowering import CodegenState
from .._compiler.type_propagation import Origin
    from .._compiler.type_propagation import TypeInfo

__all__ = ["associative_scan", "cumprod", "cumsum"]


@overload
@_decorators.device_func_replacement(higher_order_ops.associative_scan)
@_decorators.api(is_device_only=True)
def associative_scan(
combine_fn: CombineFunction,
input_tensor: torch.Tensor,
dim: int,
reverse: bool = False,
) -> torch.Tensor: ...


@overload
@_decorators.device_func_replacement(higher_order_ops.associative_scan)
@_decorators.api(is_device_only=True)
def associative_scan(
combine_fn: CombineFunction,
input_tensor: tuple[torch.Tensor, ...],
dim: int,
reverse: bool = False,
) -> tuple[torch.Tensor, ...]: ...


@_decorators.device_func_replacement(higher_order_ops.associative_scan)
@_decorators.api(is_device_only=True)
def associative_scan(
combine_fn: CombineFunction,
input_tensor: torch.Tensor | tuple[torch.Tensor, ...],
dim: int,
reverse: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""
Applies an associative scan operation along a specified dimension.
Computes the prefix scan (cumulative operation) along a dimension using
a custom combine function. Unlike :func:`~helion.language.reduce`, this
preserves the input shape.
Args:
combine_fn: A binary function that combines two elements element-wise.
Must be associative for correct results.
Can be tensor->tensor or tuple->tuple function.
input_tensor: Input tensor or tuple of tensors to scan
dim: The dimension along which to scan
reverse: If True, performs the scan in reverse order
Returns:
torch.Tensor or tuple[torch.Tensor, ...]: Tensor(s) with same shape as input
containing the scan result
See Also:
- :func:`~helion.language.reduce`: For dimension-reducing operations
- :func:`~helion.language.cumsum`: For cumulative sum
- :func:`~helion.language.cumprod`: For cumulative product
Note:
- combine_fn must be associative (not necessarily commutative)
- Output has same shape as input (unlike reduce)
- For standard scans, use :func:`~helion.language.cumsum` or :func:`~helion.language.cumprod` (faster)
- Reverse scan applies the operation from right to left
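
    Examples:
        A minimal sketch of a cumulative max via a custom combine function
        (illustrative only; it assumes ``hl`` refers to ``helion.language`` and
        that the call sits inside a Helion device loop, since this API is
        device-only)::

            @helion.kernel()
            def cummax_rows(x: torch.Tensor) -> torch.Tensor:
                out = torch.empty_like(x)
                for tile in hl.tile(x.size(0)):
                    # torch.maximum is associative, so it is a valid combine_fn
                    out[tile, :] = hl.associative_scan(torch.maximum, x[tile, :], dim=1)
                return out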
"""
    raise exc.NotInsideKernel


@_decorators.register_fake(associative_scan)
def _(
combine_fn: CombineFunction,
input_tensor: torch.Tensor | tuple[torch.Tensor, ...],
dim: int,
reverse: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""Fake implementation that returns fake tensors with the same shape as input."""
if isinstance(input_tensor, (tuple, list)):
return tuple(torch.empty_like(t) for t in input_tensor)
    return torch.empty_like(input_tensor)


@_decorators.ref(associative_scan)
def _(
combine_fn: CombineFunction,
input_tensor: torch.Tensor | tuple[torch.Tensor, ...],
dim: int,
reverse: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
    """Reference implementation that delegates to the PyTorch associative_scan higher-order op."""
    return higher_order_ops.associative_scan(
combine_fn, input_tensor, dim, reverse=reverse
    )


@_decorators.register_to_device_ir(associative_scan)
def _(
tracer: proxy_tensor.PythonKeyTracer,
combine_fn: CombineFunction,
input_tensor: torch.Tensor | tuple[torch.Tensor, ...],
dim: int,
reverse: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""
Device IR implementation that handles tracing for associative_scan. We map
associative_scan to _associative_scan, with a pre-traced graph for the combine
function.
"""
from .._compiler.device_ir import DeviceIR
from .._compiler.device_ir import HelperFunctionGraphInfo
from .._compiler.device_ir import args_to_proxies
from .._compiler.device_ir import select_decomp_table
from .._compiler.helper_function import create_combine_function_wrapper
    from .._compiler.helper_function import extract_helper_function_name

    is_tuple_input = isinstance(input_tensor, (tuple, list))
if is_tuple_input:
assert all(isinstance(t, torch.Tensor) for t in input_tensor), (
"associative_scan input must be a tuple of tensors"
)
else:
assert isinstance(input_tensor, torch.Tensor), (
"associative_scan input must be a tensor"
)
assert isinstance(dim, int), "associative_scan dim must be an integer"
assert callable(combine_fn), "combine_fn must be callable"
# Extract the function name before wrapping
original_function_name = extract_helper_function_name(combine_fn)
combine_fn = create_combine_function_wrapper(
combine_fn, is_tuple_input=is_tuple_input, target_format="unpacked"
)
# Create fake inputs for the combine function
fake_inputs = []
for tensor in input_tensor if is_tuple_input else [input_tensor]:
fake_inputs.extend(
[
torch.empty([1], dtype=tensor.dtype, device=tensor.device),
torch.empty([1], dtype=tensor.dtype, device=tensor.device),
]
)
combine_graph = proxy_tensor.make_fx(
combine_fn, decomposition_table=select_decomp_table()
)(*fake_inputs).graph
combine_graph_id = DeviceIR.current().add_graph(
combine_graph,
HelperFunctionGraphInfo,
node_args=[],
original_function_name=original_function_name,
)
# Create the associative_scan tracing operation
scan_args = (combine_graph_id, input_tensor, dim, reverse, is_tuple_input)
proxy_args, proxy_kwargs = args_to_proxies(tracer, scan_args)
proxy_out = tracer.create_proxy(
"call_function",
_associative_scan,
proxy_args,
proxy_kwargs,
)
# Create new output tensors to avoid aliasing input with output.
# tl.associative_scan modifies its input in-place, so we must track
# distinct output tensors to ensure the input remains usable after the scan.
if is_tuple_input:
output_tensors = []
assert isinstance(input_tensor, (tuple, list))
for i, tensor in enumerate(input_tensor):
output_tensor = torch.empty_like(tensor)
element_proxy = tracer.create_proxy(
"call_function",
operator.getitem,
(proxy_out, i),
{},
)
proxy_tensor.track_tensor_tree(
output_tensor, element_proxy, constant=None, tracer=tracer
)
output_tensors.append(output_tensor)
return tuple(output_tensors)
output_tensor = torch.empty_like(input_tensor)
proxy_tensor.track_tensor_tree(
output_tensor, proxy_out, constant=None, tracer=tracer
)
    return output_tensor


@_decorators.type_propagation(associative_scan)
def _(
combine_fn: TypeInfo,
input_tensor: TypeInfo,
dim: TypeInfo,
reverse: TypeInfo | None = None,
*,
origin: Origin,
) -> TypeInfo:
"""Type propagation for associative_scan - output has same type as input."""
from .._compiler.type_propagation import CallableType
from .._compiler.type_propagation import SequenceType
from .._compiler.type_propagation import TensorType
# Validate that combine_fn is callable
if not isinstance(combine_fn, CallableType):
raise exc.TypeInferenceError(f"combine_fn must be callable, got {combine_fn}")
# Validate that input_tensor is a tensor or tuple of tensors
if isinstance(input_tensor, TensorType):
# Single tensor case
return input_tensor
if isinstance(input_tensor, SequenceType):
# Tuple of tensors case - validate all elements are tensors
for elem_type in input_tensor.unpack():
if not isinstance(elem_type, TensorType):
raise exc.TypeInferenceError(
f"All elements in tuple must be tensors, got {elem_type}"
)
# Return the same tuple type
return input_tensor
raise exc.TypeInferenceError(
f"input_tensor must be a tensor or tuple of tensors, got {input_tensor}"
    )


@_decorators.device_func_replacement(torch.cumsum)
def cumsum(input_tensor: torch.Tensor, dim: int, reverse: bool = False) -> torch.Tensor:
"""
Compute the cumulative sum along a specified dimension.
Equivalent to ``hl.associative_scan(torch.add, input_tensor, dim, reverse)``.
Args:
input_tensor: Input tensor to compute cumulative sum
dim: The dimension along which to compute cumulative sum
reverse: If True, performs the cumsum in reverse order
Returns:
torch.Tensor: Tensor with same shape as input containing cumulative sum
See Also:
- :func:`~helion.language.associative_scan`: For custom scan operations
- :func:`~helion.language.cumprod`: For cumulative product
- :func:`~helion.language.reduce`: For dimension-reducing operations
Note:
- Output has same shape as input
- Reverse=True computes cumsum from right to left
- Equivalent to torch.cumsum
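
    Examples:
        A minimal sketch (illustrative; assumes ``hl`` is ``helion.language``
        and a device-loop context)::

            @helion.kernel()
            def cumsum_rows(x: torch.Tensor) -> torch.Tensor:
                out = torch.empty_like(x)
                for tile in hl.tile(x.size(0)):
                    out[tile, :] = hl.cumsum(x[tile, :], dim=1)
                return out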
"""
    return associative_scan(torch.add, input_tensor, dim, reverse)


@_decorators.device_func_replacement(torch.cumprod)
def cumprod(
input_tensor: torch.Tensor, dim: int, reverse: bool = False
) -> torch.Tensor:
"""
Compute the cumulative product along a specified dimension.
Equivalent to ``hl.associative_scan(torch.mul, input_tensor, dim, reverse)``.
Args:
input_tensor: Input tensor to compute cumulative product
dim: The dimension along which to compute cumulative product
reverse: If True, performs the cumprod in reverse order
Returns:
torch.Tensor: Tensor with same shape as input containing cumulative product
See Also:
- :func:`~helion.language.associative_scan`: For custom scan operations
- :func:`~helion.language.cumsum`: For cumulative sum
- :func:`~helion.language.reduce`: For dimension-reducing operations
Note:
- Output has same shape as input
- Reverse=True computes cumprod from right to left
- Equivalent to torch.cumprod
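
    Examples:
        A minimal sketch (illustrative; assumes ``hl`` is ``helion.language``
        and a device-loop context)::

            @helion.kernel()
            def cumprod_rows(x: torch.Tensor) -> torch.Tensor:
                out = torch.empty_like(x)
                for tile in hl.tile(x.size(0)):
                    out[tile, :] = hl.cumprod(x[tile, :], dim=1)
                return out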
"""
    return associative_scan(torch.mul, input_tensor, dim, reverse)


@_decorators.api()
def _associative_scan(
combine_graph_id: int,
input_tensor: torch.Tensor | tuple[torch.Tensor, ...],
dim: int,
reverse: bool = False,
is_tuple_input: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""Device IR implementation of associative scan, not meant to be called directly."""
    raise AssertionError("this should never be called")


@_decorators.register_fake(_associative_scan)
def _(
combine_graph_id: int,
input_tensor: torch.Tensor | tuple[torch.Tensor, ...],
dim: int,
reverse: bool = False,
is_tuple_input: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""Fake implementation that returns a tensor/tuple with the same shape as input."""
if is_tuple_input:
assert isinstance(input_tensor, (tuple, list)), input_tensor
return tuple(torch.empty_like(t) for t in input_tensor)
assert isinstance(input_tensor, torch.Tensor), input_tensor
    return torch.empty_like(input_tensor)


@_decorators.codegen(_associative_scan, "triton")
def _(state: CodegenState) -> ast.AST | list[ast.AST]:
"""Generate code for associative scan with combine function."""
combine_graph_id = state.proxy_arg(0)
dim = state.proxy_arg(2)
reverse = state.proxy_arg(3)
is_tuple_input = state.proxy_arg(4)
input_tensor = _get_input_tensor_ast(state, bool(is_tuple_input))
helper_func_name = _register_helper_function(state, cast("int", combine_graph_id))
scan_expr = _create_scan_expression(
input_tensor, cast("int", dim), helper_func_name, bool(reverse)
)
if is_tuple_input:
return _create_tuple_result_expressions(state, scan_expr)
    return scan_expr


@_decorators.codegen(_associative_scan, "cute")
def _(state: CodegenState) -> ast.AST:
    """Generate cute code for associative scan as a sequential combine loop per element."""
from torch.fx.node import Node
from .._compiler.ast_extension import expr_from_string
from .._compiler.ast_extension import statement_from_string
from .._compiler.compile_environment import CompileEnvironment
from .._compiler.cute.indexing import CuteSortableLoad
from .._compiler.device_ir import HelperFunctionGraphInfo
combine_graph_id = cast("int", state.proxy_arg(0))
dim = cast("int", state.proxy_arg(2))
reverse = bool(state.proxy_arg(3))
is_tuple_input = bool(state.proxy_arg(4))
if is_tuple_input:
raise exc.BackendUnsupported("cute", "tuple associative_scan")
helper_graph_info = state.get_graph(combine_graph_id)
assert isinstance(helper_graph_info, HelperFunctionGraphInfo)
op = _scan_combine_operator(helper_graph_info)
if op not in ("add", "max", "min", "mul"):
raise exc.BackendUnsupported("cute", "associative_scan combine function")
fx_node = state.fx_node
if fx_node is None:
raise exc.BackendUnsupported("cute", "associative_scan without FX node")
input_node = fx_node.args[1]
input_tensor = fx_node.meta["val"]
if dim < 0:
dim += input_tensor.ndim
if dim != input_tensor.ndim - 1:
raise exc.BackendUnsupported("cute", "associative_scan non-last dimension")
scan_source = state.ast_args[1]
sorted_source: tuple[CuteSortableLoad, bool] | None = None
if (
isinstance(input_node, Node)
and input_node.target is operator.getitem
and isinstance(input_node.args[0], Node)
and input_node.args[0].target is torch.ops.aten.sort.default
):
sort_node = input_node.args[0]
load = sort_node.meta.get("cute_sort_load")
descending = sort_node.meta.get("cute_sort_descending")
if isinstance(load, CuteSortableLoad) and isinstance(descending, bool):
sorted_source = (load, descending)
if sorted_source is None:
if not isinstance(scan_source, CuteSortableLoad):
if isinstance(input_node, Node):
scan_source = input_node.meta.get("cute_sortable_load")
if not isinstance(scan_source, CuteSortableLoad):
raise exc.BackendUnsupported("cute", "associative_scan input")
load = scan_source
else:
load = sorted_source[0]
env = CompileEnvironment.current()
n = input_tensor.shape[-1]
n_hint = env.size_hint(n) if isinstance(n, torch.SymInt) else n
if not isinstance(n_hint, int):
raise exc.BackendUnsupported("cute", "dynamic associative_scan extent")
dtype_str = env.backend.dtype_str(input_tensor.dtype)
index_dtype = env.backend.dtype_str(env.index_dtype)
out_pos = state.device_function.new_var("scan_out_pos")
scan_i = state.device_function.new_var("scan_i")
acc = state.device_function.new_var("scan_acc")
initialized = state.device_function.new_var("scan_initialized")
include = state.device_function.new_var("scan_include")
value = state.device_function.new_var("scan_value")
state.codegen.add_statement(
statement_from_string(
f"{out_pos} = {index_dtype}({load.index_exprs[load.sort_index_pos]})"
)
)
identity = "1" if op == "mul" else "0"
state.codegen.add_statement(
statement_from_string(f"{acc} = {dtype_str}({identity})")
)
state.codegen.add_statement(statement_from_string(f"{initialized} = False"))
if op == "add":
combine_expr = f"{acc} + {value}"
elif op == "mul":
combine_expr = f"{acc} * {value}"
elif op == "max":
combine_expr = f"{acc} if {acc} > {value} else {value}"
elif op == "min":
combine_expr = f"{acc} if {acc} < {value} else {value}"
else:
raise AssertionError(op)
if sorted_source is not None:
value_lines = _cute_sorted_value_lines(
state, load, sorted_source[1], scan_i, value, n_hint
)
else:
value_lines = [f" {value} = {_cute_scan_load_expr(load, scan_i)}"]
include_expr = f"{scan_i} >= {out_pos}" if reverse else f"{scan_i} <= {out_pos}"
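    # Illustrative shape of the generated loop (variable names are generated
    # fresh via new_var; <op> is the detected combine operator):
    #
    #   for scan_i in range(0, n, 1):
    #       scan_include = scan_i <= scan_out_pos        # ">=" when reverse
    #       scan_value = <load element scan_i>
    #       scan_acc = (scan_acc <op> scan_value) if included and initialized,
    #                  else scan_value if included, else unchanged
    #       scan_initialized = True if included else scan_initialized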
state.codegen.add_statement(
statement_from_string(
"\n".join(
[
f"for {scan_i} in range(cutlass.Int32(0), cutlass.Int32({n_hint}), cutlass.Int32(1)):",
f" {include} = {include_expr}",
*value_lines,
f" {acc} = ({combine_expr}) if ({include} and {initialized}) else ({value} if {include} else {acc})",
f" {initialized} = True if {include} else {initialized}",
]
)
)
)
    return expr_from_string(acc)


def _get_input_tensor_ast(state: CodegenState, is_tuple_input: bool) -> ast.AST:
"""Get the input tensor AST, handling tuple inputs specially."""
if not is_tuple_input:
return state.ast_arg(1)
raw_input = state.ast_args[1]
if isinstance(raw_input, tuple):
from .._compiler.ast_extension import create
tuple_elts = [
elt if isinstance(elt, ast.AST) else ast.Constant(value=elt)
for elt in raw_input
]
return create(ast.Tuple, elts=tuple_elts, ctx=ast.Load())
    return state.ast_arg(1)


def _scan_combine_operator(helper_graph_info: HelperFunctionGraphInfo) -> str:
    """Map a traced combine graph to a scalar operator name: add, mul, max, or min."""
    import operator as operator_mod
graph = helper_graph_info.graph
for node in graph.nodes:
if node.op != "call_function":
continue
if node.target in (
operator_mod.add,
torch.add,
torch.ops.aten.add.Tensor,
torch.ops.aten.add.Scalar,
):
return "add"
if node.target in (
operator_mod.mul,
torch.mul,
torch.ops.aten.mul.Tensor,
torch.ops.aten.mul.Scalar,
):
return "mul"
if node.target in (
torch.maximum,
torch.ops.aten.maximum.default,
):
return "max"
if node.target in (
torch.minimum,
torch.ops.aten.minimum.default,
):
return "min"
    raise exc.BackendUnsupported("cute", "associative_scan combine graph")


def _cute_scan_load_expr(load: object, index: str) -> str:
    """Build the load expression for element ``index``, masked to zero where the load is inactive."""
    from .._compiler.compile_environment import CompileEnvironment
    from .._compiler.cute.indexing import CuteSortableLoad
assert isinstance(load, CuteSortableLoad)
index_exprs = list(load.index_exprs)
index_exprs[load.sort_index_pos] = index
expr = f"{load.tensor_name}[{', '.join(index_exprs)}]"
if load.mask_expr is not None:
dtype_str = CompileEnvironment.current().backend.dtype_str(load.dtype)
return f"({expr} if {load.mask_expr} else {dtype_str}(0))"
    return expr


def _cute_sorted_value_lines(
state: CodegenState,
load: object,
descending: bool,
out_pos: str,
output_var: str,
n_hint: int,
) -> list[str]:
    """Emit lines that recover the ``out_pos``-th element of the sorted sequence by rank selection."""
from .._compiler.compile_environment import CompileEnvironment
from .._compiler.cute.indexing import CuteSortableLoad
assert isinstance(load, CuteSortableLoad)
env = CompileEnvironment.current()
dtype_str = env.backend.dtype_str(load.dtype)
index_dtype = env.backend.dtype_str(env.index_dtype)
sorted_value = state.device_function.new_var("scan_sorted_value")
candidate = state.device_function.new_var("scan_sort_k")
probe = state.device_function.new_var("scan_sort_j")
candidate_rank = state.device_function.new_var("scan_sort_rank")
candidate_value = state.device_function.new_var("scan_sort_candidate")
probe_value = state.device_function.new_var("scan_sort_probe")
before = state.device_function.new_var("scan_sort_before")
selected = state.device_function.new_var("scan_sort_selected")
cmp_op = ">" if descending else "<"
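    # Rank selection: for each candidate element, count how many elements order
    # strictly before it (ties broken by original index); the candidate whose
    # rank equals this lane's output position is the sorted value. O(n^2) in
    # the scan extent.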
return [
f" {sorted_value} = {dtype_str}(0)",
f" for {candidate} in range(cutlass.Int32(0), cutlass.Int32({n_hint}), cutlass.Int32(1)):",
f" {candidate_value} = {_cute_scan_load_expr(load, candidate)}",
f" {candidate_rank} = {index_dtype}(0)",
f" for {probe} in range(cutlass.Int32(0), cutlass.Int32({n_hint}), cutlass.Int32(1)):",
f" {probe_value} = {_cute_scan_load_expr(load, probe)}",
f" {before} = ({probe_value} {cmp_op} {candidate_value}) or (({probe_value} == {candidate_value}) and ({probe} < {candidate}))",
f" {candidate_rank} = {candidate_rank} + ({index_dtype}(1) if {before} else {index_dtype}(0))",
f" {selected} = {candidate_rank} == {out_pos}",
f" {sorted_value} = {candidate_value} if {selected} else {sorted_value}",
f" {output_var} = {sorted_value}",
    ]


def _register_helper_function(state: CodegenState, combine_graph_id: int) -> str:
"""Register the helper function and return its final name."""
from .._compiler.device_ir import HelperFunctionGraphInfo
helper_graph_info = state.get_graph(combine_graph_id)
assert isinstance(helper_graph_info, HelperFunctionGraphInfo)
state.codegen.device_function.register_helper_function(helper_graph_info)
# Get the final name from the helper manager (which uses the namespace)
return state.codegen.device_function.helper_manager.get_final_name(
helper_graph_info
    )


def _create_scan_expression(
input_tensor: ast.AST, dim: int, helper_func_name: str, reverse: bool
) -> ast.AST:
"""Create the tl.associative_scan expression."""
from .._compiler.ast_extension import expr_from_string
template = (
f"tl.associative_scan({{input_tensor}}, {{dim_value}}, {helper_func_name}, reverse=True)"
if reverse
else f"tl.associative_scan({{input_tensor}}, {{dim_value}}, {helper_func_name})"
)
return expr_from_string(
template,
input_tensor=input_tensor,
dim_value=ast.Constant(value=dim),
    )


def _create_tuple_result_expressions(
state: CodegenState, scan_expr: ast.AST
) -> list[ast.AST]:
"""Create getitem expressions for tuple results."""
from .._compiler.ast_extension import expr_from_string
raw_input = state.ast_args[1]
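    # Assume two elements when the raw input is not a tuple AST (arity unknown).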
num_elements = len(raw_input) if isinstance(raw_input, tuple) else 2
return [
expr_from_string(
"{scan_result}[{index}]", scan_result=scan_expr, index=ast.Constant(value=i)
)
for i in range(num_elements)
]