# Source code for helion.language.memory_ops

from __future__ import annotations

import ast
import contextlib
import dataclasses
import logging
import operator
import textwrap
from typing import TYPE_CHECKING

import torch
from torch.fx import has_side_effect
from torch.fx.node import map_arg

from .. import exc
from .._compiler.ast_extension import expr_from_string
from .._compiler.ast_extension import statement_from_string
from .._compiler.compile_environment import CompileEnvironment
from .._compiler.compile_environment import _symint_expr
from .._compiler.cute.cute_epilogue import Tcgen05UnaryEpilogueChain
from .._compiler.cute.cute_epilogue import _AuxiliaryTensorStep
from .._compiler.cute.cute_epilogue import analyze_tcgen05_unary_epilogue_chain
from .._compiler.cute.cute_fx_walk import reach_tcgen05_matmul_anchors
from .._compiler.cute.cutedsl_compat import emit_pipeline_advance
from .._compiler.cute.strategies import tcgen05_explicit_d_store_tile_expr
from .._compiler.cute.strategies import tcgen05_is_two_cta_m128
from .._compiler.cute.strategies import tcgen05_resolve_epilogue_tile
from .._compiler.cute.strategies import tcgen05_two_cta_m128_epilogue_tile_expr
from .._compiler.cute.tcgen05_constants import (
    TCGEN05_ACC_WAIT_PLACEMENT_BEFORE_SUBTILE_LOOP,
)
from .._compiler.cute.tcgen05_constants import TCGEN05_ACC_WAIT_PLACEMENT_CONFIG_KEY
from .._compiler.cute.tcgen05_constants import TCGEN05_ACC_WAIT_PLACEMENT_SUBTILE_LOOP
from .._compiler.cute.tcgen05_constants import TCGEN05_C_ACQUIRE_PLACEMENT_CONFIG_KEY
from .._compiler.cute.tcgen05_constants import TCGEN05_C_ACQUIRE_PLACEMENT_FIRST_IN_LOOP
from .._compiler.cute.tcgen05_constants import (
    TCGEN05_C_ACQUIRE_PLACEMENT_LATER_BEFORE_BARRIER,
)
from .._compiler.cute.tcgen05_constants import TCGEN05_C_ACQUIRE_PLACEMENT_PRE_LOOP
from .._compiler.cute.tcgen05_constants import TCGEN05_C_STORE_MODE_CONFIG_KEY
from .._compiler.cute.tcgen05_constants import TCGEN05_C_STORE_MODE_NORMAL
from .._compiler.cute.tcgen05_constants import TCGEN05_C_STORE_MODE_SKIP_EPILOGUE_STORE
from .._compiler.cute.tcgen05_constants import TCGEN05_EPILOGUE_LAYOUT_CONFIG_KEY
from .._compiler.cute.tcgen05_constants import (
    TCGEN05_EPILOGUE_LAYOUT_MODULE_HELPER_ACC_T2R,
)
from .._compiler.cute.tcgen05_constants import (
    TCGEN05_EPILOGUE_LAYOUT_MODULE_HELPER_STORE_TAIL,
)
from .._compiler.cute.tcgen05_constants import TCGEN05_EPILOGUE_LAYOUT_NORMAL
from .._compiler.cute.tcgen05_constants import (
    TCGEN05_EPILOGUE_LAYOUT_SPLIT_ACC_T2R_STORE_TAIL,
)
from .._compiler.cute.tcgen05_constants import TCGEN05_EPILOGUE_LAYOUT_SPLIT_FIRST_T2R
from .._compiler.cute.tcgen05_constants import TCGEN05_TWO_CTA_BLOCK_N
from .._compiler.cute.tcgen05_pure_matmul import Tcgen05TmaStoreBodyCoreParams
from .._compiler.cute.tcgen05_pure_matmul import Tcgen05TmaStorePipelineParams
from .._compiler.cute.tcgen05_pure_matmul import Tcgen05TmaStoreSubtileLoopParams
from .._compiler.cute.tcgen05_pure_matmul import Tcgen05TmaStoreTailParams
from .._compiler.host_function import HostFunction
from .._compiler.indexing_strategy import SubscriptIndexing
from .._compiler.indexing_strategy import TileWithOffsetInfo
from .._compiler.indexing_strategy import _get_tile_with_offset_info
from .._compiler.pallas import codegen as pallas_codegen
from .._compiler.utils import compute_slice_size
from .._compiler.variable_origin import GridOrigin
from .._compiler.variable_origin import TileBeginOrigin
from .._compiler.variable_origin import TileCountOrigin
from .._compiler.variable_origin import TileEndOrigin
from .._compiler.variable_origin import TileIdOrigin
from . import _decorators
from .stack_tensor import StackTensor

if TYPE_CHECKING:
    from .._compiler.inductor_lowering import CodegenState
    from .._compiler.tile_strategy import LoopDimInfo

from .._compiler.host_function import SymbolOrigin

# TileBeginWithOffset removed - using TileBeginWithOffsetPattern instead

__all__ = ["load", "store"]

log = logging.getLogger(__name__)


# Short eviction-policy names used in configs, keyed to the full Triton API
# spelling. The empty string maps to ``None``.
_EVICTION_POLICY_MAP = {
    "": None,
    **{short_name: f"evict_{short_name}" for short_name in ("first", "last")},
}


@dataclasses.dataclass(frozen=True)
class _AuxStepRecord:
    """Splice-side AST locals belonging to a single auxiliary chain step.

    One record carries the name of the underlying aux tensor, its
    broadcast axis (``None`` when the aux is an exact-shape rank-2
    tensor), and the AST variable names allocated for the partition
    pipeline. ``aux_view2d`` is populated only for broadcast aux steps
    and stays ``None`` for exact-shape steps.

    ``_codegen_cute_store_tcgen05_tile`` uses these records to pass the
    per-aux locals through the per-output-tile setup helper and the
    per-subtile load source helper.
    """

    # Underlying aux tensor and how it broadcasts (None: exact-shape rank-2).
    aux_tensor_name: str
    broadcast_axis: int | None

    # AST variable names allocated for the partition pipeline.
    aux_tile: str
    aux_part_base: str
    aux_xfm: str
    aux_planned: str
    aux_epi: str

    # Element type information and (optional) extent of the aux tensor.
    aux_dtype: str
    aux_dtype_bits: int
    aux_extent: int | None

    # Further AST variable names used by the per-subtile load path.
    ttr_aux: str
    ttr_aux_grouped: str
    ttr_aux_subtile: str
    aux_rmem: str
    aux_loaded: str

    # Set only for broadcast aux steps; None for exact-shape steps.
    aux_view2d: str | None

    # Pre-wait register hoist (bm=128 2-CTA family only): name of the
    # whole-fragment register tensor that ``autovec_copy`` fills BEFORE the
    # accumulator ``consumer_wait``, so the rowvec GMEM latency hides under
    # the MMA wait. ``None`` keeps the per-subtile GMEM load.
    aux_rmem_full: str | None = None


@dataclasses.dataclass(frozen=True)
class _RowvecAuxStageRecord:
    """Locals used to stage one row-vector aux step through compact SMEM.

    One record is kept per tile; every ``str`` field is a name used by the
    generated staging code, and the ``int`` fields carry the copy width
    and the aux extent.
    """

    # Shared-memory layout, pointer and tensor names.
    smem_layout: str
    smem_ptr: str
    smem: str

    # Tiled-copy object and its per-thread slice.
    tiled_copy: str
    thr_copy: str

    # Global-memory tile plus the partitioned source/destination views.
    gmem_tile: str
    gmem_part: str
    smem_part: str

    # Names used for the copy predicate (coordinate, limit, predicate).
    coord: str
    limit: str
    pred: str

    # Copy width in bits and in elements, and the aux extent.
    copy_bits: int
    copy_elems: int
    aux_extent: int


def _tcgen05_rowvec_aux_stage_copy_elems(
    aux_dtype_bits: int,
    block_n: int,
    aux_extent: int | None,
    *,
    copy_bits: int = 128,
) -> int | None:
    """Return the per-copy element count for staging a row-vector aux.

    The result is ``copy_bits // aux_dtype_bits`` when all of the following
    hold, and ``None`` otherwise:

    * ``aux_extent`` is known and ``aux_dtype_bits`` is positive;
    * ``copy_bits`` is an exact, positive multiple of ``aux_dtype_bits``;
    * both ``block_n`` and ``aux_extent`` are multiples of that element
      count.
    """

    if aux_extent is None:
        return None
    if aux_dtype_bits <= 0:
        return None

    elems_per_copy, leftover_bits = divmod(copy_bits, aux_dtype_bits)
    if leftover_bits != 0 or elems_per_copy <= 0:
        return None

    # Both the tile width and the aux extent must split into whole vectors.
    if any(extent % elems_per_copy != 0 for extent in (block_n, aux_extent)):
        return None
    return elems_per_copy


@has_side_effect
@_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
def store(
    tensor: torch.Tensor | StackTensor,
    index: list[object],
    value: torch.Tensor | torch.SymInt | float,
    extra_mask: torch.Tensor | None = None,
) -> None:
    """Store a value to a tensor using a list of indices.

    This function is equivalent to `tensor[index] = value` but allows
    setting `extra_mask=` to mask elements beyond the default masking
    based on the hl.tile range.

    Args:
        tensor: The tensor / stack tensor to store to
        index: The indices to use to index into the tensor
        value: The value to store
        extra_mask: The extra mask (beyond automatic tile bounds masking) to apply to the tensor

    Returns:
        None
    """
    # The undecorated body only runs when called outside a kernel; the
    # per-backend behavior is registered on ``store`` by the decorated
    # ``_`` functions that follow.
    raise exc.NotInsideKernel
@_decorators.prepare_args(store)
def _(
    tensor: torch.Tensor | StackTensor,
    index: list[object],
    value: torch.Tensor | torch.SymInt | float,
    extra_mask: torch.Tensor | None = None,
) -> tuple[
    torch.Tensor | tuple,
    list[object],
    torch.Tensor | torch.SymInt | float | int,
    torch.Tensor | None,
]:
    """Normalize the arguments of ``store`` before tracing.

    A tensor ``value`` whose dtype differs from ``tensor.dtype`` is cast to
    the destination dtype, ``index`` is passed through
    ``Tile._tiles_to_sizes_for_index``, and a ``StackTensor`` destination is
    unpacked into a plain tuple.

    Raises:
        NotImplementedError: if ``tensor`` is neither a ``StackTensor`` nor a
            ``torch.Tensor``.
    """
    from .tile_proxy import Tile

    if isinstance(value, torch.Tensor) and value.dtype != tensor.dtype:
        value = value.to(tensor.dtype)
    index = Tile._tiles_to_sizes_for_index(index)

    if isinstance(tensor, StackTensor):
        return (tuple(tensor), index, value, extra_mask)

    if isinstance(tensor, torch.Tensor):
        return (tensor, index, value, extra_mask)

    raise NotImplementedError(f"Cannot store to type: {type(tensor)}")


@_decorators.register_fake(store)
def _(
    tensor: torch.Tensor | tuple[object, ...],
    index: list[object],
    value: torch.Tensor | torch.SymInt | float,
    extra_mask: torch.Tensor | None = None,
) -> None:
    """Fake implementation of ``store``: a store produces no value."""
    return None


def _next_store_cache_modifier(state: CodegenState) -> ast.Constant | None:
    """Consume the next store cache-modifier slot and return it as an AST constant.

    Returns ``None`` when codegen is not on device, when the config has no
    entry for this slot, or when the entry is falsy. The per-device-function
    counter is advanced whenever codegen is on device.
    """
    if not state.codegen.on_device:
        return None
    device_fn = state.device_function
    modifier_idx = device_fn.device_store_cache_modifier_index
    device_fn.device_store_cache_modifier_index += 1
    modifiers = state.config.store_cache_modifiers
    if modifier_idx < len(modifiers) and modifiers[modifier_idx]:
        return ast.Constant(value=modifiers[modifier_idx])
    return None


@_decorators.codegen(store, "triton")
def _(state: CodegenState) -> ast.AST:
    """Triton codegen for ``store``.

    Handles a single-tensor destination (through the indexing strategy
    selected for this store, optionally routed via
    ``state.codegen.store_transform``) and a stack-tensor destination
    (through ``StackIndexingStrategy``).

    Raises:
        NotImplementedError: if the destination is neither a tensor nor a
            tuple.
    """
    tensor = state.proxy_arg(0)
    subscript = state.proxy_arg(1)
    assert isinstance(subscript, (list, tuple))
    value = state.ast_arg(2)
    extra_mask = state.ast_args[3]
    assert isinstance(extra_mask, (type(None), ast.AST))

    if isinstance(tensor, torch.Tensor):
        device_fn = state.device_function
        fx_node = state.fx_node
        assert fx_node is not None
        epilogue_subtile_group_id = fx_node.meta.get("epilogue_subtile_group_id")
        if epilogue_subtile_group_id is None:
            indexing_idx = device_fn.allocate_store_index()
        elif fx_node.meta.get("epilogue_subtile_primary_output", False):
            # The primary output of a subtile group allocates the index and
            # records it so the other stores of the group can reuse it.
            indexing_idx = device_fn.allocate_store_index()
            device_fn.epilogue_subtile_store_indices[epilogue_subtile_group_id] = (
                indexing_idx
            )
        else:
            indexing_idx = device_fn.epilogue_subtile_store_indices[
                epilogue_subtile_group_id
            ]
        strategy = device_fn.get_indexing_strategy(indexing_idx)
        cache_modifier = _next_store_cache_modifier(state)

        if state.codegen.store_transform is not None:
            return state.codegen.store_transform(
                state,
                tensor,
                [*subscript],
                value,
                extra_mask,
                cache_modifier,
                strategy.codegen_store,
            )
        return strategy.codegen_store(
            state, tensor, [*subscript], value, extra_mask, cache_modifier
        )

    if isinstance(tensor, tuple):
        from .._compiler.indexing_strategy import StackIndexingStrategy

        # Fusion is not supported for stack stores (multi-tensor device pointers);
        # fall through to the unfused path regardless of store_transform.
        device_fn = state.device_function
        device_fn.allocate_store_index()
        cache_modifier = _next_store_cache_modifier(state)

        stack_tensor_ast = state.ast_args[0]
        assert isinstance(stack_tensor_ast, tuple)
        assert len(stack_tensor_ast) == 2
        _tensor_like_ast, dev_ptrs_ast = stack_tensor_ast
        return StackIndexingStrategy.codegen_store(
            state, tensor, dev_ptrs_ast, [*subscript], value, extra_mask, cache_modifier
        )

    raise NotImplementedError(f"Cannot store to type: {type(tensor)}")


def _record_pad_info(
    state: CodegenState,
    tensor: torch.Tensor,
    tensor_dim: int,
    block_id: int,
    extra_pad: int = 0,
) -> None:
    """Record that a tensor dimension uses pl.ds() and may need host-side padding.

    *extra_pad* accounts for non-zero loop begins: 0 when the loop starts at
    offset 0, ``begin % block_size`` for a constant begin, or
    ``block_size - 1`` for a data-dependent begin.

    Note: stores one entry per (tensor, dim). If two inner loops tile the
    same dim with different block_ids, the last one wins. This is fine when
    both loops use the same block size (the common case).
    """
    pad_info = state.device_function.pallas_pad_info
    tensor_id = id(tensor)
    if tensor_id not in pad_info:
        pad_info[tensor_id] = {}
    pad_info[tensor_id][tensor_dim] = (block_id, extra_pad)


def _maybe_get_symbol_origin(idx: object) -> SymbolOrigin | None:
    """Return the recorded origin of a ``torch.SymInt`` index, if there is one.

    Returns ``None`` for non-SymInt indices, for SymInts without a symbolic
    expression, and for expressions the current host function has no origin
    for.
    """
    if not isinstance(idx, torch.SymInt):
        return None
    expr = _symint_expr(idx)
    if expr is None:
        return None
    return HostFunction.current().expr_to_origin.get(expr)


@_decorators.codegen(store, "pallas")
def _(state: CodegenState) -> None:
    """Pallas codegen for ``store``: emit ``name[idx] = value``.

    An indirect-scatter pattern (at most one) rewrites the value through
    ``emit_scatter_store``; otherwise, when carry tiles are active,
    ``emit_carry_store`` may fully handle the store, in which case no plain
    assignment is emitted.
    """
    tensor = state.proxy_arg(0)
    subscript = state.proxy_arg(1)
    assert isinstance(subscript, (list, tuple))
    value = state.ast_arg(2)
    assert isinstance(tensor, torch.Tensor)
    name = state.device_function.tensor_arg(tensor).name
    name = pallas_codegen.vmem_name(state, name)

    # Increment memory op index to stay in sync with triton backend
    device_fn = state.device_function
    device_fn.device_store_index += 1
    device_fn.device_memory_op_index += 1

    parts, _ = pallas_codegen.index_parts(state, subscript, tensor)
    value = pallas_codegen.sliced_value_for_store(
        state, tensor, subscript, parts, value
    )
    idx_str = ", ".join(parts)

    patterns = state.fx_node.meta.get("indexing_patterns") if state.fx_node else ()
    from .._compiler.pallas.gather import emit_scatter_store
    from .._compiler.pallas.plan_tiling import IndirectScatterPattern

    scatter_patterns = [
        pattern
        for pattern in patterns or ()
        if isinstance(pattern, IndirectScatterPattern)
    ]
    assert len(scatter_patterns) <= 1, (
        "Pallas store expected at most one indirect scatter pattern"
    )
    if scatter_patterns:
        value = emit_scatter_store(
            state, scatter_patterns[0].plan, name, idx_str, value
        )

    from .._compiler.pallas.ordered_carry import emit_carry_store

    if not scatter_patterns and state.device_function.carry_tiles:
        if emit_carry_store(state, tensor, subscript, name, idx_str, value):
            return

    state.codegen.add_statement(
        statement_from_string(f"{name}[{idx_str}] = {{value}}", value=value)
    )


def _matching_block_ids(env: CompileEnvironment, size: object) -> list[int]:
    """Find all block_ids that match the given dimension size.

    The direct ``env.get_block_id(size)`` hit (if any) comes first, followed
    by every block whose size is known-equal to *size*, without duplicates.
    Sizes that are neither ``int`` nor ``torch.SymInt`` match nothing.
    """
    if not isinstance(size, (int, torch.SymInt)):
        return []
    candidates: list[int] = []
    if (direct := env.get_block_id(size)) is not None:
        candidates.append(direct)
    for info in env.block_sizes:
        if not isinstance(info.size, (int, torch.SymInt)):
            continue
        if not env.known_equal(info.size, size):
            continue
        if info.block_id not in candidates:
            candidates.append(info.block_id)
    return candidates


def _log_cute_layout(state: CodegenState, op_name: str) -> None:
    """Log the CuTe layout annotation for the current node, if any.

    This is used during CuTe load/store codegen to make layout info visible
    for debugging and future codegen integration.
    """
    layout = state.cute_layout
    if layout is None:
        return
    node_name = state.fx_node.name if state.fx_node else "?"
    log.debug(
        "cute %s %s: layout tag=%s thread=%s value=%s",
        op_name,
        node_name,
        layout.tag.value,
        layout.thread_shape,
        layout.value_shape,
    )


def _cute_remap_block_id(state: CodegenState, block_id: int) -> int:
    """Apply the active matmul-operand block-id remap, if any.

    Used while re-materializing a matmul operand load so its contraction
    dimension is indexed by the active contraction block instead of the
    loop-invariant block it was originally lowered with. Returns *block_id*
    unchanged when no remap is active.
    """
    remap = state.device_function.cute_state.matmul_operand_block_remap
    if not remap:
        return block_id
    return remap.get(block_id, block_id)


def _cute_index_override(state: CodegenState, block_id: int) -> str | None:
    """Return a raw index-expression override for *block_id*, if active.

    Applied after ``_cute_remap_block_id``. When set (only while
    re-materializing the rhs of a static-MN-collapse baddbmm), the operand's
    free (N) axis is indexed by this serial-loop variable instead of the
    shared M thread index, and masking for that axis is suppressed.
    """
    override = state.device_function.cute_state.matmul_operand_index_override
    if not override:
        return None
    return override.get(_cute_remap_block_id(state, block_id))


def _cute_active_index_var(state: CodegenState, block_id: int) -> str | None:
    """Return the index variable currently active for *block_id*, or ``None``.

    Precedence: an index override, then the innermost active device loop for
    the (remapped) block, then the current grid state.
    """
    if (override := _cute_index_override(state, block_id)) is not None:
        return override
    block_id = _cute_remap_block_id(state, block_id)
    loops = state.codegen.active_device_loops.get(block_id)
    if loops:
        return loops[-1].strategy.index_var(block_id)
    grid_state = state.codegen.current_grid_state
    if grid_state is not None and block_id in grid_state.block_ids:
        return grid_state.strategy.index_var(block_id)
    return None


def _cute_active_mask_var(state: CodegenState, block_id: int) -> str | None:
    """Return the mask variable of the innermost active loop for *block_id*.

    Returns ``None`` when an index override is active for the block or when
    no device loop is active for it.
    """
    if _cute_index_override(state, block_id) is not None:
        return None
    block_id = _cute_remap_block_id(state, block_id)
    loops = state.codegen.active_device_loops.get(block_id)
    if loops:
        return loops[-1].strategy.mask_var(block_id)
    return None


def _cute_unique_graph_block_id(state: CodegenState) -> int | None:
    """Return the single block id of the graph that owns the current FX node.

    Returns ``None`` unless exactly one codegen graph matches the node's
    graph and that graph has exactly one block id.
    """
    fx_node = state.fx_node
    if fx_node is None:
        return None
    graph_block_ids = [
        graph_info.block_ids
        for graph_info in state.codegen.codegen_graphs
        if graph_info.graph is fx_node.graph and hasattr(graph_info, "block_ids")
    ]
    if len(graph_block_ids) != 1 or len(graph_block_ids[0]) != 1:
        return None
    (block_id,) = graph_block_ids[0]
    return block_id


def _maybe_codegen_cute_packed_affine_lhs_load(
    state: CodegenState,
    tensor: torch.Tensor,
    subscript: list[object] | tuple[object, ...],
    extra_mask: ast.AST | None,
) -> object | None:
    """Try to lower a load as the packed-affine lhs of a matmul-like user.

    The pattern applies when the load has a single user that is one of
    ``dot``/``bmm``/``baddbmm``/``mm``/``addmm``, the last subscript entry is
    an affine range, and the user's rhs is a stack+reshape with the same
    factor. On a match, one (optionally masked) scalar load expression is
    produced per offset in ``range(factor)`` and wrapped in
    ``CutePackedAffineLoad``. Returns ``None`` when the pattern does not
    match.
    """
    from .._compiler.cute.indexing import CutePackedAffineLoad
    from .._compiler.cute.indexing import match_cute_affine_range_iota
    from .._compiler.cute.indexing import match_cute_stack_reshape_rhs
    from .matmul_ops import dot

    fx_node = state.fx_node
    if (
        fx_node is None
        or len(fx_node.users) != 1
        or len(subscript) not in (2, 3)
        or len(fx_node.args) < 2
    ):
        return None
    fx_subscript = fx_node.args[1]
    if not isinstance(fx_subscript, (list, tuple)) or len(fx_subscript) != len(
        subscript
    ):
        return None
    range_node = fx_subscript[-1]
    if not isinstance(range_node, torch.fx.Node):
        return None
    affine_range = match_cute_affine_range_iota(range_node)
    if affine_range is None:
        return None

    user = next(iter(fx_node.users))
    if user.op != "call_function" or user.target not in {
        dot,
        torch.ops.aten.bmm.default,
        torch.ops.aten.baddbmm.default,
        torch.ops.aten.mm.default,
        torch.ops.aten.addmm.default,
    }:
        return None
    # addmm/baddbmm carry the bias first, so their rhs is argument 2.
    rhs_index = (
        2
        if user.target
        in (torch.ops.aten.addmm.default, torch.ops.aten.baddbmm.default)
        else 1
    )
    rhs_arg = user.args[rhs_index]
    if not isinstance(rhs_arg, torch.fx.Node):
        return None
    packed_rhs = match_cute_stack_reshape_rhs(rhs_arg)
    if packed_rhs is None:
        return None
    _, factor = packed_rhs
    if factor != affine_range.factor:
        return None

    packed_block_id = _cute_unique_graph_block_id(state)
    if packed_block_id is None:
        return None
    packed_index = _cute_active_index_var(state, packed_block_id)
    if packed_index is None:
        return None

    leading_subscript = [*subscript[:-1]]
    row_index_exprs = _cute_index_exprs(
        state,
        leading_subscript,
        tensor=tensor,
        inactive_slice_expr="None",
        inactive_singleton_slice_expr="0",
    )
    if len(row_index_exprs) != len(leading_subscript):
        return None

    tensor_name = state.device_function.tensor_arg(tensor).name
    mask_terms: list[str] = []
    row_mask = _cute_combined_mask(state, leading_subscript, extra_mask, tensor=tensor)
    if row_mask is not None:
        mask_terms.append(row_mask)
    if packed_mask := _cute_active_mask_var(state, packed_block_id):
        mask_terms.append(f"({packed_mask})")
    mask_expr = " and ".join(mask_terms) if mask_terms else None
    zero = CompileEnvironment.current().backend.dtype_str(tensor.dtype)

    terms: list[ast.AST] = []
    for offset in range(factor):
        index_expr = ", ".join(
            [
                *row_index_exprs,
                f"cutlass.Int32({factor}) * ({packed_index}) + cutlass.Int32({offset})",
            ]
        )
        term = expr_from_string(f"{tensor_name}[{index_expr}]")
        if mask_expr is not None:
            term = expr_from_string(
                f"({{value}} if {mask_expr} else {zero}(0))",
                value=term,
            )
        terms.append(term)
    return CutePackedAffineLoad(tuple(terms))


def _maybe_codegen_cute_packed_rhs_load(
    state: CodegenState,
    tensor: torch.Tensor,
    subscript: list[object] | tuple[object, ...],
    extra_mask: ast.AST | None,
) -> ast.AST | None:
    """Try to lower a load as the rhs of a duplicate-stack+reshape matmul operand.

    The pattern applies when the load's only user is ``aten.stack`` whose
    only user matches ``match_cute_duplicate_stack_reshape_rhs`` for this
    node. On a match, the second-to-last axis is indexed by the active index
    of the unique graph block and the result is an (optionally masked) scalar
    load expression. Returns ``None`` when the pattern does not match.
    """
    from .._compiler.cute.indexing import match_cute_duplicate_stack_reshape_rhs

    fx_node = state.fx_node
    if fx_node is None or len(subscript) not in (2, 3) or len(fx_node.users) != 1:
        return None
    user = next(iter(fx_node.users))
    if user.op != "call_function" or user.target is not torch.ops.aten.stack.default:
        return None
    stack_users = list(user.users)
    if len(stack_users) != 1 or not isinstance(stack_users[0], torch.fx.Node):
        return None
    rhs_node = stack_users[0]
    packed_rhs = match_cute_duplicate_stack_reshape_rhs(rhs_node)
    if packed_rhs != (
        fx_node,
        len(user.args[0]) if isinstance(user.args[0], (list, tuple)) else 0,
    ):
        return None

    packed_block_id = _cute_unique_graph_block_id(state)
    if packed_block_id is None:
        return None
    packed_index = _cute_active_index_var(state, packed_block_id)
    if packed_index is None:
        return None

    leading_subscript = [*subscript[:-2]]
    col_index_exprs = _cute_index_exprs(
        state,
        [subscript[-1]],
        tensor=tensor,
        inactive_slice_expr="None",
        inactive_singleton_slice_expr="0",
    )
    if len(col_index_exprs) != 1:
        return None
    (col_index,) = col_index_exprs
    leading_index_exprs = _cute_index_exprs(
        state,
        leading_subscript,
        tensor=tensor,
        inactive_slice_expr="None",
        inactive_singleton_slice_expr="0",
    )
    if len(leading_index_exprs) != len(leading_subscript):
        return None

    tensor_name = state.device_function.tensor_arg(tensor).name
    load_index_expr = ", ".join([*leading_index_exprs, packed_index, col_index])
    load_expr: ast.AST = expr_from_string(f"{tensor_name}[{load_index_expr}]")

    mask_terms: list[str] = []
    col_mask = _cute_combined_mask(
        state,
        [*leading_subscript, subscript[-1]],
        extra_mask,
        tensor=tensor,
    )
    if col_mask is not None:
        mask_terms.append(col_mask)
    if packed_mask := _cute_active_mask_var(state, packed_block_id):
        mask_terms.append(f"({packed_mask})")
    if not mask_terms:
        return load_expr
    zero = CompileEnvironment.current().backend.dtype_str(tensor.dtype)
    return expr_from_string(
        f"({{value}} if {' and '.join(mask_terms)} else {zero}(0))",
        value=load_expr,
    )


def _cute_index_exprs(
    state: CodegenState,
    subscript: list[object] | tuple[object, ...],
    ast_subscript: list[object] | tuple[object, ...] | None = None,
    tensor: torch.Tensor | None = None,
    *,
    inactive_slice_expr: str | None = None,
    inactive_singleton_slice_expr: str | None = None,
) -> list[str]:
    """Lower each entry of *subscript* to a CuTe index-expression string.

    ``None`` entries are skipped (they produce no output and do not advance
    the tensor dimension). Every other entry yields exactly one string.

    Args:
        state: Current codegen state.
        subscript: Proxy-level index entries (SymInt, int, Tensor, slice,
            or ``None``).
        ast_subscript: Optional AST counterparts of *subscript*; required for
            tensor indices.
        tensor: The indexed tensor; required for slice entries.
        inactive_slice_expr: Expression to emit for a full slice whose
            dimension has no active block; when ``None`` such a slice is an
            error.
        inactive_singleton_slice_expr: Expression to emit for a full slice
            over a size-1 dimension that has no active block.

    Raises:
        exc.BackendUnsupported: for indices that cannot be lowered (inactive
            dimensions, strided slices, unsupported index types, ...).
    """
    env = CompileEnvironment.current()

    def symint_index_expr(idx: torch.SymInt, used_block_ids: set[int]) -> str:
        expr = _symint_expr(idx)
        if expr is not None:
            origin_info = HostFunction.current().expr_to_origin.get(expr)
            if origin_info is not None and isinstance(origin_info.origin, GridOrigin):
                # Subclasses of GridOrigin (tile begin/end/count/id) get a
                # dedicated expression; a plain GridOrigin falls through.
                if type(origin_info.origin) is not GridOrigin:
                    block_id = origin_info.origin.block_id
                    loop_info = active_loop_info(block_id)
                    begin_var = tile_begin_expr(block_id, loop_info)
                    block_size_var = (
                        state.device_function.block_size_var(block_id) or "1"
                    )
                    if isinstance(origin_info.origin, TileBeginOrigin):
                        return begin_var
                    if isinstance(origin_info.origin, TileEndOrigin):
                        if loop_info is not None and loop_info.end_var_name is not None:
                            return env.backend.minimum_expr(
                                f"({begin_var}) + ({block_size_var})",
                                loop_info.end_var_name,
                            )
                        return f"({begin_var}) + ({block_size_var})"
                    if isinstance(origin_info.origin, TileCountOrigin):
                        end_var = (
                            loop_info.end_var_name
                            if loop_info is not None
                            and loop_info.end_var_name is not None
                            else f"({begin_var}) + ({block_size_var})"
                        )
                        extent = f"({end_var}) - ({begin_var})"
                        return env.backend.cdiv_expr(
                            extent, block_size_var, is_device=True
                        )
                    if isinstance(origin_info.origin, TileIdOrigin):
                        if block_size_var == "1":
                            return begin_var
                        return f"({begin_var}) // ({block_size_var})"
                return state.sympy_expr(expr)
        block_id = env.get_block_id(idx)
        if block_id is not None:
            used_block_ids.add(block_id)
            return index_var_for_block_id(block_id, idx)
        if expr is not None:
            return state.sympy_expr(expr)
        raise exc.BackendUnsupported("cute", f"unlowerable symbolic index: {idx}")

    def active_loop_info(block_id: int) -> LoopDimInfo | None:
        block_id = _cute_remap_block_id(state, block_id)
        loops = state.codegen.active_device_loops.get(block_id)
        if loops:
            return loops[-1].block_id_to_info.get(block_id)
        grid_state = state.codegen.current_grid_state
        if grid_state is not None:
            return grid_state.block_id_to_info.get(block_id)
        return None

    def active_local_coord(block_id: int) -> str | None:
        from .._compiler.cute.cute_reshape import _grid_local_coord_expr

        block_id = _cute_remap_block_id(state, block_id)
        loops = state.codegen.active_device_loops.get(block_id)
        if loops:
            thread_axis = loops[-1].block_thread_axes.get(block_id)
            if thread_axis is not None:
                return _grid_local_coord_expr(state.codegen, block_id, thread_axis)
        grid_state = state.codegen.current_grid_state
        if grid_state is not None:
            thread_axis = grid_state.block_thread_axes.get(block_id)
            if thread_axis is not None:
                return _grid_local_coord_expr(state.codegen, block_id, thread_axis)
        return None

    def tile_begin_expr(block_id: int, loop_info: LoopDimInfo | None) -> str:
        block_id = _cute_remap_block_id(state, block_id)
        loops = state.codegen.active_device_loops.get(block_id)
        if loops:
            return state.codegen.offset_var(block_id)
        begin_var = "0"
        if loop_info is not None and loop_info.begin_var_name is not None:
            begin_var = loop_info.begin_var_name
        global_index = active_index_var(block_id)
        local_coord = active_local_coord(block_id)
        if global_index is not None and local_coord is not None:
            return state.codegen.lift(
                expr_from_string(f"({global_index}) - ({local_coord})"),
                dce=True,
                prefix="tile_begin",
            ).id
        if global_index is not None:
            return global_index
        return begin_var

    def active_index_var(block_id: int) -> str | None:
        # Same lookup as the module-level helper, bound to this ``state``.
        return _cute_active_index_var(state, block_id)

    def resolve_active_slice_block_id(
        size: object,
        used_block_ids: set[int],
    ) -> int | None:
        candidates = _matching_block_ids(env, size)
        active_candidates = [
            block_id
            for block_id in candidates
            if active_index_var(block_id) is not None
        ]
        active_unused_candidates = [
            block_id
            for block_id in active_candidates
            if block_id not in used_block_ids
        ]
        if len(active_unused_candidates) == 1:
            return active_unused_candidates[0]
        if len(active_candidates) == 1:
            return active_candidates[0]
        # Ambiguous: prefer the single reduction block, unused ones first.
        if len(active_unused_candidates) > 1:
            reduction_unused = [
                block_id
                for block_id in active_unused_candidates
                if env.block_sizes[block_id].reduction
            ]
            if len(reduction_unused) == 1:
                return reduction_unused[0]
        if len(active_candidates) > 1:
            reduction_active = [
                block_id
                for block_id in active_candidates
                if env.block_sizes[block_id].reduction
            ]
            if len(reduction_active) == 1:
                return reduction_active[0]
        return None

    def index_var_for_block_id(block_id: int, size: object) -> str:
        if (idx_var := active_index_var(block_id)) is not None:
            return idx_var
        raise exc.BackendUnsupported(
            "cute",
            (
                "indexing dimension is not active in this scope "
                f"(block_id={block_id}, size={size})"
            ),
        )

    def local_coord_for_block_id(block_id: int, begin_var: str) -> str | None:
        if (local_coord := active_local_coord(block_id)) is not None:
            return local_coord
        if (idx_var := active_index_var(block_id)) is not None:
            return f"({idx_var}) - ({begin_var})"
        return None

    def tile_with_offset_index_expr(tile_info: TileWithOffsetInfo) -> str:
        block_id = tile_info.block_id
        begin_var = tile_begin_expr(block_id, active_loop_info(block_id))
        local_coord = local_coord_for_block_id(block_id, begin_var)
        if local_coord is None:
            raise exc.BackendUnsupported(
                "cute",
                (
                    "indexing dimension is not active in this scope "
                    f"(block_id={block_id})"
                ),
            )
        offset_expr = state.device_function.literal_expr(tile_info.offset)
        return f"({begin_var}) + cutlass.Int32({offset_expr}) + ({local_coord})"

    # Blocks claimed by explicit SymInt indices are not reused for slices
    # unless no other candidate exists.
    used_block_ids = {
        block_id
        for idx in subscript
        if isinstance(idx, torch.SymInt)
        if (block_id := env.get_block_id(idx)) is not None
    }
    result = []
    tensor_dim = 0
    for pos, idx in enumerate(subscript):
        ast_idx = None
        if ast_subscript is not None:
            ast_idx = ast_subscript[pos]
        if idx is None:
            continue
        # Any non-full-slice index into a size-1 dimension is index 0.
        if (
            tensor is not None
            and tensor_dim < tensor.ndim
            and env.known_equal(tensor.shape[tensor_dim], 1)
            and not (isinstance(idx, slice) and idx == slice(None))
        ):
            result.append("0")
            tensor_dim += 1
            continue
        if (
            tile_info := _get_tile_with_offset_info(
                idx, getattr(state, "fx_node", None), pos
            )
        ) is not None and tile_info.block_size is not None:
            used_block_ids.add(tile_info.block_id)
            result.append(tile_with_offset_index_expr(tile_info))
            tensor_dim += 1
            continue
        if isinstance(idx, torch.SymInt):
            result.append(symint_index_expr(idx, used_block_ids))
            tensor_dim += 1
        elif isinstance(idx, int):
            result.append(str(idx))
            tensor_dim += 1
        elif isinstance(idx, torch.Tensor):
            from .._compiler.cute.indexing import CuteAffineRangeIndex

            if isinstance(ast_idx, CuteAffineRangeIndex):
                raise exc.BackendUnsupported(
                    "cute",
                    "affine hl.arange() indexing is only supported in CuTe packed-matmul load fusion",
                )
            if not isinstance(ast_idx, ast.AST):
                raise exc.BackendUnsupported(
                    "cute", f"tensor index without AST at position {pos}"
                )
            lifted = state.codegen.lift(ast_idx, dce=True, prefix="index")
            index_dtype = env.backend.dtype_str(env.index_dtype)
            result.append(f"{index_dtype}({lifted.id})")
            tensor_dim += 1
        elif isinstance(idx, slice) and idx == slice(None):
            if tensor is None:
                raise exc.BackendUnsupported("cute", "slice indexing without tensor")
            dim_size = tensor.shape[tensor_dim]
            block_id = resolve_active_slice_block_id(dim_size, used_block_ids)
            if block_id is not None:
                idx_var = active_index_var(block_id)
                assert idx_var is not None
                used_block_ids.add(block_id)
                result.append(idx_var)
                tensor_dim += 1
                continue
            if inactive_singleton_slice_expr is not None and env.known_equal(
                dim_size, 1
            ):
                result.append(inactive_singleton_slice_expr)
                tensor_dim += 1
                continue
            if inactive_slice_expr is None:
                # NOTE(review): the message labels the subscript position
                # ``pos`` as ``tensor_dim``; the two differ when an earlier
                # entry is ``None`` -- confirm which one is intended.
                raise exc.BackendUnsupported(
                    "cute",
                    (
                        "indexing dimension is not active in this scope "
                        f"(tensor_dim={pos}, size={dim_size})"
                    ),
                )
            result.append(inactive_slice_expr)
            tensor_dim += 1
        elif isinstance(idx, slice) and (idx.step is None or idx.step == 1):
            # Partial slice (e.g. :16, 16:, or 5:20)
            if tensor is None:
                raise exc.BackendUnsupported(
                    "cute", "partial slice indexing without tensor"
                )
            dim_size = tensor.shape[tensor_dim]
            slice_size = compute_slice_size(idx, dim_size)
            start = idx.start if idx.start is not None else 0
            block_id = resolve_active_slice_block_id(slice_size, used_block_ids)
            if block_id is not None:
                idx_var = active_index_var(block_id)
                assert idx_var is not None
                used_block_ids.add(block_id)
                if start == 0:
                    result.append(idx_var)
                else:
                    start_expr = state.device_function.literal_expr(start)
                    result.append(f"({start_expr} + {idx_var})")
                tensor_dim += 1
                continue
            raise exc.BackendUnsupported(
                "cute",
                (
                    "partial slice dimension is not active in this scope "
                    f"(tensor_dim={pos}, size={slice_size})"
                ),
            )
        elif isinstance(idx, slice):
            raise exc.BackendUnsupported(
                "cute", f"strided slices (step={idx.step}) are not supported"
            )
        else:
            raise exc.BackendUnsupported("cute", f"index type: {type(idx)}")
    return result


def _cute_index_tuple(index_exprs: list[str]) -> str:
    """Render *index_exprs* as Python tuple source, keeping the 1-tuple comma."""
    if len(index_exprs) == 1:
        return f"({index_exprs[0]},)"
    return f"({', '.join(index_exprs)})"


def _cute_scalar_pointer_expr(tensor_name: str, index_exprs: list[str]) -> str:
    """Build ``tensor.iterator + sum(index[d] * stride[d])`` as source text.

    Each index and stride is wrapped in the environment's index type.
    """
    env = CompileEnvironment.current()
    index_dtype = env.index_type()
    offset = " + ".join(
        f"({index_dtype}({index}) * "
        f"{index_dtype}({tensor_name}.layout.stride[{dim}]))"
        for dim, index in enumerate(index_exprs)
    )
    return f"({tensor_name}.iterator + {offset})"


def _cute_scalar_storage_dtype(dtype: torch.dtype) -> str:
    """Return the cutlass type name used to store scalars of *dtype*.

    ``float4_e2m1fn_x2`` and ``float8_e4m3fn`` are stored as
    ``cutlass.Uint8``; every other dtype uses the backend's dtype string.
    """
    if dtype in (torch.float4_e2m1fn_x2, torch.float8_e4m3fn):
        return "cutlass.Uint8"
    return CompileEnvironment.current().backend.dtype_str(dtype)


def _cute_scalar_load_expr(
    tensor_name: str,
    index_exprs: list[str],
    dtype: torch.dtype,
) -> str:
    """Build the source text of a scalar load from *tensor_name*.

    Uses plain subscripting when any index is the literal ``"None"``, a
    ``cute.arch.load(..., cutlass.Uint8)`` for the Uint8-stored dtypes, and a
    pointer ``.load()`` otherwise.
    """
    if "None" in index_exprs:
        return f"{tensor_name}[{', '.join(index_exprs)}]"
    if dtype in (torch.float4_e2m1fn_x2, torch.float8_e4m3fn):
        return (
            f"cute.arch.load({_cute_scalar_pointer_expr(tensor_name, index_exprs)}, "
            "cutlass.Uint8)"
        )
    return f"{_cute_scalar_pointer_expr(tensor_name, index_exprs)}.load()"


def _cute_scalar_store_expr(
    tensor_name: str, index_exprs: list[str], value: str
) -> str:
    """Build the source text of a scalar store of *value* into *tensor_name*.

    Uses ``__setitem__`` when any index is the literal ``"None"`` and a
    pointer ``.store(...)`` otherwise.
    """
    if "None" in index_exprs:
        return f"{tensor_name}.__setitem__({_cute_index_tuple(index_exprs)}, {value})"
    return f"{_cute_scalar_pointer_expr(tensor_name, index_exprs)}.store({value})"


# Maximum bytes per vector load/store transaction (LDG.128/STG.128).
_CUTE_VECTOR_MAX_BYTES = 16

# Dtype -> (cutlass scalar type name, max vector width).  Used for the
# ``vec`` mode that issues an explicit
# ``cute.arch.load(ptr, ir.VectorType.get([V], elem.mlir_type))`` and folds
# the result via ``_cute_pre_vec_fold``.
_CUTE_VECTOR_DTYPES: dict[torch.dtype, tuple[str, int]] = {
    torch.float32: ("cutlass.Float32", _CUTE_VECTOR_MAX_BYTES // 4),
    torch.float16: ("cutlass.Float16", _CUTE_VECTOR_MAX_BYTES // 2),
    torch.bfloat16: ("cutlass.BFloat16", _CUTE_VECTOR_MAX_BYTES // 2),
}

# ``unroll`` mode loads bf16/fp16 inputs as Uint16 vectors and bitcasts each
# extracted lane back to the original dtype. This avoids the CuTe DSL
# crash that fires when subscripting a bf16/fp16 vector value. Cutlass
# scalar type for the extracted lane is paired with the vec-element type
# name used in ``ir.VectorType.get``.
_CUTE_VECTOR_UNROLL_DTYPES: dict[torch.dtype, str] = { torch.float16: "cutlass.Float16", torch.bfloat16: "cutlass.BFloat16", } # 1-byte fp8 dtypes also use ``unroll`` mode. Rather than a # ``VectorType([V], Uint8)`` load (which ICEs at V=8 in the CuTe DSL and # emits two LDG.32s for V=4), an fp8 vec chunk is loaded as a SINGLE packed # integer (``Uint32`` for V=4, ``Uint64`` for V=8) — one LDG.32 / LDG.64 — # and each lane byte is extracted with a shift+mask. The extracted ``Uint8`` # is decoded downstream by the matmul fallback's PTX helper. _CUTE_VECTOR_UNROLL_BYTE_DTYPES: frozenset[torch.dtype] = frozenset( {torch.float8_e4m3fn} ) # Packed-integer cutlass type per total byte width of an fp8 vec chunk. _CUTE_BYTE_PACK_TYPE: dict[int, str] = { 1: "cutlass.Uint8", 2: "cutlass.Uint16", 4: "cutlass.Uint32", 8: "cutlass.Uint64", } def _cute_is_byte_packed(dtype: torch.dtype) -> bool: return dtype in _CUTE_VECTOR_UNROLL_BYTE_DTYPES def _cute_is_unroll_dtype(dtype: torch.dtype) -> bool: """True for dtypes that use ``unroll`` mode: bf16/fp16 (Uint16 vector + bitcast) and fp8 (packed-integer load + shift extract).""" return ( dtype in _CUTE_VECTOR_UNROLL_DTYPES or dtype in _CUTE_VECTOR_UNROLL_BYTE_DTYPES ) def _cute_unroll_vec_elem_type(dtype: torch.dtype, vec_width: int = 1) -> str: """Cutlass load type for an ``unroll``-mode hoisted vec load. fp8 loads ``vec_width`` bytes as a single packed integer; bf16/fp16 load a ``Uint16`` vector element (the ``VectorType`` width is applied by callers). """ if _cute_is_byte_packed(dtype): pack = _CUTE_BYTE_PACK_TYPE.get(vec_width) assert pack is not None, f"unsupported fp8 vec_width {vec_width}" return pack return "cutlass.Uint16" def _cute_lane_axis_pos(strategy: object, block_id: int, index_exprs: list[str]) -> int: """Index_exprs position of the stride-1 lane axis for a tile_unroll hoist. Defaults to the last position (row-major lhs); the dispatcher records a different position (e.g. 
0 for a K-major rhs) in ``_cute_lane_axis_pos_by_block``. """ pos_by_block = getattr(strategy, "_cute_lane_axis_pos_by_block", None) if isinstance(pos_by_block, dict): pos = pos_by_block.get(block_id) if isinstance(pos, int): return pos return len(index_exprs) - 1 def _cute_unroll_vec_load_dtype_arg(dtype: torch.dtype, vec_width: int) -> str: """The dtype argument to ``cute.arch.load`` for an unroll-mode hoist. fp8 loads ``vec_width`` contiguous bytes as ONE packed scalar integer (no ``VectorType`` — avoids the V=8 ``nvvm.load.ext`` ICE and emits a single LDG). bf16/fp16 load a ``Uint16`` vector of width ``vec_width``. """ if _cute_is_byte_packed(dtype): return _cute_unroll_vec_elem_type(dtype, vec_width) + ".mlir_type" return f"ir.VectorType.get([{vec_width}], cutlass.Uint16.mlir_type)" def _cute_unroll_vec_load_expr( ptr_expr: str, dtype: torch.dtype, vec_width: int ) -> str: """Build the ``cute.arch.load(...)`` RHS for an unroll-mode hoist.""" if _cute_is_byte_packed(dtype): pack = _cute_unroll_vec_elem_type(dtype, vec_width) return f"cute.arch.load({ptr_expr}, {pack})" return ( f"cute.arch.load({ptr_expr}, " f"ir.VectorType.get([{vec_width}], cutlass.Uint16.mlir_type))" ) def _cute_unroll_vec_extract(hoist_var: str, idx: str, dtype: torch.dtype) -> str: """Per-lane extract expr from a hoisted ``unroll``-mode vec load. fp8: ``hoist_var`` is a packed integer (Uint32/Uint64); byte ``idx`` is extracted with a shift+mask and returned as a ``Uint8`` (decoded downstream). bf16/fp16: ``hoist_var`` is a ``Uint16`` vector; lane ``idx`` is bitcast back to the original dtype. 
""" if dtype in _CUTE_VECTOR_UNROLL_BYTE_DTYPES: return f"cutlass.Uint8(({hoist_var} >> (8 * ({idx}))) & 0xFF)" elem_dtype = _CUTE_VECTOR_UNROLL_DTYPES[dtype] return f"cutlass.Uint16({hoist_var}[{idx}]).bitcast({elem_dtype})" def _cute_vector_load_expr( tensor_name: str, index_exprs: list[str], dtype: torch.dtype, *, vec_width: int, ) -> str: elem_str, _ = _CUTE_VECTOR_DTYPES[dtype] ptr = _cute_scalar_pointer_expr(tensor_name, index_exprs) return ( f"cute.arch.load({ptr}, ir.VectorType.get([{vec_width}], {elem_str}.mlir_type))" ) def _cute_vector_store_expr( tensor_name: str, index_exprs: list[str], value: str, dtype: torch.dtype, *, vec_width: int, ) -> str: elem_str, _ = _CUTE_VECTOR_DTYPES[dtype] ptr = _cute_scalar_pointer_expr(tensor_name, index_exprs) return ( f"cute.arch.store({ptr}, {value}, " f"ir.VectorType.get([{vec_width}], {elem_str}.mlir_type))" ) def _cute_register_unroll_vec_hoist( state: CodegenState, strategy: object, # LoopedReductionStrategy at runtime tensor: torch.Tensor, tensor_name: str, index_exprs: list[str], vec_width: int, ) -> str: """Register a Uint16 vec load to be hoisted above the constexpr V-loop in the active lane body and return the per-element extract expression. The hoist runs once per outer-lane iter; the constexpr V-loop's body receives ``hoist_var[vi].bitcast(dtype)`` (a scalar) so the existing cast/mul/accumulate pipeline keeps working unchanged. """ elem_dtype = _CUTE_VECTOR_UNROLL_DTYPES[tensor.dtype] base_index_var = getattr(strategy, "_cute_lane_base_index_var", None) lane_body = getattr(strategy, "_cute_lane_body", None) assert isinstance(base_index_var, str) assert isinstance(lane_body, list) # The inner reduction-axis index_expr is the last entry; swap it with # the per-lane base so the vec load points at the start of the V-wide # chunk this thread owns. 
base_exprs = list(index_exprs) base_exprs[-1] = base_index_var base_ptr_expr = _cute_scalar_pointer_expr(tensor_name, base_exprs) cache_key = (tensor_name, base_ptr_expr) cache = getattr(strategy, "_cute_lane_vec_loads", None) if cache is None: cache = {} # pyrefly: ignore [missing-attribute] strategy._cute_lane_vec_loads = cache if cache_key not in cache: hoist_var = state.device_function.new_var( f"_unroll_vec_{len(cache)}", dce=False ) cache[cache_key] = (hoist_var, tensor.dtype) hoist_stmt = statement_from_string( f"{hoist_var} = cute.arch.load({base_ptr_expr}, " f"ir.VectorType.get([{vec_width}], cutlass.Uint16.mlir_type))" ) # Insert the hoist just BEFORE the constexpr V-loop (the last entry # in lane_body). ``lane_body[-1]`` is the constexpr loop. lane_body.insert(len(lane_body) - 1, hoist_stmt) else: hoist_var, _ = cache[cache_key] # The constexpr V-loop's target var is the last element's loop var. constexpr_loop = lane_body[-1] assert isinstance(constexpr_loop, ast.For) assert isinstance(constexpr_loop.target, ast.Name) vec_lane_var = constexpr_loop.target.id return f"cutlass.Uint16({hoist_var}[{vec_lane_var}]).bitcast({elem_dtype})" def _cute_register_tile_unroll_vec_hoist( state: CodegenState, strategy: object, # BlockSizeTileStrategy (CuteNDTileStrategy) block_id: int, tensor: torch.Tensor, tensor_name: str, index_exprs: list[str], vec_width: int, ) -> str: """Tile-loop variant of ``_cute_register_unroll_vec_hoist`` for ``CuteNDTileStrategy`` lane loops. Splices a single ``cute.arch.load(base_ptr, <elem>x V)`` into the outer-lane body (above the constexpr V-loop) and returns the per-element extract expression so the existing scalar pipeline keeps working. bf16/fp16 load as ``Uint16`` and bitcast; fp8 loads ``vec_width`` bytes as one packed integer that the matmul fallback decodes downstream. 
""" base_var_by_block = getattr(strategy, "_cute_lane_base_index_var_by_block", {}) lane_body_by_block = getattr(strategy, "_cute_lane_body_by_block", {}) vec_lane_var_by_block = getattr(strategy, "_cute_vec_lane_var_by_block", {}) base_index_var = base_var_by_block.get(block_id) lane_body = lane_body_by_block.get(block_id) vec_lane_var = vec_lane_var_by_block.get(block_id) assert isinstance(base_index_var, str) assert isinstance(lane_body, list) assert isinstance(vec_lane_var, str) # The lane-axis index_expr (stride-1 dim) is swapped with the per-lane # base so the vec load points at the start of the V-wide chunk this # thread owns. The position is the last entry for a row-major lhs, or # the recorded position for a K-major rhs. lane_pos = _cute_lane_axis_pos(strategy, block_id, index_exprs) base_exprs = list(index_exprs) base_exprs[lane_pos] = base_index_var base_ptr_expr = _cute_scalar_pointer_expr(tensor_name, base_exprs) cache_key = (tensor_name, base_ptr_expr) cache_by_block = getattr(strategy, "_cute_lane_vec_loads_by_block", None) if cache_by_block is None: cache_by_block = {} # pyrefly: ignore [missing-attribute] strategy._cute_lane_vec_loads_by_block = cache_by_block cache = cache_by_block.setdefault(block_id, {}) if cache_key not in cache: hoist_var = state.device_function.new_var( f"_tile_unroll_vec_{block_id}_{len(cache)}", dce=False ) cache[cache_key] = (hoist_var, tensor.dtype) # Guard the LDG against per-thread OOB: on the very last grid # block + tail outer-tile iter, a thread whose vec base equals # ``numel`` would otherwise read past the end of the underlying # allocation (the next row doesn't exist for the last grid # block). Use an "anchor pointer" fallback for the unsafe # threads: it points inside the tensor (specifically at the # per-thread base of the FIRST outer-tile iter, which is the # ``base_ptr_expr`` with the outer-lane index folded to 0). 
The # fetched bytes are then ignored downstream by the per-lane # mask gate that wraps the bitcast result. env_local = CompileEnvironment.current() numel = env_local.block_sizes[block_id].numel numel_expr = state.sympy_expr(numel) # Build the "anchor" pointer: same index_exprs but with the # inner reduction-axis index forced to 0. This is the # ``tile_offset == 0, lane_var == 0, vec_lane_var == 0`` base # for the very first outer-tile iter, which is always in-bounds # for any grid block. anchor_exprs = list(index_exprs) anchor_exprs[lane_pos] = "0" anchor_ptr_expr = _cute_scalar_pointer_expr(tensor_name, anchor_exprs) guarded_ptr = ( f"({base_ptr_expr} if {base_index_var} < {numel_expr} " f"else {anchor_ptr_expr})" ) hoist_stmt = statement_from_string( f"{hoist_var} = {_cute_unroll_vec_load_expr(guarded_ptr, tensor.dtype, vec_width)}" ) # Insert the hoist just BEFORE the constexpr V-loop (the last # entry in lane_body). lane_body.insert(len(lane_body) - 1, hoist_stmt) else: hoist_var, _ = cache[cache_key] return _cute_unroll_vec_extract(hoist_var, vec_lane_var, tensor.dtype) def _cute_register_tile_unroll_vec_hoist_split2( state: CodegenState, strategy: object, # BlockSizeTileStrategy (CuteNDTileStrategy) block_id: int, tensor: torch.Tensor, tensor_name: str, index_exprs: list[str], vec_width: int, ) -> str: """Split-2 variant of ``_cute_register_tile_unroll_vec_hoist`` for V=8 on fp16/bf16. The CuTe DSL's ``nvvm.load.ext`` ICEs at V=8 for these dtypes, so the full 16-byte LDG.128 is decomposed into TWO back-to-back V=4 loads (lanes 0-3 and 4-7). The SASS scheduler is free to overlap the two LDGs, so the per-thread bytes-per-load grows from 8 (V=4) to the full 16 (effective V=8) without invoking the DSL bug. 
Returns a per-vec-lane expression of the form:: ( cutlass.Uint16(_tile_unroll_vec_ < n > _ < m > _a[vi]).bitcast(dtype) if vi < 4 else cutlass.Uint16(_tile_unroll_vec_ < n > _ < m > _b[vi - 4]).bitcast( dtype ) ) Because ``vec_lane_var`` is the target of a ``cutlass.range_constexpr(8)`` loop, it is a Python-int constant at each unrolled iter, so the ``if vi < 4`` branch folds away at trace time and the emitted SASS contains only the active load's extract. """ assert vec_width == 8, ( "tile_unroll_split2 expects V=8 (4+4); other widths use tile_unroll" ) half = vec_width // 2 vec_elem_type = _cute_unroll_vec_elem_type(tensor.dtype) base_var_by_block = getattr(strategy, "_cute_lane_base_index_var_by_block", {}) lane_body_by_block = getattr(strategy, "_cute_lane_body_by_block", {}) vec_lane_var_by_block = getattr(strategy, "_cute_vec_lane_var_by_block", {}) base_index_var = base_var_by_block.get(block_id) lane_body = lane_body_by_block.get(block_id) vec_lane_var = vec_lane_var_by_block.get(block_id) assert isinstance(base_index_var, str) assert isinstance(lane_body, list) assert isinstance(vec_lane_var, str) lane_pos = _cute_lane_axis_pos(strategy, block_id, index_exprs) base_exprs = list(index_exprs) base_exprs[lane_pos] = base_index_var base_ptr_expr_a = _cute_scalar_pointer_expr(tensor_name, base_exprs) # The second-half pointer points 4 elements past the first. Build # it by substituting ``base_index_var + half`` for the inner index. 
base_exprs_b = list(index_exprs) base_exprs_b[lane_pos] = f"({base_index_var} + {half})" base_ptr_expr_b = _cute_scalar_pointer_expr(tensor_name, base_exprs_b) cache_key = (tensor_name, base_ptr_expr_a, "split2") cache_by_block = getattr(strategy, "_cute_lane_vec_loads_by_block", None) if cache_by_block is None: cache_by_block = {} # pyrefly: ignore [missing-attribute] strategy._cute_lane_vec_loads_by_block = cache_by_block cache = cache_by_block.setdefault(block_id, {}) if cache_key not in cache: slot = len(cache) hoist_var_a = state.device_function.new_var( f"_tile_unroll_vec_{block_id}_{slot}_a", dce=False ) hoist_var_b = state.device_function.new_var( f"_tile_unroll_vec_{block_id}_{slot}_b", dce=False ) # Stash both names plus the split marker so this entry doesn't # collide with the V=4 cache_key shape. Downstream readers # don't introspect this tuple — it's just a sentinel. cache[cache_key] = ((hoist_var_a, hoist_var_b), tensor.dtype) env_local = CompileEnvironment.current() numel = env_local.block_sizes[block_id].numel numel_expr = state.sympy_expr(numel) anchor_exprs = list(index_exprs) anchor_exprs[lane_pos] = "0" anchor_ptr_expr = _cute_scalar_pointer_expr(tensor_name, anchor_exprs) # The first-half OOB guard checks the same V-aligned base used by # the V=4 path; the second-half pointer is ``base + 4`` and only # needs guarding when ``base + 4 < numel``. Reuse the same # anchor pointer for both halves' fallbacks (the per-element # mask gate downstream drops any anchor-fetched bytes anyway). 
guarded_ptr_a = ( f"({base_ptr_expr_a} if {base_index_var} < {numel_expr} " f"else {anchor_ptr_expr})" ) guarded_ptr_b = ( f"({base_ptr_expr_b} if ({base_index_var} + {half}) < {numel_expr} " f"else {anchor_ptr_expr})" ) hoist_stmt_a = statement_from_string( f"{hoist_var_a} = cute.arch.load({guarded_ptr_a}, " f"ir.VectorType.get([{half}], {vec_elem_type}.mlir_type))" ) hoist_stmt_b = statement_from_string( f"{hoist_var_b} = cute.arch.load({guarded_ptr_b}, " f"ir.VectorType.get([{half}], {vec_elem_type}.mlir_type))" ) # Insert both hoists just BEFORE the constexpr V-loop (the last # entry in lane_body). Emit them back-to-back so the SASS # scheduler can issue the two LDGs together. lane_body.insert(len(lane_body) - 1, hoist_stmt_a) lane_body.insert(len(lane_body) - 1, hoist_stmt_b) else: (hoist_var_a, hoist_var_b), _ = cache[cache_key] extract_a = _cute_unroll_vec_extract(hoist_var_a, vec_lane_var, tensor.dtype) extract_b = _cute_unroll_vec_extract( hoist_var_b, f"{vec_lane_var} - {half}", tensor.dtype ) return f"({extract_a} if {vec_lane_var} < {half} else {extract_b})" def _cute_vector_load_ctx( state: CodegenState, tensor: torch.Tensor, subscript: list[object] | tuple[object, ...], index_exprs: list[str], extra_mask: ast.AST | None, ) -> tuple[int, int, str] | None: """Return (vec_width, lane_block_id, mode) when a vec load may be emitted. ``mode`` is one of ``"vec"`` (explicit ``cute.arch.load(..., V)``) or ``"unroll"`` (per-element scalar bitcast inside a constexpr V-loop). Returns None when any predicate for a 128-bit gmem load fails, in which case the caller falls back to ``_cute_scalar_load_expr``. 
""" from .._compiler.reduction_strategy import LoopedReductionStrategy env = CompileEnvironment.current() if env.backend.name != "cute": return None if extra_mask is not None: return None if "None" in index_exprs: return None if tensor.dtype not in _CUTE_VECTOR_DTYPES and not _cute_is_unroll_dtype( tensor.dtype ): return None # Only enable the vec path when the load's result eventually feeds a # reduction op. The consume-sweep mixes the loaded vector with scalar # values (e.g. the post-reduction inverse-RMS), and broadcasting # scalar->vec is not supported by the CuTe DSL today. When the load's # immediate user is a dtype cast (``to(torch.float32)``), the # ``"unroll"`` mode further down keeps the strategy on a per-element # scalar pipeline and the explicit-vec path is skipped — the explicit # ``cute.arch.load(ptr, ir.VectorType.get([V], dtype.mlir_type))`` form # would otherwise crash inside the CuTe DSL when subscripting bf16/fp16 # vectors. fx_node = state.fx_node if fx_node is None: return None visited: set[torch.fx.Node] = set() pending = list(fx_node.users.keys()) feeds_reduction = False while pending: user = pending.pop() if user in visited: continue visited.add(user) target_name = getattr(user.target, "__name__", "") or "" target_qualname = getattr(user.target, "_qualname", "") or "" if ( "reduction" in target_name or "_inductor_lowering_extra" in target_name or "reduction" in target_qualname ): feeds_reduction = True break pending.extend(user.users.keys()) # Note: ``feeds_reduction`` is required ONLY for the ``vec`` mode below; # the ``unroll`` mode also applies to the consume sweep where the load # result feeds an elementwise pipeline (no reduction). # The lane/vec axis must be a tensor dim that is stride-1 so that # consecutive lane iters fetch consecutive bytes. For a row-major lhs # the reduction axis is the LAST subscript position; for a column-major # rhs (e.g. the K-major ``y`` of a tcgen05 fp8 matmul) it is the FIRST. 
# ``_cute_lane_axis_pos`` records the index_exprs position of that # stride-1 lane axis so the hoist substitutes the per-lane base there # (not blindly at ``[-1]``). # Find the stride-1 dim WITHOUT forcing specialization of a symbolic # stride: a contiguous dim has a concrete ``int`` stride of 1, so only # accept plain ints here. Calling ``int()`` on a ``SymInt`` stride would # bake the (otherwise-dynamic) size into the kernel — see the # ``test_mark_static`` regression where ``int(stride(0))`` specialized # ``n``. stride1_tensor_dim: int | None = None for d in range(tensor.ndim): s = tensor.stride(d) if isinstance(s, int) and s == 1: stride1_tensor_dim = d break if stride1_tensor_dim is None: return None # Locate the non-None subscript carrying an active lane block. Slices # resolve to the matching tensor-dim block via the strategy that's # currently active for that block. Prefer the block sitting on the # stride-1 tensor dim (the true lane axis), and record its index_exprs # position. inner_block_id: int | None = None lane_axis_pos: int | None = None expr_pos = -1 tensor_dim = 0 for idx in subscript: if idx is None: continue expr_pos += 1 if isinstance(idx, torch.SymInt): bid = env.get_block_id(idx) if bid is not None and state.codegen.active_device_loops.get(bid): if tensor_dim == stride1_tensor_dim or inner_block_id is None: inner_block_id = bid lane_axis_pos = expr_pos elif isinstance(idx, slice) and idx == slice(None): if tensor_dim < tensor.ndim: dim_size = tensor.shape[tensor_dim] for cand_bid, bs in enumerate(env.block_sizes): if not isinstance(bs.size, (int, torch.SymInt)): continue bs_numel = bs.numel # Try a few candidate forms for the size equality # check: sympy.Integer (most common via specialize()), # int, and torch.SymInt all flow through known_equal # after we coerce to plain int when possible. 
bs_int: int | torch.SymInt | None if isinstance(bs_numel, (int, torch.SymInt)): bs_int = bs_numel else: try: bs_int = int(bs_numel) except (TypeError, ValueError): bs_int = None if bs_int is None: continue dim_int: int | torch.SymInt | None if isinstance(dim_size, (int, torch.SymInt)): dim_int = dim_size else: try: dim_int = int(dim_size) except (TypeError, ValueError): dim_int = None if dim_int is None: continue if env.known_equal( bs_int, dim_int ) and state.codegen.active_device_loops.get(cand_bid): if tensor_dim == stride1_tensor_dim or inner_block_id is None: inner_block_id = cand_bid lane_axis_pos = expr_pos break tensor_dim += 1 if inner_block_id is None or lane_axis_pos is None: return None loops = state.codegen.active_device_loops.get(inner_block_id) if not loops: return None strategy = getattr(loops[-1], "strategy", None) if isinstance(strategy, LoopedReductionStrategy): vec_width = getattr(strategy, "_cute_reduction_vec_width", 1) if vec_width <= 1: return None if strategy._mask_var is not None: return None if strategy._cute_reduction_lane_extent <= 0: return None mode = getattr(strategy, "_cute_reduction_vec_mode", "vec") if mode == "vec": if not feeds_reduction: return None if tensor.dtype not in _CUTE_VECTOR_DTYPES: return None return vec_width, inner_block_id, "vec" if mode == "unroll": if tensor.dtype not in _CUTE_VECTOR_UNROLL_DTYPES: return None # The CuTe DSL's ``nvvm.load.ext`` only supports vec sizes 2 # and 4 for bf16/fp16 (V=8 raises ICE). Cap effective V # here so the autotuner's V=8 seed still compiles instead # of crashing. if vec_width > 4: return None # Need a lane base index var + a constexpr V-loop var; both # are set up by the strategy's codegen_device_loop. 
if ( getattr(strategy, "_cute_lane_base_index_var", None) is None or getattr(strategy, "_cute_lane_body", None) is None ): return None return vec_width, inner_block_id, "unroll" return None # CuTe N-D tile strategy with lane loops: vec is set up per-block in # ``CuteNDTileStrategy.__init__`` when the autotuner picks # ``cute_vector_widths[block_id]`` > 1 and EPT is divisible by V. Mode # is forced to ``"unroll"`` (per-element bitcast) for fp16/bf16 since # subscripting a bf16/fp16 vector in the CuTe DSL is unsafe; fp32 # could in principle use ``"vec"`` but the per-element pipeline runs # most of the consume-sweep code after a cast, so unroll is the # robust choice. from .._compiler.tile_strategy import BlockSizeTileStrategy if isinstance(strategy, BlockSizeTileStrategy): vec_by_block = getattr(strategy, "_cute_lane_vec_width_by_block", None) if not isinstance(vec_by_block, dict): return None vec_width = vec_by_block.get(inner_block_id, 1) if vec_width <= 1: return None if not _cute_is_unroll_dtype(tensor.dtype): return None # The CuTe DSL's ``nvvm.load.ext`` ICEs at V=8 for fp16/bf16 (and # for the V=8 ``Uint8`` vector used by fp8), so widths > 4 cannot # use a single ``cute.arch.load``. V=8 still # gets full LDG.128 throughput via the ``tile_unroll_split2`` # mode: two back-to-back ``cute.arch.load(..., V=4)`` calls # (covering vec lanes 0-3 and 4-7) emit as two LDG.64s that the # SASS scheduler can overlap. Wider Vs (16, 32, ...) are not # supported. 
if vec_width > 8: return None if vec_width == 8 and vec_width % 4 != 0: return None base_var_by_block = getattr( strategy, "_cute_lane_base_index_var_by_block", None ) lane_body_by_block = getattr(strategy, "_cute_lane_body_by_block", None) vec_lane_var_by_block = getattr(strategy, "_cute_vec_lane_var_by_block", None) if ( not isinstance(base_var_by_block, dict) or not isinstance(lane_body_by_block, dict) or not isinstance(vec_lane_var_by_block, dict) or inner_block_id not in base_var_by_block or inner_block_id not in lane_body_by_block or inner_block_id not in vec_lane_var_by_block ): return None # When the per-thread vec base could straddle the tensor edge # (e.g. ``numel`` not a multiple of V), the masked-tail iter # could load garbage in some lanes. Gate the per-element mask # path correctly by requiring ``numel % V == 0`` so partial-vec # straddles are impossible. numel = env.block_sizes[inner_block_id].numel if not env.known_multiple(numel, vec_width): return None # Record the index_exprs position of the stride-1 lane axis so the # hoist substitutes the per-lane base there. Row-major lhs loads # use the last position; a column-major rhs (K-major ``y``) uses # position 0. pos_by_block = getattr(strategy, "_cute_lane_axis_pos_by_block", None) if not isinstance(pos_by_block, dict): pos_by_block = {} # pyrefly: ignore [missing-attribute] strategy._cute_lane_axis_pos_by_block = pos_by_block pos_by_block[inner_block_id] = lane_axis_pos # fp8 loads a packed Uint64 (V=8) / Uint32 (V=4) in the regular # ``tile_unroll`` path — no ``VectorType`` so no V=8 ICE, hence no # split2 needed. bf16/fp16 V=8 still needs the 2x V=4 split. 
if vec_width == 8 and not _cute_is_byte_packed(tensor.dtype): return vec_width, inner_block_id, "tile_unroll_split2" return vec_width, inner_block_id, "tile_unroll" return None def _cute_stack_tensor_offset_expr( state: CodegenState, tensor_like: torch.Tensor, subscript: list[object], ast_subscript: list[object] | tuple[object, ...], ) -> str: env = CompileEnvironment.current() index_exprs = _cute_index_exprs( state, subscript, ast_subscript, tensor=tensor_like, inactive_slice_expr="None", inactive_singleton_slice_expr="0", ) if "None" in index_exprs: raise exc.BackendUnsupported("cute", "inactive stack tensor load dimension") index_dtype = env.index_type() terms = [] for dim, index in enumerate(index_exprs): stride = tensor_like.stride(dim) stride_expr = ( str(stride) if isinstance(stride, int) else state.sympy_expr(stride) ) terms.append(f"({index_dtype}({index}) * {index_dtype}({stride_expr}))") return " + ".join(terms) if terms else "0" def _cute_stack_tensor_mask_expr( state: CodegenState, tensor_like: torch.Tensor, dev_ptrs: torch.Tensor, subscript: list[object], extra_mask: ast.AST | None, ) -> str | None: terms = [] tensor_mask = _cute_combined_mask( state, subscript, extra_mask, tensor=tensor_like, include_tensor_index_masks=False, ) if tensor_mask is not None: terms.append(tensor_mask) stack_mask = _cute_combined_mask( state, [slice(None)] * dev_ptrs.ndim, None, tensor=dev_ptrs, ) if stack_mask is not None and stack_mask not in terms: terms.append(stack_mask) if not terms: return None return " and ".join(f"({term})" for term in terms) def _cute_stack_tensor_pointer_expr( target_dtype: str, dev_ptrs_ast: ast.AST, offset_expr: str, ) -> ast.AST: return expr_from_string( f"(cute.make_ptr({target_dtype}, cutlass.Int64({{base}}), " f"cute.AddressSpace.gmem) + ({offset_expr}))", base=dev_ptrs_ast, ) def _codegen_cute_store_stack_load( state: CodegenState, tensor: torch.Tensor, subscript: tuple[object, ...] | list[object], ast_subscript: tuple[object, ...] 
| list[object], value: ast.AST, extra_mask: ast.AST | None, value_node: torch.fx.Node, ) -> ast.AST | None: if value_node.op != "call_function" or value_node.target is not load: return None stack_arg = value_node.args[0] if not isinstance(stack_arg, tuple) or len(stack_arg) != 2: return None ptr_node = stack_arg[1] if ( not isinstance(ptr_node, torch.fx.Node) or ptr_node.op != "call_function" or ptr_node.target is not load or len(ptr_node.args) < 2 ): return None dev_ptrs = ( ptr_node.args[0].meta.get("val") if isinstance(ptr_node.args[0], torch.fx.Node) else None ) ptr_subscript = ptr_node.args[1] if not isinstance(dev_ptrs, torch.Tensor) or not isinstance( ptr_subscript, (list, tuple) ): return None tensor_like_node = stack_arg[0] tensor_like = ( tensor_like_node.meta.get("val") if isinstance(tensor_like_node, torch.fx.Node) else tensor_like_node ) if not isinstance(tensor_like, torch.Tensor): return None if ( dev_ptrs.ndim == 2 and len(ptr_subscript) == 2 and all(isinstance(idx, slice) and idx == slice(None) for idx in ptr_subscript) and len(subscript) >= 3 and isinstance(subscript[0], slice) and subscript[0] == slice(None) and isinstance(subscript[1], slice) and subscript[1] == slice(None) ): stack_value_subscript = value_node.args[1] if not isinstance(stack_value_subscript, (list, tuple)): return None stack_value_subscript_proxy = map_arg( stack_value_subscript, lambda arg: arg.meta["val"] ) stack_value_subscript_ast = map_arg( stack_value_subscript, lambda arg: state.env[arg] ) tensor_offset_expr = _cute_stack_tensor_offset_expr( state, tensor_like, [*stack_value_subscript_proxy], [*stack_value_subscript_ast], ) target_index_exprs = _cute_index_exprs( state, [*subscript], ast_subscript, tensor=tensor, inactive_singleton_slice_expr="0", ) if len(target_index_exprs) != tensor.ndim: return None first_stack_index = target_index_exprs[0] target_tail = target_index_exprs[2:] loop_var = state.device_function.new_var("stack_dim", dce=True) env = 
CompileEnvironment.current() index_dtype = env.index_type() dev_ptrs_name = state.device_function.tensor_arg(dev_ptrs).name tensor_name = state.device_function.tensor_arg(tensor).name target_dtype = env.backend.dtype_str(tensor.dtype) dev_ptr_offset = ( f"{index_dtype}({first_stack_index}) * " f"{index_dtype}({dev_ptrs.stride(0)}) + " f"{index_dtype}({loop_var}) * {index_dtype}({dev_ptrs.stride(1)})" ) stack_ptr_expr = ( f"(cute.make_ptr({target_dtype}, " f"cutlass.Int64(({dev_ptrs_name}.iterator + {dev_ptr_offset}).load()), " f"cute.AddressSpace.gmem) + ({tensor_offset_expr}))" ) target_indices = [first_stack_index, loop_var, *target_tail] store_expr = _cute_scalar_store_expr( tensor_name, target_indices, f"({stack_ptr_expr}).load()", ) mask_expr = _cute_combined_mask(state, [*subscript], extra_mask, tensor=tensor) if mask_expr is None: body = f" {store_expr}" else: body = f" if {mask_expr}:\n {store_expr}" state.add_statement( statement_from_string( f"for {loop_var} in range({dev_ptrs.size(1)}):\n{body}" ) ) return ast.Constant(value=None) ptr_subscript_proxy = map_arg(ptr_subscript, lambda arg: arg.meta["val"]) ptr_subscript_ast = map_arg(ptr_subscript, lambda arg: state.env[arg]) ptr_index_exprs = _cute_index_exprs( state, [*ptr_subscript_proxy], [*ptr_subscript_ast], tensor=dev_ptrs, inactive_slice_expr="None", inactive_singleton_slice_expr="0", ) if "None" in ptr_index_exprs: return None target_index_exprs = _cute_index_exprs( state, [*subscript], ast_subscript, tensor=tensor, inactive_singleton_slice_expr="0", ) ptr_pos = 0 rewritten_index_exprs = [] for idx, index_expr in zip(subscript, target_index_exprs, strict=True): if isinstance(idx, slice) and idx == slice(None): replacement = ( ptr_index_exprs[ptr_pos] if ptr_pos < len(ptr_index_exprs) else None ) ptr_pos += 1 rewritten_index_exprs.append( replacement if replacement is not None else index_expr ) else: if ptr_pos < len(ptr_subscript_proxy) and not ( isinstance(ptr_subscript_proxy[ptr_pos], slice) and 
def _cute_affine_range_block_id(state: CodegenState, affine: object) -> int | None:
    """Resolve the block id that an affine range index is based on.

    Returns ``None`` when ``affine`` is not a ``CuteAffineRangeIndex`` or when
    no block id can be derived from the metadata of ``affine.base``.  The id
    comes from the ``"val"`` metadata entry when the environment resolves it,
    and is otherwise parsed from a ``_BLOCK_SIZE_<n>`` name stored under the
    ``"codegen"`` entry.  When ``state`` carries an FX node, the id is further
    mapped through ``resolve_codegen_block_id`` for that node's graph.
    """
    from .._compiler.cute.indexing import CuteAffineRangeIndex

    if not isinstance(affine, CuteAffineRangeIndex):
        return None

    env = CompileEnvironment.current()
    meta = getattr(affine.base, "meta", {})
    meta_is_dict = isinstance(meta, dict)

    resolved: int | None = None
    tracked_value = meta.get("val") if meta_is_dict else None
    if tracked_value is not None:
        resolved = env.resolve_block_id(tracked_value)

    if resolved is None:
        # Fall back to the block-size symbol recorded under "codegen".
        size_name = meta.get("codegen") if meta_is_dict else None
        prefix = "_BLOCK_SIZE_"
        if isinstance(size_name, ast.Name) and size_name.id.startswith(prefix):
            try:
                resolved = int(size_name.id.removeprefix(prefix))
            except ValueError:
                # Suffix is not an integer: leave the id unresolved.
                resolved = None

    if resolved is None:
        return None

    node = state.fx_node
    if node is None:
        return resolved
    return env.resolve_codegen_block_id(resolved, state.codegen, node.graph)


def _cute_affine_range_expr(
    state: CodegenState,
    affine: object,
    lane_var: str,
    *,
    dtype: torch.dtype | None = None,
) -> str | None:
    """Build the source text for one lane of an affine range index.

    The result is ``(factor) * (index_var) + cutlass.Int32(lane_var)``, wrapped
    in a cast to ``dtype`` when one is given.  ``None`` is returned when
    ``affine`` is not a ``CuteAffineRangeIndex`` with ``step == 1`` and a
    positive ``factor``, or when its block id or active index variable cannot
    be resolved.
    """
    from .._compiler.cute.indexing import CuteAffineRangeIndex

    if (
        not isinstance(affine, CuteAffineRangeIndex)
        or affine.step != 1
        or affine.factor <= 0
    ):
        return None

    if (block_id := _cute_affine_range_block_id(state, affine)) is None:
        return None
    if (index_var := _cute_active_index_var(state, block_id)) is None:
        return None

    scaled = f"({affine.factor}) * ({index_var}) + cutlass.Int32({lane_var})"
    if dtype is None:
        return scaled
    return f"{CompileEnvironment.current().backend.dtype_str(dtype)}({scaled})"


def _codegen_cute_affine_range_store(
    state: CodegenState,
    tensor: torch.Tensor,
    subscript: list[object] | tuple[object, ...],
    ast_subscript: list[object] | tuple[object, ...],
    value: object,
    extra_mask: ast.AST | None,
    value_node: torch.fx.Node | None = None,
) -> ast.AST | None:
    """Lower a 1-D store through a single affine range index to a lane loop.

    Emits ``for lane in range(factor): tensor.__setitem__(index, value)`` as an
    ``ast.For`` node.  The stored value is one of:

    * a ``load`` through an affine range index with the same factor
      (``value_node`` is the FX node of that load),
    * another ``CuteAffineRangeIndex``,
    * an AST expression, or
    * a Python ``int`` / ``float`` / ``bool`` literal.

    Returns ``None`` whenever the pattern does not apply (more than one
    subscript, an extra mask, an unresolvable index, an unsupported value).
    """
    from .._compiler.ast_extension import create
    from .._compiler.cute.indexing import CuteAffineRangeIndex

    affine_indices = [
        idx for idx in ast_subscript if isinstance(idx, CuteAffineRangeIndex)
    ]
    if len(affine_indices) != 1 or len(subscript) != 1 or extra_mask is not None:
        return None
    (affine,) = affine_indices

    block_id = _cute_affine_range_block_id(state, affine)
    if block_id is None:
        return None

    device_fn = state.device_function
    lane_var = device_fn.new_var("affine_lane", dce=True)
    index_expr = _cute_affine_range_expr(
        state, affine, lane_var, dtype=CompileEnvironment.current().index_dtype
    )
    if index_expr is None:
        return None
    backend = CompileEnvironment.current().backend

    value_is_load = (
        value_node is not None
        and value_node.op == "call_function"
        and value_node.target is load
    )
    if value_is_load:
        # Re-read the source element inside the lane loop instead of using the
        # already-lowered ``value``.
        src_node = value_node.args[0]
        if not isinstance(src_node, torch.fx.Node):
            return None
        src_tensor = src_node.meta.get("val")
        if not isinstance(src_tensor, torch.Tensor):
            return None
        src_subscript = value_node.args[1]
        if not isinstance(src_subscript, (list, tuple)):
            return None
        if len(src_subscript) != 1:
            return None
        (src_affine,) = list(
            map_arg(tuple(src_subscript), lambda arg: state.env[arg])
        )
        if (
            not isinstance(src_affine, CuteAffineRangeIndex)
            or src_affine.factor != affine.factor
        ):
            return None
        src_index_expr = _cute_affine_range_expr(
            state,
            src_affine,
            lane_var,
            dtype=CompileEnvironment.current().index_dtype,
        )
        if src_index_expr is None:
            return None
        src_name = device_fn.tensor_arg(src_tensor).name
        value_expr = f"{src_name}[{src_index_expr}]"
        if src_tensor.dtype is torch.bool:
            value_expr = f"({value_expr} != cutlass.Uint8(0))"
    elif isinstance(value, CuteAffineRangeIndex):
        value_expr = _cute_affine_range_expr(state, value, lane_var, dtype=value.dtype)
        if value_expr is None:
            return None
    elif isinstance(value, ast.AST):
        value_expr = ast.unparse(value)
    elif isinstance(value, (int, float, bool)):
        value_expr = repr(value)
    else:
        return None

    target_dtype = backend.dtype_str(tensor.dtype)
    value_expr = backend.ast_to_dtype_expr(value_expr, target_dtype)
    tensor_name = device_fn.tensor_arg(tensor).name
    store_src = (
        f"{tensor_name}.__setitem__({_cute_index_tuple([index_expr])}, {value_expr})"
    )
    mask_var = _cute_active_mask_var(state, block_id)
    if mask_var is not None:
        store_src = f"{store_src} if {mask_var} else None"

    loop_target = create(ast.Name, id=lane_var, ctx=ast.Store())
    loop_body = [create(ast.Expr, value=expr_from_string(store_src))]
    return create(
        ast.For,
        target=loop_target,
        iter=expr_from_string(f"range({affine.factor})"),
        body=loop_body,
        orelse=[],
        type_comment=None,
    )
Each m-tile thread owns row ``m_local`` of the source; the reshaped tensor has ``K`` rows per source row, so the thread loops ``s in range(K)`` and writes the value resolved at flat index ``(K*m_local + s)*block_n + n_local`` to output row ``K*m_global + s``, column ``n_global``. """ from .._compiler.ast_extension import create from .._compiler.cute.cute_reshape import _get_block_local_coord from .._compiler.cute.cute_reshape import resolve_cute_shape_chain_value_at from .._compiler.cute.indexing import CuteAffineRangeIndex from .._compiler.cute.indexing import is_cute_shape_chain_target from .._compiler.generate_ast import GenerateAST if ( tensor.ndim != 2 or len(subscript) != 2 or len(ast_subscript) != 2 or extra_mask is not None or value_node is None or not isinstance(state.codegen, GenerateAST) ): return None affine = ast_subscript[0] if not isinstance(affine, CuteAffineRangeIndex): return None if affine.step != 1 or affine.factor <= 0: return None n_index = subscript[1] if not isinstance(n_index, torch.SymInt): return None env = CompileEnvironment.current() block_id_n = env.get_block_id(n_index) if block_id_n is None: return None block_id_m = _cute_affine_range_block_id(state, affine) if block_id_m is None: return None if value_node.op != "call_function" or not is_cute_shape_chain_target( value_node.target ): return None value_val = value_node.meta.get("val") if not isinstance(value_val, torch.Tensor) or value_val.ndim != 2: return None m_global = _cute_active_index_var(state, block_id_m) n_global = _cute_active_index_var(state, block_id_n) if m_global is None or n_global is None: return None m_local = _get_block_local_coord(state.codegen, block_id_m) n_local = _get_block_local_coord(state.codegen, block_id_n) if m_local is None or n_local is None: return None block_n = state.device_function.resolved_block_size(block_id_n) if not isinstance(block_n, int): return None factor = affine.factor lane_var = state.device_function.new_var("affine_lane", dce=True) 
row_local = f"cutlass.Int32({factor}) * ({m_local}) + cutlass.Int32({lane_var})" flat_index = ( f"(({row_local}) * cutlass.Int32({block_n})) + ({n_local})" if block_n != 1 else f"({row_local}) + ({n_local})" ) value_ast = resolve_cute_shape_chain_value_at(state, value_node, flat_index) if value_ast is None: return None backend = env.backend index_dtype = backend.dtype_str(env.index_dtype) target_dtype = backend.dtype_str(tensor.dtype) value_expr = backend.ast_to_dtype_expr(ast.unparse(value_ast), target_dtype) # Bind the resolved (possibly select-based) value to a variable so the CuTe # DSL sees the stack `ifexp` as its own assignment rather than nested inside # the `.store(...)` call / masked store ternary. value_var = state.device_function.new_var("affine_value", dce=True) row_index = ( f"{index_dtype}(cutlass.Int32({factor}) * ({m_global}) " f"+ cutlass.Int32({lane_var}))" ) col_index = f"{index_dtype}({n_global})" tensor_name = state.device_function.tensor_arg(tensor).name store_expr = _cute_scalar_store_expr(tensor_name, [row_index, col_index], value_var) store_stmt: ast.stmt = create(ast.Expr, value=expr_from_string(store_expr)) mask_parts = [ mask for mask in ( _cute_active_mask_var(state, block_id_m), _cute_active_mask_var(state, block_id_n), ) if mask is not None ] if mask_parts: # Use a guard statement (not a ternary) so the CuTe DSL accepts the # device-value mask condition. 
def _is_cute_affine_range_load_for_store(
    state: CodegenState,
    subscript: list[object] | tuple[object, ...],
    ast_subscript: list[object] | tuple[object, ...],
) -> bool:
    """Report whether the current load only feeds affine-range stores.

    True when ``state.fx_node`` has at least one user, every user is an
    unmasked ``store`` whose value argument is this node and whose single
    subscript matches ``match_cute_affine_range_iota``, and the load itself is
    indexed by exactly one ``CuteAffineRangeIndex``.
    """
    from .._compiler.cute.indexing import CuteAffineRangeIndex
    from .._compiler.cute.indexing import match_cute_affine_range_iota

    def compatible_store_user(user: torch.fx.Node) -> bool:
        # args layout read here: (target, subscript, value, extra_mask, ...)
        if (
            user.op != "call_function"
            or user.target is not store
            or len(user.args) < 4
            or user.args[2] is not state.fx_node
            or user.args[3] is not None
        ):
            return False
        store_subscript = user.args[1]
        return (
            isinstance(store_subscript, (list, tuple))
            and len(store_subscript) == 1
            and isinstance(store_subscript[0], torch.fx.Node)
            and match_cute_affine_range_iota(store_subscript[0]) is not None
        )

    return (
        state.fx_node is not None
        and len(state.fx_node.users) > 0
        and all(compatible_store_user(user) for user in state.fx_node.users)
        and len(subscript) == 1
        and len(ast_subscript) == 1
        and isinstance(ast_subscript[0], CuteAffineRangeIndex)
    )


def _cute_positive_1d_slice_bounds(
    tensor: torch.Tensor, index: object
) -> tuple[int, int, int, int] | None:
    """Normalise a forward slice over dimension 0 of ``tensor``.

    Returns ``(start, stop, step, length)`` where ``start``/``stop``/``step``
    come from ``slice.indices`` applied to ``tensor.shape[0]`` and ``length``
    is the number of selected elements (0 for an empty slice).

    Returns ``None`` when ``index`` is not a slice, is the full slice
    ``slice(None)``, has a non-positive step, or cannot be normalised: the
    dimension or a slice component is not integer-like (``TypeError``), or the
    step is zero (``ValueError`` from ``slice.indices``).
    """
    if not isinstance(index, slice) or index == slice(None):
        return None
    try:
        dim_size = int(tensor.shape[0])
        start, stop, step = index.indices(dim_size)
    except (TypeError, ValueError):
        # ``slice.indices`` raises ValueError for a zero step; treat it like
        # any other unsupported slice instead of crashing the caller.
        return None
    if step <= 0:
        return None
    # Ceiling division; clamped so ``start >= stop`` yields an empty range.
    length = max(0, (stop - start + step - 1) // step)
    return start, stop, step, length
def _codegen_cute_strided_slice_store(
    state: CodegenState,
    tensor: torch.Tensor,
    subscript: list[object] | tuple[object, ...],
    value: object,
    extra_mask: ast.AST | None,
    value_node: torch.fx.Node | None = None,
) -> ast.AST | None:
    """Lower an unmasked store into a 1-D positive-step slice to a loop.

    Emits ``for i in range(length): tensor.__setitem__((start + i * step,), v)``
    as an ``ast.For`` node.  ``v`` is either an element re-read from a 1-D
    source slice of equal length (when ``value_node`` is a ``load``), an AST
    expression, or a Python ``int`` / ``float`` / ``bool`` literal.

    Returns ``None`` when the target is not a 1-D tensor with a single
    supported slice subscript, when an extra mask is present, or when the
    value is not one of the supported forms.
    """
    from .._compiler.ast_extension import create

    if tensor.ndim != 1 or len(subscript) != 1 or extra_mask is not None:
        return None
    dst_bounds = _cute_positive_1d_slice_bounds(tensor, subscript[0])
    if dst_bounds is None:
        return None
    dst_start, _dst_stop, dst_step, trip_count = dst_bounds

    env = CompileEnvironment.current()
    backend = env.backend
    index_dtype = backend.dtype_str(env.index_dtype)
    device_fn = state.device_function
    loop_var = device_fn.new_var("slice_idx", dce=True)

    def strided_index(start: int, step: int) -> str:
        # Element position for the current loop iteration, cast to index dtype.
        return f"{index_dtype}({start} + {loop_var} * {step})"

    dst_index = strided_index(dst_start, dst_step)

    value_is_load = (
        value_node is not None
        and value_node.op == "call_function"
        and value_node.target is load
    )
    if value_is_load:
        src_node = value_node.args[0]
        if not isinstance(src_node, torch.fx.Node):
            return None
        src_tensor = src_node.meta.get("val")
        if not isinstance(src_tensor, torch.Tensor) or src_tensor.ndim != 1:
            return None
        src_subscript = value_node.args[1]
        if not isinstance(src_subscript, (list, tuple)):
            return None
        if len(src_subscript) != 1:
            return None
        src_bounds = _cute_positive_1d_slice_bounds(src_tensor, src_subscript[0])
        if src_bounds is None:
            return None
        src_start, _src_stop, src_step, src_length = src_bounds
        # Source and target must select the same number of elements.
        if src_length != trip_count:
            return None
        src_index = strided_index(src_start, src_step)
        src_name = device_fn.tensor_arg(src_tensor).name
        value_expr = f"{src_name}[{src_index}]"
        if src_tensor.dtype is torch.bool:
            value_expr = f"({value_expr} != cutlass.Uint8(0))"
    elif isinstance(value, ast.AST):
        value_expr = ast.unparse(value)
    elif isinstance(value, (int, float, bool)):
        value_expr = repr(value)
    else:
        return None

    dst_name = device_fn.tensor_arg(tensor).name
    dst_dtype = backend.dtype_str(tensor.dtype)
    value_expr = backend.ast_to_dtype_expr(value_expr, dst_dtype)
    store_src = f"{dst_name}.__setitem__(({dst_index},), {value_expr})"

    loop_target = create(ast.Name, id=loop_var, ctx=ast.Store())
    loop_body = [create(ast.Expr, value=expr_from_string(store_src))]
    return create(
        ast.For,
        target=loop_target,
        iter=expr_from_string(f"range({trip_count})"),
        body=loop_body,
        orelse=[],
        type_comment=None,
    )
_cute_remap_block_id(state, block_id) loops = state.codegen.active_device_loops.get(block_id) if loops: return loops[-1].strategy.index_var(block_id) grid_state = state.codegen.current_grid_state if grid_state is not None and block_id in grid_state.block_ids: return grid_state.strategy.index_var(block_id) return None def active_local_coord(block_id: int) -> str | None: from .._compiler.cute.cute_reshape import _grid_local_coord_expr if _cute_index_override(state, block_id) is not None: return None block_id = _cute_remap_block_id(state, block_id) loops = state.codegen.active_device_loops.get(block_id) if loops: thread_axis = loops[-1].block_thread_axes.get(block_id) if thread_axis is not None: return _grid_local_coord_expr(state.codegen, block_id, thread_axis) grid_state = state.codegen.current_grid_state if grid_state is not None: thread_axis = grid_state.block_thread_axes.get(block_id) if thread_axis is not None: return _grid_local_coord_expr(state.codegen, block_id, thread_axis) return None def tile_begin_expr(block_id: int) -> str: block_id = _cute_remap_block_id(state, block_id) loops = state.codegen.active_device_loops.get(block_id) if loops: return state.codegen.offset_var(block_id) global_index = active_index_var(block_id) local_coord = active_local_coord(block_id) if global_index is not None and local_coord is not None: return state.codegen.lift( expr_from_string(f"({global_index}) - ({local_coord})"), dce=True, prefix="tile_begin", ).id if global_index is not None: return global_index return "0" def tile_with_offset_mask_terms( tile_info: TileWithOffsetInfo, tensor_dim: int, ) -> list[str]: block_id = tile_info.block_id local_coord = active_local_coord(block_id) begin_var = tile_begin_expr(block_id) if local_coord is None: if (idx_var := active_index_var(block_id)) is None: raise exc.BackendUnsupported( "cute", ( "indexing dimension is not active in this scope " f"(block_id={block_id})" ), ) local_coord = f"({idx_var}) - ({begin_var})" tile_terms = [] if 
tile_info.block_size is not None: block_size_expr = state.device_function.literal_expr(tile_info.block_size) tile_terms.append(f"({local_coord}) < cutlass.Int32({block_size_expr})") if tensor is not None and tensor_dim < tensor.ndim: offset_expr = state.device_function.literal_expr(tile_info.offset) dim_size = _cute_tensor_dim_size_expr(state, tensor, tensor_dim) tile_terms.append( f"(({begin_var}) + cutlass.Int32({offset_expr}) + " f"({local_coord})) < {dim_size}" ) return tile_terms def tensor_index_bounds_term(pos: int, tensor_dim: int) -> str | None: if tensor is None or tensor_dim >= tensor.ndim: return None ast_args = state.ast_args if not ( isinstance(ast_args, list) and len(ast_args) > 1 and isinstance(ast_args[1], (list, tuple)) and len(ast_args[1]) == len(subscript) and isinstance(ast_args[1][pos], ast.AST) ): return None index_var = state.codegen.lift( ast_args[1][pos], dce=True, prefix="index_mask", ).id index_dtype = env.backend.dtype_str(env.index_dtype) dim_size = _cute_tensor_dim_size_expr(state, tensor, tensor_dim) return f"{index_dtype}({index_var}) < {dim_size}" if extra_mask is not None: terms.append(state.codegen.lift(extra_mask, dce=True, prefix="mask").id) seen: set[int] = set() tensor_dim = 0 for pos, idx in enumerate(subscript): block_id: int | None = None if idx is None: continue if ( tile_info := _get_tile_with_offset_info( idx, getattr(state, "fx_node", None), pos ) ) is not None and tile_info.block_size is not None: seen.add(tile_info.block_id) for term in tile_with_offset_mask_terms(tile_info, tensor_dim): if term not in terms: terms.append(term) tensor_dim += 1 continue if isinstance(idx, torch.SymInt): block_id = env.get_block_id(idx) elif isinstance(idx, slice) and idx == slice(None) and tensor is not None: for bid in _matching_block_ids(env, tensor.shape[tensor_dim]): if bid not in seen and mask_var_for_block_id(bid) is not None: block_id = bid break elif isinstance(idx, slice) and idx != slice(None) and tensor is not None: 
slice_size = compute_slice_size(idx, tensor.shape[tensor_dim]) for bid in _matching_block_ids(env, slice_size): if bid not in seen and mask_var_for_block_id(bid) is not None: block_id = bid break elif isinstance(idx, torch.Tensor): added_tensor_index_mask = False # A free ``hl.arange`` mapped onto a synthetic thread axis carries no # block id, so the loops below add no bound for it. Emit its lane # bound explicitly to mask the out-of-bounds lanes a wider sibling # branch can introduce on a shared axis. arange_bound = _cute_synthetic_arange_lane_bound( state, pos, tensor, tensor_dim ) if arange_bound is not None and arange_bound not in terms: terms.append(arange_bound) # A reduction/grid dim mapped onto a thread axis can be *widened* # beyond its own extent by a mutually-exclusive sibling branch that # reuses the same axis (the launch block is sized to the widest # user). When this access addresses such a dim per-lane, bound the # lane to its own dim size so the surplus lanes a wider sibling adds # do not load/store out of bounds. ``active_local_coord`` is the # per-lane in-axis coordinate; ``< dim_size`` is a no-op when the # axis already matches this dim. 
def _cute_synthetic_arange_lane_bound(
    state: CodegenState,
    subscript_pos: int,
    tensor: torch.Tensor | None,
    tensor_dim: int,
) -> str | None:
    """Lane-bound mask term for a free ``hl.arange`` on a synthetic thread axis.

    On the CuTe backend the launch block is sized to the *widest* arange among
    the mutually exclusive grid branches sharing a thread axis.  A narrower
    arange in a sibling branch (say ``hl.arange(0, 32)`` on an axis another
    branch uses for ``hl.arange(0, 64)``) then sees lanes past its own extent;
    unmasked, those lanes would load/store out of bounds.

    The returned term is ``(coord) < length``, where ``coord`` is either
    ``thread_idx()[axis]`` or the cached lane-loop coordinate.  It is a no-op
    when the block matches the arange exactly.  ``None`` is returned when the
    index at ``subscript_pos`` is not such a synthetic arange.

    The synthetic coordinate is the arange *position* ``0..length-1``,
    independent of ``start``/``step`` (that wrapping is applied separately), so
    bounding by ``length`` is correct for canonical and non-canonical aranges
    alike.
    """
    if tensor is None or tensor_dim >= tensor.ndim:
        return None

    codegen = state.codegen
    # A free arange is mapped either directly onto a thread axis
    # (``cute_synthetic_arange_axes``: key -> axis) or, when it overflows the
    # thread budget, onto a lane loop whose coordinate is cached in
    # ``cute_synthetic_arange_lane_exprs``.
    axis_by_key = getattr(codegen, "cute_synthetic_arange_axes", None) or {}
    lane_expr_by_key = getattr(codegen, "cute_synthetic_arange_lane_exprs", None) or {}
    if not axis_by_key and not lane_expr_by_key:
        return None

    node = getattr(state, "fx_node", None)
    if node is None or len(node.args) < 2:
        return None
    subscript_arg = node.args[1]
    if not isinstance(subscript_arg, (list, tuple)):
        return None
    if subscript_pos >= len(subscript_arg):
        return None
    index_node = subscript_arg[subscript_pos]
    if not isinstance(index_node, torch.fx.Node):
        return None

    from .._compiler.cute.iota_utils import cute_free_arange_indexed_dim_key

    dim_key = cute_free_arange_indexed_dim_key(index_node, codegen)
    if dim_key is None:
        return None

    # Keys are ``(dim_key, length, start, step)``.  Only ``length`` bounds the
    # lanes, so entries are matched on ``dim_key`` alone.
    def _match(key: object) -> bool:
        return isinstance(key, tuple) and len(key) == 4 and key[0] == dim_key

    axis_hit = next(
        ((key, axis) for key, axis in axis_by_key.items() if _match(key)), None
    )
    if axis_hit is not None:
        key, axis = axis_hit
        coord = f"cutlass.Int32(cute.arch.thread_idx()[{axis}])"
        length = key[1]
        return f"({coord}) < {length}"

    lane_hit = next(
        ((key, expr) for key, expr in lane_expr_by_key.items() if _match(key)), None
    )
    if lane_hit is None:
        return None
    key, coord = lane_hit
    if coord is None:
        return None
    length = key[1]
    return f"({coord}) < {length}"


def _cute_tensor_dim_size_expr(
    state: CodegenState, tensor: torch.Tensor, dim: int
) -> str:
    """Return the name of the device-function size argument for ``tensor`` dim ``dim``."""
    size_arg = state.device_function.tensor_size(tensor, dim)
    return size_arg.name


def _cute_tile_begin_expr(state: CodegenState, idx: object) -> str:
    """Return source text for the first element offset of a tile index.

    ``idx`` may be a Python ``int`` (returned as its decimal string) or a
    ``torch.SymInt``.  For a ``SymInt`` the block id is looked up, in order,
    from a ``TileBeginOrigin`` recorded for its symbolic expression and from
    the compile environment; if neither yields one, the symbolic expression
    itself is rendered through ``state.sympy_expr``.

    Raises:
        exc.BackendUnsupported: ``idx`` is neither ``int`` nor ``SymInt``, or
            no lowering applies.
    """
    env = CompileEnvironment.current()

    def index_var_for(block_id: int) -> str | None:
        # Innermost active device loop wins; otherwise try the grid state.
        device_loops = state.codegen.active_device_loops.get(block_id)
        if device_loops:
            return device_loops[-1].strategy.index_var(block_id)
        grid = state.codegen.current_grid_state
        if grid is None or block_id not in grid.block_ids:
            return None
        return grid.strategy.index_var(block_id)

    def local_coord_for(block_id: int) -> str | None:
        from .._compiler.cute.cute_reshape import _grid_local_coord_expr

        device_loops = state.codegen.active_device_loops.get(block_id)
        if device_loops:
            axis = device_loops[-1].block_thread_axes.get(block_id)
            if axis is not None:
                return _grid_local_coord_expr(state.codegen, block_id, axis)
        grid = state.codegen.current_grid_state
        axis = None if grid is None else grid.block_thread_axes.get(block_id)
        if axis is None:
            return None
        return _grid_local_coord_expr(state.codegen, block_id, axis)

    def begin_for(block_id: int) -> str:
        if state.codegen.active_device_loops.get(block_id):
            return state.codegen.offset_var(block_id)
        global_index = index_var_for(block_id)
        local_coord = local_coord_for(block_id)
        if global_index is None:
            return "0"
        if local_coord is None:
            return global_index
        # Tile begin = global index minus the in-tile coordinate.
        lifted = state.codegen.lift(
            expr_from_string(f"({global_index}) - ({local_coord})"),
            dce=True,
            prefix="tile_begin",
        )
        return lifted.id

    if isinstance(idx, int):
        return str(idx)
    if not isinstance(idx, torch.SymInt):
        raise exc.BackendUnsupported("cute", f"tile base index type: {type(idx)}")

    sym_expr = _symint_expr(idx)
    if sym_expr is not None:
        origin_info = HostFunction.current().expr_to_origin.get(sym_expr)
        origin = None if origin_info is None else origin_info.origin
        if isinstance(origin, TileBeginOrigin):
            return begin_for(origin.block_id)

    block_id = env.get_block_id(idx)
    if block_id is not None:
        return begin_for(block_id)
    if sym_expr is None:
        raise exc.BackendUnsupported("cute", f"unlowerable tile base index: {idx}")
    return state.sympy_expr(sym_expr)
(e.g. # aux = pre-activation and out = gelu(pre)), the per-matmul TMA-store # atom/tensor kernel-arg names allocated in cute_mma are shared by every # store site. Emitting them verbatim at each site produces duplicate kernel # parameters (SyntaxError) and binds both device epilogues to the same TMA # descriptor. The secondary store gets fresh per-store descriptor names so # each store threads its own TMA descriptor; the first store keeps the # original names. The secondary store also reuses the accumulator the first # store already consumed: the accumulator TMEM stays live until the # one-shot teardown frees it, so the secondary store reads it directly # without re-running the accumulator pipeline's consumer wait/release/advance # (those would hang waiting on a producer that has already drained) and # without re-emitting the matmul drain / TMEM-free teardown. is_secondary_store = ( tcgen05_value.use_tma_store_epilogue and not tcgen05_value.pure_matmul_role_lifecycle and df.cute_state.tcgen05_tma_store_names_already_emitted(tcgen05_value) ) if is_secondary_store: tcgen05_value = dataclasses.replace( tcgen05_value, tma_store_atom=df.new_var("tcgen05_tma_store_atom"), tma_store_tensor=df.new_var("tcgen05_tma_store_tensor"), ) tcgen05_lifecycle = tcgen05_value.lifecycle_context tcgen05_pure_matmul_object = tcgen05_value.pure_matmul_object # Snapshot the accumulator consumer-state stage index. The primary store # captures it before advancing the consumer state; fan-out stores read the # same live TMEM stage through the snapshot rather than the already-advanced # live index. For single-store kernels the assignment is unused and DCE # drops it, so the generated code is unchanged. 
tcgen05_acc_stage_index_var, tcgen05_acc_stage_index_is_primary = ( df.cute_state.get_or_create_tcgen05_acc_stage_index_var( tcgen05_lifecycle.acc_consumer_state, df.new_var, ) ) # The snapshot is captured at top level (before the store's control-flow # block) by the primary store so fan-out stores can read it; CuTe DSL # forbids defining a value inside one control-flow block and reading it in # another. For single-store kernels the assignment is unused and DCE drops # it, keeping generated code unchanged. tcgen05_acc_stage_index_top_level_stmts = ( [ statement_from_string( f"{tcgen05_acc_stage_index_var} = " f"{tcgen05_lifecycle.acc_consumer_state}.index" ) ] if tcgen05_acc_stage_index_is_primary else [] ) # The primary store keeps reading the live consumer index so single-store # codegen is byte-identical; only fan-out stores route through the snapshot. tcgen05_acc_stage_index_expr = ( f"{tcgen05_lifecycle.acc_consumer_state}.index" if not is_secondary_store else tcgen05_acc_stage_index_var ) # Backstop for callers that bypass Config.normalize() validation; # see _tcgen05_epi_warp_count docstring and cute_plan.md. if tcgen05_value.epi_warp_count != 4: raise exc.BackendUnsupported( "cute", f"tcgen05 SIMT-store epilogue requires " f"tcgen05_num_epi_warps=4 (got {tcgen05_value.epi_warp_count}). " "CUTLASS tmem_warp_shape_mn=(4,1) hard-codes a 4-warp t2r " "partition for the supported tcgen05 path; per-warp " "tcgen05.ld semantics make the partition uncoverable by " "fewer warps. Lifts when the c_pipeline-driven multi-warp " "epilogue lands (see cute_plan.md).", ) backend = CompileEnvironment.current().backend tensor_name = df.tensor_arg(tensor).name target_dtype = backend.dtype_str(tensor.dtype) # The matmul plan computed `tcgen05_epi_tile` (role-local t2r # partition) with `epi_elem_dtype_str`; the store path below # recomputes `tcgen05_store_epi_tile` with `target_dtype`. 
They must # match or `compute_epilogue_tile_shape` selects different `tile_n` # values on the two sides and the t2r / r2s SMEM staging silently # corrupts. The loud-failure backstop covers cases where MMA-codegen- # time forward-tracing of the matmul fx_node could not pin a unique # store target dtype. if ( tcgen05_value.epi_elem_dtype_str and tcgen05_value.epi_elem_dtype_str != target_dtype ): raise exc.BackendUnsupported( "cute", "tcgen05 epilogue element-type mismatch: matmul plan was set " f"up with epi_elem_dtype_str={tcgen05_value.epi_elem_dtype_str!r} " f"but the store target tensor dtype is {target_dtype!r}.", ) base_indices = [_cute_tile_begin_expr(state, idx) for idx in subscript] if len(base_indices) != 2: if tcgen05_value.pure_matmul_role_lifecycle: raise exc.BackendUnsupported( "cute", "tcgen05 pure role-lifecycle store requires a rank-2 tile store", ) return None m_size = _cute_tensor_dim_size_expr(state, tensor, 0) n_size = _cute_tensor_dim_size_expr(state, tensor, 1) tile_coord_m = f"({base_indices[0]}) // cutlass.Int32({tcgen05_value.bm})" tile_coord_n = f"({base_indices[1]}) // cutlass.Int32({tcgen05_value.bn})" full_tile = df.new_var("tcgen05_full_tile") gmem_tile = df.new_var("tcgen05_gC") coord_tile = df.new_var("tcgen05_cC") tcgc_base = df.new_var("tcgen05_tCgC_base") tccc_base = df.new_var("tcgen05_tCcC_base") tcgc = df.new_var("tcgen05_tCgC") tcgc_planned = df.new_var("tcgen05_tCgC_planned") tccc = df.new_var("tcgen05_tCcC") tacc = df.new_var("tcgen05_tAcc") epi_tile = df.new_var("tcgen05_store_epi_tile") tiled_copy_t2r = df.new_var("tcgen05_tiled_copy_t2r") thr_copy_t2r = df.new_var("tcgen05_thr_copy_t2r") ttr_tacc_base = df.new_var("tcgen05_tTR_tAcc_base") tcgc_epi = df.new_var("tcgen05_tCgC_epi") tccc_epi = df.new_var("tcgen05_tCcC_epi") ttr_gc = df.new_var("tcgen05_tTR_gC") ttr_cc = df.new_var("tcgen05_tTR_cC") ttr_racc = df.new_var("tcgen05_tTR_rAcc") ttr_rd = df.new_var("tcgen05_tTR_rD") ttr_tacc_stage = 
df.new_var("tcgen05_tTR_tAcc_stage") ttr_tacc = df.new_var("tcgen05_tTR_tAcc") ttr_gc_grouped = df.new_var("tcgen05_tTR_gC_grouped") ttr_cc_grouped = df.new_var("tcgen05_tTR_cC_grouped") ttr_tacc_mn = df.new_var("tcgen05_tTR_tAcc_mn") ttr_gc_subtile = df.new_var("tcgen05_tTR_gC_subtile") ttr_cc_subtile = df.new_var("tcgen05_tTR_cC_subtile") acc_vec = df.new_var("tcgen05_acc_vec") kernel_desc = df.new_var("tcgen05_kernel_desc") mcld = df.new_var("tcgen05_mcld") num_bits = df.new_var("tcgen05_num_bits") simt_atom = df.new_var("tcgen05_simt_atom") smem_d_layout = df.new_var("tcgen05_sD_layout") smem_d_ptr = df.new_var("tcgen05_sD_ptr") smem_d = df.new_var("tcgen05_sD") tiled_copy_r2s = df.new_var("tcgen05_tiled_copy_r2s") trs_rd = df.new_var("tcgen05_tRS_rD") trs_racc = df.new_var("tcgen05_tRS_rAcc") trs_sd = df.new_var("tcgen05_tRS_sD") bsg_sd = df.new_var("tcgen05_bSG_sD") bsg_gd_partitioned = df.new_var("tcgen05_bSG_gD_partitioned") bsg_gd = df.new_var("tcgen05_bSG_gD") c_buffer = df.new_var("tcgen05_c_buffer") epilog_sync_barrier = df.new_var("tcgen05_epilog_sync_barrier") c_pipeline_producer_group = df.new_var("tcgen05_c_pipeline_producer_group") c_pipeline = df.new_var("tcgen05_c_pipeline") subtile_count = df.new_var("tcgen05_subtile_count") # Workstream A Stage 4 (cycle 93, Path B): the C-store producer->consumer # edge over the C-ring SMEM (``tRS_sD``, depth ``c_stage_count``). Producer # = the 4 epi warps (arrive after R2S + ``fence_view_async_shared``); # consumer = the single store warp (waits, issues the TMA-D, releases the # SMEM stage). Replaces the second ``epilog_sync_barrier`` (R2S-visible) # CTA-wide barrier with a cheaper cross-warp pipeline edge that lets the # epi warps proceed to the next subtile while the store warp drains. 
c_store_edge_barriers = df.new_var("tcgen05_c_store_edge_barriers") c_store_edge_producer_group = df.new_var("tcgen05_c_store_edge_producer_group") c_store_edge_consumer_group = df.new_var("tcgen05_c_store_edge_consumer_group") c_store_edge = df.new_var("tcgen05_c_store_edge") c_store_edge_producer_state = df.new_var("tcgen05_c_store_edge_producer_state") c_store_edge_consumer_state = df.new_var("tcgen05_c_store_edge_consumer_state") # Separate consumer state for the LAGGED release. The store warp's TMA-D is # an async bulk copy that reads the C-ring SMEM stage; the stage may not be # reused (epi R2S overwrite) until that read completes. ``c_pipeline`` # (PipelineTmaStore) tracks store completion via ``cp_async_bulk_wait_group`` # (read=True), which after committing store i and waiting drains every store # except the ``c_stages - 1`` most recent. So the store warp releases the # C-ring stage from ``c_stages - 1`` subtiles ago (provably drained), lagging # the consumer-wait by ``c_stages - 1``. This leaves exactly one free stage # (edge depth ``c_stages``), giving the ~1-subtile store/T2R overlap the # acc_stages=2 bound permits. The first ``c_stages - 1`` releases are # suppressed (no drained stage yet); the trailing stages release naturally # in subsequent tiles as the global subtile index advances. c_store_edge_release_state = df.new_var("tcgen05_c_store_edge_release_state") epi_warp_ids = ", ".join( f"cutlass.Int32({i})" for i in range(tcgen05_value.epi_warp_count) ) if tcgen05_value.epi_warp_count == 1: epi_warp_ids += "," # Per-aux-step plumbing: per-thread auxiliary tensor reads at # the splice site. For each ``_AuxiliaryTensorStep`` in the # chain we register the auxiliary tensor as a kernel arg, # allocate fresh AST var names for the partitioning chain, and # later (inside each per-thread splice site) emit per-subtile # ``aux_loaded = ...`` lines that the chain renderer references. 
# Static-full TMA-store tiles use the historical direct # ``ttr_aux_subtile.load()`` form. SIMT-store edge tiles use a # predicated GMEM-to-register copy first, so the aux read observes # the same runtime predicate as the output store. aux_steps_in_chain: tuple[_AuxiliaryTensorStep, ...] = ( epilogue_chain.auxiliary_tensor_steps if epilogue_chain is not None else () ) aux_step_records: list[_AuxStepRecord] = [] for aux_idx, aux_step in enumerate(aux_steps_in_chain): aux_tensor_node = aux_step.load_node.args[0] assert isinstance(aux_tensor_node, torch.fx.Node) aux_torch_tensor = aux_tensor_node.meta.get("val") assert isinstance(aux_torch_tensor, torch.Tensor) aux_tensor_name = df.tensor_arg(aux_torch_tensor).name aux_dtype = backend.dtype_str(aux_torch_tensor.dtype) aux_dtype_bits = aux_torch_tensor.dtype.itemsize * 8 # Aux tensors must be passed through to the device function as # placeholder args so the wrapper plumbs them into the cute # kernel signature (the role-local persistent path otherwise # treats unreferenced tensors as captures, which doesn't work # for tensors only read inside a per-subtile loop body). df.placeholder_args.add(aux_tensor_name) # Broadcast aux steps need a fresh AST var for the 2-D view # of the rank-1 underlying tensor (stride 0 on the orthogonal # axis). Exact-shape aux steps leave ``aux_view2d`` as None. # broadcast_axis 0/1 build a stride-0 2-D view of a rank-1 tensor; # the colvec form (2) reuses the exact-shape pipeline over its own # (M, N) stride-(1,0) view, so it needs no separate ``aux_view2d``. 
aux_view2d = ( df.new_var(f"tcgen05_aux_view2d_{aux_idx}") if aux_step.broadcast_axis in (0, 1) else None ) aux_step_records.append( _AuxStepRecord( aux_tensor_name=aux_tensor_name, broadcast_axis=aux_step.broadcast_axis, aux_tile=df.new_var(f"tcgen05_aux_tile_{aux_idx}"), aux_part_base=df.new_var(f"tcgen05_tCgAux_base_{aux_idx}"), aux_xfm=df.new_var(f"tcgen05_tCgAux_xfm_{aux_idx}"), aux_planned=df.new_var(f"tcgen05_tCgAux_planned_{aux_idx}"), aux_epi=df.new_var(f"tcgen05_tCgAux_epi_{aux_idx}"), aux_dtype=aux_dtype, aux_dtype_bits=aux_dtype_bits, aux_extent=( aux_torch_tensor.shape[0] if ( aux_step.broadcast_axis == 1 and isinstance(aux_torch_tensor.shape[0], int) ) else None ), ttr_aux=df.new_var(f"tcgen05_tTR_gAux_{aux_idx}"), ttr_aux_grouped=df.new_var(f"tcgen05_tTR_gAux_grouped_{aux_idx}"), ttr_aux_subtile=df.new_var(f"tcgen05_tTR_gAux_subtile_{aux_idx}"), aux_rmem=df.new_var(f"tcgen05_aux_rmem_{aux_idx}"), aux_loaded=df.new_var(f"tcgen05_aux_loaded_{aux_idx}"), aux_view2d=aux_view2d, # Pre-wait whole-fragment register hoist of N-broadcast # (rowvec) aux on the bm=128 2-CTA full-tile TMA-store path. # The fragment there is small (2 subtiles x epi-tile N of 64 # at bn=128 = a handful of fp32 registers per thread), so the # whole-fragment LDG fits without spills and hides its GMEM # latency under the MMA wait (standalone CUTLASS does the # same; ~2% on the 512x6144x2048 fp8 scaled_mm shape). It is # deliberately NOT applied to the bm=256 family: its larger # whole-tile fragment regressed via register spills (see the # fp8_gap_v2 history of the rowvec hoist removal at bn=128/ # epi-32 -- 409k LDL/STL on the 4096^3 shape). 
aux_rmem_full=( df.new_var(f"tcgen05_aux_rmem_full_{aux_idx}") if ( aux_step.broadcast_axis in (0, 1) and tcgen05_is_two_cta_m128( is_two_cta=tcgen05_lifecycle.is_two_cta, bm=tcgen05_value.bm, ) and tcgen05_value.use_tma_store_epilogue and not tcgen05_value.partial_output_tma_store ) else None ), ) ) # Pyrefly does not preserve the non-None ``tcgen05_value`` narrowing # inside the nested source-formatter closures, so keep local # string aliases for attributes the closures read. tcgen05_aux_bm = tcgen05_value.bm tcgen05_aux_bn = tcgen05_value.bn tcgen05_aux_thr_mma = tcgen05_value.thr_mma tcgen05_aux_epi_tidx = tcgen05_value.epi_tidx tcgen05_aux_epi_active = tcgen05_lifecycle.epi_active tcgen05_aux_epi_warp_count = tcgen05_value.epi_warp_count tcgen05_aux_epilogue_rest_mode = tcgen05_value.epilogue_rest_mode tcgen05_aux_use_tma_store_epilogue = tcgen05_value.use_tma_store_epilogue tcgen05_explicit_store_tile_expr: str | None = None if tcgen05_value.has_explicit_epilogue_tile: assert tcgen05_value.explicit_epi_tile_m is not None assert tcgen05_value.explicit_d_store_box_n is not None tcgen05_explicit_store_tile_expr = tcgen05_explicit_d_store_tile_expr( tcgen05_value.explicit_epi_tile_m, tcgen05_value.explicit_d_store_box_n, ) # Per-thread epilogue M extent. The tcgen05 TMEM->register (T2R) copy # distributes the epilogue tile's M dimension across the 128-lane TMEM # datapath: when ``epi_tile_m >= 128`` every lane owns >= 1 full output # row, so each thread's per-subtile fragment stays within a SINGLE M row; # when ``epi_tile_m < 128`` (e.g. block_m=64) a lane's fragment spans # multiple M rows. This decides whether the per-row colvec scalar read is # valid (see the ``broadcast_axis == 2`` branch below). Mirrors the # ``epi_tile_m`` computation in ``tcgen05_resolve_epilogue_tile`` / # ``cute_mma``: ``bm`` for the default tile, ``bm // 2`` for the 2-CTA # bm=128 family, or the user-supplied explicit tile M. 
# # Correctness of the colvec scalar fast-path rests on the invariant that # the T2R atom FILLS the 128-lane M datapath before packing multiple rows # per lane (a property of CUTLASS's epilogue_tmem_copy_and_partition, not # enforced here). Verified geometrically by partitioning an identity # coordinate tensor through the same T2R copy and reading the distinct M # coordinates a lane holds: per-CTA per-thread M-extent is 1 at # epi_tile_m=128 (block_m=128) and 1 at epi_tile_m=128 (block_m=256 2-CTA), # but 2 at epi_tile_m=64 (block_m=64) -- matching the >= 128 threshold # exactly. If a future CUTLASS changes the lane distribution only the # threshold needs revisiting: the materialize fallback below is always # correct, so the worst case is the fast-path mis-firing (caught by the # row-dependent colvec tests). _TCGEN05_TMEM_DATAPATH_M = 128 if tcgen05_value.has_explicit_epilogue_tile: assert tcgen05_value.explicit_epi_tile_m is not None tcgen05_epi_tile_m = tcgen05_value.explicit_epi_tile_m elif tcgen05_is_two_cta_m128( is_two_cta=tcgen05_lifecycle.is_two_cta, bm=tcgen05_value.bm ): tcgen05_epi_tile_m = tcgen05_value.bm // 2 else: tcgen05_epi_tile_m = tcgen05_value.bm tcgen05_colvec_fragment_single_m_row = ( tcgen05_epi_tile_m >= _TCGEN05_TMEM_DATAPATH_M ) # C-input warp productive-body gate (``cute_plan.md`` §7.5.3.2 # cycle 2b producer + consumer flip). When the matmul plan has # ``has_c_input_warp`` AND a non-empty ``aux_tensor_descriptors`` # tuple AND the aux pipeline plan was registered by # ``cute_mma._codegen_cute_mma``, the consumer-side per-thread # GMEM aux LDG flips to an SMEM read from the # ``c_pipeline_aux``-staged ring populated by the C-input warp's # cooperative copy. 
The producer body in # ``program_id._build_c_input_warp_role_local_while`` writes # ONE ``epi_tile`` subtile of the per-CTA aux region # (``(bm_per_cta, bn)`` under 2cta; ``(bm, bn)`` otherwise) per # stage per subtile iteration under ``producer_acquire`` / # ``producer_commit`` framing; the consumer issues one # ``consumer_wait`` / lane-0-gated ``consumer_release`` pair # per subtile and feeds the SMEM stage into Quack's # ``tiled_copy_s2r`` flow (``make_tiled_copy_D`` against # ``tiled_copy_t2r`` → ``partition_S(sC_ring)`` → per- # subtile ``cute.copy(s2r, sC[..., stage], rmem)`` → # ``rmem.load()``). Gate-closed configs keep the historical # GMEM path byte-identical. aux_matmul_plan = df.cute_state.matmul_plan aux_pipeline_plan_obj = df.cute_state.aux_pipeline_plan # Workstream A Stage 4 (cycle 93, Path B): when the plan carries a store # warp, the per-subtile R2S->TMA-D tail is split by warp role and the # second epilogue barrier is replaced by the C-store pipeline edge. The # store warp drains the TMA-D so the 4 epi warps proceed to the next # subtile's T2R. ``store_warps=0`` keeps the original fused tail unchanged # (the production path; byte-identical codegen). has_store_warp = aux_matmul_plan is not None and aux_matmul_plan.has_store_warp store_warp_predicate = ( f"{tcgen05_value.warp_idx} == cutlass.Int32({aux_matmul_plan.store_warp_id})" if aux_matmul_plan is not None and has_store_warp else "" ) # Match each store-side record to its descriptor by # ``load_node`` FX-node identity rather than positional # index. The descriptor walker dedups by ``store_value_node`` # at MMA-codegen time, so a single-store kernel's # descriptors and records share the same ``load_node`` # values in some permutation. 
The matmul plan's # ``aux_single_store_value`` gate (in ``cute_mma`` and the # ``program_id`` role-local-while admission) only allocates # the producer-side pipeline when every descriptor shares # one ``store_value_node``, so the multi-store fan-out # wedge (producer commits to rings the per-store consumer # never releases) cannot occur — the productive body # closes its gate at MMA-codegen time and the consumer # path here falls back to GMEM. Broadcast row-vector aux loads are # deliberately not staged by the C-input producer, so the per-record lookup # below allows a mixed chain: matched exact-shape records read from SMEM, # unmatched records keep the direct GMEM path. aux_step_load_nodes: tuple = ( tuple(rec_step.load_node for rec_step in aux_steps_in_chain) if aux_step_records else () ) aux_ring_index_by_step: list[int | None] = [] aux_descriptor_load_nodes: tuple = ( tuple(d.load_node for d in aux_matmul_plan.c_input_aux_tensor_descriptors) if aux_matmul_plan is not None else () ) for step_load_node in aux_step_load_nodes: try: aux_ring_index_by_step.append( aux_descriptor_load_nodes.index(step_load_node) ) except ValueError: aux_ring_index_by_step.append(None) aux_has_staged_steps = any( ring_idx is not None for ring_idx in aux_ring_index_by_step ) # Workstream A Stage 5 (cycle 94, the merge): the aux SMEM ring producer is # the C-input warp normally (SIMT or TMA), or the store warp under the merge # — but the store warp is TMA-ONLY (there is no SIMT store-warp producer; # ``store_warps=1 + SIMT aux`` falls back to direct-GMEM aux). The epi-warp # consumer reads the staged ring whenever a producer is present. The # ``aux_pipeline_plan_obj is not None`` term already closes this gate for # ``store_warps=1 + SIMT`` (``cute_mma`` never allocates the plan there); # the explicit ``use_tma_load`` term on the store-warp branch makes the # TMA-only requirement local and defensive. 
aux_producer_warp_present = aux_matmul_plan is not None and ( aux_matmul_plan.has_c_input_warp or ( aux_matmul_plan.has_store_warp and aux_pipeline_plan_obj is not None and aux_pipeline_plan_obj.use_tma_load ) ) use_aux_smem_source = ( aux_step_records and aux_matmul_plan is not None and aux_producer_warp_present and bool(aux_matmul_plan.c_input_aux_tensor_descriptors) and aux_pipeline_plan_obj is not None and aux_has_staged_steps # Multi-store fan-out gate (same predicate as the # producer-side allocator + role-local-while # admission). Without this guard the producer fires # ``producer_commit`` on rings whose only matching # consumer-store is a different per-store-codegen # invocation — the per-store splice site here only # releases its own subset, leaving the unmatched rings # uncommitted and deadlocking the producer once a CTA # wraps the pipeline depth. and len( {d.store_value_node for d in aux_matmul_plan.c_input_aux_tensor_descriptors} ) <= 1 ) if use_aux_smem_source: assert aux_pipeline_plan_obj is not None aux_pipeline_name = aux_pipeline_plan_obj.pipeline aux_consumer_state_name = aux_pipeline_plan_obj.consumer_state aux_pipeline_uses_tma_load = aux_pipeline_plan_obj.use_tma_load all_rings = aux_pipeline_plan_obj.rings aux_ring_smem_names: tuple[str | None, ...] = tuple( all_rings[ring_idx].smem if ring_idx is not None else None for ring_idx in aux_ring_index_by_step ) else: aux_pipeline_name = "" aux_consumer_state_name = "" aux_pipeline_uses_tma_load = False aux_ring_smem_names = tuple(None for _ in aux_step_records) # Row-vector aux (``bias[n]`` / rowwise ``scale_b[n]``) reads stay # per-subtile (the generic ``ttr_aux_subtile.load()`` path below, placed # after the c_pipeline acquire / acc ``consumer_wait`` / T2R prefix per the # cycle-69 placement). 
rowvec_aux_stage_records: list[_RowvecAuxStageRecord | None] = [] for aux_idx, rec in enumerate(aux_step_records): copy_bits = 128 copy_elems = _tcgen05_rowvec_aux_stage_copy_elems( rec.aux_dtype_bits, tcgen05_aux_bn, rec.aux_extent, copy_bits=copy_bits, ) if ( tcgen05_value.partial_output_tma_store and tcgen05_value.use_tma_store_epilogue and rec.broadcast_axis == 1 and copy_elems is not None ): assert rec.aux_extent is not None rowvec_aux_stage_records.append( _RowvecAuxStageRecord( smem_layout=df.new_var(f"tcgen05_aux_rowvec_smem_layout_{aux_idx}"), smem_ptr=df.new_var(f"tcgen05_aux_rowvec_smem_ptr_{aux_idx}"), smem=df.new_var(f"tcgen05_aux_rowvec_smem_{aux_idx}"), tiled_copy=df.new_var(f"tcgen05_aux_rowvec_tiled_copy_{aux_idx}"), thr_copy=df.new_var(f"tcgen05_aux_rowvec_thr_copy_{aux_idx}"), gmem_tile=df.new_var(f"tcgen05_aux_rowvec_gmem_tile_{aux_idx}"), gmem_part=df.new_var(f"tcgen05_aux_rowvec_gmem_part_{aux_idx}"), smem_part=df.new_var(f"tcgen05_aux_rowvec_smem_part_{aux_idx}"), coord=df.new_var(f"tcgen05_aux_rowvec_coord_{aux_idx}"), limit=df.new_var(f"tcgen05_aux_rowvec_limit_{aux_idx}"), pred=df.new_var(f"tcgen05_aux_rowvec_pred_{aux_idx}"), copy_bits=copy_bits, copy_elems=copy_elems, aux_extent=rec.aux_extent, ) ) else: rowvec_aux_stage_records.append(None) partial_tma_needs_full_tile_guard = tcgen05_value.partial_output_tma_store and any( # ``aux_ring_smem_names`` and ``rowvec_aux_stage_records`` are both # positionally aligned with ``aux_step_records``. 
name is None and rowvec_aux_stage_records[aux_idx] is None for aux_idx, name in enumerate(aux_ring_smem_names) ) def _rowvec_aux_smem_setup_lines() -> list[str]: """Emit compact per-tile SMEM allocation for staged row-vector aux.""" lines: list[str] = [] for aux_idx, rec in enumerate(aux_step_records): stage = rowvec_aux_stage_records[aux_idx] if stage is None: continue lines.extend( [ ( f"{stage.smem_layout} = cute.make_layout(" f"({tcgen05_aux_bn},), stride=(1,))" ), ( f"{stage.smem_ptr} = cute.arch.alloc_smem(" f"{rec.aux_dtype}, cute.cosize({stage.smem_layout}), " "alignment=128)" ), ( f"{stage.smem} = cute.make_tensor(" f"{stage.smem_ptr}, {stage.smem_layout})" ), ] ) return lines def _rowvec_aux_copy_lines() -> list[str]: """Emit the predicated GMEM-to-SMEM copy for staged row-vector aux.""" lines: list[str] = [] for aux_idx, rec in enumerate(aux_step_records): stage = rowvec_aux_stage_records[aux_idx] if stage is None: continue lines.append( f"if {tcgen05_aux_epi_active}:\n" f" {stage.tiled_copy} = cute.make_tiled_copy_tv(" f"cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), " f"{rec.aux_dtype}, num_bits_per_copy={stage.copy_bits}), " f"cute.make_layout({tcgen05_aux_epi_warp_count * 32}), " f"cute.make_layout({stage.copy_elems}))\n" f" {stage.thr_copy} = {stage.tiled_copy}.get_slice(" f"{tcgen05_aux_epi_tidx})\n" f" {stage.gmem_tile} = cute.local_tile(" f"{rec.aux_tensor_name}, ({tcgen05_aux_bn},), " f"({tile_coord_n},))\n" f" {stage.gmem_part} = {stage.thr_copy}.partition_S(" f"{stage.gmem_tile})\n" f" {stage.smem_part} = {stage.thr_copy}.partition_D(" f"{stage.smem})\n" f" {stage.coord} = {stage.thr_copy}.partition_S(" f"cute.make_identity_tensor({tcgen05_aux_bn}))\n" f" {stage.limit} = min({n_size} - ({base_indices[1]}), " f"cutlass.Int32({stage.aux_extent}) - ({base_indices[1]}), " f"cutlass.Int32({tcgen05_aux_bn}))\n" f" {stage.pred} = cute.make_rmem_tensor(" f"(1, cute.size({stage.smem_part}.shape[1])), cutlass.Boolean)\n" f" for _rowvec_i in 
cutlass.range(" f"cute.size({stage.smem_part}.shape[1]), unroll_full=True):\n" f" {stage.pred}[0, _rowvec_i] = " f"{stage.coord}[0, _rowvec_i] < {stage.limit}\n" f" cute.copy({stage.tiled_copy}, {stage.gmem_part}, " f"{stage.smem_part}, pred={stage.pred})\n" f" cute.arch.fence_acq_rel_cta()\n" f" {epilog_sync_barrier}.arrive_and_wait()" ) return lines def _simt_edge_coord_subtile_source(indent: str) -> str: return ( f"{indent}{coord_tile} = cute.local_tile(" f"cute.make_identity_tensor(({m_size}, {n_size})), " f"({tcgen05_aux_bm}, {tcgen05_aux_bn}), " f"({tile_coord_m}, {tile_coord_n}))\n" f"{indent}{tccc_base} = {tcgen05_aux_thr_mma}.partition_C(" f"{coord_tile})\n" f"{indent}{tccc} = " "cutlass.utils.gemm.sm100.transform_partitioned_tensor_layout(" f"{tccc_base})\n" f"{indent}{tccc_epi} = cute.flat_divide({tccc}, {epi_tile})\n" f"{indent}{ttr_cc} = {thr_copy_t2r}.partition_D({tccc_epi})\n" f"{indent}{ttr_cc_grouped} = cute.group_modes({ttr_cc}, 3, " f"cute.rank({ttr_cc}))\n" f"{indent}{ttr_cc_subtile} = {ttr_cc_grouped}[(None, None, None, " f"cutlass.Int32(_tcgen05_subtile))]\n" ) def _simt_edge_scalar_copy_source( indent: str, src: str, dst: str, *, include_coord_setup: bool = True ) -> str: # General SIMT edge copies keep the scalar loop unless the call site # retile below can build a predicate with one lane per logical element. return ( (_simt_edge_coord_subtile_source(indent) if include_coord_setup else "") + f"{indent}for _edge_i in range(cute.size({src}.shape)):\n" f"{indent} _coord = {ttr_cc_subtile}[_edge_i]\n" f"{indent} if cute.elem_less(_coord, ({m_size}, {n_size})):\n" f"{indent} {dst}[_edge_i] = {src}[_edge_i]\n" ) def _simt_edge_logical_divide_copy_source( indent: str, src: str, dst: str, *, include_coord_setup: bool = True, var_prefix: str = "tcgen05_edge", copy_atom: str | None = None, ) -> str: # Shared edge-only vector copy emitter. 
The make_layout(1) retile gives # cute.copy a per-element predicate, while var_prefix/copy_atom let the # same shape drive D stores or exact-aux G2R register loads. copy_atom = copy_atom or simt_atom edge_src = df.new_var(f"{var_prefix}_src") edge_dst = df.new_var(f"{var_prefix}_dst") edge_coord = df.new_var(f"{var_prefix}_coord") edge_pred = df.new_var(f"{var_prefix}_pred") return ( (_simt_edge_coord_subtile_source(indent) if include_coord_setup else "") + f"{indent}{edge_src} = cute.logical_divide({src}, cute.make_layout(1))\n" f"{indent}{edge_dst} = cute.logical_divide({dst}, cute.make_layout(1))\n" f"{indent}{edge_coord} = cute.logical_divide({ttr_cc_subtile}, cute.make_layout(1))\n" f"{indent}{edge_pred} = cute.make_rmem_tensor((1, {edge_src}.shape[1]), cutlass.Boolean)\n" f"{indent}for _edge_i in range(cute.size({edge_src}.shape[1])):\n" f"{indent} _coord = {edge_coord}[0, _edge_i]\n" f"{indent} {edge_pred}[0, _edge_i] = cute.elem_less(_coord, ({m_size}, {n_size}))\n" f"{indent}cute.copy({copy_atom}, {edge_src}, {edge_dst}, pred={edge_pred})\n" ) def _aux_tile_setup_lines( *, thr_copy_t2r_var: str, define_thr_copy_t2r: bool, force_gmem_aux: bool = False, retile_for_r2s: bool = False, ) -> list[str]: """Emit the per-output-tile aux partitioning lines. Each line goes once per output tile, before the per-subtile loop. Mirrors the existing ``tcgc -> tcgc_planned -> tcgc_epi -> ttr_gc -> ttr_gc_grouped`` pipeline used for the result D tensor, but partitions a separate auxiliary GMEM tensor per chain step. Calls ``thr_mma.partition_C`` and ``thr_copy_t2r.partition_D`` against the aux tile so the per-thread layout matches D's layout exactly — both the exact-shape (``residual[tile_m, tile_n]``) and rank-1 broadcast (``bias[tile_n]`` / ``bias[tile_m]``) forms feed the same downstream pipeline. 
For the broadcast form the helper first builds a 2-D view of the underlying rank-1 tensor with stride 0 on the orthogonal axis (see :class:`_AuxiliaryTensorStep` for the canonical contract). When ``define_thr_copy_t2r`` is True the helper emits the ``thr_copy_t2r = tiled_copy_t2r.get_slice(...)`` line first (the TMA-store path does not otherwise create ``thr_copy_t2r``); the SIMT path passes False because it already creates the slice as part of its existing partition pipeline. ``retile_for_r2s`` mirrors Quack's SM100 epilogue visitor layout: TMA-store chains read aux operands in the R2S-retiled layout so the chain carrier can be ``tRS_rAcc`` / ``tRS_rD`` instead of the raw T2R fragment layout. ``force_gmem_aux`` is used by the hybrid edge-only SIMT path: C-input staging is only safe for full tiles because the producer-side bulk copy is not predicated for M/N fringes. """ lines: list[str] = [] if not aux_step_records: return lines if define_thr_copy_t2r: lines.append( f"{thr_copy_t2r_var} = " f"{tiled_copy_t2r}.get_slice({tcgen05_aux_epi_tidx})" ) for aux_idx, rec in enumerate(aux_step_records): staged_ring_name = aux_ring_smem_names[aux_idx] rowvec_stage = rowvec_aux_stage_records[aux_idx] if ( use_aux_smem_source and staged_ring_name is not None and not force_gmem_aux ): # C-input warp productive-body gate is open for this exact-shape # descriptor: build the Quack-style SMEM->register path. Rowvec # broadcast records are not staged and fall through to the GMEM # partition setup below. 
assert aux_matmul_plan is not None ring_idx = aux_ring_index_by_step[aux_idx] assert ring_idx is not None aux_dtype_str = backend.dtype_str( aux_matmul_plan.c_input_aux_tensor_descriptors[ ring_idx ].host_tensor_val.dtype ) tiled_copy_s2r_var = f"{rec.aux_tile}_tiled_copy_s2r" thr_copy_s2r_var = f"{rec.aux_tile}_thr_copy_s2r" tsr_sc_var = f"{rec.aux_tile}_tSR_sC" trs_rc_var = f"{rec.aux_tile}_tRS_rC" tsr_rc_var = f"{rec.aux_tile}_tSR_rC" rmem_shape_expr = ( f"{trs_rd}.layout" if retile_for_r2s else f"{ttr_racc}.shape" ) lines.extend( [ ( f"{tiled_copy_s2r_var} = " f"cute.make_tiled_copy_D(" f"cute.make_copy_atom(" f"cute.nvgpu.CopyUniversalOp(), " f"{aux_dtype_str}), " f"{tiled_copy_t2r})" ), ( f"{thr_copy_s2r_var} = " f"{tiled_copy_s2r_var}.get_slice(" f"{tcgen05_aux_epi_tidx})" ), ( f"{tsr_sc_var} = " f"{thr_copy_s2r_var}.partition_S(" f"{staged_ring_name})" ), ( f"{trs_rc_var} = cute.make_rmem_tensor(" f"{rmem_shape_expr}, {aux_dtype_str})" ), (f"{tsr_rc_var} = {tiled_copy_s2r_var}.retile({trs_rc_var})"), ] ) continue if rec.broadcast_axis is None or rec.broadcast_axis == 2: # Exact-shape rank-2 aux (or the colvec form, which is a full # (M, N) stride-(1,0) view): slice the per-tile region of the # underlying 2-D tensor directly. The colvec's per-subtile read # is specialized to a scalar in ``_aux_subtile_load_source``. source_for_local_tile = rec.aux_tensor_name aux_tile_is_local = False elif rowvec_stage is not None: assert rec.broadcast_axis == 1 assert rec.aux_view2d is not None # The compact SMEM rowvec is allocated and populated per output # tile, so its 2-D broadcast view is already tile-sized. 
lines.append( f"{rec.aux_view2d} = cute.make_tensor(" f"{rowvec_stage.smem}.iterator, " f"cute.make_layout(({tcgen05_bm}, {tcgen05_bn}), " f"stride=(0, 1)))" ) source_for_local_tile = rec.aux_view2d aux_tile_is_local = True else: # M-axis (row) broadcast aux: build a 2-D logical view # over the underlying tensor's ``.iterator`` with # stride 0 on the leading (M) axis and stride 1 on the # trailing (N) axis. Stride 0 on M causes every lane # "owning" output ``(m, n)`` to read the same source # element regardless of m, which is the broadcast # semantic shared by two accepted forms: # * ``broadcast_axis == 1`` — a bare rank-1 tensor # ``bias[tile_n]`` with shape ``(N,)`` (rank-1 RHS # aligns to the trailing axis under PyTorch # broadcasting). # * ``broadcast_axis == 0`` — an explicit ``(1, N)`` # tensor ``bias[tile_m, tile_n]`` (row 0 broadcasts # over M). # Both have the same contiguous N-major memory layout # (element ``(0, n)`` at offset ``n``), so the # stride-(0, 1) view over ``.iterator`` is identical # and feeds the same ``partition_C → flat_divide → # partition_D`` pipeline used by exact-shape aux. # Mirrors Quack's ``RowVecLoad`` epilogue # (``quack/quack/epi_ops.py``). The classifier # (``aux_tensor_load_kind``) admits only these two # broadcast shapes; everything else drops to the # loud-failure backstop. 
assert rec.broadcast_axis in (0, 1) assert rec.aux_view2d is not None lines.append( f"{rec.aux_view2d} = cute.make_tensor(" f"{rec.aux_tensor_name}.iterator, " f"cute.make_layout(({m_size}, {n_size}), " f"stride=(0, 1)))" ) source_for_local_tile = rec.aux_view2d aux_tile_is_local = False if aux_tile_is_local: lines.append(f"{rec.aux_tile} = {source_for_local_tile}") else: lines.append( f"{rec.aux_tile} = cute.local_tile(" f"{source_for_local_tile}, ({tcgen05_bm}, {tcgen05_bn}), " f"({tile_coord_m}, {tile_coord_n}))" ) lines.extend( [ ( f"{rec.aux_part_base} = " f"{tcgen05_thr_mma}.partition_C({rec.aux_tile})" ), ( f"{rec.aux_xfm} = " "cutlass.utils.gemm.sm100.transform_partitioned_tensor_layout(" f"{rec.aux_part_base})" ), ( f"{rec.aux_planned} = cute.make_tensor(" f"{rec.aux_xfm}.iterator, " f"cute.append(cute.append(cute.append({rec.aux_xfm}.layout, " f"{tcgen05_aux_epilogue_rest_mode}), " f"{tcgen05_aux_epilogue_rest_mode}), " f"{tcgen05_aux_epilogue_rest_mode}))" ), ( f"{rec.aux_epi} = cute.flat_divide(" f"{rec.aux_planned}, {epi_tile})" ), (f"{rec.ttr_aux} = {thr_copy_t2r_var}.partition_D({rec.aux_epi})"), *( [f"{rec.ttr_aux} = {tiled_copy_r2s}.retile({rec.ttr_aux})"] if retile_for_r2s else [] ), ( f"{rec.ttr_aux_grouped} = cute.group_modes(" f"{rec.ttr_aux}, 3, cute.rank({rec.ttr_aux}))" ), # Pre-wait hoist: one cooperative LDG of the whole rowvec # fragment, issued here (before the accumulator # consumer_wait downstream) so the GMEM latency overlaps # the MMA wait. Per-subtile reads then come from # registers. See the ``aux_rmem_full`` field docs for the # family gate. 
*( [ ( f"{rec.aux_rmem_full} = cute.make_rmem_tensor(" f"{rec.ttr_aux_grouped}.shape, {rec.aux_dtype})" ), ( f"cute.autovec_copy({rec.ttr_aux_grouped}, " f"{rec.aux_rmem_full})" ), ] if rec.aux_rmem_full is not None and not force_gmem_aux else [] ), ] ) return lines


def _materialize_broadcast_aux_source(
    indent: str, rec: object, carrier_name: str
) -> str:
    """Emit a per-subtile aux load that matches the accumulator carrier's
    register profile.

    Args:
        indent: Prefix placed in front of each emitted source line.
        rec: Aux step record. Only its ``aux_rmem``, ``aux_dtype``,
            ``ttr_aux_subtile`` and ``aux_loaded`` attributes are read.
        carrier_name: Name of the accumulator carrier whose ``.shape`` is
            used to build the layout of the aux register tensor.

    Returns:
        Three newline-terminated statements: allocate the register tensor,
        ``cute.autovec_copy`` the aux subtile into it, and bind the loaded
        value to ``rec.aux_loaded``.

    Example output (``indent=' '``, ``carrier_name='tcgen05_tRS_rAcc'``)::

        tcgen05_aux_rmem_0 = cute.make_rmem_tensor(
            cute.make_layout(tcgen05_tRS_rAcc.shape), cutlass.Float32
        )
        cute.autovec_copy(tcgen05_tTR_gAux_subtile_0, tcgen05_aux_rmem_0)
        tcgen05_aux_loaded_0 = tcgen05_aux_rmem_0.load()
    """
    return (
        f"{indent}{rec.aux_rmem} = "  # type: ignore[attr-defined]
        f"cute.make_rmem_tensor(cute.make_layout({carrier_name}.shape), "
        f"{rec.aux_dtype})\n"  # type: ignore[attr-defined]
        f"{indent}cute.autovec_copy("
        f"{rec.ttr_aux_subtile}, {rec.aux_rmem})\n"  # type: ignore[attr-defined]
        f"{indent}{rec.aux_loaded} = "  # type: ignore[attr-defined]
        f"{rec.aux_rmem}.load()\n"  # type: ignore[attr-defined]
    )


def _aux_subtile_load_source(
    prelude_indent: str,
    carrier_name: str,
    *,
    force_simt_edge_aux: bool = False,
    safe_direct_aux_with_full_tile: bool = False,
) -> str:
    """Per-subtile aux GMEM-load source lines (one per aux step).

    Each step emits the per-thread GMEM subtile slice of
    ``tTR_gAux_grouped_<idx>`` followed by a ``.load()`` call into the
    per-subtile ``tcgen05_aux_loaded_*`` local. Goes inside the
    per-subtile loop body. The slice depends on ``_tcgen05_subtile`` so it
    cannot be hoisted out of the loop entirely. Splice sites choose where
    to place this block: the default TMA-store path keeps it after the
    c_pipeline acquire, acc ``consumer_wait``, and t2r async TMEM→reg copy
    so residual and bias fragments are not live through the store-prefix
    waits. SIMT fallback concatenates it with the chain prelude because it
    does not use the TMA aux-pipeline shape; diagnostic helper paths keep
    the same flat prelude order for unary chains and reject aux chains at
    validation time.

    Cycle 39 (GPU 6) replan note: an alternative form that pre-loads all
    subtile aux into a per-thread register tensor outside the per-subtile
    loop (``cute.autovec_copy`` from ``tTR_gAux_grouped_<idx>`` into a
    fresh ``tTR_rAux_<idx>``) was tested. The single cooperative LDG fired
    before the per-subtile loop, but the multi-subtile register tensor
    pushed local-memory spills from 356k to 1.17M and grew kernel duration
    from 308 µs to 332 µs. The per-subtile GMEM load form below pays one
    LDG per chain-add but the compiler IR / SASS scheduler already lifts
    the LDG ahead of the chain-add given the independent dependency graph.

    Cycle 69 found a related spill tradeoff inside the default TMA-store
    body: placing the per-subtile aux LDG after the acquire/T2R prefix
    removes most local-memory spill traffic, so that path no longer uses
    the older top-of-loop hoist.

    Args:
        prelude_indent: Prefix placed in front of each emitted source line.
        carrier_name: Accumulator carrier name; only forwarded to
            ``_materialize_broadcast_aux_source`` for the branches that
            materialize the aux fragment against the carrier's shape.
        force_simt_edge_aux: When true, the SMEM-ring branch is bypassed
            and every aux step is loaded into a zero-filled register
            tensor through one of the SIMT edge copy helpers.
        safe_direct_aux_with_full_tile: When true, steps without a rowvec
            stage record emit an ``if <full_tile>`` guarded direct load
            with a zero-filled scalar edge copy in the ``else`` arm (the
            same form is used when ``tcgen05_aux_use_tma_store_epilogue``
            is false).

    Returns:
        The concatenated, newline-terminated source text, or ``""`` when
        there are no aux step records.
    """
    if not aux_step_records:
        return ""
    lines: list[str] = []
    # Tracks whether an earlier aux step already asked the edge-copy helper
    # to include its coordinate setup, so only the first step requests it.
    force_simt_edge_coord_emitted = False
    if use_aux_smem_source and not force_simt_edge_aux:
        # C-input warp productive-body gate is open: per-subtile
        # SMEM ring staging. Each subtile iteration waits on
        # ``c_pipeline_aux`` for the producer warp to fill the
        # active stage, then issues one filtered
        # ``cute.copy(tiled_copy_s2r, tSR_sC[..., stage], tSR_rC)``
        # per descriptor to load the active stage into the
        # per-thread register tensor (Quack's
        # ``epilog_smem_load_and_partition`` flow from
        # ``quack/gemm_sm100.py``: ``tiled_copy_s2r`` is built via
        # ``make_tiled_copy_D`` against ``tiled_copy_t2r``;
        # ``tSR_sC = thr_copy_s2r.partition_S(sC_ring)`` selects
        # the SMEM source; ``tSR_rC`` is a re-layout view of the
        # same register memory as ``tRS_rC``). The chain reads
        # ``tRS_rC.load()`` (== ``aux_loaded``). The post-copy
        # lane-0-gated release plus state advance run in the
        # same per-subtile iteration so the producer can refill
        # the same stage on the very next persistent tile
        # (matches the consumer cooperative-group arrive count
        # of ``epi_warp_count`` set by
        # ``_emit_tcgen05_aux_pipeline_setup``).
        #
        # Note: ``partition_D(smem_stage).load()`` on
        # ``thr_copy_t2r`` (an earlier prior-subagent variant)
        # produced a deadlocking SMEM read — TMEM→reg-shaped
        # partition_D applied to a SMEM tensor does not
        # compose with the producer's
        # ``make_tiled_copy_tv`` cooperative copy in a way the
        # mbarrier handshake recognizes. The Quack-style
        # ``tiled_copy_s2r`` flow is the canonical CUTLASS-DSL
        # pattern.
        lines.append(
            f"{prelude_indent}{aux_pipeline_name}.consumer_wait("
            f"{aux_consumer_state_name})\n"
        )
        if aux_pipeline_uses_tma_load:
            # TMA producer writes arrive through the async proxy; after the
            # pipeline wait, fence that view before generic SMEM reads.
            # The warp sync mirrors CUTLASS/Quack's TMA-load consumer
            # sequence so every lane observes the fenced view before the
            # per-lane SMEM->register copy below.
            lines.extend(
                [
                    f"{prelude_indent}cute.arch.fence_view_async_shared()\n",
                    f"{prelude_indent}cute.arch.sync_warp()\n",
                ]
            )
        for aux_idx, rec in enumerate(aux_step_records):
            # Steps without a ring buffer name are handled by the generic
            # per-step loop further down.
            if aux_ring_smem_names[aux_idx] is None:
                continue
            tiled_copy_s2r_var = f"{rec.aux_tile}_tiled_copy_s2r"
            tsr_sc_var = f"{rec.aux_tile}_tSR_sC"
            trs_rc_var = f"{rec.aux_tile}_tRS_rC"
            tsr_rc_var = f"{rec.aux_tile}_tSR_rC"
            lines.extend(
                [
                    (
                        # The S2R visitor layout can carry zero/unused lanes;
                        # filtering keeps the residual SMEM read footprint
                        # aligned with the lanes that feed the R2S fragment.
                        f"{prelude_indent}cute.copy("
                        f"{tiled_copy_s2r_var}, "
                        f"cute.filter_zeros({tsr_sc_var}[None, None, None, "
                        f"{aux_consumer_state_name}.index]), "
                        f"cute.filter_zeros({tsr_rc_var}))\n"
                    ),
                    (f"{prelude_indent}{rec.aux_loaded} = {trs_rc_var}.load()\n"),
                ]
            )
        # Release the consumed stage (elected lane only) and advance the
        # consumer state once per emitted block.
        lines.extend(
            [
                (
                    f"{prelude_indent}with cute.arch.elect_one():\n"
                    f"{prelude_indent} {aux_pipeline_name}.consumer_release("
                    f"{aux_consumer_state_name})\n"
                ),
                emit_pipeline_advance(
                    aux_consumer_state_name, indent=prelude_indent
                )
                + "\n",
            ]
        )
    for aux_idx, rec in enumerate(aux_step_records):
        rowvec_stage = rowvec_aux_stage_records[aux_idx]
        # Already emitted by the SMEM-ring branch above.
        if (
            use_aux_smem_source
            and not force_simt_edge_aux
            and aux_ring_smem_names[aux_idx] is not None
        ):
            continue
        if force_simt_edge_aux:
            include_coord_setup = not force_simt_edge_coord_emitted
            force_simt_edge_coord_emitted = True
            if rec.broadcast_axis is None:
                edge_aux_copy_source = _simt_edge_logical_divide_copy_source(
                    prelude_indent,
                    rec.ttr_aux_subtile,
                    rec.aux_rmem,
                    include_coord_setup=include_coord_setup,
                    var_prefix=f"{rec.aux_rmem}_edge",
                    copy_atom=simt_edge_aux_atoms[aux_idx],
                )
            else:
                # Rowvec broadcast stayed scalar in the cycle-74 ablation:
                # vectorizing it did not reduce stack pressure or runtime.
                edge_aux_copy_source = _simt_edge_scalar_copy_source(
                    prelude_indent,
                    rec.ttr_aux_subtile,
                    rec.aux_rmem,
                    include_coord_setup=include_coord_setup,
                )
            # Slice the subtile, zero-fill a register tensor, run the edge
            # copy into it, then bind the loaded value.
            lines.append(
                f"{prelude_indent}{rec.ttr_aux_subtile} = "
                f"{rec.ttr_aux_grouped}"
                f"[(None, None, None, cutlass.Int32(_tcgen05_subtile))]\n"
                f"{prelude_indent}{rec.aux_rmem} = "
                f"cute.make_rmem_tensor({rec.ttr_aux_subtile}.shape, "
                f"{rec.aux_dtype})\n"
                f"{prelude_indent}{rec.aux_rmem}.fill(0)\n"
                + edge_aux_copy_source
                + f"{prelude_indent}{rec.aux_loaded} = "
                f"{rec.aux_rmem}.load()\n"
            )
            continue
        if rowvec_stage is None and (
            safe_direct_aux_with_full_tile or not tcgen05_aux_use_tma_store_epilogue
        ):
            # Emits a zero default, a direct ``.load()`` under
            # ``if <full_tile>``, and a zero-filled scalar edge copy in the
            # ``else`` arm.
            lines.append(
                f"{prelude_indent}{rec.ttr_aux_subtile} = "
                f"{rec.ttr_aux_grouped}"
                f"[(None, None, None, cutlass.Int32(_tcgen05_subtile))]\n"
                f"{prelude_indent}{rec.aux_loaded} = cute.full("
                f"{rec.ttr_aux_subtile}.shape, 0, {rec.aux_dtype})\n"
                f"{prelude_indent}if {full_tile}:\n"
                f"{prelude_indent} {rec.aux_loaded} = "
                f"{rec.ttr_aux_subtile}.load()\n"
                f"{prelude_indent}else:\n"
                f"{prelude_indent} {rec.aux_rmem} = "
                f"cute.make_rmem_tensor({rec.ttr_aux_subtile}.shape, "
                f"{rec.aux_dtype})\n"
                f"{prelude_indent} {rec.aux_rmem}.fill(0)\n"
                f"{_simt_edge_scalar_copy_source(prelude_indent + ' ', rec.ttr_aux_subtile, rec.aux_rmem)}"
                f"{prelude_indent} {rec.aux_loaded} = "
                f"{rec.aux_rmem}.load()\n"
            )
            continue
        if rowvec_stage is not None and not force_simt_edge_aux:
            # Row-vector staging broadcasts through a stride-0 M mode; filter
            # that layout so the SMEM read does not reload duplicate lanes.
            lines.append(
                f"{prelude_indent}{rec.ttr_aux_subtile} = "
                f"{rec.ttr_aux_grouped}"
                "[(None, None, None, cutlass.Int32(_tcgen05_subtile))]\n"
                f"{prelude_indent}{rec.aux_rmem} = "
                f"cute.make_rmem_tensor({rec.ttr_aux_subtile}.layout, "
                f"{rec.aux_dtype})\n"
                f"{prelude_indent}cute.autovec_copy("
                f"cute.filter_zeros({rec.ttr_aux_subtile}), "
                f"cute.filter_zeros({rec.aux_rmem}))\n"
                f"{prelude_indent}{rec.aux_loaded} = {rec.aux_rmem}.load()\n"
            )
            continue
        if rec.broadcast_axis == 2:
            # Column-vector (per-row) aux: a rank-2 ``(m, n)`` operand
            # broadcast over N (its value depends only on the row m). When
            # a thread's fragment is within a single M row the per-row value
            # is uniform across it, so the cheap scalar read
            # ``tTR_gAux[(0, 0, 0, subtile)]`` is exact (PR #2742). When it
            # spans multiple M rows the scalar applies row 0's value to
            # every row, so materialize per element instead.
            #
            # Decide this at codegen time via
            # ``tcgen05_colvec_fragment_single_m_row`` (epi_tile_m vs the
            # 128-lane TMEM datapath). A runtime layout test cannot: after
            # ``partition_D`` + ``group_modes`` the per-thread strides are
            # all dynamic (nothing for ``cute.filter`` to drop, and mode 0
            # conflates M and N), so such tests silently degrade to
            # always-materialize.
            lines.append(
                f"{prelude_indent}{rec.ttr_aux_subtile} = "
                f"{rec.ttr_aux_grouped}"
                f"[(None, None, None, cutlass.Int32(_tcgen05_subtile))]\n"
            )
            if tcgen05_colvec_fragment_single_m_row:
                lines.append(
                    f"{prelude_indent}{rec.aux_loaded} = "
                    f"{rec.ttr_aux_grouped}"
                    f"[(0, 0, 0, cutlass.Int32(_tcgen05_subtile))]\n"
                )
            else:
                lines.append(
                    _materialize_broadcast_aux_source(
                        prelude_indent, rec, carrier_name
                    )
                )
            continue
        if rec.aux_rmem_full is not None:
            # Pre-wait hoisted rowvec: the whole fragment is already in
            # registers (loaded before the accumulator consumer_wait by
            # ``_aux_tile_setup_lines``); slice the active subtile.
            lines.append(
                f"{prelude_indent}{rec.aux_loaded} = "
                f"{rec.aux_rmem_full}"
                f"[(None, None, None, cutlass.Int32(_tcgen05_subtile))].load()\n"
            )
            continue
        # Both remaining cases -- rowvec / leading-broadcast aux
        # (``broadcast_axis in (0, 1)``: a per-column operand broadcast
        # over M) and exact-shape aux (``broadcast_axis is None``: a full
        # ``(m, n)`` operand) -- always materialize. Neither has a cheaper
        # scalar form (the operand is a full per-thread vector either way),
        # and both otherwise produce a nested profile that cannot combine
        # with the flat carrier (rowvec from its stride-0 mode, exact-shape
        # once the fragment spans multiple M rows). The materialize copy is
        # a no-op reshape when the profile already matches (block_m >= the
        # atom M).
        lines.append(
            f"{prelude_indent}{rec.ttr_aux_subtile} = "
            f"{rec.ttr_aux_grouped}"
            f"[(None, None, None, cutlass.Int32(_tcgen05_subtile))]\n"
            + _materialize_broadcast_aux_source(prelude_indent, rec, carrier_name)
        )
    return "".join(lines)


# Render the per-thread carrier expression for the accumulator
# vector. The identity epilogue (no chain or empty chain) emits
# the original `rAcc.load().to(target_dtype)` line. When a
# chain is present, hoist `rAcc.load()` to a local TensorSSA so
# the chain reads the loaded vector once; for chains with
# auxiliary-tensor steps, also emit per-subtile aux-load lines
# that bind the aux locals the chain references. Each splice
# site below uses the appropriate carrier name (`ttr_racc` for
# the SIMT path, `trs_racc` for the TMA path, and
# `tcgen05_tRS_rAcc` for the @cute.jit module helper). The
# returned snippet is a sequence of zero-or-more prelude
# statements (each newline-terminated, indented with
# `prelude_indent`) plus the assignment expression for
# `tcgen05_acc_vec`.
def _splice_acc_vec(
    carrier_name: str,
    prelude_indent: str,
    *,
    force_simt_edge_aux: bool = False,
    safe_direct_aux_with_full_tile: bool = False,
) -> tuple[str, str, str]:
    """Return ``(early_aux_prelude, late_prelude, assignment_rhs)``.

    * ``early_aux_prelude`` -- the per-subtile auxiliary-tensor load block
      produced by ``_aux_subtile_load_source``; empty when the chain has no
      aux steps.
    * ``late_prelude`` -- the ``acc_loaded = carrier.load()`` line followed
      by the rendered chain-step statements.
    * ``assignment_rhs`` -- the right-hand side of ``acc_vec = ...`` with no
      leading whitespace and no trailing newline.

    For the identity epilogue (no chain, or a chain without steps) both
    preludes are empty and ``assignment_rhs`` is the plain
    ``carrier.load().to(target_dtype)`` expression.

    Each chain step is rendered into its own ``tcgen05_chain_step*`` local
    so the emitted source stays linear in chain depth: the relu template
    repeats ``{inner}`` 5 times, so a 3-deep relu chain without per-step
    binding would emit 125x duplication and slow parse / IR build.

    The aux load block is returned separately from the chain prelude so
    each splice site can place the GMEM load where it best suits its live
    ranges (the per-tile aux setup itself is emitted once per output tile
    by the surrounding scaffolding via ``_aux_tile_setup_lines()``). The
    default TMA-store splice inserts it after the c_pipeline acquire, acc
    ``consumer_wait``, and t2r async TMEM→reg copy; SIMT-store edge tiles
    reuse the same aux prelude but route the load through a predicated
    copy before the chain is rendered.
    """
    carrier_load = f"{carrier_name}.load()"
    chain = epilogue_chain
    if chain is None or not chain.steps:
        # Identity epilogue: nothing to hoist, no aux loads to emit.
        return ("", "", f"{carrier_load}.to({target_dtype})")

    # Name allocation order matters for the generated source: the loaded
    # accumulator local is reserved first, then the aux block is built,
    # then the chain renderer allocates its per-step locals.
    acc_loaded = df.new_var("tcgen05_acc_loaded")
    aux_prelude = _aux_subtile_load_source(
        prelude_indent,
        carrier_name,
        force_simt_edge_aux=force_simt_edge_aux,
        safe_direct_aux_with_full_tile=safe_direct_aux_with_full_tile,
    )
    aux_loaded_names: tuple[str, ...] = tuple(
        rec.aux_loaded for rec in aux_step_records
    )
    steps_prelude, chain_expr = chain.render_prelude_and_expr(
        acc_loaded,
        df.new_var,
        prelude_indent,
        aux_locals_by_step=aux_loaded_names if aux_loaded_names else None,
    )
    late_prelude = f"{prelude_indent}{acc_loaded} = {carrier_load}\n" + steps_prelude
    return (aux_prelude, late_prelude, f"({chain_expr}).to({target_dtype})")


if tcgen05_value.use_tma_store_epilogue:
    df.placeholder_args.add(tensor_name)
    df.wrapper_only_params.extend(
        [tcgen05_value.tma_store_atom, tcgen05_value.tma_store_tensor]
    )
    if tcgen05_value.use_role_local_epi and tcgen05_value.role_local_tile_counter:
        df.cute_state.register_tcgen05_epi_role_tile_counter(
            tcgen05_value.role_local_tile_counter,
            increment_per_tile=not tcgen05_value.tma_store_full_tiles_only,
        )
    # The bm=128 CtaGroup.TWO family's epilogue tile is N-mode permuted
    # (see ``tcgen05_two_cta_m128_epilogue_tile_expr``); the host TMA-store
    # atom must be built from this *exact* device-side expression, not from
    # the plain ``epi_tile_m/n`` integer keys (which build an unpermuted
    # ``(m, n)`` tile and silently scramble the output -- the correctness
    # bug §7 in FINDINGS_512_SHAPE.md). ``epi_tile_raw_expr`` carries the
    # verbatim device expression to the wrapper.
    d_two_cta_m128 = tcgen05_is_two_cta_m128(
        is_two_cta=tcgen05_lifecycle.is_two_cta, bm=tcgen05_value.bm
    )
    d_tma_plan: dict[str, object] = {
        "kind": "tcgen05_d_tma",
        "d_name": tensor_name,
        "bm": tcgen05_value.bm,
        "bn": tcgen05_value.bn,
        "c_stage_count": tcgen05_value.c_stage_count,
        "output_dtype": target_dtype,
        "kernel_args": [
            tcgen05_value.tma_store_atom,
            tcgen05_value.tma_store_tensor,
        ],
        **(
            {
                "epi_tile_raw_expr": tcgen05_two_cta_m128_epilogue_tile_expr(
                    tcgen05_value.bm,
                    tcgen05_value.bn,
                    target_dtype,
                    c_layout="cutlass.utils.layout.LayoutEnum.ROW_MAJOR",
                )
            }
            if d_two_cta_m128 and not tcgen05_value.has_explicit_epilogue_tile
            else {}
        ),
        **(
            {
                "epi_tile_m": tcgen05_value.explicit_epi_tile_m,
                "epi_tile_n": tcgen05_value.explicit_epi_tile_n,
                "d_store_box_n": tcgen05_value.explicit_d_store_box_n,
            }
            if tcgen05_value.has_explicit_epilogue_tile
            else {}
        ),
    }
    state.codegen.cute_wrapper_plans.append(d_tma_plan)
tcgen05_bm = tcgen05_value.bm
tcgen05_bn = tcgen05_value.bn
tcgen05_bk = tcgen05_value.bk
tcgen05_epilog_sync_barrier_id = tcgen05_value.epilog_sync_barrier_id
tcgen05_c_stage_count = tcgen05_value.c_stage_count
tcgen05_is_two_cta = tcgen05_lifecycle.is_two_cta
tcgen05_thr_mma = tcgen05_value.thr_mma
# The bm=128 CtaGroup.TWO family stores through the per-CTA epilogue tile
# (m of 64, ``use_2cta=True``, no-source) so the store-side
# ``tcgen05_store_epi_tile``, the ``kernel_desc.cta_tile_shape_mnk``, and
# the host TMA-store atom all match the device-side N-mode-permuted tile.
# Resolved through the shared helper so this ``(store_tile_m, epi_tile_expr)``
# pair stays identical to the layout-plan side in ``cute_mma.py``.
tcgen05_store_tile_m, tcgen05_store_epi_tile_expr = tcgen05_resolve_epilogue_tile(
    bm=tcgen05_bm,
    bn=tcgen05_bn,
    is_two_cta=tcgen05_is_two_cta,
    elem_dtype=target_dtype,
    c_layout="cutlass.utils.layout.LayoutEnum.ROW_MAJOR",
    explicit_expr=tcgen05_explicit_store_tile_expr,
)
full_tile_expr = (
    f"({base_indices[0]}) + cutlass.Int32({tcgen05_bm}) <= {m_size} "
    f"and ({base_indices[1]}) + cutlass.Int32({tcgen05_bn}) <= {n_size}"
)


def store_common_setup(
    gmem_tensor: str, *, include_full_tile: bool
) -> tuple[list[str], list[str]]:
    """Build the setup statements shared by the store paths.

    Args:
        gmem_tensor: Name of the tensor that ``cute.local_tile`` slices
            for the current output tile.
        include_full_tile: When true, the tile setup starts with the
            ``<full_tile> = <full_tile_expr>`` assignment.

    Returns:
        ``(static_setup, tile_setup)``. ``static_setup`` holds the
        kernel-descriptor construction and the epilogue-tile assignment;
        ``tile_setup`` holds the optional full-tile flag, the
        ``cute.local_tile`` slice and its ``partition_C`` view.
    """
    # Attribute entries of the ad-hoc ``Tcgen05KernelDesc`` type, in the
    # order they appear in the emitted dict literal.
    kernel_desc_fields = (
        f"'cta_tile_shape_mnk': ({tcgen05_store_tile_m}, {tcgen05_bn}, {tcgen05_bk}), ",
        "'c_layout': cutlass.utils.layout.LayoutEnum.ROW_MAJOR, ",
        f"'c_dtype': {target_dtype}, ",
        "'acc_dtype': cutlass.Float32, ",
        f"'epilog_sync_bar_id': cutlass.Int32({tcgen05_epilog_sync_barrier_id}), ",
        f"'epilogue_warp_id': ({epi_warp_ids}), ",
        f"'num_c_stage': cutlass.Int32({tcgen05_c_stage_count}), ",
        f"'use_2cta_instrs': {tcgen05_is_two_cta!s}",
    )
    kernel_desc_line = (
        f"{kernel_desc} = type('Tcgen05KernelDesc', (), {{"
        + "".join(kernel_desc_fields)
        + "})()"
    )
    # The fallback helper must receive the D-output dtype through
    # ``layout_c=`` / ``elem_ty_c=`` so it selects the same
    # with-source branch as the matmul-plan ``tcgen05_epi_tile``.
    # The explicit path instead uses the D-store box field directly.
    # Keep both forms in lockstep with the wrapper-side TMA atom.
    epi_tile_line = f"{epi_tile} = {tcgen05_store_epi_tile_expr}"

    tile_setup: list[str] = (
        [f"{full_tile} = {full_tile_expr}"] if include_full_tile else []
    )
    tile_setup.append(
        f"{gmem_tile} = cute.local_tile("
        f"{gmem_tensor}, ({tcgen05_bm}, {tcgen05_bn}), "
        f"({tile_coord_m}, {tile_coord_n}))"
    )
    tile_setup.append(f"{tcgc_base} = {tcgen05_thr_mma}.partition_C({gmem_tile})")
    return [kernel_desc_line, epi_tile_line], tile_setup


simt_edge_only = tcgen05_value.tma_store_full_tiles_only
simt_edge_aux_atoms: dict[int, str] = {}
simt_edge_aux_atom_setup: list[str] = []
if simt_edge_only:
    for aux_idx, rec in enumerate(aux_step_records):
        if rec.broadcast_axis is None:
            edge_aux_atom = df.new_var(f"{rec.aux_rmem}_edge_atom")
            simt_edge_aux_atoms[aux_idx] = edge_aux_atom
            # Use a per-aux atom typed to the aux dtype. Reusing the
            # output SIMT atom here was spill-free but slower on the
            # measured Target8 edge path.
            simt_edge_aux_atom_setup.append(
                f"{edge_aux_atom} = "
                f"cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), "
                f"{rec.aux_dtype})"
            )
simt_static_store_setup, simt_tile_store_setup = store_common_setup(
    tensor_name, include_full_tile=not simt_edge_only
)
simt_early_aux, simt_late_prelude, simt_acc_vec_rhs = _splice_acc_vec(
    ttr_racc,
    " ",
    force_simt_edge_aux=tcgen05_value.tma_store_full_tiles_only,
)
simt_acc_vec_prelude = simt_early_aux + simt_late_prelude
tma_static_store_setup, tma_tile_store_setup = store_common_setup(
    tcgen05_value.tma_store_tensor,
    include_full_tile=partial_tma_needs_full_tile_guard,
)
# Role-local TMA stores reuse one C pipeline across work tiles. Static-full
# kernels increment this counter once per role-local tile; hybrid
# output-edge kernels increment it only in the full-tile branch so SIMT
# fallback edge tiles do not perturb the C-pipeline SMEM stage sequence.
tma_c_buffer_expr = "cutlass.Int32(_tcgen05_subtile)" if tcgen05_value.role_local_tile_counter: tma_c_buffer_expr = ( f"{tcgen05_value.role_local_tile_counter} * " f"cutlass.Int32({subtile_count}) + cutlass.Int32(_tcgen05_subtile)" ) simt_store_edge_coord_preloaded = simt_edge_only and bool(aux_steps_in_chain) if simt_edge_only: simt_store_copy_source = _simt_edge_logical_divide_copy_source( " ", ttr_rd, ttr_gc_subtile, include_coord_setup=not simt_store_edge_coord_preloaded, ) else: simt_store_copy_source = ( f" if {full_tile}:\n" f" cute.copy({simt_atom}, {ttr_rd}, {ttr_gc_subtile})\n" f" else:\n" f"{_simt_edge_scalar_copy_source(' ', ttr_rd, ttr_gc_subtile)}" ) simt_store_body_core = [ *simt_static_store_setup, *simt_tile_store_setup, ( f"{tcgc} = cutlass.utils.gemm.sm100.transform_partitioned_tensor_layout(" f"{tcgc_base})" ), ( f"{tcgc_planned} = cute.make_tensor(" f"{tcgc}.iterator, " f"cute.append(cute.append(cute.append({tcgc}.layout, {tcgen05_value.epilogue_rest_mode}), {tcgen05_value.epilogue_rest_mode}), {tcgen05_value.epilogue_rest_mode}))" ), ( f"{tacc} = cutlass.utils.gemm.sm100.transform_partitioned_tensor_layout(" f"{tcgen05_value.epi_acc_frag_base})" ), ( f"{tiled_copy_t2r}, {ttr_tacc_base}, {ttr_racc} = " "cutlass.utils.gemm.sm100.epilogue_tmem_copy_and_partition(" f"{kernel_desc}, {tcgen05_value.epi_tidx}, {tacc}, {tcgc_planned}, {epi_tile}, {tcgen05_lifecycle.is_two_cta!s})" ), f"{thr_copy_t2r} = {tiled_copy_t2r}.get_slice({tcgen05_value.epi_tidx})", f"{tcgc_epi} = cute.flat_divide({tcgc_planned}, {epi_tile})", f"{ttr_gc} = {thr_copy_t2r}.partition_D({tcgc_epi})", ( f"{ttr_tacc_stage} = {ttr_tacc_base}[" f"(None, None, None, None, None, {tcgen05_acc_stage_index_expr})]" ), *( [] if is_secondary_store else [ ( f"if {tcgen05_lifecycle.epi_active}:\n" f" {tcgen05_lifecycle.acc_pipeline}.consumer_wait({tcgen05_lifecycle.acc_consumer_state})" ) ] ), f"{ttr_tacc} = cute.group_modes({ttr_tacc_stage}, 3, cute.rank({ttr_tacc_stage}))", f"{ttr_gc_grouped} 
= cute.group_modes({ttr_gc}, 3, cute.rank({ttr_gc}))", # Per-aux-step partitioning lines (one chain per auxiliary # tensor). No-op when the chain has no aux steps; generated # source is byte-identical to the unary-chain shape for # unary chains and to the identity-store golden for identity # stores. *_aux_tile_setup_lines( thr_copy_t2r_var=thr_copy_t2r, define_thr_copy_t2r=False, force_gmem_aux=simt_edge_only, ), ( f"{ttr_racc} = cute.make_rmem_tensor(" f"{ttr_gc_grouped}[(None, None, None, 0)].shape, cutlass.Float32)" ), f"{ttr_rd} = cute.make_rmem_tensor({ttr_racc}.shape, {target_dtype})", ( f"{mcld} = cute.max_common_layout(" f"{ttr_rd}.layout, {ttr_gc_grouped}[(None, None, None, 0)].layout)" ), ( f"{num_bits} = min(" f"{ttr_gc_grouped}.iterator.alignment * 8, " f"cute.size({mcld}) * {target_dtype}.width, 256)" ), ( f"{simt_atom} = cute.make_copy_atom(" f"cute.nvgpu.CopyR2GOp(), {target_dtype}, " f"num_bits_per_copy={num_bits}, " f"l1c_evict_priority=cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE)" ), *simt_edge_aux_atom_setup, f"{subtile_count} = cutlass.const_expr(cute.size({ttr_tacc}.shape, mode=[3]))", ( # Per-subtile loop: TMEM->reg (t2r) first, then reg->GMEM (SIMT # store). On the last subtile we release the acc consumer slot # *before* the GMEM store so the next mainloop tile's MMA can # producer_acquire the TMEM stage and begin issuing UMMAs while # this tile's epilogue is still draining to GMEM. This mirrors the # release-acc-inside-the-subtile-loop pattern in Quack's sm100 # gemm epilogue. Without c_pipeline SMEM staging we can only # release after the final t2r (not per-subtile), but even one # tile of overlap measurably improves the wide tcgen05 path on # B200. `cutlass.range(..., unroll_full=True)` keeps the loop # statically unrolled so `tiled_copy_t2r` (a TiledCopy that wraps # a tcgen05 tmem_load atom) is not captured as an scf.for iter_arg # — the cute-to-nvvm pass cannot legalize that conversion through # iter_args and aborts during compile. 
f"for _tcgen05_subtile in cutlass.range({subtile_count}, unroll_full=True):\n" f" if {tcgen05_lifecycle.epi_active}:\n" f" {ttr_tacc_mn} = {ttr_tacc}[(None, None, None, cutlass.Int32(_tcgen05_subtile))]\n" f" {ttr_gc_subtile} = {ttr_gc_grouped}[(None, None, None, cutlass.Int32(_tcgen05_subtile))]\n" f" cute.copy({tiled_copy_t2r}, {ttr_tacc_mn}, {ttr_racc})\n" f"{simt_acc_vec_prelude}" f" {acc_vec} = {simt_acc_vec_rhs}\n" f" {ttr_rd}.store({acc_vec})\n" # The secondary fan-out store reuses the still-live accumulator and # must not release it; the primary store owns the release + advance. + ( "" if is_secondary_store else ( f" if _tcgen05_subtile == {subtile_count} - 1:\n" # `cute.copy(t2r, ...)` issues async TMEM->reg loads. # Releasing the acc consumer slot lets the MMA producer # re-acquire the TMEM stage and issue UMMAs that overwrite # TMEM, so we must fence the in-flight async TMEM loads # first to avoid a race on the last subtile's `ttr_racc` / # `ttr_rd` data. This matches Quack's sm100 gemm # fence-before-release pattern. f" cute.arch.fence_view_async_tmem_load()\n" f" with cute.arch.elect_one():\n" f" {tcgen05_lifecycle.acc_pipeline}.consumer_release({tcgen05_lifecycle.acc_consumer_state})\n" ) ) + f"{simt_store_copy_source}" # Advance is a per-thread local state update, so it intentionally # stays outside elect_one; only the mbarrier release is elected. + ( "" if is_secondary_store else ( f"if {tcgen05_lifecycle.epi_active}:\n" + emit_pipeline_advance( tcgen05_lifecycle.acc_consumer_state, indent=" " ) ) ) ), ] # Workstream A Stage 4 (cycle 93, Path B): C-store producer->consumer edge. # Mirrors ``_emit_tcgen05_aux_pipeline_setup``'s SIMT PipelineAsync shape. # producer_arrive_count = ``epi_warp_count`` (per-warp: each of the 4 epi # warps arrives once via ``elect_one`` after R2S + fence); consumer_arrive # _count = 1 (the single store warp); num_stages = ``c_stage_count`` so the # store can lag up to ``c_stages`` subtiles behind the epi warps' T2R/R2S. 
# Producer (epi ``producer_commit``) AND consumer (store ``consumer_wait`` / # ``consumer_release``) BOTH land in this commit so the ring is never a # one-sided handshake that wedges only after wrapping the depth (the # cycle-2a partial-handshake lesson). c_store_edge_setup = ( [ ( f"{c_store_edge_barriers} = cute.arch.alloc_smem(" f"cutlass.Int64, cutlass.Int32({tcgen05_value.c_stage_count * 2}))" ), ( f"{c_store_edge_producer_group} = cutlass.pipeline.CooperativeGroup(" f"cutlass.pipeline.Agent.Thread, " f"cutlass.Int32({tcgen05_value.epi_warp_count}))" ), ( f"{c_store_edge_consumer_group} = cutlass.pipeline.CooperativeGroup(" "cutlass.pipeline.Agent.Thread, cutlass.Int32(1))" ), ( f"{c_store_edge} = cutlass.pipeline.PipelineAsync.create(" f"num_stages={tcgen05_value.c_stage_count}, " f"producer_group={c_store_edge_producer_group}, " f"consumer_group={c_store_edge_consumer_group}, " f"barrier_storage={c_store_edge_barriers})" ), ( f"{c_store_edge_producer_state} = cutlass.pipeline.make_pipeline_state(" f"cutlass.pipeline.PipelineUserType.Producer, " f"{tcgen05_value.c_stage_count})" ), ( f"{c_store_edge_consumer_state} = cutlass.pipeline.make_pipeline_state(" f"cutlass.pipeline.PipelineUserType.Consumer, " f"{tcgen05_value.c_stage_count})" ), ( f"{c_store_edge_release_state} = cutlass.pipeline.make_pipeline_state(" f"cutlass.pipeline.PipelineUserType.Consumer, " f"{tcgen05_value.c_stage_count})" ), ] if has_store_warp else [] ) tma_store_pipeline_setup = [ ( f"{epilog_sync_barrier} = cutlass.pipeline.NamedBarrier(" f"barrier_id=cutlass.Int32({tcgen05_value.epilog_sync_barrier_id}), " f"num_threads=cutlass.Int32({tcgen05_value.epi_warp_count * 32}))" ), *c_store_edge_setup, ( f"{c_pipeline_producer_group} = cutlass.pipeline.CooperativeGroup(" f"cutlass.pipeline.Agent.Thread, cutlass.Int32({tcgen05_value.epi_warp_count * 32}))" ), ( f"{c_pipeline} = cutlass.pipeline.PipelineTmaStore.create(" f"num_stages={tcgen05_value.c_stage_count}, " 
f"producer_group={c_pipeline_producer_group})" ), ] c_acquire_placement = state.device_function.config.get( TCGEN05_C_ACQUIRE_PLACEMENT_CONFIG_KEY, TCGEN05_C_ACQUIRE_PLACEMENT_PRE_LOOP, ) acc_wait_placement = state.device_function.config.get( TCGEN05_ACC_WAIT_PLACEMENT_CONFIG_KEY, TCGEN05_ACC_WAIT_PLACEMENT_SUBTILE_LOOP, ) c_store_mode = state.device_function.config.get( TCGEN05_C_STORE_MODE_CONFIG_KEY, TCGEN05_C_STORE_MODE_NORMAL, ) epilogue_layout = state.device_function.config.get( TCGEN05_EPILOGUE_LAYOUT_CONFIG_KEY, TCGEN05_EPILOGUE_LAYOUT_NORMAL, ) diagnose_first_c_acquire_in_loop = ( c_acquire_placement == TCGEN05_C_ACQUIRE_PLACEMENT_FIRST_IN_LOOP ) diagnose_later_c_acquire_before_barrier = ( c_acquire_placement == TCGEN05_C_ACQUIRE_PLACEMENT_LATER_BEFORE_BARRIER ) diagnose_acc_wait_before_subtile_loop = ( acc_wait_placement == TCGEN05_ACC_WAIT_PLACEMENT_BEFORE_SUBTILE_LOOP ) diagnose_skip_epilogue_store = ( c_store_mode == TCGEN05_C_STORE_MODE_SKIP_EPILOGUE_STORE ) diagnose_split_first_t2r = ( epilogue_layout == TCGEN05_EPILOGUE_LAYOUT_SPLIT_FIRST_T2R ) diagnose_split_acc_t2r_store_tail = ( epilogue_layout == TCGEN05_EPILOGUE_LAYOUT_SPLIT_ACC_T2R_STORE_TAIL ) diagnose_module_helper_acc_t2r = ( epilogue_layout == TCGEN05_EPILOGUE_LAYOUT_MODULE_HELPER_ACC_T2R ) diagnose_module_helper_store_tail = ( epilogue_layout == TCGEN05_EPILOGUE_LAYOUT_MODULE_HELPER_STORE_TAIL ) diagnose_split_epilogue_layout = ( diagnose_split_first_t2r or diagnose_split_acc_t2r_store_tail or diagnose_module_helper_acc_t2r or diagnose_module_helper_store_tail ) if tcgen05_pure_matmul_object is not None and diagnose_split_epilogue_layout: raise exc.BackendUnsupported( "cute", "tcgen05_strategy='pure_matmul_role_lifecycle' does not support " f"{TCGEN05_EPILOGUE_LAYOUT_CONFIG_KEY}={epilogue_layout!r}", ) if tcgen05_pure_matmul_object is not None and has_store_warp: # Workstream A Stage 4 (cycle 93) wires the store-warp tail split into # the non-pure ROLE_LOCAL_WITH_SCHEDULER path only. 
The pure-matmul # role-lifecycle object renders its own tail (``render_tma_store_tail # _region``) and is gated out here so a store warp never silently lands # on the unsplit pure tail (a correctness break). Stage 5 may wire it. raise exc.BackendUnsupported( "cute", "tcgen05_strategy='pure_matmul_role_lifecycle' does not support " "tcgen05_warp_spec_store_warps>0 (Workstream A Stage 4 wires the " "store-warp epilogue split into the non-pure WITH_SCHEDULER path)", ) # The diagnostic split / module-helper epilogue layouts route the # per-subtile tail through helpers that emit ONLY the ``if epi_active`` # half under ``has_store_warp`` (and ``module_helper_store_tail`` keeps the # OLD two-barrier warp-0 ``c_pipeline`` tail while the main path suppressed # the matching acquires) — so the C-store edge would have no consumer and # wedge once the ring wraps, or the ``c_pipeline`` commit/acquire counts # mismatch. They are diagnostic-only source-boundary layouts; production # uses the DEFAULT layout, so reject the combination loudly (same guard # class as the pure-matmul tail above). ``split_first_t2r`` routes through # ``tma_store_subtile_body`` and IS handled by the Stage-4 split, so it is # intentionally excluded. 
if has_store_warp and ( diagnose_split_acc_t2r_store_tail or diagnose_module_helper_acc_t2r or diagnose_module_helper_store_tail ): raise exc.BackendUnsupported( "cute", f"{TCGEN05_EPILOGUE_LAYOUT_CONFIG_KEY}={epilogue_layout!r} does not " "support tcgen05_warp_spec_store_warps>0 (the diagnostic split / " "module-helper epilogue layouts do not emit the store-warp tail " "half of the Workstream A Stage 4 split; use the default layout)", ) if tcgen05_pure_matmul_object is not None: pure_c_store_pipeline = Tcgen05TmaStorePipelineParams( c_pipeline=c_pipeline, warp_idx=tcgen05_value.warp_idx, ) tma_store_pipeline_tail = ( tcgen05_pure_matmul_object.render_c_store_pipeline_tail( pure_c_store_pipeline ) ) tma_store_first_subtile_acquire = ( tcgen05_pure_matmul_object.render_c_store_pre_loop_acquire_lines( pure_c_store_pipeline, first_c_acquire_in_loop=diagnose_first_c_acquire_in_loop, ) ) tma_store_loop_first_subtile_acquire = ( tcgen05_pure_matmul_object.render_c_store_loop_first_acquire( pure_c_store_pipeline, first_c_acquire_in_loop=diagnose_first_c_acquire_in_loop, ) ) tma_store_loop_later_subtile_acquire = ( tcgen05_pure_matmul_object.render_c_store_loop_later_acquire( pure_c_store_pipeline, later_c_acquire_before_barrier=( diagnose_later_c_acquire_before_barrier ), ) ) tma_store_loop_late_later_subtile_acquire = ( tcgen05_pure_matmul_object.render_c_store_loop_late_later_acquire( pure_c_store_pipeline, later_c_acquire_before_barrier=( diagnose_later_c_acquire_before_barrier ), ) ) else: # Workstream A Stage 4 (cycle 93, Path B): the ``c_pipeline`` # (PipelineTmaStore) producer lifecycle is per-warp — its # ``producer_acquire`` is a ``cp_async_bulk_wait_group`` and its # ``producer_commit`` a ``cp_async_bulk_commit_group``, both scoped to # the warp that ISSUES the TMA-D bulk copy. 
So when a store warp owns # the TMA-D, the entire ``c_pipeline`` lifecycle (acquire + commit + # tail) moves onto the store warp: its ``wait_group`` reuse guard lives # in the store-warp tail (after the TMA-D + commit, gating the lagged # release), the epi warps' historical store-prefix acquire lines are # dropped (the C-ring is gated by the cross-warp C-store edge instead), # and ``producer_tail`` (final ``wait_group(0)``) stays on the store warp. c_pipeline_owner_predicate = ( store_warp_predicate if has_store_warp else f"{tcgen05_value.warp_idx} == cutlass.Int32(0)" ) first_acquire_role_gate = ( f"{tcgen05_lifecycle.epi_active} and " f"{tcgen05_value.warp_idx} == cutlass.Int32(0)" ) tma_store_pipeline_tail = ( f"if {c_pipeline_owner_predicate}:\n {c_pipeline}.producer_tail()" ) tma_store_first_subtile_acquire = ( [] if (diagnose_first_c_acquire_in_loop or has_store_warp) else [ (f"if {first_acquire_role_gate}:\n {c_pipeline}.producer_acquire()") ] ) tma_store_loop_first_subtile_acquire = ( ( f" if _tcgen05_subtile == 0 and " f"{tcgen05_value.warp_idx} == cutlass.Int32(0):\n" f" {c_pipeline}.producer_acquire()\n" ) if (diagnose_first_c_acquire_in_loop and not has_store_warp) else "" ) tma_store_loop_later_subtile_acquire = ( "" if (diagnose_later_c_acquire_before_barrier or has_store_warp) else ( f" if _tcgen05_subtile != 0 and " f"{tcgen05_value.warp_idx} == cutlass.Int32(0):\n" f" {c_pipeline}.producer_acquire()\n" ) ) tma_store_loop_late_later_subtile_acquire = ( ( f" if _tcgen05_subtile != 0 and " f"{tcgen05_value.warp_idx} == cutlass.Int32(0):\n" f" {c_pipeline}.producer_acquire()\n" ) if (diagnose_later_c_acquire_before_barrier and not has_store_warp) else "" ) if diagnose_split_epilogue_layout: if not ( tcgen05_value.use_role_local_epi and tcgen05_value.use_tma_store_epilogue ): raise exc.BackendUnsupported( "cute", f"{TCGEN05_EPILOGUE_LAYOUT_CONFIG_KEY}={epilogue_layout!r} " "requires the " "role-local TMA-store tcgen05 epilogue", ) if not 
tcgen05_lifecycle.is_two_cta: raise exc.BackendUnsupported( "cute", f"{TCGEN05_EPILOGUE_LAYOUT_CONFIG_KEY}={epilogue_layout!r} requires " "CtaGroup.TWO", ) # Conservative proxy for the validated static-full CtaGroup.TWO # two-or-more-subtile envelope; the exact subtile count is only # available after the CUTLASS epilogue partitioning below. if tcgen05_value.bn < TCGEN05_TWO_CTA_BLOCK_N: raise exc.BackendUnsupported( "cute", f"{TCGEN05_EPILOGUE_LAYOUT_CONFIG_KEY}={epilogue_layout!r} is only " f"validated for CtaGroup.TWO block_n >= {TCGEN05_TWO_CTA_BLOCK_N}", ) # The diagnostic split-epilogue layouts emit the per-thread # chain into separate ``@cute.jit`` helpers (module-helper # layouts) or split source boundaries; the auxiliary-tensor # splice site needs per-tile aux setup that is not currently # plumbed into those helper signatures. Reject the # combination loudly so a user does not silently get a # kernel that drops the aux read. The diagnostic layouts # are only used for source-boundary investigation and do not # block any production path. if ( diagnose_module_helper_acc_t2r or diagnose_module_helper_store_tail or diagnose_split_first_t2r or diagnose_split_acc_t2r_store_tail ) and aux_steps_in_chain: raise exc.BackendUnsupported( "cute", "auxiliary-tensor epilogue (e.g. " "`out[tile] = (acc + residual[tile]).to(dtype)`) is " f"not plumbed through {TCGEN05_EPILOGUE_LAYOUT_CONFIG_KEY}=" f"{epilogue_layout!r}. 
The diagnostic split-epilogue " "layouts are only used for source-boundary " "investigation; drop the layout config to use the " "default production layout.", ) tma_store_split_first_subtile_acquire = ( ( f" if {tcgen05_value.warp_idx} == cutlass.Int32(0):\n" f" {c_pipeline}.producer_acquire()\n" ) if diagnose_first_c_acquire_in_loop else "" ) tma_store_pre_loop_acc_wait = ( [ ( f"if {tcgen05_lifecycle.epi_active}:\n" f" {tcgen05_lifecycle.acc_pipeline}.consumer_wait({tcgen05_lifecycle.acc_consumer_state})" ) ] if diagnose_acc_wait_before_subtile_loop and not is_secondary_store else [] ) tma_store_loop_acc_wait = ( "" if diagnose_acc_wait_before_subtile_loop or is_secondary_store else ( f" if _tcgen05_subtile == 0:\n" f" {tcgen05_lifecycle.acc_pipeline}.consumer_wait({tcgen05_lifecycle.acc_consumer_state})\n" ) ) tma_store_split_first_acc_wait = ( "" if diagnose_acc_wait_before_subtile_loop else ( f" {tcgen05_lifecycle.acc_pipeline}.consumer_wait({tcgen05_lifecycle.acc_consumer_state})\n" ) ) tma_store_split_tail_later_subtile_acquire = ( "" if diagnose_later_c_acquire_before_barrier else ( f" if {tcgen05_value.warp_idx} == cutlass.Int32(0):\n" f" {c_pipeline}.producer_acquire()\n" ) ) tma_store_split_tail_late_later_subtile_acquire = ( ( f" if {tcgen05_value.warp_idx} == cutlass.Int32(0):\n" f" {c_pipeline}.producer_acquire()\n" ) if diagnose_later_c_acquire_before_barrier else "" ) # Pyrefly does not preserve the non-None tcgen05_value narrowing inside # the nested source formatter, so keep local string aliases for attributes # read only by that closure. tcgen05_epi_active = tcgen05_lifecycle.epi_active tcgen05_acc_pipeline = tcgen05_lifecycle.acc_pipeline tcgen05_acc_consumer_state = tcgen05_lifecycle.acc_consumer_state tcgen05_warp_idx = tcgen05_value.warp_idx tcgen05_tma_store_atom = tcgen05_value.tma_store_atom # Locals for the store-warp tail closure (Pyrefly drops the non-None # tcgen05_value narrowing inside nested source formatters; see above). 
tcgen05_role_local_tile_counter = tcgen05_value.role_local_tile_counter def tma_store_acc_t2r_region_body( *, acc_wait: str, allow_aux_chain: bool = False ) -> str: """Return the t2r/math/store-source region. The aux prelude is rendered inside ``body`` immediately after the TMEM→register copy and before ``acc.load()`` / fused math. Keeping residual and bias fragments out of the acquire/T2R prefix shortens their live ranges through the R2S store path; the long-scoreboard overlap from the older hoist was less valuable on the packed Target8 epilogue than eliminating the resulting local-memory spills. """ assert allow_aux_chain or not aux_steps_in_chain, ( "diagnostic / module-helper layouts reject aux-tensor chains at " "validate time; use allow_aux_chain=True only for the default TMA " "store body that threads the aux LDG through the main T2R body." ) carrier = trs_racc store_target = trs_rd early_aux_prelude, late_prelude, rhs = _splice_acc_vec( carrier, " ", safe_direct_aux_with_full_tile=partial_tma_needs_full_tile_guard, ) # The secondary fan-out store reuses the still-live accumulator TMEM and # must not release it: the primary store already owns the accumulator # pipeline consumer release, and the one-shot teardown frees the TMEM # after every store has read it. 
acc_release = ( "" if is_secondary_store else ( f" if _tcgen05_subtile == {subtile_count} - 1:\n" f" cute.arch.fence_view_async_tmem_load()\n" f" with cute.arch.elect_one():\n" f" {tcgen05_acc_pipeline}.consumer_release({tcgen05_acc_consumer_state})\n" ) ) return ( f"{acc_wait}" f" {ttr_tacc_mn} = {ttr_tacc}[(None, None, None, cutlass.Int32(_tcgen05_subtile))]\n" f" cute.copy({tiled_copy_t2r}, {ttr_tacc_mn}, {ttr_racc})\n" f"{early_aux_prelude}" f"{late_prelude}" f" {acc_vec} = {rhs}\n" f"{acc_release}" f" {store_target}.store({acc_vec})\n" ) def tma_store_tail_params( *, late_later_subtile_acquire: str ) -> Tcgen05TmaStoreTailParams: return Tcgen05TmaStoreTailParams( late_later_subtile_acquire=late_later_subtile_acquire, epilog_sync_barrier=epilog_sync_barrier, c_buffer=c_buffer, c_buffer_expr=tma_c_buffer_expr, c_stage_count=tcgen05_c_stage_count, tiled_copy_r2s=tiled_copy_r2s, trs_rd=trs_rd, trs_sd=trs_sd, warp_idx=tcgen05_warp_idx, tma_store_atom=tcgen05_tma_store_atom, bsg_sd=bsg_sd, bsg_gd=bsg_gd, c_pipeline=c_pipeline, ) def tma_store_tail_region(*, late_later_subtile_acquire: str) -> str: if tcgen05_pure_matmul_object is not None: return tcgen05_pure_matmul_object.render_tma_store_tail_region( tma_store_tail_params( late_later_subtile_acquire=late_later_subtile_acquire ) ) if has_store_warp: # Path B epi-warp tail (inside ``if epi_active:``): acquire the # C-store edge stage (wait until the store warp released it, i.e. # the prior TMA-D reading this physical C-ring slot completed), # barrier-1 (intra-epi convergence), R2S, fence, then a C-store-edge # PRODUCER commit in place of the second CTA barrier. The TMA-D + # ``c_pipeline`` lifecycle move to the store warp's tail # (``tma_store_store_warp_tail_region``). The epi warps drop # straight into the next subtile's T2R after committing — that is # the store/T2R overlap. 
The producer cooperative group is per-warp # (count ``epi_warp_count``), so ``producer_acquire`` is a full-warp # wait on every epi warp and ``producer_commit`` arrives once per # warp via ``elect_one``. return ( f" {c_store_edge}.producer_acquire({c_store_edge_producer_state})\n" f" {epilog_sync_barrier}.arrive_and_wait()\n" f" {c_buffer} = ({tma_c_buffer_expr}) % cutlass.Int32({tcgen05_c_stage_count})\n" f" cute.copy({tiled_copy_r2s}, {trs_rd}, {trs_sd}[(None, None, None, {c_buffer})])\n" f" cute.arch.fence_view_async_shared()\n" f" with cute.arch.elect_one():\n" f" {c_store_edge}.producer_commit({c_store_edge_producer_state})\n" f" {c_store_edge_producer_state}.advance()\n" ) return ( f"{late_later_subtile_acquire}" f" {epilog_sync_barrier}.arrive_and_wait()\n" f" {c_buffer} = ({tma_c_buffer_expr}) % cutlass.Int32({tcgen05_c_stage_count})\n" f" cute.copy({tiled_copy_r2s}, {trs_rd}, {trs_sd}[(None, None, None, {c_buffer})])\n" f" cute.arch.fence_view_async_shared()\n" f" {epilog_sync_barrier}.arrive_and_wait()\n" f" if {tcgen05_warp_idx} == cutlass.Int32(0):\n" f" cute.copy({tcgen05_tma_store_atom}, {bsg_sd}[(None, {c_buffer})], {bsg_gd}[(None, cutlass.Int32(_tcgen05_subtile))])\n" f" {c_pipeline}.producer_commit()\n" ) def tma_store_store_warp_tail_region() -> str: # Path B store-warp tail (inside ``if store_warp_predicate:``): consume # the C-store edge, issue the TMA-D, and recycle the C-ring SMEM stage # with a ``c_stages - 1`` lagged release so a stage is only freed for # the epi producer AFTER its TMA-D read has provably completed. # # Ordering (per subtile, ``S`` = ``c_buffer``): # 1. ``consumer_wait``: the epi warps' R2S of stage ``S`` has landed. # 2. TMA-D ``S`` -> GMEM + ``c_pipeline.producer_commit`` (commit_group). # 3. ``c_pipeline.producer_acquire`` = ``cp_async_bulk_wait_group( # c_stages - 1, read=True)``: after committing store i this drains # every store except the ``c_stages - 1`` most recent, i.e. 
proves # store ``i - (c_stages - 1)`` finished reading its SMEM stage. # 4. release that proven-drained stage (the ``release_state``, which # lags the wait ``consumer_state`` by ``c_stages - 1``). Suppressed # for the first ``c_stages - 1`` global subtiles (nothing drained # yet); the trailing stages release naturally as later tiles' global # subtile index advances, and the final unreleased stores drain via # ``c_pipeline.producer_tail`` after the loop. lag = tcgen05_c_stage_count - 1 global_subtile = ( f"({tcgen05_role_local_tile_counter} * " f"cutlass.Int32({subtile_count}) + cutlass.Int32(_tcgen05_subtile))" if tcgen05_role_local_tile_counter else "cutlass.Int32(_tcgen05_subtile)" ) return ( f" {c_store_edge}.consumer_wait({c_store_edge_consumer_state})\n" f" {c_store_edge_consumer_state}.advance()\n" f" {c_buffer} = ({tma_c_buffer_expr}) % cutlass.Int32({tcgen05_c_stage_count})\n" f" cute.copy({tcgen05_tma_store_atom}, {bsg_sd}[(None, {c_buffer})], {bsg_gd}[(None, cutlass.Int32(_tcgen05_subtile))])\n" f" {c_pipeline}.producer_commit()\n" f" {c_pipeline}.producer_acquire()\n" f" if {global_subtile} >= cutlass.Int32({lag}):\n" f" with cute.arch.elect_one():\n" f" {c_store_edge}.consumer_release({c_store_edge_release_state})\n" f" {c_store_edge_release_state}.advance()\n" ) def tma_store_subtile_body( *, first_subtile_acquire: str, later_subtile_acquire: str, acc_wait: str, late_later_subtile_acquire: str, ) -> str: # The aux LDG depends on ``_tcgen05_subtile`` and stays inside # the per-subtile T2R body. It intentionally runs after the # c_pipeline acquire and TMEM→register copy so the residual/bias # fragments are not live through the store-prefix waits. 
t2r_body = tma_store_acc_t2r_region_body( acc_wait=acc_wait, allow_aux_chain=True, ) if has_store_warp: # Path B: the epi warps own T2R/R2S + the C-store producer commit; # the store warp (a SEPARATE ``if``, NOT under ``epi_active``) owns # the TMA-D + ``c_pipeline`` commit/acquire (its ``cp_async_bulk # _wait_group`` reuse guard) + the lagged edge release. The C-ring # acquire/commit move WHOLLY onto the store warp (PipelineTmaStore # is per-warp commit-group state), so the epi warps never touch # ``c_pipeline``; their store-prefix acquire lines are dropped. return ( f" if {tcgen05_epi_active}:\n" f"{t2r_body}" f"{tma_store_tail_region(late_later_subtile_acquire='')}" f" if {store_warp_predicate}:\n" f"{tma_store_store_warp_tail_region()}" ) return ( f" if {tcgen05_epi_active}:\n" f"{first_subtile_acquire}" f"{later_subtile_acquire}" f"{t2r_body}" f"{tma_store_tail_region(late_later_subtile_acquire=late_later_subtile_acquire)}" ) def indented_diagnostic_region(source: str) -> str: if not source: return " pass\n" return "".join(f" {line}" for line in source.splitlines(keepends=True)) def tma_store_helper_boundary_subtile_body( *, first_subtile_acquire: str, later_subtile_acquire: str, acc_wait: str, late_later_subtile_acquire: str, ) -> str: acquire_region = f"{first_subtile_acquire}{later_subtile_acquire}" acc_region = tma_store_acc_t2r_region_body(acc_wait=acc_wait) tail_region = tma_store_tail_region( late_later_subtile_acquire=late_later_subtile_acquire ) # These constant-true blocks are diagnostic source boundaries. The # generated-code AST round trip preserves them, while emitted comments # are not reliable line-info anchors. 
return ( f" if {tcgen05_epi_active}:\n" f" if True:\n" f"{indented_diagnostic_region(acquire_region)}" f" if True:\n" f"{indented_diagnostic_region(acc_region)}" f" if True:\n" f"{indented_diagnostic_region(tail_region)}" ) module_acc_t2r_helper_name = ( df.unique_name("tcgen05_acc_t2r_region") if diagnose_module_helper_acc_t2r else "" ) module_store_tail_helper_name = ( df.unique_name("tcgen05_store_tail_region") if diagnose_module_helper_store_tail else "" ) def tma_store_module_acc_t2r_helper_source(*, acc_wait: str) -> str: # Aux-tensor chains are rejected for the diagnostic module-helper # layouts (see the ``BackendUnsupported`` raise above), so # ``module_early_aux`` is always empty here. Concatenating it # with ``module_late_prelude`` preserves the prior flat-prelude # source order for unary chains and identity stores in this # diagnostic layout. module_early_aux, module_late_prelude, rhs = _splice_acc_vec( "tcgen05_tRS_rAcc", " " ) prelude = module_early_aux + module_late_prelude return ( "@cute.jit\n" f"def {module_acc_t2r_helper_name}(" "_tcgen05_subtile, " "tcgen05_acc_pipeline, " "tcgen05_acc_consumer_state, " "tcgen05_tTR_tAcc, " "tcgen05_tiled_copy_t2r, " "tcgen05_tTR_rAcc, " "tcgen05_tRS_rAcc, " "tcgen05_tRS_rD, " "tcgen05_subtile_count" "):\n" f"{acc_wait}" " tcgen05_tTR_tAcc_mn = tcgen05_tTR_tAcc[(None, None, None, cutlass.Int32(_tcgen05_subtile))]\n" " cute.copy(tcgen05_tiled_copy_t2r, tcgen05_tTR_tAcc_mn, tcgen05_tTR_rAcc)\n" f"{prelude}" f" tcgen05_acc_vec = {rhs}\n" " if _tcgen05_subtile == tcgen05_subtile_count - 1:\n" " cute.arch.fence_view_async_tmem_load()\n" " with cute.arch.elect_one():\n" " tcgen05_acc_pipeline.consumer_release(tcgen05_acc_consumer_state)\n" " tcgen05_tRS_rD.store(tcgen05_acc_vec)" ) def tma_store_module_acc_t2r_helper_call() -> str: return ( f" {module_acc_t2r_helper_name}(" f"_tcgen05_subtile, " f"{tcgen05_acc_pipeline}, " f"{tcgen05_acc_consumer_state}, " f"{ttr_tacc}, " f"{tiled_copy_t2r}, " f"{ttr_racc}, " 
f"{trs_racc}, " f"{trs_rd}, " f"{subtile_count})\n" ) def tma_store_module_helper_subtile_body( *, first_subtile_acquire: str, later_subtile_acquire: str, late_later_subtile_acquire: str, ) -> str: return ( f" if {tcgen05_epi_active}:\n" f"{first_subtile_acquire}" f"{later_subtile_acquire}" f"{tma_store_module_acc_t2r_helper_call()}" f"{tma_store_tail_region(late_later_subtile_acquire=late_later_subtile_acquire)}" ) def tma_store_module_tail_helper_source(*, late_later_subtile_acquire: str) -> str: return ( "@cute.jit\n" f"def {module_store_tail_helper_name}(" "_tcgen05_subtile, " "tcgen05_tma_c_buffer_index, " "tcgen05_epilog_sync_barrier, " "tcgen05_tiled_copy_r2s, " "tcgen05_tRS_rD, " "tcgen05_tRS_sD, " "tcgen05_tma_store_atom, " "tcgen05_bSG_sD, " "tcgen05_bSG_gD, " "tcgen05_c_pipeline, " "tcgen05_warp_idx" "):\n" f"{late_later_subtile_acquire}" " tcgen05_epilog_sync_barrier.arrive_and_wait()\n" f" tcgen05_c_buffer = tcgen05_tma_c_buffer_index % cutlass.Int32({tcgen05_c_stage_count})\n" " cute.copy(tcgen05_tiled_copy_r2s, tcgen05_tRS_rD, tcgen05_tRS_sD[(None, None, None, tcgen05_c_buffer)])\n" " cute.arch.fence_view_async_shared()\n" " tcgen05_epilog_sync_barrier.arrive_and_wait()\n" " if tcgen05_warp_idx == cutlass.Int32(0):\n" " cute.copy(tcgen05_tma_store_atom, tcgen05_bSG_sD[(None, tcgen05_c_buffer)], tcgen05_bSG_gD[(None, cutlass.Int32(_tcgen05_subtile))])\n" " tcgen05_c_pipeline.producer_commit()" ) def tma_store_module_tail_helper_call() -> str: return ( f" {module_store_tail_helper_name}(" f"_tcgen05_subtile, " f"{tma_c_buffer_expr}, " f"{epilog_sync_barrier}, " f"{tiled_copy_r2s}, " f"{trs_rd}, " f"{trs_sd}, " f"{tcgen05_tma_store_atom}, " f"{bsg_sd}, " f"{bsg_gd}, " f"{c_pipeline}, " f"{tcgen05_warp_idx})\n" ) def tma_store_module_tail_subtile_body( *, first_subtile_acquire: str, later_subtile_acquire: str, acc_wait: str, ) -> str: return ( f" if {tcgen05_epi_active}:\n" f"{first_subtile_acquire}" f"{later_subtile_acquire}" 
f"{tma_store_acc_t2r_region_body(acc_wait=acc_wait)}" f"{tma_store_module_tail_helper_call()}" ) if diagnose_split_first_t2r: tma_store_split_first_subtile_body = tma_store_subtile_body( first_subtile_acquire=tma_store_split_first_subtile_acquire, later_subtile_acquire="", acc_wait=tma_store_split_first_acc_wait, late_later_subtile_acquire="", ) tma_store_split_tail_subtile_body = tma_store_subtile_body( first_subtile_acquire="", later_subtile_acquire=tma_store_split_tail_later_subtile_acquire, acc_wait="", late_later_subtile_acquire=( tma_store_split_tail_late_later_subtile_acquire ), ) # Diagnostic-only scaffolding: reuse the one-indent subtile formatter # for a static first subtile without changing production source layout. # The tail loop maps split-loop indices back to logical subtile ids 1..N-1; # unroll_full=True keeps those subtile values compile-time constants. tma_store_subtile_loop = ( "if True:\n" f" _tcgen05_subtile = 0\n" f"{tma_store_split_first_subtile_body}" f"for _tcgen05_split_subtile in cutlass.range({subtile_count} - 1, unroll_full=True):\n" f" _tcgen05_subtile = _tcgen05_split_subtile + 1\n" f"{tma_store_split_tail_subtile_body}" ) elif diagnose_split_acc_t2r_store_tail: tma_store_helper_boundary_body = tma_store_helper_boundary_subtile_body( first_subtile_acquire=tma_store_loop_first_subtile_acquire, later_subtile_acquire=tma_store_loop_later_subtile_acquire, acc_wait=tma_store_loop_acc_wait, late_later_subtile_acquire=tma_store_loop_late_later_subtile_acquire, ) tma_store_subtile_loop = ( f"for _tcgen05_subtile in cutlass.range({subtile_count}, unroll_full=True):\n" f"{tma_store_helper_boundary_body}" ) elif diagnose_module_helper_acc_t2r: module_helper_acc_wait = ( "" if diagnose_acc_wait_before_subtile_loop else ( " if _tcgen05_subtile == 0:\n" " tcgen05_acc_pipeline.consumer_wait(tcgen05_acc_consumer_state)\n" ) ) state.codegen.module_statements.append( statement_from_string( 
tma_store_module_acc_t2r_helper_source(acc_wait=module_helper_acc_wait) ) ) tma_store_module_helper_body = tma_store_module_helper_subtile_body( first_subtile_acquire=tma_store_loop_first_subtile_acquire, later_subtile_acquire=tma_store_loop_later_subtile_acquire, late_later_subtile_acquire=tma_store_loop_late_later_subtile_acquire, ) tma_store_subtile_loop = ( f"for _tcgen05_subtile in cutlass.range({subtile_count}, unroll_full=True):\n" f"{tma_store_module_helper_body}" ) elif diagnose_module_helper_store_tail: module_tail_late_later_subtile_acquire = ( ( " if _tcgen05_subtile != 0 and " "tcgen05_warp_idx == cutlass.Int32(0):\n" " tcgen05_c_pipeline.producer_acquire()\n" ) if diagnose_later_c_acquire_before_barrier else "" ) state.codegen.module_statements.append( statement_from_string( tma_store_module_tail_helper_source( late_later_subtile_acquire=module_tail_late_later_subtile_acquire ) ) ) tma_store_module_tail_body = tma_store_module_tail_subtile_body( first_subtile_acquire=tma_store_loop_first_subtile_acquire, later_subtile_acquire=tma_store_loop_later_subtile_acquire, acc_wait=tma_store_loop_acc_wait, ) tma_store_subtile_loop = ( f"for _tcgen05_subtile in cutlass.range({subtile_count}, unroll_full=True):\n" f"{tma_store_module_tail_body}" ) else: tma_store_default_subtile_body = tma_store_subtile_body( first_subtile_acquire=tma_store_loop_first_subtile_acquire, later_subtile_acquire=tma_store_loop_later_subtile_acquire, acc_wait=tma_store_loop_acc_wait, late_later_subtile_acquire=tma_store_loop_late_later_subtile_acquire, ) tma_store_subtile_loop = ( f"for _tcgen05_subtile in cutlass.range({subtile_count}, unroll_full=True):\n" f"{tma_store_default_subtile_body}" ) tma_store_smem_setup = [ # Must match the wrapper-side `tcgen05_d_tma` TMA atom layout in # `helion/runtime/__init__.py`; both describe one D SMEM stage. 
( f"{smem_d_layout} = cutlass.utils.blackwell_helpers.make_smem_layout_epi(" f"{target_dtype}, cutlass.utils.layout.LayoutEnum.ROW_MAJOR, " f"{epi_tile}, {tcgen05_value.c_stage_count})" ), ( f"{smem_d_ptr} = cute.arch.alloc_smem(" f"{target_dtype}, cute.cosize({smem_d_layout}.outer), alignment=1024)" ), ( f"{smem_d} = cute.make_tensor(" f"cute.recast_ptr({smem_d_ptr}, {smem_d_layout}.inner, dtype={target_dtype}), " f"{smem_d_layout}.outer)" ), *_rowvec_aux_smem_setup_lines(), ] tma_store_acc_layout_setup = [ ( f"{tacc} = cutlass.utils.gemm.sm100.transform_partitioned_tensor_layout(" f"{tcgen05_value.epi_acc_frag_base})" ), ] tma_store_role_invariant_setup = [ *tma_static_store_setup, *tma_store_smem_setup, *tma_store_acc_layout_setup, ] suppressed_store_body_core = [ ( # Diagnostic-only invalid-output mode. Keep the accumulator # pipeline draining so persistent kernels do not deadlock, but # suppress C-pipeline acquire/commit, R2S/SMEM work, and TMA D # stores to bound whether hot waits are tied to the C-store path. f"if {tcgen05_lifecycle.epi_active}:\n" f" {tcgen05_lifecycle.acc_pipeline}.consumer_wait({tcgen05_lifecycle.acc_consumer_state})\n" f" with cute.arch.elect_one():\n" f" {tcgen05_lifecycle.acc_pipeline}.consumer_release({tcgen05_lifecycle.acc_consumer_state})\n" + emit_pipeline_advance( tcgen05_lifecycle.acc_consumer_state, indent=" ", ) ) ] # C-input warp aux pipeline consumer-wait + lane-0-gated # consumer-release framing (``cute_plan.md`` §7.5.3.2 cycle 2b). # Gate-closed configs (default ``c_input_warps=0`` or no aux # residual) keep the historical GMEM aux path. When the gate # fires, the wait/release pair runs once per *subtile* of the # per-output-tile aux region: per-subtile staging keeps the # SMEM ring footprint at one ``epi_tile`` chunk per stage # rather than one ``(bm, bn)`` chunk, which is essential to # fit cluster_m=2 + ``tcgen05_ab_stages=3`` in the 228 KB # B200 SMEM cap. 
The wait begins the aux-load block emitted by # ``_aux_subtile_load_source`` (before any ``.load()`` from the # SMEM ring); the default TMA-store path now splices that block # after the c_pipeline acquire and T2R copy to keep aux fragments # out of the store-prefix live range. The release + ``advance`` # happen at the bottom of the same per-subtile iteration (after # the chain has consumed ``aux_loaded``). Lane-0 gating mirrors # the per-warp consumer arrive count # (``epi_warp_count``) allocated on the aux pipeline. # Static-full role-local stores have no dynamic full-tile branch, so all # C-store invariant setup can be hoisted once. Scheduler-backed hybrid # output-edge stores split full and fringe tiles into separate role-local # scheduler phases, which gives the full-tile phase the same hoist shape. # The monolithic hybrid path still keeps descriptor/SMEM layout setup # inside its dynamic full-tile branch. split_hybrid_tma_store_role = ( tcgen05_value.use_role_local_epi and tcgen05_value.use_tma_store_epilogue and tcgen05_value.tma_store_full_tiles_only and aux_matmul_plan is not None and aux_matmul_plan.has_scheduler_warp # CLC publishes a single hardware-scheduled stream today. The # full/edge split below requires the scheduler warp to publish two # static streams with a sentinel between them. 
and not aux_matmul_plan.is_clc_persistent and not diagnose_skip_epilogue_store ) hoist_tma_store_resources = ( tcgen05_value.use_role_local_epi and tcgen05_value.use_tma_store_epilogue and (not tcgen05_value.tma_store_full_tiles_only or split_hybrid_tma_store_role) and not diagnose_skip_epilogue_store ) hoist_hybrid_tma_store_pipeline = ( tcgen05_value.use_role_local_epi and tcgen05_value.use_tma_store_epilogue and tcgen05_value.tma_store_full_tiles_only and not split_hybrid_tma_store_role and not diagnose_skip_epilogue_store ) tma_store_body_setup_core = [ *(tma_static_store_setup if not hoist_tma_store_resources else []), *( tma_store_pipeline_setup if not (hoist_tma_store_resources or hoist_hybrid_tma_store_pipeline) else [] ), *(tma_store_smem_setup if not hoist_tma_store_resources else []), *_rowvec_aux_copy_lines(), *tma_store_first_subtile_acquire, *tma_tile_store_setup, ( f"{tcgc} = cutlass.utils.gemm.sm100.transform_partitioned_tensor_layout(" f"{tcgc_base})" ), ( f"{tcgc_planned} = cute.make_tensor(" f"{tcgc}.iterator, " f"cute.append(cute.append(cute.append({tcgc}.layout, {tcgen05_value.epilogue_rest_mode}), {tcgen05_value.epilogue_rest_mode}), {tcgen05_value.epilogue_rest_mode}))" ), *(tma_store_acc_layout_setup if not hoist_tma_store_resources else []), ( f"{tiled_copy_t2r}, {ttr_tacc_base}, {ttr_racc} = " "cutlass.utils.gemm.sm100.epilogue_tmem_copy_and_partition(" f"{kernel_desc}, {tcgen05_value.epi_tidx}, {tacc}, {tcgc_planned}, {epi_tile}, {tcgen05_lifecycle.is_two_cta!s})" ), (f"{ttr_rd} = cute.make_rmem_tensor({ttr_racc}.shape, {target_dtype})"), ( f"{tiled_copy_r2s}, {trs_rd}, {trs_sd} = " "cutlass.utils.gemm.sm100.epilogue_smem_copy_and_partition(" f"{kernel_desc}, {tiled_copy_t2r}, {ttr_rd}, " f"{tcgen05_value.epi_tidx}, {smem_d})" ), f"{trs_racc} = {tiled_copy_r2s}.retile({ttr_racc})", f"{tcgc_epi} = cute.flat_divide({tcgc_planned}, {epi_tile})", # Per-aux-step partitioning lines (one chain per auxiliary # tensor). 
No-op when the chain has no aux steps; the TMA # path requires an explicit ``thr_copy_t2r`` slice because # (unlike the SIMT path) the TMA path does not otherwise # create one — the t2r partition is consumed directly by # the SMEM-staged store, never via partition_D. The aux # load needs partition_D to compute a per-thread GMEM read # for the auxiliary tile so we create the slice here. # When the C-input warp productive-body gate is open the # source switches from per-tile GMEM to the per-subtile # SMEM ring stage (see ``_aux_tile_setup_lines`` SMEM # branch); the partition pipeline is layout-only and # compiles unchanged, and the per-subtile ``consumer_wait`` # / lane-0-gated ``consumer_release`` are emitted by # ``_aux_subtile_load_source`` inside the per-subtile loop. *_aux_tile_setup_lines( thr_copy_t2r_var=thr_copy_t2r, define_thr_copy_t2r=True, retile_for_r2s=True, ), ( f"{bsg_sd}, {bsg_gd_partitioned} = cute.nvgpu.cpasync.tma_partition(" f"{tcgen05_value.tma_store_atom}, 0, cute.make_layout(1), " f"cute.group_modes({smem_d}, 0, 2), " f"cute.group_modes({tcgc_epi}, 0, 2))" ), ( f"{bsg_gd} = {bsg_gd_partitioned}[" f"(None, None, None, cutlass.Int32(0), cutlass.Int32(0), cutlass.Int32(0))]" ), f"{bsg_gd} = cute.group_modes({bsg_gd}, 1, cute.rank({bsg_gd}))", ( f"{ttr_tacc_stage} = {ttr_tacc_base}[" f"(None, None, None, None, None, {tcgen05_acc_stage_index_expr})]" ), f"{ttr_tacc} = cute.group_modes({ttr_tacc_stage}, 3, cute.rank({ttr_tacc_stage}))", f"{subtile_count} = cutlass.const_expr(cute.size({ttr_tacc}.shape, mode=[3]))", *tma_store_pre_loop_acc_wait, ] # Warp 0 pre-acquires the first TMA-store SMEM stage before per-tile # C-store setup. The subtile loop acquires only later stages, so C-stage # waits can overlap setup, the first acc-pipeline wait, and the other epi # warps' TMEM load/conversion work on later subtile iterations. 
Most # alternate placements are diagnostics, but the edge+K-tail production seed # uses the measured first_in_loop / before_subtile_loop pair. # tcgen05_c_acquire_placement=first_in_loop moves only that first acquire # into the subtile loop; later acquires and the accumulator wait keep their # default order. The diagnostic later_before_barrier placement keeps the # first acquire in production position and moves only later-subtile # acquires just before the first epilogue barrier. # tcgen05_acc_wait_placement=before_subtile_loop keeps both C acquire sites # in production position and moves only the accumulator consumer wait # before the subtile loop. A CTA-scoped named barrier ensures all epi warps # have observed warp 0's acquire before they write SMEM; a second barrier # ensures the SMEM writes and Quack-style async-shared fence are visible # before warp 0 issues and commits the TMA operation. Compute the SMEM ring # index after the first barrier so the acquire/barrier/index order stays # aligned with Quack's TMA-store epilogue. # The accumulator consumer state advances after the loop, matching Quack's # call-site ordering while preserving the early release. After warp 0 # commits the TMA store, the next subtile's producer_acquire plus the first # named barrier are enough to keep all epi warps from writing a reused SMEM # stage too early. Avoiding a post-commit barrier matches Quack's epilogue # loop. The split_first_t2r diagnostic emits the first static subtile as a # standalone source block, then loops over later subtile work. It is a # layout discriminator for the hot acc-wait/T2R SASS row; the default # production source shape remains the single loop. # Advance is a per-thread local state update, so it intentionally stays # outside elect_one; only the mbarrier release is elected. 
tma_store_pipeline_tail_lines = ( [tma_store_pipeline_tail] if not (hoist_tma_store_resources or hoist_hybrid_tma_store_pipeline) else [] ) if tcgen05_pure_matmul_object is not None: tma_store_body_core = tcgen05_pure_matmul_object.build_tma_store_body_core( Tcgen05TmaStoreBodyCoreParams( setup_lines=tma_store_body_setup_core, subtile_loop=Tcgen05TmaStoreSubtileLoopParams( subtile_count=subtile_count, epi_active=tcgen05_epi_active, first_subtile_acquire=tma_store_loop_first_subtile_acquire, later_subtile_acquire=tma_store_loop_later_subtile_acquire, acc_t2r_region_body=tma_store_acc_t2r_region_body( acc_wait=tma_store_loop_acc_wait, allow_aux_chain=True, ), tail=tma_store_tail_params( late_later_subtile_acquire=( tma_store_loop_late_later_subtile_acquire ), ), ), pipeline_tail_lines=tma_store_pipeline_tail_lines, ) ) else: # The secondary fan-out store does not own the accumulator consumer # state, so it must not advance it (the primary store advances once). tma_store_acc_advance = ( "" if is_secondary_store else ( f"if {tcgen05_lifecycle.epi_active}:\n" + emit_pipeline_advance( tcgen05_lifecycle.acc_consumer_state, indent=" ", ) ) ) tma_store_body_core = [ *tma_store_body_setup_core, tma_store_subtile_loop + tma_store_acc_advance, *tma_store_pipeline_tail_lines, ] tma_store_full_tile_body_core = list(tma_store_body_core) if ( tcgen05_value.tma_store_full_tiles_only and tcgen05_value.role_local_tile_counter ): tma_store_full_tile_body_core.append( f"{tcgen05_value.role_local_tile_counter} = " f"{tcgen05_value.role_local_tile_counter} + cutlass.Int32(1)" ) tma_store_body_source = "\n".join(tma_store_full_tile_body_core) simt_store_body_source = "\n".join(simt_store_body_core) hybrid_tma_store_body_core = [ f"{full_tile} = {full_tile_expr}", ( f"if {full_tile}:\n" f"{textwrap.indent(tma_store_body_source, ' ')}\n" "else:\n" f"{textwrap.indent(simt_store_body_source, ' ')}" ), ] if diagnose_skip_epilogue_store: store_body_core = suppressed_store_body_core elif 
tcgen05_value.tma_store_full_tiles_only: store_body_core = hybrid_tma_store_body_core elif tcgen05_value.use_tma_store_epilogue: store_body_core = tma_store_body_core else: store_body_core = simt_store_body_core main_stmts: list[ast.AST] if tcgen05_value.use_role_local_epi: # These setup statements intentionally remain virtual-pid-independent. # The persistent splitter hoists pipeline state before the role-local # scheduler loops. Scheduler-backed hybrid stores keep descriptor and # layout Python objects inside the epilogue role prelude so they do # not leak across unrelated dynamic warp-role ``if`` regions. tma_store_pipeline_hoisted_stmts = ( [statement_from_string(line) for line in tma_store_pipeline_setup] if (hoist_tma_store_resources or hoist_hybrid_tma_store_pipeline) else [] ) tma_store_role_invariant_stmts = ( [statement_from_string(line) for line in tma_store_role_invariant_setup] if hoist_tma_store_resources else [] ) if split_hybrid_tma_store_role: tma_store_hoisted_stmts = tma_store_pipeline_hoisted_stmts elif hoist_tma_store_resources or hoist_hybrid_tma_store_pipeline: tma_store_hoisted_stmts = [ *tma_store_pipeline_hoisted_stmts, *tma_store_role_invariant_stmts, ] else: tma_store_hoisted_stmts = [] if tcgen05_pure_matmul_object is not None: assert not split_hybrid_tma_store_role, ( "pure lifecycle is admitted only for static-full pure matmul" ) assert not hoist_hybrid_tma_store_pipeline, ( "pure lifecycle does not use hybrid edge TMA-store pipeline setup" ) main_stmts = tcgen05_pure_matmul_object.emit_store_role_stmts( df.cute_state, tma_store_hoisted_stmts=tma_store_hoisted_stmts, store_body_core=store_body_core, ) elif split_hybrid_tma_store_role: sync_before_stmt = statement_from_string("cute.arch.sync_threads()") sync_after_stmt = statement_from_string("cute.arch.sync_threads()") full_main_stmt = statement_from_string( "if True:\n" + textwrap.indent("\n".join(tma_store_full_tile_body_core), " ") ) edge_main_stmt = statement_from_string( "if 
True:\n" + textwrap.indent("\n".join(simt_store_body_core), " ") ) df.cute_state.register_tcgen05_per_tile_stmts( [sync_before_stmt, full_main_stmt, edge_main_stmt, sync_after_stmt] ) df.cute_state.register_tcgen05_epi_role_full_edge_stmts( full_tile_stmts=[full_main_stmt], edge_tile_stmts=[edge_main_stmt], ) # `cute.arch.alloc_smem` is a CuTe DSL static allocation even # though it is represented as a statement. Keeping the descriptor, # layout, and allocation statements in the epi-role prelude scopes # CuTe Python objects away from unrelated warp-role branches # without making the shared-memory reservation data-dependent on # the runtime epi-warp predicate. df.cute_state.register_tcgen05_epi_role_prelude_stmts( tma_store_role_invariant_stmts ) main_stmts = [ *tcgen05_acc_stage_index_top_level_stmts, *tma_store_hoisted_stmts, *tma_store_role_invariant_stmts, sync_before_stmt, full_main_stmt, edge_main_stmt, sync_after_stmt, ] else: sync_before_stmt = statement_from_string("cute.arch.sync_threads()") sync_after_stmt = statement_from_string("cute.arch.sync_threads()") main_stmt = statement_from_string( "if True:\n" + textwrap.indent("\n".join(store_body_core), " ") ) df.cute_state.register_tcgen05_per_tile_stmts( [sync_before_stmt, main_stmt, sync_after_stmt] ) df.cute_state.register_tcgen05_epi_role_stmts([main_stmt]) main_stmts = [ *tcgen05_acc_stage_index_top_level_stmts, *tma_store_hoisted_stmts, sync_before_stmt, main_stmt, sync_after_stmt, ] else: store_body = [ "cute.arch.sync_threads()", *store_body_core, "cute.arch.sync_threads()", ] main_stmt = statement_from_string( "if True:\n" + textwrap.indent("\n".join(store_body), " ") ) main_stmts = [*tcgen05_acc_stage_index_top_level_stmts, main_stmt] # Pipeline drain + TMEM dealloc are one-shot cleanup. They must run # AFTER all tiles have been processed (in the persistent path) and # naturally land at the end of the kernel in the non-persistent path. 
# Keep them as separate statements so the persistent splitter can # extract them via the post-loop registration below. tma_store_post_loop_tail = "" if hoist_tma_store_resources or hoist_hybrid_tma_store_pipeline: # Role-local persistent epilogues reuse the C-store pipeline across # scheduler-recycled work tiles. Draining it inside each tile would # serialize the next tile's epilogue against this tile's TMA stores. # The tail must run before TMEM dealloc setup below. tma_store_post_loop_tail = tma_store_pipeline_tail if is_secondary_store: # The matmul drain + TMEM-free teardown is one-shot and owned by the # primary store; the secondary fan-out store emits only its store body. post_loop_stmts = [] elif tcgen05_pure_matmul_object is not None: post_loop_stmts = tcgen05_pure_matmul_object.emit_store_post_loop_stmts( df.cute_state, candidate_names, tma_store_pipeline_tail=tma_store_post_loop_tail, ) else: post_loop_lines = tcgen05_lifecycle.render_store_post_loop_lines( tma_store_pipeline_tail=tma_store_post_loop_tail ) post_loop_stmts = [statement_from_string(line) for line in post_loop_lines] df.cute_state.register_tcgen05_post_loop_stmts(post_loop_stmts) return [*main_stmts, *post_loop_stmts] def _codegen_cute_store_loaded_index_trailing_slices( state: CodegenState, tensor: torch.Tensor, subscript: list[object] | tuple[object, ...], ast_subscript: list[object] | tuple[object, ...], extra_mask: ast.AST | None, value_node: torch.fx.Node, ) -> ast.AST | None: from .._compiler.ast_extension import create if value_node.target is not load or len(value_node.args) < 2: return None source_tensor_node = value_node.args[0] if not isinstance(source_tensor_node, torch.fx.Node): return None source_tensor = source_tensor_node.meta.get("val") if not isinstance(source_tensor, torch.Tensor): return None source_subscript = value_node.args[1] if not isinstance(source_subscript, (list, tuple)) or not source_subscript: return None indexer = source_subscript[0] if not isinstance(indexer, 
torch.fx.Node): return None indexer_value = indexer.meta.get("val") if not isinstance(indexer_value, torch.Tensor) or indexer_value.ndim == 0: return None trailing_source = [*source_subscript[1:]] if not trailing_source or not all(idx == slice(None) for idx in trailing_source): return None if len(subscript) != indexer_value.ndim + len(trailing_source): return None trailing_store = subscript[indexer_value.ndim :] if not all(idx == slice(None) for idx in trailing_store): return None ast_source_subscript = list( map_arg(tuple(source_subscript), lambda arg: state.env[arg]) ) index_exprs = _cute_index_exprs( state, [indexer_value], [ast_source_subscript[0]], tensor=source_tensor, inactive_singleton_slice_expr="0", ) if len(index_exprs) != 1: return None prefix_subscript = [*subscript[: indexer_value.ndim]] prefix_ast_subscript = [*ast_subscript[: indexer_value.ndim]] target_prefix = _cute_index_exprs( state, prefix_subscript, prefix_ast_subscript, tensor=tensor, inactive_singleton_slice_expr="0", ) if len(target_prefix) != indexer_value.ndim: return None env = CompileEnvironment.current() index_dtype = env.backend.dtype_str(env.index_dtype) source_loop_vars = [ state.device_function.new_var("slice_idx", dce=True) for _ in trailing_source ] source_indices = [ index_exprs[0], *[f"{index_dtype}({var})" for var in source_loop_vars], ] target_indices = [ *target_prefix, *[f"{index_dtype}({var})" for var in source_loop_vars], ] if len(source_indices) != source_tensor.ndim or len(target_indices) != tensor.ndim: return None source_name = state.device_function.tensor_arg(source_tensor).name target_name = state.device_function.tensor_arg(tensor).name source_dtype = env.backend.dtype_str(source_tensor.dtype) target_dtype = env.backend.dtype_str(tensor.dtype) source_mask = _cute_combined_mask( state, [indexer_value], None, tensor=source_tensor, ) target_mask = _cute_combined_mask( state, prefix_subscript, extra_mask, tensor=tensor, ) masks = [mask for mask in (source_mask, 
target_mask) if mask is not None] mask_expr = " and ".join(f"({mask})" for mask in masks) if masks else None load_expr = f"{source_name}[{', '.join(source_indices)}]" if mask_expr is not None: load_expr = f"({load_expr} if {mask_expr} else {source_dtype}(0))" store_expr = ( f"{target_name}.__setitem__({_cute_index_tuple(target_indices)}, " f"{env.backend.ast_to_dtype_expr(load_expr, target_dtype)})" ) if mask_expr is not None: store_expr = f"{store_expr} if {mask_expr} else None" tensor_dim = 0 for idx in prefix_subscript: block_id = None if isinstance(idx, torch.SymInt): block_id = env.get_block_id(idx) elif idx == slice(None) and tensor_dim < tensor.ndim: block_id = next( ( candidate for candidate in _matching_block_ids(env, tensor.shape[tensor_dim]) if candidate in state.codegen.active_device_loops ), None, ) tensor_dim += 1 if block_id is None: continue axis = None grid_state = state.codegen.current_grid_state if grid_state is not None: axis = grid_state.block_thread_axes.get(block_id) if axis is None: loops = state.codegen.active_device_loops.get(block_id) if loops: axis = loops[-1].block_thread_axes.get(block_id) if axis is None or not (0 <= axis < 3): continue block_size = state.device_function.resolved_block_size(block_id) if not isinstance(block_size, int): continue state.codegen.max_thread_block_dims[axis] = max( state.codegen.max_thread_block_dims[axis], block_size, ) state.codegen.referenced_thread_block_dims[axis] = max( state.codegen.referenced_thread_block_dims[axis], block_size, ) stmt: ast.stmt = create(ast.Expr, value=expr_from_string(store_expr)) for loop_var, source_pos in reversed( [*zip(source_loop_vars, range(1, len(source_subscript)), strict=True)] ): extent = _cute_tensor_dim_size_expr(state, source_tensor, source_pos) stmt = create( ast.For, target=create(ast.Name, id=loop_var, ctx=ast.Store()), iter=expr_from_string(f"range({extent})"), body=[stmt], orelse=[], type_comment=None, ) state.add_statement(stmt) return ast.Constant(value=None) 
def _cute_expand_broadcast_dim(value_node: torch.fx.Node) -> int | None: """Return the dim an ``aten.expand`` broadcasts (input size 1 -> >1). Returns ``None`` unless ``value_node`` is an ``aten.expand`` whose value has exactly one broadcast dimension — i.e. the expanded value carries a stride-0 mode at exactly one position whose pre-expand extent was 1. This is the signal that the stored value replicates one source element across that dim. """ if value_node.target is not torch.ops.aten.expand.default: return None input_arg = value_node.args[0] if not isinstance(input_arg, torch.fx.Node): return None out_val = value_node.meta.get("val") in_val = input_arg.meta.get("val") if not isinstance(out_val, torch.Tensor) or not isinstance(in_val, torch.Tensor): return None if out_val.ndim != in_val.ndim: return None env = CompileEnvironment.current() broadcast_dims = [ dim for dim in range(out_val.ndim) if env.known_equal(in_val.shape[dim], 1) and not env.known_equal(out_val.shape[dim], 1) and out_val.stride(dim) == 0 ] if len(broadcast_dims) != 1: return None return broadcast_dims[0] def _cute_block_tile_begin_expr(state: CodegenState, block_id: int) -> str | None: """Return the *per-block* tile start for a tile mapped onto a thread axis. In the CuTe SIMT model a tile dimension is spread across a thread axis, so the strategy's ``index_var`` is the per-*thread* global index (``pid * block + thread_idx[axis]``). Subtracting the thread-local coordinate yields the per-*block* tile base (``pid * block``), shared by every thread in the tile — the correct anchor for a broadcast lane loop. Returns ``None`` when the block id has no active thread axis in this scope. 
""" from .._compiler.cute.cute_reshape import _grid_local_coord_expr loops = state.codegen.active_device_loops.get(block_id) if not loops: return None loop_state = loops[-1] thread_axis = loop_state.block_thread_axes.get(block_id) global_index = loop_state.strategy.index_var(block_id) if thread_axis is None or global_index is None: return None local_coord = _grid_local_coord_expr(state.codegen, block_id, thread_axis) return state.codegen.lift( expr_from_string(f"({global_index}) - ({local_coord})"), dce=True, prefix="tile_begin", ).id def _cute_unsqueeze_expand_load_source( value_node: torch.fx.Node, broadcast_dim: int ) -> torch.fx.Node | None: """Return the ``hl.load`` feeding ``expand(val[..., None, ...])``. Walks ``value_node`` (an ``aten.expand``) back through a single unsqueeze-style subscript op (``val[:, None, :]`` inserting the broadcast dim) to the originating ``hl.load``. Returns ``None`` unless the chain is exactly that shape, so the caller falls back to the load-agnostic path. """ from .view_ops import subscript as subscript_op inner = value_node.args[0] if not isinstance(inner, torch.fx.Node): return None if inner.op == "call_function" and inner.target is subscript_op: index_arg = inner.args[1] if len(inner.args) > 1 else None if not isinstance(index_arg, (list, tuple)): return None # Exactly one ``None`` (the inserted broadcast dim) at ``broadcast_dim``. 
none_positions = [pos for pos, entry in enumerate(index_arg) if entry is None] if none_positions != [broadcast_dim]: return None load_node = inner.args[0] else: load_node = inner if ( isinstance(load_node, torch.fx.Node) and load_node.op == "call_function" and load_node.target is load and len(load_node.args) >= 2 ): return load_node return None def _codegen_cute_store_expand_broadcast_tile( state: CodegenState, tensor: torch.Tensor, subscript: list[object] | tuple[object, ...], ast_subscript: list[object] | tuple[object, ...], value: ast.AST, extra_mask: ast.AST | None, value_node: torch.fx.Node, ) -> ast.AST | None: """Lower a store whose value is broadcast across a reused tile dimension. Handles the pattern:: val = hl.load(src, [tile, hl.arange(k)]) # (block, k) val_3d = val[:, None, :].expand(block, block, k) # stride-0 middle dim hl.store(out, [idx[tile], tile.index, hl.arange(k)], val_3d) Here ``tile`` appears twice in the store index — once as a tensor indexer (``idx[tile]``) and once as the bare tile index (``tile.index``) — while the value is broadcast (stride 0) along the second (``tile.index``) position. The generic SIMT store lowers both positions onto ``tile``'s single thread axis, so each thread only writes the ``a == b`` diagonal of the ``(block, block)`` block. Instead emit a sequential lane loop over the broadcast position so a thread holding ``val[a]`` writes the full ``out[idx[a], begin+b, :]`` row for every ``b`` in the tile, filling the block. ``val`` is broadcast, so every lane reads the same per-thread register. Returns ``None`` (a strict no-op) unless every gate matches, so existing kernels are byte-for-byte unchanged. 
""" env = CompileEnvironment.current() broadcast_dim = _cute_expand_broadcast_dim(value_node) if broadcast_dim is None: return None if broadcast_dim >= len(subscript): return None broadcast_idx = subscript[broadcast_dim] # The broadcast position must be a bare tile index (a SymInt block id), and # that same block id must be reused by another (tensor) index position — the # collision the generic path mis-handles. if not isinstance(broadcast_idx, torch.SymInt): return None broadcast_block_id = env.get_block_id(broadcast_idx) if broadcast_block_id is None: return None block_size = state.device_function.resolved_block_size(broadcast_block_id) if not isinstance(block_size, int) or block_size <= 1: return None reused = False for pos, idx in enumerate(subscript): if pos == broadcast_dim: continue if isinstance(idx, torch.Tensor): for dim_size in idx.shape: if broadcast_block_id in _matching_block_ids(env, dim_size): reused = True break if reused: break if not reused: return None # Walk the value chain ``expand -> unsqueeze(None) -> load`` to recover the # source load. The stored value is a per-thread register holding ``val[a, c]`` # whose coordinates live on the *load*'s thread axes; the store's own free # ``hl.arange`` index entries are distinct nodes that the synthetic-axis # machinery assigns to *different* axes. Reusing the load's coordinate for # those non-broadcast positions keeps the register and the store address on # the same thread axis (otherwise thread ``(a, c_load, c_store)`` would write # ``out[..., c_store] = val[a, c_load]`` for ``c_load != c_store``). load_node = _cute_unsqueeze_expand_load_source(value_node, broadcast_dim) load_coords: list[str] | None = None load_subscript_proxy: tuple[object, ...] 
| None = None if load_node is not None: load_tensor_node = load_node.args[0] load_subscript = load_node.args[1] if isinstance(load_tensor_node, torch.fx.Node) and isinstance( load_subscript, (list, tuple) ): load_tensor = load_tensor_node.meta.get("val") if isinstance(load_tensor, torch.Tensor): load_subscript_proxy = tuple( map_arg([*load_subscript], lambda arg: arg.meta["val"]) ) load_subscript_ast = map_arg( [*load_subscript], lambda arg: state.env[arg] ) load_coords = _cute_index_exprs( state, [*load_subscript_proxy], [*load_subscript_ast], tensor=load_tensor, inactive_singleton_slice_expr="0", ) if len(load_coords) != load_tensor.ndim: load_coords = None load_subscript_proxy = None index_exprs = _cute_index_exprs( state, subscript, ast_subscript, tensor=tensor, inactive_singleton_slice_expr="0", ) if len(index_exprs) != tensor.ndim or "None" in index_exprs: return None # Re-align each non-broadcast free-``hl.arange`` store position onto the # load's matching coordinate. Value dim ``d`` maps to load dim ``d`` before # the unsqueezed broadcast dim and ``d - 1`` after it. Only positions where # *both* the store and the matching load entry are free ``hl.arange`` index # tensors are remapped — a tensor *indexer* (``idx[tile]``) keeps its own # coordinate. if load_coords is not None and load_subscript_proxy is not None: for pos, idx in enumerate(subscript): if pos == broadcast_dim or not isinstance(idx, torch.Tensor): continue load_dim = pos if pos < broadcast_dim else pos - 1 if not (0 <= load_dim < len(load_coords)): continue if isinstance(load_subscript_proxy[load_dim], torch.Tensor): index_exprs[pos] = load_coords[load_dim] # Replace the broadcast position's coordinate (currently the reused tile's # per-thread global index) with ``block_begin + lane`` so the lane loop sweeps # the full tile block, identically for every thread in the tile. 
``block_begin`` # is the *per-block* tile start (``global_index - local_coord``); in the CuTe # SIMT model the tile is mapped onto a thread axis, so the bare offset var # still carries the per-thread ``thread_idx`` lane and must be stripped. block_begin = _cute_block_tile_begin_expr(state, broadcast_block_id) if block_begin is None: return None lane_var = state.device_function.new_var("bcast_lane", dce=True) index_dtype = env.index_type() broadcast_coord = f"({block_begin}) + {index_dtype}({lane_var})" index_exprs[broadcast_dim] = broadcast_coord backend = env.backend target_dtype = backend.dtype_str(tensor.dtype) tensor_name = state.device_function.tensor_arg(tensor).name value = expr_from_string( backend.ast_to_dtype_expr("{value}", target_dtype), value=value, ) store_expr = expr_from_string( _cute_scalar_store_expr(tensor_name, index_exprs, "{value}"), value=value, ) # Base mask excludes the broadcast position (its bound is enforced by the lane # bound below); other positions keep their tile/tensor masks. 
base_subscript = [ slice(None) if pos == broadcast_dim else idx for pos, idx in enumerate(subscript) ] mask_expr = _cute_combined_mask(state, base_subscript, extra_mask, tensor=tensor) dim_size = _cute_tensor_dim_size_expr(state, tensor, broadcast_dim) lane_bound = f"({broadcast_coord}) < {dim_size}" mask_expr = lane_bound if mask_expr is None else f"({mask_expr}) and {lane_bound}" from .._compiler.ast_extension import create mask_ast = expr_from_string(mask_expr) assert isinstance(mask_ast, ast.expr) assert isinstance(store_expr, ast.expr) body_stmt: ast.stmt = ast.fix_missing_locations( ast.If( test=mask_ast, body=[ast.Expr(value=store_expr)], orelse=[], ) ) loop_stmt = create( ast.For, target=create(ast.Name, id=lane_var, ctx=ast.Store()), iter=expr_from_string(f"range({block_size})"), body=[body_stmt], orelse=[], type_comment=None, ) state.add_statement(loop_stmt) return ast.Constant(value=None) def _codegen_cute_store_permute_lane_loops( state: CodegenState, tensor: torch.Tensor, subscript: list[object] | tuple[object, ...], ast_subscript: list[object] | tuple[object, ...], value: ast.AST, extra_mask: ast.AST | None, value_node: torch.fx.Node, ) -> ast.AST | None: from .._compiler.cute.cute_reshape import _coords_from_flat_index from .._compiler.cute.cute_reshape import _flat_index_from_coords from .._compiler.cute.cute_reshape import _get_dim_local_coord from .._compiler.cute.cute_reshape import _get_tile_shape from .._compiler.cute.cute_reshape import _permute_reorders_active_dims from .._compiler.cute.cute_reshape import _shape_op_needs_materialization from .._compiler.cute.cute_reshape import _store_permute_info from .._compiler.generate_ast import GenerateAST from .._compiler.tile_strategy import DeviceGridState if not isinstance(state.codegen, GenerateAST): return None grid_state = state.codegen.current_grid_state if not isinstance(grid_state, DeviceGridState) or not grid_state.has_lane_loops(): return None if _shape_op_needs_materialization(value_node): 
return None index_exprs = _cute_index_exprs( state, subscript, ast_subscript, tensor=tensor, inactive_singleton_slice_expr="0", ) index_tuple = _cute_index_tuple(index_exprs) mask_expr = _cute_combined_mask(state, subscript, extra_mask, tensor=tensor) tensor_name = state.device_function.tensor_arg(tensor).name input_node: torch.fx.Node output_val = value_node.meta.get("val") read_flat: str input_shape: list[int] info = _store_permute_info(value_node) if info is not None: input_node, perm = info input_val = input_node.meta.get("val") if not isinstance(input_val, torch.Tensor) or not isinstance( output_val, torch.Tensor ): return None if not _permute_reorders_active_dims(state.codegen, input_val, perm): return None source_tensor_node = input_node.args[0] if input_node.args else None source_extra_mask = input_node.args[2] if len(input_node.args) > 2 else None if ( input_node.op == "call_function" and input_node.target is load and isinstance(source_tensor_node, torch.fx.Node) and source_extra_mask is None ): source_tensor = source_tensor_node.meta.get("val") if isinstance(source_tensor, torch.Tensor): reordered_subscript = [ subscript[perm.index(i)] for i in range(len(perm)) ] reordered_ast_subscript = ( [ast_subscript[perm.index(i)] for i in range(len(perm))] if isinstance(ast_subscript, (list, tuple)) else None ) source_index_exprs = _cute_index_exprs( state, reordered_subscript, ast_subscript=reordered_ast_subscript, tensor=source_tensor, inactive_singleton_slice_expr="0", ) source_index_tuple = _cute_index_tuple(source_index_exprs) source_name = state.device_function.tensor_arg(source_tensor).name source_mask = _cute_combined_mask( state, reordered_subscript, None, tensor=source_tensor, ) source_dtype = CompileEnvironment.current().backend.dtype_str( source_tensor.dtype ) return expr_from_string( ( f"({tensor_name}.__setitem__({index_tuple}, " f"({source_name}[{source_index_tuple}] if {source_mask} else {source_dtype}(0))) " f"if {mask_expr} else None)" ) if 
source_mask is not None and mask_expr is not None else ( f"{tensor_name}.__setitem__({index_tuple}, " f"{source_name}[{source_index_tuple}] if {source_mask} else {source_dtype}(0))" if source_mask is not None else ( f"({tensor_name}.__setitem__({index_tuple}, {source_name}[{source_index_tuple}]) " f"if {mask_expr} else None)" if mask_expr is not None else f"{tensor_name}.__setitem__({index_tuple}, {source_name}[{source_index_tuple}])" ) ) ) raise exc.BackendUnsupported("cute", "permute lane-loop source tensor") env = CompileEnvironment.current() df = state.device_function input_shape = _get_tile_shape(input_val, env, df.config) output_shape = _get_tile_shape(output_val, env, df.config) src_coords = [ _get_dim_local_coord(state.codegen, input_val, i) for i in range(len(input_shape)) ] current_flat = _flat_index_from_coords(src_coords, input_shape) output_coords = _coords_from_flat_index(current_flat, output_shape) read_coords = [output_coords[perm.index(i)] for i in range(len(perm))] read_flat = _flat_index_from_coords(read_coords, input_shape) elif value_node.target in { torch.ops.aten.view.default, torch.ops.aten.reshape.default, }: input_arg = value_node.args[0] if not isinstance(input_arg, torch.fx.Node): return None input_node = input_arg input_val = input_node.meta.get("val") if not isinstance(input_val, torch.Tensor) or not isinstance( output_val, torch.Tensor ): return None env = CompileEnvironment.current() df = state.device_function input_shape = _get_tile_shape(input_val, env, df.config) output_shape = _get_tile_shape(output_val, env, df.config) if input_shape == output_shape: return None input_non_unit = [s for s in input_shape if s != 1] output_non_unit = [s for s in output_shape if s != 1] if input_non_unit == output_non_unit: return None src_coords = [ _get_dim_local_coord(state.codegen, input_val, i) for i in range(len(input_shape)) ] current_flat = _flat_index_from_coords(src_coords, input_shape) output_coords = [ _get_dim_local_coord(state.codegen, 
output_val, i) for i in range(len(output_shape)) ] read_flat = _flat_index_from_coords(output_coords, output_shape) else: return None env = CompileEnvironment.current() df = state.device_function input_numel = 1 for size in input_shape: input_numel *= size dtype_str = env.backend.dtype_str(input_val.dtype) smem_ptr = df.new_var("permute_smem_ptr") smem = df.new_var("permute_smem") state.codegen.add_statement( statement_from_string( f"{smem_ptr} = cute.arch.alloc_smem({dtype_str}, {input_numel})" ) ) state.codegen.add_statement( statement_from_string( f"{smem} = cute.make_tensor({smem_ptr}, ({input_numel},))" ) ) read_expr = ( f"{df.tensor_arg(tensor).name}.__setitem__({index_tuple}, {smem}[{read_flat}])" if mask_expr is None else ( f"({df.tensor_arg(tensor).name}.__setitem__({index_tuple}, {smem}[{read_flat}]) " f"if {mask_expr} else None)" ) ) return expr_from_string( f"({smem}.__setitem__({current_flat}, {{value}}), " f"cute.arch.sync_threads(), " f"{read_expr})", value=value, ) @_decorators.codegen(store, "metal") def _(state: CodegenState) -> ast.AST: # Metal delegates to the same PointerIndexingStrategy as Triton. # This produces tl.store(ptr + offset, val, mask) in the AST; # the MSL walker translates it to Metal. 
tensor = state.proxy_arg(0) subscript = state.proxy_arg(1) assert isinstance(subscript, (list, tuple)) value = state.ast_arg(2) extra_mask = state.ast_args[3] assert isinstance(extra_mask, (type(None), ast.AST)) if isinstance(tensor, torch.Tensor): device_fn = state.device_function device_fn.device_store_index += 1 indexing_idx = device_fn.device_memory_op_index device_fn.device_memory_op_index += 1 strategy = device_fn.get_indexing_strategy(indexing_idx) return strategy.codegen_store( state, tensor, [*subscript], value, extra_mask, None ) raise exc.BackendUnsupported("metal", f"store target type: {type(tensor)}") def _try_splice_tcgen05_unary_epilogue( state: CodegenState, tensor: object, subscript: list[object] | tuple[object, ...], ast_subscript: list[object] | tuple[object, ...], extra_mask: ast.AST | None, value_node: torch.fx.Node | None, ) -> ast.AST | None: """Splice attempt for ``out[tile] = chain(acc).to(x.dtype)``. Returns the splice-completion sentinel (``ast.Constant(value=None)``) on a successful splice (the caller should return it directly), and ``None`` if the splice did not fire — the caller should continue to the loud-failure backstop or the SIMT fallback. Splice is attempted only when the kernel has a tcgen05-registered matmul fx_node (``cute_state.matmul_fx_nodes`` non-empty), the store value has a backing FX node, the store target is a 2-D ``torch.Tensor``, and the chain analyzer accepts the value chain (returning ``(chain, anchor)`` for a non-empty chain rooted at a tcgen05 matmul). Chains the whitelist rejects (broadcast aux loads, reductions, kwarg-bearing binaries, etc.) leave the analyzer returning ``None`` and the splice does not fire — the loud-failure backstop then catches them. 
""" cute_state = state.device_function.cute_state if not cute_state.matmul_fx_nodes: return None if value_node is None: return None if not isinstance(tensor, torch.Tensor): return None analyzed = analyze_tcgen05_unary_epilogue_chain( state, value_node, output_global_shape=tuple(tensor.shape) ) if analyzed is None: return None chain, anchor = analyzed assert chain.steps anchor_result_var = cute_state.matmul_fx_node_result_vars.get(anchor) if anchor_result_var is None: return None rewritten_stmt = _codegen_cute_store_tcgen05_tile( state, tensor, subscript, ast_subscript, extra_mask, anchor_result_var, epilogue_chain=chain, ) if rewritten_stmt is None: return None stmts = rewritten_stmt if isinstance(rewritten_stmt, list) else [rewritten_stmt] for stmt in stmts: state.add_statement(stmt) return ast.Constant(value=None) @_decorators.codegen(store, "cute") def _(state: CodegenState) -> ast.AST: tensor = state.proxy_arg(0) subscript = state.proxy_arg(1) assert isinstance(subscript, (list, tuple)) ast_subscript = state.ast_args[1] assert isinstance(ast_subscript, (list, tuple)) raw_value = state.ast_args[2] extra_mask = state.ast_args[3] assert isinstance(extra_mask, (type(None), ast.AST)) value_node = None if state.fx_node is not None and len(state.fx_node.args) > 2: maybe_value_node = state.fx_node.args[2] if isinstance(maybe_value_node, torch.fx.Node): value_node = maybe_value_node if isinstance(tensor, torch.Tensor): affine_range_store = _codegen_cute_affine_range_store( state, tensor, subscript, ast_subscript, raw_value, extra_mask, value_node, ) if affine_range_store is not None: state.add_statement(affine_range_store) return ast.Constant(value=None) affine_reshape_store = _codegen_cute_affine_reshape_store( state, tensor, subscript, ast_subscript, extra_mask, value_node, ) if affine_reshape_store is not None: state.add_statement(affine_reshape_store) return ast.Constant(value=None) strided_slice_store = _codegen_cute_strided_slice_store( state, tensor, subscript, 
raw_value, extra_mask, value_node, ) if strided_slice_store is not None: state.add_statement(strided_slice_store) return ast.Constant(value=None) value = state.ast_arg(2) if value_node is not None: if value_node.op == "call_function": if isinstance(tensor, torch.Tensor): rewritten_stmt = _codegen_cute_store_stack_load( state, tensor, subscript, ast_subscript, value, extra_mask, value_node, ) if rewritten_stmt is not None: return rewritten_stmt rewritten_stmt = _codegen_cute_store_loaded_index_trailing_slices( state, tensor, subscript, ast_subscript, extra_mask, value_node, ) if rewritten_stmt is not None: return rewritten_stmt rewritten_stmt = _codegen_cute_store_expand_broadcast_tile( state, tensor, subscript, ast_subscript, value, extra_mask, value_node, ) if rewritten_stmt is not None: return rewritten_stmt rewritten_stmt = _codegen_cute_store_permute_lane_loops( state, tensor, subscript, ast_subscript, value, extra_mask, value_node, ) if rewritten_stmt is not None: return rewritten_stmt from .._compiler.cute.cute_reshape import codegen_cute_store_permute rewritten = codegen_cute_store_permute(state, value, value_node) if rewritten is not None: value = rewritten if isinstance(tensor, tuple): stack_tensor_ast = state.ast_args[0] assert isinstance(stack_tensor_ast, tuple) assert len(stack_tensor_ast) == 2 _tensor_like_ast, dev_ptrs_ast = stack_tensor_ast assert isinstance(dev_ptrs_ast, ast.AST) tensor_like, dev_ptrs = tensor offset_expr = _cute_stack_tensor_offset_expr( state, tensor_like, [*subscript], ast_subscript, ) backend = CompileEnvironment.current().backend target_dtype = backend.dtype_str(tensor_like.dtype) value = expr_from_string( backend.ast_to_dtype_expr("{value}", target_dtype), value=value, ) ptr_expr = _cute_stack_tensor_pointer_expr( target_dtype, dev_ptrs_ast, offset_expr ) store_expr = expr_from_string( "({ptr}).store({value})", ptr=ptr_expr, value=value ) mask_expr = _cute_stack_tensor_mask_expr( state, tensor_like, dev_ptrs, [*subscript], 
extra_mask, ) if mask_expr is None: return store_expr mask_ast = expr_from_string(mask_expr) assert isinstance(mask_ast, ast.expr) assert isinstance(store_expr, ast.expr) state.add_statement( ast.fix_missing_locations( ast.If( test=mask_ast, body=[ast.Expr(value=store_expr)], orelse=[], ) ) ) return ast.Constant(value=None) if not isinstance(tensor, torch.Tensor): raise exc.BackendUnsupported("cute", f"store target type: {type(tensor)}") _log_cute_layout(state, "store") if isinstance(value, ast.Name): rewritten_stmt = _codegen_cute_store_tcgen05_tile( state, tensor, subscript, ast_subscript, extra_mask, value.id, ) if rewritten_stmt is not None: stmts = ( rewritten_stmt if isinstance(rewritten_stmt, list) else [rewritten_stmt] ) for stmt in stmts: state.add_statement(stmt) return ast.Constant(value=None) # Try to splice a whitelisted chain epilogue # (`out[tile] = chain(acc).to(x.dtype)`) into the role-local # tcgen05 epilogue's per-thread T2R loop. Implementation in # ``_try_splice_tcgen05_unary_epilogue``. Chains the whitelist # rejects (broadcast aux loads, reductions, etc.) leave the # splice off and fall through to the loud-failure backstop # below. spliced = _try_splice_tcgen05_unary_epilogue( state, tensor, subscript, ast_subscript, extra_mask, value_node ) if spliced is not None: return spliced # Loud-failure backstop for fused-epilogue stores that follow a # tcgen05 matmul. The tcgen05 grid-emission path (in `program_id.py`) # does not bind the per-block-id `indices_<n>` / `mask_<n>` variable # names that the SIMT-fallback store path expects, so falling through # here would emit a kernel that crashes inside the cute DSL with # `name 'mask_0' is not defined`. Detect the pattern here — any # store value whose FX user chain transitively reaches a # tcgen05-registered matmul fx node — and raise a structured error # so the caller sees the actionable message instead of a cute-DSL # crash. 
Fixing this requires either (a) extending the tcgen05 grid # to emit per-block-id index/mask vars, or (b) per-subtile lambda # emission in `_codegen_cute_store_tcgen05_tile`. if ( state.device_function.cute_state.matmul_fx_nodes and value_node is not None and reach_tcgen05_matmul_anchors(state, value_node) ): raise exc.BackendUnsupported( "cute", "tcgen05 MMA path does not yet emit per-block-id indices " "and masks for non-whitelisted fused epilogues that follow " "the MMA. The store target's value chain depends on a " "tcgen05 matmul result through ops the chain analyzer " "rejects (e.g. aux tensors with a 3-D underlying shape " "and a static collapse like `aux3d[tile_m, tile_n, 0]`, " "loads whose index expression is not exactly the " "carrier tile-id symbol, non-scalar binary ops, " "`aten.add.Tensor` with `alpha=k`, or an intermediate " "`.to(d_inter)` cast where `d_inter` differs from the " "store-target dtype). Identity stores " "(`out[tile] = acc.to(x.dtype)`), whitelisted unary chains " "(relu/tanh/exp/log/sqrt/abs/neg + scalar add/sub/mul/div " "on the accumulator carrier), exact-shape 2-D " "auxiliary-tensor binary ops (`acc + residual[tile_m, " "tile_n]`), and rank-1 trailing-axis (rowvec) broadcast " "aux loads (`acc + bias[tile_n]`) all work via the " "fused-epilogue splice path. 
The leading-axis rank-1 " "form (`acc + bias[tile_m]`) is rejected because a bare " "rank-1 RHS aligns to the trailing axis under PyTorch " "broadcasting; an explicit colvec broadcast must be " "written with `bias[tile_m][:, None]` / " "`.unsqueeze(-1)`.", ) tensor_name = state.device_function.tensor_arg(tensor).name backend = CompileEnvironment.current().backend target_dtype = backend.dtype_str(tensor.dtype) value = expr_from_string( backend.ast_to_dtype_expr("{value}", target_dtype), value=value, ) index_exprs = _cute_index_exprs( state, subscript, ast_subscript, tensor=tensor, inactive_singleton_slice_expr="0", ) topk_lane_expr: object | None = None topk_k: object | None = None if state.fx_node is not None and len(state.fx_node.args) > 2: value_node = state.fx_node.args[2] if ( isinstance(value_node, torch.fx.Node) and value_node.target is operator.getitem and isinstance(value_node.args[0], torch.fx.Node) and value_node.args[0].target is torch.ops.aten.topk.default ): topk_lane_expr = value_node.args[0].meta.get("cute_topk_lane_expr") topk_k = value_node.args[0].meta.get("cute_topk_k") if isinstance(topk_lane_expr, str) and isinstance(topk_k, int): index_exprs[-1] = topk_lane_expr store_uses_pointer = "None" not in index_exprs store_expr = _cute_scalar_store_expr(tensor_name, index_exprs, "{value}") assign_expr = expr_from_string(store_expr, value=value) mask_expr = _cute_combined_mask(state, subscript, extra_mask, tensor=tensor) if isinstance(topk_lane_expr, str) and isinstance(topk_k, int): topk_mask = f"({topk_lane_expr}) < {topk_k}" mask_expr = topk_mask if mask_expr is None else f"({mask_expr}) and {topk_mask}" if mask_expr is None: return assign_expr if store_uses_pointer: mask_ast = expr_from_string(mask_expr) assert isinstance(mask_ast, ast.expr) assert isinstance(assign_expr, ast.expr) state.add_statement( ast.fix_missing_locations( ast.If( test=mask_ast, body=[ast.Expr(value=assign_expr)], orelse=[], ) ) ) return ast.Constant(value=None) return 
expr_from_string( f"({store_expr} if {mask_expr} else None)", value=value, )


# TODO(joydddd): Add support for stack tensor in ref mode.
@_decorators.ref(store)
def _(
    tensor: torch.Tensor,
    index: list[object],
    value: torch.Tensor | torch.SymInt | float,
    extra_mask: torch.Tensor | None = None,
) -> None:
    """Ref-mode (eager) implementation of ``store``.

    Writes ``value`` into ``tensor`` at ``index``. Without ``extra_mask`` this
    is a plain indexed assignment. With ``extra_mask`` only the positions where
    the mask is true *and* every tensor index is in bounds are written, using
    ``index_put_``.

    Args:
        tensor: Destination tensor, modified in place.
        index: One entry per indexed position; ``RefTile`` entries are replaced
            by their ``.index`` value. Multiple tensor indices are expanded to
            a Cartesian product with ``torch.meshgrid(..., indexing="ij")``.
        value: Tensor or scalar to store. ``torch.SymInt`` scalars are
            converted with ``int()``.
        extra_mask: Optional mask; converted to ``torch.bool`` before use.

    Raises:
        exc.DuplicateStoreIndicesError: If the masked store would write the
            same location more than once.
    """
    from .ref_tile import RefTile

    # Normalize indices and identify tensor indices
    indices = []
    tensor_idx_positions = []
    for i, idx in enumerate(index):
        if isinstance(idx, RefTile):
            idx = idx.index
        # pyrefly: ignore [bad-argument-type]
        indices.append(idx)
        if isinstance(idx, torch.Tensor):
            tensor_idx_positions.append(i)

    # Handle broadcasting for multiple tensor indices
    if len(tensor_idx_positions) > 1:
        grids = torch.meshgrid(
            # pyrefly: ignore [bad-argument-type]
            *(indices[i] for i in tensor_idx_positions),
            indexing="ij",
        )
        for i, grid in zip(tensor_idx_positions, grids, strict=False):
            # pyrefly: ignore [unsupported-operation]
            indices[i] = grid

    if extra_mask is not None:
        mask = extra_mask.to(torch.bool)

        # Check bounds for tensor indices
        for i, idx in enumerate(indices):
            if isinstance(idx, torch.Tensor):
                mask = mask & (idx >= 0) & (idx < tensor.shape[i])
        mask_count = int(mask.sum().item())
        if mask_count == 0:
            # Nothing survives the mask: leave the destination untouched.
            return

        # Use index_put_ for masked stores
        valid_indices = []
        for idx in indices:
            if isinstance(idx, torch.Tensor):
                valid_indices.append(idx[mask].long())
            else:
                # Non-tensor indices are repeated once per surviving element.
                idx_val = int(idx) if isinstance(idx, torch.SymInt) else idx
                valid_indices.append(
                    # pyrefly: ignore [no-matching-overload]
                    torch.full(
                        (mask_count,), idx_val, dtype=torch.long, device=tensor.device
                    )
                )

        if isinstance(value, torch.Tensor):
            # index_put_ requires the source dtype to match the destination.
            # Cast here so the masked path accepts the same value dtypes as
            # the scalar branch below and the unmasked assignment, both of
            # which already produce/cast to ``tensor.dtype``.
            values = value[mask].to(tensor.dtype)
        else:
            val = int(value) if isinstance(value, torch.SymInt) else value
            values = torch.full(
                (mask_count,), val, dtype=tensor.dtype, device=tensor.device
            )

        # Check for duplicate indices - this is undefined behavior in Triton
        if valid_indices:
            stacked = torch.stack(valid_indices, dim=1)
            unique_count = stacked.unique(dim=0).size(0)
            if unique_count < stacked.size(0):
                raise exc.DuplicateStoreIndicesError(
                    "hl.store with duplicate indices has undefined behavior in compiled mode. "
                    "The order in which values are written to the same memory location is "
                    "non-deterministic and may vary between Triton versions and backends."
                )

        tensor.index_put_(tuple(valid_indices), values, accumulate=False)
        return

    # Simple assignment
    tensor[tuple(indices)] = (  # pyrefly: ignore[unsupported-operation]
        int(value) if isinstance(value, torch.SymInt) else value
    )
@_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
def load(
    tensor: torch.Tensor | StackTensor,
    index: list[object],
    extra_mask: torch.Tensor | None = None,
    eviction_policy: str | None = None,
) -> torch.Tensor:
    """Read values out of ``tensor`` at the positions given by ``index``.

    Behaves like ``tensor[index]``, with two extras:

    * ``extra_mask`` masks out additional elements on top of the automatic
      masking derived from the ``hl.tile`` range.
    * ``eviction_policy`` is forwarded to the underlying Triton ``tl.load``
      call to steer cache eviction (e.g., "evict_last").

    Args:
        tensor: The tensor / stack tensor to read from.
        index: The indices used to index into the tensor.
        extra_mask: Extra mask applied in addition to the automatic tile
            bounds masking.
        eviction_policy: Optional Triton load eviction policy used as a cache
            behavior hint.

    Returns:
        torch.Tensor: The loaded value.

    Raises:
        exc.NotInsideKernel: Always, when this Python body itself is executed.
    """
    raise exc.NotInsideKernel
@_decorators.prepare_args(load)
def _(
    tensor: torch.Tensor | StackTensor,
    index: list[object],
    extra_mask: torch.Tensor | None = None,
    eviction_policy: str | None = None,
) -> tuple[torch.Tensor | tuple, list[object], torch.Tensor | None, str | None]:
    """Normalize the arguments of ``load``.

    The index is passed through ``Tile._tiles_to_sizes_for_index`` and a
    ``StackTensor`` is converted to a plain ``tuple``; the remaining arguments
    are returned unchanged.
    """
    from .tile_proxy import Tile

    index = Tile._tiles_to_sizes_for_index(index)
    if isinstance(tensor, StackTensor):
        return (tuple(tensor), index, extra_mask, eviction_policy)
    assert isinstance(tensor, torch.Tensor)
    return (tensor, index, extra_mask, eviction_policy)


@_decorators.register_fake(load)
def _(
    tensor: torch.Tensor | tuple[object, ...],
    index: list[object],
    extra_mask: torch.Tensor | None = None,
    eviction_policy: str | None = None,
) -> torch.Tensor:
    """Fake implementation of ``load`` that produces a result of the right shape.

    For a plain tensor the shape comes from ``SubscriptIndexing.compute_shape``
    and the backend is notified through ``process_fake_tensor_load``. For a
    ``(tensor_like, dev_ptrs)`` tuple the result shape is the shape of
    ``dev_ptrs`` followed by the indexed shape of ``tensor_like``.

    Raises:
        NotImplementedError: If ``tensor`` is neither a tensor nor a tuple.
    """
    if isinstance(tensor, torch.Tensor):
        target_shape = SubscriptIndexing.compute_shape(tensor, index)
        env = CompileEnvironment.current()
        env.backend.process_fake_tensor_load(tensor, index)
        return env.new_index_result(tensor, target_shape)
    if isinstance(tensor, tuple):
        tensor_like, dev_ptrs = tensor
        assert isinstance(tensor_like, torch.Tensor)
        assert isinstance(dev_ptrs, torch.Tensor)
        tensor_shape = SubscriptIndexing.compute_shape(tensor_like, index)
        # Leading dims come from the pointer tensor, trailing dims from the index.
        target_shape = list(dev_ptrs.size()) + tensor_shape
        return tensor_like.new_empty(target_shape)
    raise NotImplementedError(f"Unsupported tensor type: {type(tensor)}")


def _maybe_materialize_tile_index_load(
    state: CodegenState,
    tensor: torch.Tensor,
    subscript: list[object] | tuple[object, ...],
) -> ast.AST | None:
    """If this load is on a ``tile.index`` value (e.g. ``tile_m.index[:, None]``),
    emit the inline ``indices_<bid>[<sub>]`` expression and return it.
    Returns ``None`` otherwise.

    ``tile.index`` tensors are synthesized inside the kernel — they aren't
    registered in ``tensor_to_origin`` — so the regular load path's
    ``tensor_arg`` lookup would ``KeyError``.

    Supported subscript entries are ``None`` (new axis) and ``slice(None)``
    (full slice).

    Raises:
        AssertionError: If the block id for ``tensor.size(0)`` cannot be
            resolved, or the subscript holds any other kind of entry.
    """
    from ..language import tile_index

    # The load's first FX argument tells us whether the source is tile.index.
    tensor_node = state.fx_node.args[0] if state.fx_node is not None else None
    if not (
        isinstance(tensor_node, torch.fx.Node)
        and tensor_node.op == "call_function"
        and tensor_node.target == tile_index
    ):
        return None
    env = CompileEnvironment.current()
    block_id = env.get_block_id(tensor.size(0))
    assert block_id is not None
    base_var = state.codegen.index_var(block_id)
    # Translate each subscript entry into its source-text form.
    parts = []
    for idx in subscript:
        if idx is None:
            parts.append("None")
        elif idx == slice(None):
            parts.append(":")
        else:
            raise AssertionError(f"Unexpected index type in tile_index load: {idx}")
    return expr_from_string(f"{base_var}[{', '.join(parts)}]")


@_decorators.codegen(load, "triton")
def _(state: CodegenState) -> ast.AST:
    """Generate the ``load`` expression for the Triton backend.

    Advances the per-load counters on the device function, resolves the
    eviction policy and cache modifier (explicit argument or config entry for
    this load), and then delegates to the indexing strategy for plain tensors
    or to ``StackIndexingStrategy`` for stack tensors.

    Raises:
        NotImplementedError: If the tensor argument is neither a tensor nor a
            tuple.
    """
    tensor = state.proxy_arg(0)
    subscript = state.proxy_arg(1)
    assert isinstance(subscript, (list, tuple))
    ast_subscript = state.ast_args[1]
    assert isinstance(ast_subscript, (list, tuple))
    extra_mask = state.ast_args[2]
    assert isinstance(extra_mask, (type(None), ast.AST))
    eviction_policy = state.ast_args[3] if len(state.ast_args) > 3 else None

    # Every load consumes one load index, whether or not a policy is applied.
    device_fn = state.device_function
    load_idx = device_fn.device_load_index
    device_fn.device_load_index += 1

    # If no explicit eviction_policy and we're in device code, use tunable
    if eviction_policy is None and state.codegen.on_device:
        policies = state.config.load_eviction_policies
        if load_idx < len(policies):
            policy_value = policies[load_idx]
            # Unknown values are passed through unchanged.
            eviction_policy = _EVICTION_POLICY_MAP.get(policy_value, policy_value)

    if eviction_policy is not None:
        assert isinstance(eviction_policy, str)
        eviction_policy = ast.Constant(value=eviction_policy)

    # Cache modifiers have their own counter, advanced only for device code.
    cache_modifier = None
    if state.codegen.on_device:
        modifier_idx = device_fn.device_load_cache_modifier_index
        device_fn.device_load_cache_modifier_index += 1
        modifiers = state.config.load_cache_modifiers
        if modifier_idx < len(modifiers) and modifiers[modifier_idx]:
            cache_modifier = ast.Constant(value=modifiers[modifier_idx])

    if isinstance(tensor, torch.Tensor):
        # tile.index loads return early and do not consume a memory-op index.
        tile_index_result = _maybe_materialize_tile_index_load(state, tensor, subscript)
        if tile_index_result is not None:
            return tile_index_result
        # Use the shared memory op index for indexing strategy
        indexing_idx = device_fn.device_memory_op_index
        device_fn.device_memory_op_index += 1
        strategy = device_fn.get_indexing_strategy(indexing_idx)

        if state.codegen.load_transform is not None:
            return state.codegen.load_transform(
                state,
                tensor,
                [*subscript],
                extra_mask,
                eviction_policy,
                cache_modifier,
                strategy.codegen_load,
            )
        return strategy.codegen_load(
            state, tensor, [*subscript], extra_mask, eviction_policy, cache_modifier
        )

    if isinstance(tensor, tuple):
        from .._compiler.indexing_strategy import StackIndexingStrategy

        # Fusion is not supported for stack loads (multi-tensor device pointers);
        # fall through to the unfused path regardless of load_transform.
        stack_tensor_ast = state.ast_args[0]
        assert isinstance(stack_tensor_ast, tuple)
        assert len(stack_tensor_ast) == 2
        tensor_like_ast, dev_ptrs_ast = stack_tensor_ast
        return StackIndexingStrategy.codegen_load(
            state,
            tensor,
            dev_ptrs_ast,
            [*subscript],
            extra_mask,
            eviction_policy,
            cache_modifier,
        )
    raise NotImplementedError(f"Unsupported tensor type: {type(tensor)}")


@_decorators.codegen(load, "pallas")
def _(state: CodegenState) -> ast.AST:
    """Generate the ``load`` expression for the Pallas backend.

    ``tile.index`` loads are materialized inline; everything else is handed to
    ``pallas_codegen.load_expr``.
    """
    tensor = state.proxy_arg(0)
    subscript = state.proxy_arg(1)
    assert isinstance(tensor, torch.Tensor)
    assert isinstance(subscript, (list, tuple))
    tile_index_result = _maybe_materialize_tile_index_load(state, tensor, subscript)
    if tile_index_result is not None:
        return tile_index_result
    return pallas_codegen.load_expr(state, list(subscript), tensor)


@_decorators.codegen(load, "metal")
def _(state: CodegenState) -> ast.AST:
    """Generate the ``load`` expression for the Metal backend.

    Raises:
        exc.BackendUnsupported: If the tensor argument is not a
            ``torch.Tensor``.
    """
    # Metal delegates to the same PointerIndexingStrategy as Triton.
    # This produces tl.load(ptr + offset, mask, other=0) in the AST;
    # the MSL walker translates it to Metal.
    tensor = state.proxy_arg(0)
    subscript = state.proxy_arg(1)
    assert isinstance(subscript, (list, tuple))
    ast_subscript = state.ast_args[1]
    assert isinstance(ast_subscript, (list, tuple))
    extra_mask = state.ast_args[2]
    assert isinstance(extra_mask, (type(None), ast.AST))
    eviction_policy = state.ast_args[3] if len(state.ast_args) > 3 else None
    assert isinstance(eviction_policy, (type(None), ast.AST))

    if isinstance(tensor, torch.Tensor):
        device_fn = state.device_function
        # The load counter is advanced but its value is not used here.
        device_fn.device_load_index += 1
        indexing_idx = device_fn.device_memory_op_index
        device_fn.device_memory_op_index += 1
        strategy = device_fn.get_indexing_strategy(indexing_idx)
        # No cache modifier is passed on this backend.
        return strategy.codegen_load(
            state, tensor, [*subscript], extra_mask, eviction_policy, None
        )
    raise exc.BackendUnsupported("metal", f"load tensor type: {type(tensor)}")


def _cute_load_feeds_sort_or_scan(load_node: object) -> bool:
    """Return True if ``load_node`` feeds a sort/topk/_associative_scan.

    Direct users (sort/topk and the scalar ``_associative_scan`` path) are
    matched immediately. For a tuple ``_associative_scan`` the index stream
    is typically a ``load`` that flows through a chain of dtype-cast / shape
    ops (e.g. ``indices[tile].float().unsqueeze(1).expand_as(vals)``) before
    reaching the scan. To recover a scalar load for that stream we follow
    the forward chain through those pass-through ops.

    Returns False when ``load_node`` is not an FX ``Node``.
    """
    from torch.fx.node import Node

    from .._compiler.cute.indexing import is_cute_shape_chain_target

    if not isinstance(load_node, Node):
        return False
    passthrough_targets = (torch.ops.prims.convert_element_type.default,)
    # Worklist traversal over users; ``seen`` prevents revisiting a node.
    seen: set[Node] = set()
    stack: list[Node] = [load_node]
    while stack:
        node = stack.pop()
        for user in node.users:
            if not isinstance(user, Node):
                continue
            target = user.target
            if (
                target in (torch.ops.aten.sort.default, torch.ops.aten.topk.default)
                or getattr(target, "__name__", None) == "_associative_scan"
            ):
                return True
            # Keep walking only through shape-chain / dtype-cast users.
            if (
                is_cute_shape_chain_target(target) or target in passthrough_targets
            ) and user not in seen:
                seen.add(user)
                stack.append(user)
    return False


@_decorators.codegen(load, "cute")
def _(state: CodegenState) -> object:
    """Generate the ``load`` expression for the CuTe backend.

    Handles, in order: stack tensors, ``tile.index`` loads, loads that are
    replaced by a zero constant, packed affine-LHS / packed-RHS loads,
    vectorized loads, and finally the scalar load (optionally recorded as a
    ``CuteSortableLoad`` when it feeds a sort/topk/scan). Masked-out elements
    evaluate to a zero of the appropriate dtype.

    Raises:
        exc.BackendUnsupported: For unsupported tensor types, unsupported
            ``tile.index`` loads, or an unsupported sort/topk input rank.
    """
    tensor = state.proxy_arg(0)
    subscript = state.proxy_arg(1)
    assert isinstance(subscript, (list, tuple))
    ast_subscript = state.ast_args[1]
    assert isinstance(ast_subscript, (list, tuple))
    extra_mask = state.ast_args[2]
    assert isinstance(extra_mask, (type(None), ast.AST))

    # Stack tensors: load through a pointer built from dev_ptrs plus an offset.
    if isinstance(tensor, tuple):
        stack_tensor_ast = state.ast_args[0]
        assert isinstance(stack_tensor_ast, tuple)
        assert len(stack_tensor_ast) == 2
        tensor_like_ast, dev_ptrs_ast = stack_tensor_ast
        assert isinstance(dev_ptrs_ast, ast.AST)
        tensor_like, dev_ptrs = tensor
        offset_expr = _cute_stack_tensor_offset_expr(
            state,
            tensor_like,
            [*subscript],
            ast_subscript,
        )
        backend = CompileEnvironment.current().backend
        target_dtype = backend.dtype_str(tensor_like.dtype)
        ptr_expr = _cute_stack_tensor_pointer_expr(
            target_dtype, dev_ptrs_ast, offset_expr
        )
        load_expr = f"({ast.unparse(ptr_expr)}).load()"
        mask_expr = _cute_stack_tensor_mask_expr(
            state,
            tensor_like,
            dev_ptrs,
            [*subscript],
            extra_mask,
        )
        if tensor_like.dtype is torch.bool:
            # Bool values are compared against a Uint8 zero to form a Boolean.
            load_expr = f"({load_expr} != cutlass.Uint8(0))"
            if mask_expr is None:
                return expr_from_string(load_expr)
            return expr_from_string(
                f"({load_expr} if {mask_expr} else cutlass.Boolean(0))"
            )
        if mask_expr is None:
            return expr_from_string(load_expr)
        return expr_from_string(f"({load_expr} if {mask_expr} else {target_dtype}(0))")

    if not isinstance(tensor, torch.Tensor):
        raise exc.BackendUnsupported("cute", f"load tensor type: {type(tensor)}")

    _log_cute_layout(state, "load")

    from ..language import tile_index

    # tile.index loads resolve directly to the active index variable.
    tensor_node = state.fx_node.args[0] if state.fx_node is not None else None
    if (
        isinstance(tensor_node, torch.fx.Node)
        and tensor_node.op == "call_function"
        and tensor_node.target == tile_index
    ):
        env = CompileEnvironment.current()
        block_id = env.get_block_id(tensor.size(0))
        if block_id is None:
            raise exc.BackendUnsupported("cute", "tile_index load block id")
        index_var = _cute_active_index_var(state, block_id)
        if index_var is None:
            raise exc.BackendUnsupported("cute", "inactive tile_index load")
        # Only new-axis (None) and full-slice entries are accepted.
        for idx in subscript:
            if idx is None or idx == slice(None):
                continue
            raise exc.BackendUnsupported(
                "cute", f"tile_index load index type: {type(idx)}"
            )
        return expr_from_string(index_var)

    # These loads are replaced by a zero constant instead of a memory access.
    cute_state = state.device_function.cute_state
    if cute_state.suppress_root_lane_loops or (
        state.fx_node is not None
        and cute_state.is_collective_handled_load(state.fx_node.name)
    ):
        zero = CompileEnvironment.current().backend.dtype_str(tensor.dtype)
        return expr_from_string(f"{zero}(0)")

    packed_affine_lhs = _maybe_codegen_cute_packed_affine_lhs_load(
        state, tensor, subscript, extra_mask
    )
    if packed_affine_lhs is not None:
        return packed_affine_lhs

    packed_rhs_load = _maybe_codegen_cute_packed_rhs_load(
        state, tensor, subscript, extra_mask
    )
    if packed_rhs_load is not None:
        return packed_rhs_load

    # NOTE(review): presumably the matching store emits the real access for
    # these two cases, which is why a zero is returned here — confirm.
    if _is_cute_affine_range_load_for_store(state, subscript, ast_subscript):
        zero = _cute_scalar_storage_dtype(tensor.dtype)
        return expr_from_string(f"{zero}(0)")
    if _is_cute_strided_slice_load_for_store(state, tensor, subscript):
        zero = _cute_scalar_storage_dtype(tensor.dtype)
        return expr_from_string(f"{zero}(0)")

    tensor_name = state.device_function.tensor_arg(tensor).name
    index_exprs = _cute_index_exprs(
        state,
        subscript,
        ast_subscript,
        tensor=tensor,
        inactive_slice_expr="None",
        inactive_singleton_slice_expr="0",
    )
    mask_expr = _cute_combined_mask(
        state,
        subscript,
        extra_mask,
        tensor=tensor,
        include_tensor_index_masks=False,
    )

    vec_ctx = _cute_vector_load_ctx(state, tensor, subscript, index_exprs, extra_mask)
    if vec_ctx is not None:
        vec_width, vec_block_id, vec_mode = vec_ctx
        from .._compiler.reduction_strategy import LoopedReductionStrategy

        # Innermost active loop for the vectorized block, if any.
        loops = state.codegen.active_device_loops.get(vec_block_id)
        strategy = loops[-1].strategy if loops else None
        if vec_mode == "vec":
            load_expr = _cute_vector_load_expr(
                tensor_name, index_exprs, tensor.dtype, vec_width=vec_width
            )
            # The mask is deferred to the post-fold scalar in
            # codegen_reduction. The vec load itself is unconditional; the
            # mask is recorded on the active LoopedReductionStrategy and
            # applied around the folded sum.
            if isinstance(strategy, LoopedReductionStrategy):
                strategy._cute_emitted_vec_load = True
                if mask_expr is not None:
                    strategy._cute_pending_vec_masks.append(mask_expr)
                    # NOTE(review): nesting of this reset was reconstructed
                    # from a whitespace-collapsed listing — confirm against
                    # the repository source.
                    mask_expr = None
        elif vec_mode == "unroll":
            # Register (or reuse) a hoisted U16 vec load for this (tensor,
            # base_index) pair, then return ``hoist_var[vi].bitcast(dtype)``
            # so the existing scalar pipeline sees a scalar of the original
            # dtype.
            assert isinstance(strategy, LoopedReductionStrategy)
            load_expr = _cute_register_unroll_vec_hoist(
                state,
                strategy,
                tensor,
                tensor_name,
                index_exprs,
                vec_width,
            )
        elif vec_mode == "tile_unroll":
            # Same hoist protocol as ``LoopedReductionStrategy``'s
            # ``unroll`` mode but for ``CuteNDTileStrategy`` lane loops.
            from .._compiler.tile_strategy import BlockSizeTileStrategy

            assert isinstance(strategy, BlockSizeTileStrategy)
            load_expr = _cute_register_tile_unroll_vec_hoist(
                state,
                strategy,
                vec_block_id,
                tensor,
                tensor_name,
                index_exprs,
                vec_width,
            )
        else:
            assert vec_mode == "tile_unroll_split2"
            # V=8 fp16/bf16: emit two back-to-back ``cute.arch.load(...,
            # V=4)`` calls (lanes 0-3 and 4-7). Works around the CuTe
            # DSL's ``nvvm.load.ext`` ICE on V=8 while still issuing the
            # full LDG.128 of bytes-per-thread-per-outer-iter.
            from .._compiler.tile_strategy import BlockSizeTileStrategy

            assert isinstance(strategy, BlockSizeTileStrategy)
            load_expr = _cute_register_tile_unroll_vec_hoist_split2(
                state,
                strategy,
                vec_block_id,
                tensor,
                tensor_name,
                index_exprs,
                vec_width,
            )
    else:
        load_expr = _cute_scalar_load_expr(tensor_name, index_exprs, tensor.dtype)
    if tensor.dtype is torch.bool:
        # Bool loads never reach the sortable-load path below.
        load_expr = f"({load_expr} != cutlass.Uint8(0))"
        if mask_expr is None:
            return expr_from_string(load_expr)
        return expr_from_string(f"({load_expr} if {mask_expr} else cutlass.Boolean(0))")
    if state.fx_node is not None and _cute_load_feeds_sort_or_scan(state.fx_node):
        from .._compiler.cute.indexing import CuteSortableLoad

        # Find the subscript entry that lands on the tensor's last dimension,
        # skipping new-axis (None) entries which consume no tensor dimension.
        tensor_dim = 0
        sort_index_pos = -1
        for idx in subscript:
            if idx is None:
                continue
            if tensor_dim == tensor.ndim - 1:
                sort_index_pos = tensor_dim
                break
            tensor_dim += 1
        if sort_index_pos < 0:
            raise exc.BackendUnsupported("cute", "sort/topk input rank")
        sortable_load = CuteSortableLoad(
            expr=expr_from_string(
                load_expr
                if mask_expr is None
                else f"({load_expr} if {mask_expr} else {_cute_scalar_storage_dtype(tensor.dtype)}(0))"
            ),
            tensor_name=tensor_name,
            index_exprs=tuple(index_exprs),
            sort_index_pos=sort_index_pos,
            mask_expr=mask_expr,
            dtype=tensor.dtype,
        )
        # Stash the load description on the FX node's metadata.
        state.fx_node.meta["cute_sortable_load"] = sortable_load
        return sortable_load.expr
    if mask_expr is None:
        return expr_from_string(load_expr)
    zero = _cute_scalar_storage_dtype(tensor.dtype)
    return expr_from_string(f"({load_expr} if {mask_expr} else {zero}(0))")


@_decorators.get_masked_value(load)
def _(node: torch.fx.Node) -> int:
    """Return the value that masked-out elements of a ``load`` take."""
    return 0  # loads are always masked to 0


# TODO(joydddd): Add support for stack tensor in ref mode.
@_decorators.ref(load)
def _(
    tensor: torch.Tensor,
    index: list[object],
    extra_mask: torch.Tensor | None = None,
    eviction_policy: str | None = None,
) -> torch.Tensor:
    """Ref-mode (eager) implementation of ``load``.

    Without ``extra_mask`` this is plain indexing, with multiple tensor
    indices expanded to a Cartesian product via ``torch.meshgrid``. With
    ``extra_mask`` the tensor indices are clamped into range before gathering,
    and every position that is masked off or out of bounds yields zero.

    Args:
        tensor: Tensor to read from.
        index: One entry per indexed position; ``RefTile`` entries are replaced
            by their ``.index`` value.
        extra_mask: Optional mask. It is converted to ``torch.bool``, matching
            the ref implementation of ``store``, so non-bool masks are
            accepted by ``torch.where``.
        eviction_policy: Accepted for signature compatibility; unused here.

    Returns:
        The loaded values; in the masked path, shaped like ``extra_mask`` after
        broadcasting with the gathered values.
    """
    from .ref_tile import RefTile

    if extra_mask is None:
        # Convert RefTiles to indices
        indices = [idx.index if isinstance(idx, RefTile) else idx for idx in index]
        # Use meshgrid for Cartesian product when we have multiple tensor indices
        tensor_idxs = [
            i for i, idx in enumerate(indices) if isinstance(idx, torch.Tensor)
        ]
        if len(tensor_idxs) > 1:
            # pyrefly: ignore [bad-argument-type]
            grids = torch.meshgrid(*(indices[i] for i in tensor_idxs), indexing="ij")
            for i, grid in zip(tensor_idxs, grids, strict=False):
                indices[i] = grid
        # pyrefly: ignore [bad-argument-type, bad-index]
        return tensor[tuple(indices)]

    # Create zero result matching mask shape
    result = torch.zeros(extra_mask.shape, dtype=tensor.dtype, device=tensor.device)

    # Process indices: convert RefTiles and clamp tensor indices
    orig_indices, safe_indices, is_tensor_mask = [], [], []
    for i, idx in enumerate(index):
        if isinstance(idx, RefTile):
            idx = idx.index  # Convert RefTile to tensor
        if isinstance(idx, torch.Tensor):
            dim_size = tensor.shape[i] if i < len(tensor.shape) else tensor.numel()
            orig_indices.append(idx)
            # Clamp so the gather below never reads out of range; the
            # original index is kept to rebuild the in-bounds mask later.
            safe_indices.append(torch.clamp(idx, 0, dim_size - 1))
            is_tensor_mask.append(True)
        else:
            orig_indices.append(idx)
            safe_indices.append(idx)
            is_tensor_mask.append(False)

    # Apply broadcasting if we have multiple tensor indices
    tensor_positions = [i for i, is_tensor in enumerate(is_tensor_mask) if is_tensor]
    if len(tensor_positions) > 1:
        # Add unsqueeze operations for broadcasting
        broadcast_indices = []
        for i, (idx, is_tensor) in enumerate(
            zip(safe_indices, is_tensor_mask, strict=False)
        ):
            if is_tensor:
                new_idx = idx
                # Add dimension for each other tensor index
                for j, other_pos in enumerate(tensor_positions):
                    if other_pos != i:
                        new_idx = new_idx.unsqueeze(j if other_pos < i else -1)
                broadcast_indices.append(new_idx)
            else:
                broadcast_indices.append(idx)
        values = tensor[tuple(broadcast_indices)]
    else:
        values = tensor[tuple(safe_indices)]

    # Build validity mask. torch.where requires a boolean condition, so
    # normalize the dtype here. Nothing below mutates the mask in place, so no
    # defensive copy is needed.
    valid_mask = extra_mask.to(torch.bool)
    for i, (orig_idx, is_tensor) in enumerate(
        zip(orig_indices, is_tensor_mask, strict=False)
    ):
        if is_tensor:
            dim_size = tensor.shape[i] if i < len(tensor.shape) else tensor.numel()
            in_bounds = (orig_idx >= 0) & (orig_idx < dim_size)
            # Broadcast to match mask shape by adding dimensions
            # Count how many tensor indices come before and after this one
            n_before = sum(1 for j in range(i) if is_tensor_mask[j])
            n_after = sum(
                1 for j in range(i + 1, len(is_tensor_mask)) if is_tensor_mask[j]
            )
            # Add dimensions: n_after dimensions at the end, n_before at the beginning
            for _ in range(n_after):
                in_bounds = in_bounds.unsqueeze(-1)
            for _ in range(n_before):
                in_bounds = in_bounds.unsqueeze(0)
            valid_mask = valid_mask & in_bounds

    return torch.where(valid_mask, values, result)