from __future__ import annotations
import ast
import contextlib
import logging
import operator
import textwrap
from typing import TYPE_CHECKING
import torch
from torch.fx import has_side_effect
from torch.fx.node import map_arg
from .. import exc
from .._compiler.ast_extension import expr_from_string
from .._compiler.ast_extension import statement_from_string
from .._compiler.compile_environment import CompileEnvironment
from .._compiler.compile_environment import _symint_expr
from .._compiler.cute.cutedsl_compat import emit_dealloc_mbarrier_initialized_kwarg
from .._compiler.cute.cutedsl_compat import emit_pipeline_advance
from .._compiler.cute.cutedsl_compat import emit_producer_tail_tma_umma
from .._compiler.cute.cutedsl_compat import emit_producer_tail_umma_async
from .._compiler.cute.tcgen05_constants import (
TCGEN05_ACC_WAIT_PLACEMENT_BEFORE_SUBTILE_LOOP,
)
from .._compiler.cute.tcgen05_constants import TCGEN05_ACC_WAIT_PLACEMENT_CONFIG_KEY
from .._compiler.cute.tcgen05_constants import TCGEN05_ACC_WAIT_PLACEMENT_SUBTILE_LOOP
from .._compiler.cute.tcgen05_constants import TCGEN05_C_ACQUIRE_PLACEMENT_CONFIG_KEY
from .._compiler.cute.tcgen05_constants import TCGEN05_C_ACQUIRE_PLACEMENT_FIRST_IN_LOOP
from .._compiler.cute.tcgen05_constants import (
TCGEN05_C_ACQUIRE_PLACEMENT_LATER_BEFORE_BARRIER,
)
from .._compiler.cute.tcgen05_constants import TCGEN05_C_ACQUIRE_PLACEMENT_PRE_LOOP
from .._compiler.cute.tcgen05_constants import TCGEN05_C_STORE_MODE_CONFIG_KEY
from .._compiler.cute.tcgen05_constants import TCGEN05_C_STORE_MODE_NORMAL
from .._compiler.cute.tcgen05_constants import TCGEN05_C_STORE_MODE_SKIP_EPILOGUE_STORE
from .._compiler.cute.tcgen05_constants import TCGEN05_EPILOGUE_LAYOUT_CONFIG_KEY
from .._compiler.cute.tcgen05_constants import (
TCGEN05_EPILOGUE_LAYOUT_MODULE_HELPER_ACC_T2R,
)
from .._compiler.cute.tcgen05_constants import (
TCGEN05_EPILOGUE_LAYOUT_MODULE_HELPER_STORE_TAIL,
)
from .._compiler.cute.tcgen05_constants import TCGEN05_EPILOGUE_LAYOUT_NORMAL
from .._compiler.cute.tcgen05_constants import (
TCGEN05_EPILOGUE_LAYOUT_SPLIT_ACC_T2R_STORE_TAIL,
)
from .._compiler.cute.tcgen05_constants import TCGEN05_EPILOGUE_LAYOUT_SPLIT_FIRST_T2R
from .._compiler.cute.tcgen05_constants import TCGEN05_TWO_CTA_BLOCK_N
from .._compiler.host_function import HostFunction
from .._compiler.indexing_strategy import SubscriptIndexing
from .._compiler.indexing_strategy import TileWithOffsetInfo
from .._compiler.indexing_strategy import _get_tile_with_offset_info
from .._compiler.pallas import codegen as pallas_codegen
from .._compiler.variable_origin import GridOrigin
from .._compiler.variable_origin import TileBeginOrigin
from .._compiler.variable_origin import TileCountOrigin
from .._compiler.variable_origin import TileEndOrigin
from .._compiler.variable_origin import TileIdOrigin
from . import _decorators
from .stack_tensor import StackTensor
if TYPE_CHECKING:
from .._compiler.inductor_lowering import CodegenState
from .._compiler.tile_strategy import LoopDimInfo
from .._compiler.host_function import SymbolOrigin
# TileBeginWithOffset removed - using TileBeginWithOffsetPattern instead
__all__ = ["load", "store"]
log = logging.getLogger(__name__)
# Map short config names to full Triton API names for eviction policies
_EVICTION_POLICY_MAP = {
"": None,
"first": "evict_first",
"last": "evict_last",
}
@has_side_effect
@_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
def store(
tensor: torch.Tensor | StackTensor,
index: list[object],
value: torch.Tensor | torch.SymInt | float,
extra_mask: torch.Tensor | None = None,
) -> None:
"""Store a value to a tensor using a list of indices.
This function is equivalent to `tensor[index] = value` but allows
setting `extra_mask=` to mask elements beyond the default masking
based on the hl.tile range.
Args:
tensor: The tensor / stack tensor to store to
index: The indices to use to index into the tensor
value: The value to store
extra_mask: The extra mask (beyond automatic tile bounds masking) to apply to the tensor
Returns:
None
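    Example:
        A minimal sketch, assuming ``out`` and ``x`` are tensor arguments of a
        Helion kernel and this runs inside an ``hl.tile`` device loop::

            for tile in hl.tile(out.size(0)):
                hl.store(out, [tile], x[tile] * 2, extra_mask=x[tile] > 0)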
"""
raise exc.NotInsideKernel
@_decorators.prepare_args(store)
def _(
tensor: torch.Tensor | StackTensor,
index: list[object],
value: torch.Tensor | torch.SymInt | float,
extra_mask: torch.Tensor | None = None,
) -> tuple[
torch.Tensor | tuple,
list[object],
torch.Tensor | torch.SymInt | float | int,
torch.Tensor | None,
]:
from .tile_proxy import Tile
if isinstance(value, torch.Tensor) and value.dtype != tensor.dtype:
value = value.to(tensor.dtype)
index = Tile._tiles_to_sizes_for_index(index)
if isinstance(tensor, StackTensor):
return (tuple(tensor), index, value, extra_mask)
if isinstance(tensor, torch.Tensor):
return (tensor, index, value, extra_mask)
raise NotImplementedError(f"Cannot store to type: {type(tensor)}")
@_decorators.register_fake(store)
def _(
tensor: torch.Tensor | tuple[object, ...],
index: list[object],
value: torch.Tensor | torch.SymInt | float,
extra_mask: torch.Tensor | None = None,
) -> None:
return None
@_decorators.codegen(store, "triton")
def _(state: CodegenState) -> ast.AST:
tensor = state.proxy_arg(0)
subscript = state.proxy_arg(1)
assert isinstance(subscript, (list, tuple))
value = state.ast_arg(2)
extra_mask = state.ast_args[3]
assert isinstance(extra_mask, (type(None), ast.AST))
if isinstance(tensor, torch.Tensor):
device_fn = state.device_function
fx_node = state.fx_node
assert fx_node is not None
epilogue_subtile_group_id = fx_node.meta.get("epilogue_subtile_group_id")
if epilogue_subtile_group_id is None:
indexing_idx = device_fn.allocate_store_index()
elif fx_node.meta.get("epilogue_subtile_primary_store", False):
indexing_idx = device_fn.allocate_store_index()
device_fn.epilogue_subtile_store_indices[epilogue_subtile_group_id] = (
indexing_idx
)
else:
indexing_idx = device_fn.epilogue_subtile_store_indices[
epilogue_subtile_group_id
]
strategy = device_fn.get_indexing_strategy(indexing_idx)
if state.codegen.store_transform is not None:
return state.codegen.store_transform(
state,
tensor,
[*subscript],
value,
extra_mask,
strategy.codegen_store,
)
return strategy.codegen_store(state, tensor, [*subscript], value, extra_mask)
if isinstance(tensor, tuple):
from .._compiler.indexing_strategy import StackIndexingStrategy
# Fusion is not supported for stack stores (multi-tensor device pointers);
# fall through to the unfused path regardless of store_transform.
stack_tensor_ast = state.ast_args[0]
assert isinstance(stack_tensor_ast, tuple)
assert len(stack_tensor_ast) == 2
_tensor_like_ast, dev_ptrs_ast = stack_tensor_ast
return StackIndexingStrategy.codegen_store(
state, tensor, dev_ptrs_ast, [*subscript], value, extra_mask
)
raise NotImplementedError(f"Cannot store to type: {type(tensor)}")
def _record_pad_info(
state: CodegenState,
tensor: torch.Tensor,
tensor_dim: int,
block_id: int,
extra_pad: int = 0,
) -> None:
"""Record that a tensor dimension uses pl.ds() and may need host-side padding.
*extra_pad* accounts for non-zero loop begins: 0 when the loop starts
at offset 0, ``begin % block_size`` for a constant begin, or
``block_size - 1`` for a data-dependent begin.
Note: stores one entry per (tensor, dim). If two inner loops tile the
same dim with different block_ids, the last one wins. This is fine when
both loops use the same block size (the common case).
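    Example (illustrative numbers): a tiled loop that starts at a constant
    offset of 3 with ``block_size == 8`` records ``extra_pad == 3 % 8 == 3``;
    a data-dependent begin records ``extra_pad == block_size - 1 == 7``.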
"""
pad_info = state.device_function.pallas_pad_info
tensor_id = id(tensor)
if tensor_id not in pad_info:
pad_info[tensor_id] = {}
pad_info[tensor_id][tensor_dim] = (block_id, extra_pad)
def _maybe_get_symbol_origin(idx: object) -> SymbolOrigin | None:
if not isinstance(idx, torch.SymInt):
return None
expr = _symint_expr(idx)
if expr is None:
return None
return HostFunction.current().expr_to_origin.get(expr)
@_decorators.codegen(store, "pallas")
def _(state: CodegenState) -> None:
tensor = state.proxy_arg(0)
subscript = state.proxy_arg(1)
assert isinstance(subscript, (list, tuple))
value = state.ast_arg(2)
assert isinstance(tensor, torch.Tensor)
name = state.device_function.tensor_arg(tensor).name
name = pallas_codegen.vmem_name(state, name)
# Increment memory op index to stay in sync with triton backend
device_fn = state.device_function
device_fn.device_store_index += 1
device_fn.device_memory_op_index += 1
index_str, _ = pallas_codegen.index_str(state, subscript, tensor)
state.codegen.add_statement(
statement_from_string(f"{name}[{index_str}] = {{value}}", value=value)
)
def _matching_block_ids(env: CompileEnvironment, size: object) -> list[int]:
"""Find all block_ids that match the given dimension size."""
candidates: list[int] = []
    if not isinstance(size, (int, torch.SymInt)):
        return candidates
    if (direct := env.get_block_id(size)) is not None:
        candidates.append(direct)
for info in env.block_sizes:
if not isinstance(info.size, (int, torch.SymInt)):
continue
if not env.known_equal(info.size, size):
continue
if info.block_id not in candidates:
candidates.append(info.block_id)
return candidates
def _log_cute_layout(state: CodegenState, op_name: str) -> None:
"""Log the CuTe layout annotation for the current node, if any.
This is used during CuTe load/store codegen to make layout info
visible for debugging and future codegen integration.
"""
layout = state.cute_layout
if layout is None:
return
node_name = state.fx_node.name if state.fx_node else "?"
log.debug(
"cute %s %s: layout tag=%s thread=%s value=%s",
op_name,
node_name,
layout.tag.value,
layout.thread_shape,
layout.value_shape,
)
def _cute_active_index_var(state: CodegenState, block_id: int) -> str | None:
loops = state.codegen.active_device_loops.get(block_id)
if loops:
return loops[-1].strategy.index_var(block_id)
grid_state = state.codegen.current_grid_state
if grid_state is not None and block_id in grid_state.block_ids:
return grid_state.strategy.index_var(block_id)
return None
def _cute_active_mask_var(state: CodegenState, block_id: int) -> str | None:
loops = state.codegen.active_device_loops.get(block_id)
if loops:
return loops[-1].strategy.mask_var(block_id)
return None
def _cute_unique_graph_block_id(state: CodegenState) -> int | None:
fx_node = state.fx_node
if fx_node is None:
return None
graph_block_ids = [
graph_info.block_ids
for graph_info in state.codegen.codegen_graphs
if graph_info.graph is fx_node.graph and hasattr(graph_info, "block_ids")
]
if len(graph_block_ids) != 1 or len(graph_block_ids[0]) != 1:
return None
(block_id,) = graph_block_ids[0]
return block_id
def _maybe_codegen_cute_packed_affine_lhs_load(
state: CodegenState,
tensor: torch.Tensor,
subscript: list[object] | tuple[object, ...],
extra_mask: ast.AST | None,
) -> object | None:
from .._compiler.cute.indexing import CutePackedAffineLoad
from .._compiler.cute.indexing import match_cute_affine_range_iota
from .._compiler.cute.indexing import match_cute_stack_reshape_rhs
from .matmul_ops import dot
fx_node = state.fx_node
if (
fx_node is None
or len(fx_node.users) != 1
or len(subscript) not in (2, 3)
or len(fx_node.args) < 2
):
return None
fx_subscript = fx_node.args[1]
if not isinstance(fx_subscript, (list, tuple)) or len(fx_subscript) != len(
subscript
):
return None
range_node = fx_subscript[-1]
if not isinstance(range_node, torch.fx.Node):
return None
affine_range = match_cute_affine_range_iota(range_node)
if affine_range is None:
return None
user = next(iter(fx_node.users))
if user.op != "call_function" or user.target not in {
dot,
torch.ops.aten.bmm.default,
torch.ops.aten.baddbmm.default,
torch.ops.aten.mm.default,
torch.ops.aten.addmm.default,
}:
return None
rhs_index = (
2
if user.target in (torch.ops.aten.addmm.default, torch.ops.aten.baddbmm.default)
else 1
)
rhs_arg = user.args[rhs_index]
if not isinstance(rhs_arg, torch.fx.Node):
return None
packed_rhs = match_cute_stack_reshape_rhs(rhs_arg)
if packed_rhs is None:
return None
_, factor = packed_rhs
if factor != affine_range.factor:
return None
packed_block_id = _cute_unique_graph_block_id(state)
if packed_block_id is None:
return None
packed_index = _cute_active_index_var(state, packed_block_id)
if packed_index is None:
return None
leading_subscript = [*subscript[:-1]]
row_index_exprs = _cute_index_exprs(
state,
leading_subscript,
tensor=tensor,
inactive_slice_expr="None",
inactive_singleton_slice_expr="0",
)
if len(row_index_exprs) != len(leading_subscript):
return None
tensor_name = state.device_function.tensor_arg(tensor).name
mask_terms: list[str] = []
row_mask = _cute_combined_mask(state, leading_subscript, extra_mask, tensor=tensor)
if row_mask is not None:
mask_terms.append(row_mask)
if packed_mask := _cute_active_mask_var(state, packed_block_id):
mask_terms.append(f"({packed_mask})")
mask_expr = " and ".join(mask_terms) if mask_terms else None
zero = CompileEnvironment.current().backend.dtype_str(tensor.dtype)
terms: list[ast.AST] = []
for offset in range(factor):
index_expr = ", ".join(
[
*row_index_exprs,
f"cutlass.Int32({factor}) * ({packed_index}) + cutlass.Int32({offset})",
]
)
term = expr_from_string(f"{tensor_name}[{index_expr}]")
if mask_expr is not None:
term = expr_from_string(
f"({{value}} if {mask_expr} else {zero}(0))",
value=term,
)
terms.append(term)
return CutePackedAffineLoad(tuple(terms))
def _maybe_codegen_cute_packed_rhs_load(
state: CodegenState,
tensor: torch.Tensor,
subscript: list[object] | tuple[object, ...],
extra_mask: ast.AST | None,
) -> ast.AST | None:
from .._compiler.cute.indexing import match_cute_duplicate_stack_reshape_rhs
fx_node = state.fx_node
if fx_node is None or len(subscript) not in (2, 3) or len(fx_node.users) != 1:
return None
user = next(iter(fx_node.users))
if user.op != "call_function" or user.target is not torch.ops.aten.stack.default:
return None
stack_users = list(user.users)
if len(stack_users) != 1 or not isinstance(stack_users[0], torch.fx.Node):
return None
rhs_node = stack_users[0]
packed_rhs = match_cute_duplicate_stack_reshape_rhs(rhs_node)
if packed_rhs != (
fx_node,
len(user.args[0]) if isinstance(user.args[0], (list, tuple)) else 0,
):
return None
packed_block_id = _cute_unique_graph_block_id(state)
if packed_block_id is None:
return None
packed_index = _cute_active_index_var(state, packed_block_id)
if packed_index is None:
return None
leading_subscript = [*subscript[:-2]]
col_index_exprs = _cute_index_exprs(
state,
[subscript[-1]],
tensor=tensor,
inactive_slice_expr="None",
inactive_singleton_slice_expr="0",
)
if len(col_index_exprs) != 1:
return None
(col_index,) = col_index_exprs
leading_index_exprs = _cute_index_exprs(
state,
leading_subscript,
tensor=tensor,
inactive_slice_expr="None",
inactive_singleton_slice_expr="0",
)
if len(leading_index_exprs) != len(leading_subscript):
return None
tensor_name = state.device_function.tensor_arg(tensor).name
load_index_expr = ", ".join([*leading_index_exprs, packed_index, col_index])
load_expr: ast.AST = expr_from_string(f"{tensor_name}[{load_index_expr}]")
mask_terms: list[str] = []
col_mask = _cute_combined_mask(
state,
[*leading_subscript, subscript[-1]],
extra_mask,
tensor=tensor,
)
if col_mask is not None:
mask_terms.append(col_mask)
if packed_mask := _cute_active_mask_var(state, packed_block_id):
mask_terms.append(f"({packed_mask})")
if not mask_terms:
return load_expr
zero = CompileEnvironment.current().backend.dtype_str(tensor.dtype)
return expr_from_string(
f"({{value}} if {' and '.join(mask_terms)} else {zero}(0))",
value=load_expr,
)
def _cute_index_exprs(
state: CodegenState,
subscript: list[object] | tuple[object, ...],
ast_subscript: list[object] | tuple[object, ...] | None = None,
tensor: torch.Tensor | None = None,
*,
inactive_slice_expr: str | None = None,
inactive_singleton_slice_expr: str | None = None,
) -> list[str]:
env = CompileEnvironment.current()
def symint_index_expr(idx: torch.SymInt, used_block_ids: set[int]) -> str:
expr = _symint_expr(idx)
if expr is not None:
origin_info = HostFunction.current().expr_to_origin.get(expr)
if origin_info is not None and isinstance(origin_info.origin, GridOrigin):
if type(origin_info.origin) is not GridOrigin:
block_id = origin_info.origin.block_id
loop_info = active_loop_info(block_id)
begin_var = tile_begin_expr(block_id, loop_info)
block_size_var = (
state.device_function.block_size_var(block_id) or "1"
)
if isinstance(origin_info.origin, TileBeginOrigin):
return begin_var
if isinstance(origin_info.origin, TileEndOrigin):
if loop_info is not None and loop_info.end_var_name is not None:
return env.backend.minimum_expr(
f"({begin_var}) + ({block_size_var})",
loop_info.end_var_name,
)
return f"({begin_var}) + ({block_size_var})"
if isinstance(origin_info.origin, TileCountOrigin):
end_var = (
loop_info.end_var_name
if loop_info is not None
and loop_info.end_var_name is not None
else f"({begin_var}) + ({block_size_var})"
)
extent = f"({end_var}) - ({begin_var})"
return env.backend.cdiv_expr(
extent, block_size_var, is_device=True
)
if isinstance(origin_info.origin, TileIdOrigin):
if block_size_var == "1":
return begin_var
return f"({begin_var}) // ({block_size_var})"
return state.sympy_expr(expr)
block_id = env.get_block_id(idx)
if block_id is not None:
used_block_ids.add(block_id)
return index_var_for_block_id(block_id, idx)
if expr is not None:
return state.sympy_expr(expr)
raise exc.BackendUnsupported("cute", f"unlowerable symbolic index: {idx}")
def active_loop_info(block_id: int) -> LoopDimInfo | None:
loops = state.codegen.active_device_loops.get(block_id)
if loops:
return loops[-1].block_id_to_info.get(block_id)
grid_state = state.codegen.current_grid_state
if grid_state is not None:
return grid_state.block_id_to_info.get(block_id)
return None
def active_local_coord(block_id: int) -> str | None:
from .._compiler.cute.cute_reshape import _grid_local_coord_expr
loops = state.codegen.active_device_loops.get(block_id)
if loops:
thread_axis = loops[-1].block_thread_axes.get(block_id)
if thread_axis is not None:
return _grid_local_coord_expr(state.codegen, block_id, thread_axis)
grid_state = state.codegen.current_grid_state
if grid_state is not None:
thread_axis = grid_state.block_thread_axes.get(block_id)
if thread_axis is not None:
return _grid_local_coord_expr(state.codegen, block_id, thread_axis)
return None
def tile_begin_expr(block_id: int, loop_info: LoopDimInfo | None) -> str:
loops = state.codegen.active_device_loops.get(block_id)
if loops:
return state.codegen.offset_var(block_id)
begin_var = "0"
if loop_info is not None and loop_info.begin_var_name is not None:
begin_var = loop_info.begin_var_name
global_index = active_index_var(block_id)
local_coord = active_local_coord(block_id)
if global_index is not None and local_coord is not None:
return state.codegen.lift(
expr_from_string(f"({global_index}) - ({local_coord})"),
dce=True,
prefix="tile_begin",
).id
if global_index is not None:
return global_index
return begin_var
def active_index_var(block_id: int) -> str | None:
loops = state.codegen.active_device_loops.get(block_id)
if loops:
return loops[-1].strategy.index_var(block_id)
grid_state = state.codegen.current_grid_state
if grid_state is not None and block_id in grid_state.block_ids:
return grid_state.strategy.index_var(block_id)
return None
def resolve_active_slice_block_id(
size: object,
used_block_ids: set[int],
) -> int | None:
candidates = _matching_block_ids(env, size)
active_candidates = [
block_id
for block_id in candidates
if active_index_var(block_id) is not None
]
active_unused_candidates = [
block_id for block_id in active_candidates if block_id not in used_block_ids
]
if len(active_unused_candidates) == 1:
return active_unused_candidates[0]
if len(active_candidates) == 1:
return active_candidates[0]
if len(active_unused_candidates) > 1:
reduction_unused = [
block_id
for block_id in active_unused_candidates
if env.block_sizes[block_id].reduction
]
if len(reduction_unused) == 1:
return reduction_unused[0]
if len(active_candidates) > 1:
reduction_active = [
block_id
for block_id in active_candidates
if env.block_sizes[block_id].reduction
]
if len(reduction_active) == 1:
return reduction_active[0]
return None
def index_var_for_block_id(block_id: int, size: object) -> str:
if (idx_var := active_index_var(block_id)) is not None:
return idx_var
raise exc.BackendUnsupported(
"cute",
(
"indexing dimension is not active in this scope "
f"(block_id={block_id}, size={size})"
),
)
def local_coord_for_block_id(block_id: int, begin_var: str) -> str | None:
if (local_coord := active_local_coord(block_id)) is not None:
return local_coord
if (idx_var := active_index_var(block_id)) is not None:
return f"({idx_var}) - ({begin_var})"
return None
def tile_with_offset_index_expr(tile_info: TileWithOffsetInfo) -> str:
block_id = tile_info.block_id
begin_var = tile_begin_expr(block_id, active_loop_info(block_id))
local_coord = local_coord_for_block_id(block_id, begin_var)
if local_coord is None:
raise exc.BackendUnsupported(
"cute",
(
"indexing dimension is not active in this scope "
f"(block_id={block_id})"
),
)
offset_expr = state.device_function.literal_expr(tile_info.offset)
return f"({begin_var}) + cutlass.Int32({offset_expr}) + ({local_coord})"
used_block_ids = {
block_id
for idx in subscript
if isinstance(idx, torch.SymInt)
if (block_id := env.get_block_id(idx)) is not None
}
result = []
tensor_dim = 0
for pos, idx in enumerate(subscript):
ast_idx = None
if ast_subscript is not None:
ast_idx = ast_subscript[pos]
if idx is None:
continue
if (
tensor is not None
and tensor_dim < tensor.ndim
and env.known_equal(tensor.shape[tensor_dim], 1)
and not (isinstance(idx, slice) and idx == slice(None))
):
result.append("0")
tensor_dim += 1
continue
if (
tile_info := _get_tile_with_offset_info(
idx, getattr(state, "fx_node", None), pos
)
) is not None and tile_info.block_size is not None:
used_block_ids.add(tile_info.block_id)
result.append(tile_with_offset_index_expr(tile_info))
tensor_dim += 1
continue
if isinstance(idx, torch.SymInt):
result.append(symint_index_expr(idx, used_block_ids))
tensor_dim += 1
elif isinstance(idx, int):
result.append(str(idx))
tensor_dim += 1
elif isinstance(idx, torch.Tensor):
from .._compiler.cute.indexing import CuteAffineRangeIndex
if isinstance(ast_idx, CuteAffineRangeIndex):
raise exc.BackendUnsupported(
"cute",
"affine hl.arange() indexing is only supported in CuTe packed-matmul load fusion",
)
if not isinstance(ast_idx, ast.AST):
raise exc.BackendUnsupported(
"cute", f"tensor index without AST at position {pos}"
)
lifted = state.codegen.lift(ast_idx, dce=True, prefix="index")
index_dtype = env.backend.dtype_str(env.index_dtype)
result.append(f"{index_dtype}({lifted.id})")
tensor_dim += 1
elif isinstance(idx, slice) and idx == slice(None):
if tensor is None:
raise exc.BackendUnsupported("cute", "slice indexing without tensor")
dim_size = tensor.shape[tensor_dim]
block_id = resolve_active_slice_block_id(dim_size, used_block_ids)
if block_id is not None:
idx_var = active_index_var(block_id)
assert idx_var is not None
used_block_ids.add(block_id)
result.append(idx_var)
tensor_dim += 1
continue
if inactive_singleton_slice_expr is not None and env.known_equal(
dim_size, 1
):
result.append(inactive_singleton_slice_expr)
tensor_dim += 1
continue
if inactive_slice_expr is None:
raise exc.BackendUnsupported(
"cute",
(
"indexing dimension is not active in this scope "
f"(tensor_dim={pos}, size={dim_size})"
),
)
result.append(inactive_slice_expr)
tensor_dim += 1
else:
raise exc.BackendUnsupported("cute", f"index type: {type(idx)}")
return result
def _cute_index_tuple(index_exprs: list[str]) -> str:
if len(index_exprs) == 1:
return f"({index_exprs[0]},)"
return f"({', '.join(index_exprs)})"
def _cute_scalar_pointer_expr(tensor_name: str, index_exprs: list[str]) -> str:
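    """Build a raw-pointer expression for a single scalar element.

    Illustrative sketch: for ``tensor_name == "out"`` and
    ``index_exprs == ["i", "j"]`` this emits roughly
    ``(out.iterator + (Idx(i) * Idx(out.layout.stride[0])) +
    (Idx(j) * Idx(out.layout.stride[1])))``, where ``Idx`` stands for the
    backend index dtype returned by ``env.index_type()``.
    """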
env = CompileEnvironment.current()
index_dtype = env.index_type()
offset = " + ".join(
f"({index_dtype}({index}) * {index_dtype}({tensor_name}.layout.stride[{dim}]))"
for dim, index in enumerate(index_exprs)
)
return f"({tensor_name}.iterator + {offset})"
def _cute_scalar_load_expr(tensor_name: str, index_exprs: list[str]) -> str:
if "None" in index_exprs:
return f"{tensor_name}[{', '.join(index_exprs)}]"
return f"{_cute_scalar_pointer_expr(tensor_name, index_exprs)}.load()"
def _cute_scalar_store_expr(
tensor_name: str, index_exprs: list[str], value: str
) -> str:
if "None" in index_exprs:
return f"{tensor_name}.__setitem__({_cute_index_tuple(index_exprs)}, {value})"
return f"{_cute_scalar_pointer_expr(tensor_name, index_exprs)}.store({value})"
def _cute_stack_tensor_offset_expr(
state: CodegenState,
tensor_like: torch.Tensor,
subscript: list[object],
ast_subscript: list[object] | tuple[object, ...],
) -> str:
env = CompileEnvironment.current()
index_exprs = _cute_index_exprs(
state,
subscript,
ast_subscript,
tensor=tensor_like,
inactive_slice_expr="None",
inactive_singleton_slice_expr="0",
)
if "None" in index_exprs:
raise exc.BackendUnsupported("cute", "inactive stack tensor load dimension")
index_dtype = env.index_type()
terms = []
for dim, index in enumerate(index_exprs):
stride = tensor_like.stride(dim)
stride_expr = (
str(stride) if isinstance(stride, int) else state.sympy_expr(stride)
)
terms.append(f"({index_dtype}({index}) * {index_dtype}({stride_expr}))")
return " + ".join(terms) if terms else "0"
def _cute_stack_tensor_mask_expr(
state: CodegenState,
tensor_like: torch.Tensor,
dev_ptrs: torch.Tensor,
subscript: list[object],
extra_mask: ast.AST | None,
) -> str | None:
terms = []
tensor_mask = _cute_combined_mask(
state,
subscript,
extra_mask,
tensor=tensor_like,
include_tensor_index_masks=False,
)
if tensor_mask is not None:
terms.append(tensor_mask)
stack_mask = _cute_combined_mask(
state,
[slice(None)] * dev_ptrs.ndim,
None,
tensor=dev_ptrs,
)
if stack_mask is not None and stack_mask not in terms:
terms.append(stack_mask)
if not terms:
return None
return " and ".join(f"({term})" for term in terms)
def _cute_stack_tensor_pointer_expr(
target_dtype: str,
dev_ptrs_ast: ast.AST,
offset_expr: str,
) -> ast.AST:
return expr_from_string(
f"(cute.make_ptr({target_dtype}, cutlass.Int64({{base}}), "
f"cute.AddressSpace.gmem) + ({offset_expr}))",
base=dev_ptrs_ast,
)
def _codegen_cute_store_stack_load(
state: CodegenState,
tensor: torch.Tensor,
subscript: tuple[object, ...] | list[object],
ast_subscript: tuple[object, ...] | list[object],
value: ast.AST,
extra_mask: ast.AST | None,
value_node: torch.fx.Node,
) -> ast.AST | None:
if value_node.op != "call_function" or value_node.target is not load:
return None
stack_arg = value_node.args[0]
if not isinstance(stack_arg, tuple) or len(stack_arg) != 2:
return None
ptr_node = stack_arg[1]
if (
not isinstance(ptr_node, torch.fx.Node)
or ptr_node.op != "call_function"
or ptr_node.target is not load
or len(ptr_node.args) < 2
):
return None
dev_ptrs = (
ptr_node.args[0].meta.get("val")
if isinstance(ptr_node.args[0], torch.fx.Node)
else None
)
ptr_subscript = ptr_node.args[1]
if not isinstance(dev_ptrs, torch.Tensor) or not isinstance(
ptr_subscript, (list, tuple)
):
return None
tensor_like_node = stack_arg[0]
tensor_like = (
tensor_like_node.meta.get("val")
if isinstance(tensor_like_node, torch.fx.Node)
else tensor_like_node
)
if not isinstance(tensor_like, torch.Tensor):
return None
if (
dev_ptrs.ndim == 2
and len(ptr_subscript) == 2
and all(isinstance(idx, slice) and idx == slice(None) for idx in ptr_subscript)
and len(subscript) >= 3
and isinstance(subscript[0], slice)
and subscript[0] == slice(None)
and isinstance(subscript[1], slice)
and subscript[1] == slice(None)
):
stack_value_subscript = value_node.args[1]
if not isinstance(stack_value_subscript, (list, tuple)):
return None
stack_value_subscript_proxy = map_arg(
stack_value_subscript, lambda arg: arg.meta["val"]
)
stack_value_subscript_ast = map_arg(
stack_value_subscript, lambda arg: state.env[arg]
)
tensor_offset_expr = _cute_stack_tensor_offset_expr(
state,
tensor_like,
[*stack_value_subscript_proxy],
[*stack_value_subscript_ast],
)
target_index_exprs = _cute_index_exprs(
state,
[*subscript],
ast_subscript,
tensor=tensor,
inactive_singleton_slice_expr="0",
)
if len(target_index_exprs) != tensor.ndim:
return None
first_stack_index = target_index_exprs[0]
target_tail = target_index_exprs[2:]
loop_var = state.device_function.new_var("stack_dim", dce=True)
env = CompileEnvironment.current()
index_dtype = env.index_type()
dev_ptrs_name = state.device_function.tensor_arg(dev_ptrs).name
tensor_name = state.device_function.tensor_arg(tensor).name
target_dtype = env.backend.dtype_str(tensor.dtype)
dev_ptr_offset = (
f"{index_dtype}({first_stack_index}) * "
f"{index_dtype}({dev_ptrs.stride(0)}) + "
f"{index_dtype}({loop_var}) * {index_dtype}({dev_ptrs.stride(1)})"
)
stack_ptr_expr = (
f"(cute.make_ptr({target_dtype}, "
f"cutlass.Int64(({dev_ptrs_name}.iterator + {dev_ptr_offset}).load()), "
f"cute.AddressSpace.gmem) + ({tensor_offset_expr}))"
)
target_indices = [first_stack_index, loop_var, *target_tail]
store_expr = _cute_scalar_store_expr(
tensor_name,
target_indices,
f"({stack_ptr_expr}).load()",
)
mask_expr = _cute_combined_mask(state, [*subscript], extra_mask, tensor=tensor)
if mask_expr is None:
body = f" {store_expr}"
else:
body = f" if {mask_expr}:\n {store_expr}"
state.add_statement(
statement_from_string(
f"for {loop_var} in range({dev_ptrs.size(1)}):\n{body}"
)
)
return ast.Constant(value=None)
ptr_subscript_proxy = map_arg(ptr_subscript, lambda arg: arg.meta["val"])
ptr_subscript_ast = map_arg(ptr_subscript, lambda arg: state.env[arg])
ptr_index_exprs = _cute_index_exprs(
state,
[*ptr_subscript_proxy],
[*ptr_subscript_ast],
tensor=dev_ptrs,
inactive_slice_expr="None",
inactive_singleton_slice_expr="0",
)
if "None" in ptr_index_exprs:
return None
target_index_exprs = _cute_index_exprs(
state,
[*subscript],
ast_subscript,
tensor=tensor,
inactive_singleton_slice_expr="0",
)
ptr_pos = 0
rewritten_index_exprs = []
for idx, index_expr in zip(subscript, target_index_exprs, strict=True):
if isinstance(idx, slice) and idx == slice(None):
replacement = (
ptr_index_exprs[ptr_pos] if ptr_pos < len(ptr_index_exprs) else None
)
ptr_pos += 1
rewritten_index_exprs.append(
replacement if replacement is not None else index_expr
)
else:
if ptr_pos < len(ptr_subscript_proxy) and not (
isinstance(ptr_subscript_proxy[ptr_pos], slice)
and ptr_subscript_proxy[ptr_pos] == slice(None)
):
ptr_pos += 1
rewritten_index_exprs.append(index_expr)
tensor_name = state.device_function.tensor_arg(tensor).name
backend = CompileEnvironment.current().backend
target_dtype = backend.dtype_str(tensor.dtype)
value = expr_from_string(
backend.ast_to_dtype_expr("{value}", target_dtype),
value=value,
)
store_expr = expr_from_string(
_cute_scalar_store_expr(tensor_name, rewritten_index_exprs, "{value}"),
value=value,
)
mask_expr = _cute_combined_mask(state, [*subscript], extra_mask, tensor=tensor)
if mask_expr is None:
return store_expr
mask_ast = expr_from_string(mask_expr)
assert isinstance(mask_ast, ast.expr)
assert isinstance(store_expr, ast.expr)
state.add_statement(
ast.fix_missing_locations(
ast.If(
test=mask_ast,
body=[ast.Expr(value=store_expr)],
orelse=[],
)
)
)
return ast.Constant(value=None)
def _cute_affine_range_block_id(state: CodegenState, affine: object) -> int | None:
from .._compiler.cute.indexing import CuteAffineRangeIndex
if not isinstance(affine, CuteAffineRangeIndex):
return None
env = CompileEnvironment.current()
base_meta = getattr(affine.base, "meta", {})
base_val = base_meta.get("val") if isinstance(base_meta, dict) else None
block_id = env.resolve_block_id(base_val) if base_val is not None else None
if block_id is None:
codegen = base_meta.get("codegen") if isinstance(base_meta, dict) else None
if isinstance(codegen, ast.Name) and codegen.id.startswith("_BLOCK_SIZE_"):
with contextlib.suppress(ValueError):
block_id = int(codegen.id.removeprefix("_BLOCK_SIZE_"))
if block_id is None:
return None
if state.fx_node is not None:
return env.resolve_codegen_block_id(
block_id, state.codegen, state.fx_node.graph
)
return block_id
def _cute_affine_range_expr(
state: CodegenState,
affine: object,
lane_var: str,
*,
dtype: torch.dtype | None = None,
) -> str | None:
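    """Build the flat index expression for an affine ``hl.arange()`` range.

    Illustrative sketch: with ``affine.factor == 4``, an active index var
    ``offset_0`` and ``lane_var == "lane"`` this returns
    ``"(4) * (offset_0) + cutlass.Int32(lane)"``, optionally wrapped in a
    cast to *dtype*.
    """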
from .._compiler.cute.indexing import CuteAffineRangeIndex
if not isinstance(affine, CuteAffineRangeIndex):
return None
if affine.step != 1 or affine.factor <= 0:
return None
block_id = _cute_affine_range_block_id(state, affine)
if block_id is None:
return None
index_var = _cute_active_index_var(state, block_id)
if index_var is None:
return None
expr = f"({affine.factor}) * ({index_var}) + cutlass.Int32({lane_var})"
if dtype is not None:
expr = f"{CompileEnvironment.current().backend.dtype_str(dtype)}({expr})"
return expr
def _codegen_cute_affine_range_store(
state: CodegenState,
tensor: torch.Tensor,
subscript: list[object] | tuple[object, ...],
ast_subscript: list[object] | tuple[object, ...],
value: object,
extra_mask: ast.AST | None,
value_node: torch.fx.Node | None = None,
) -> ast.AST | None:
from .._compiler.ast_extension import create
from .._compiler.cute.indexing import CuteAffineRangeIndex
affine_positions = [
(pos, idx)
for pos, idx in enumerate(ast_subscript)
if isinstance(idx, CuteAffineRangeIndex)
]
if len(affine_positions) != 1 or len(subscript) != 1 or extra_mask is not None:
return None
_pos, affine = affine_positions[0]
block_id = _cute_affine_range_block_id(state, affine)
if block_id is None:
return None
lane_var = state.device_function.new_var("affine_lane", dce=True)
index_expr = _cute_affine_range_expr(
state, affine, lane_var, dtype=CompileEnvironment.current().index_dtype
)
if index_expr is None:
return None
backend = CompileEnvironment.current().backend
if (
value_node is not None
and value_node.op == "call_function"
and value_node.target is load
):
source_tensor_node = value_node.args[0]
if not isinstance(source_tensor_node, torch.fx.Node):
return None
source_tensor = source_tensor_node.meta.get("val")
if not isinstance(source_tensor, torch.Tensor):
return None
source_subscript = value_node.args[1]
if (
not isinstance(source_subscript, (list, tuple))
or len(source_subscript) != 1
):
return None
ast_source_subscript = list(
map_arg(tuple(source_subscript), lambda arg: state.env[arg])
)
(source_affine,) = ast_source_subscript
if not isinstance(source_affine, CuteAffineRangeIndex):
return None
if source_affine.factor != affine.factor:
return None
source_index_expr = _cute_affine_range_expr(
state,
source_affine,
lane_var,
dtype=CompileEnvironment.current().index_dtype,
)
if source_index_expr is None:
return None
source_name = state.device_function.tensor_arg(source_tensor).name
value_expr = f"{source_name}[{source_index_expr}]"
if source_tensor.dtype is torch.bool:
value_expr = f"({value_expr} != cutlass.Uint8(0))"
elif isinstance(value, CuteAffineRangeIndex):
value_expr = _cute_affine_range_expr(state, value, lane_var, dtype=value.dtype)
if value_expr is None:
return None
elif isinstance(value, ast.AST):
value_expr = ast.unparse(value)
elif isinstance(value, (int, float, bool)):
value_expr = repr(value)
else:
return None
target_dtype = backend.dtype_str(tensor.dtype)
value_expr = backend.ast_to_dtype_expr(value_expr, target_dtype)
tensor_name = state.device_function.tensor_arg(tensor).name
store_expr = (
f"{tensor_name}.__setitem__({_cute_index_tuple([index_expr])}, {value_expr})"
)
mask_var = _cute_active_mask_var(state, block_id)
if mask_var is not None:
store_expr = f"{store_expr} if {mask_var} else None"
return create(
ast.For,
target=create(ast.Name, id=lane_var, ctx=ast.Store()),
iter=expr_from_string(f"range({affine.factor})"),
body=[create(ast.Expr, value=expr_from_string(store_expr))],
orelse=[],
type_comment=None,
)
def _is_cute_affine_range_load_for_store(
state: CodegenState,
subscript: list[object] | tuple[object, ...],
ast_subscript: list[object] | tuple[object, ...],
) -> bool:
from .._compiler.cute.indexing import CuteAffineRangeIndex
from .._compiler.cute.indexing import match_cute_affine_range_iota
def compatible_store_user(user: torch.fx.Node) -> bool:
if (
user.op != "call_function"
or user.target is not store
or len(user.args) < 4
or user.args[2] is not state.fx_node
or user.args[3] is not None
):
return False
store_subscript = user.args[1]
return (
isinstance(store_subscript, (list, tuple))
and len(store_subscript) == 1
and isinstance(store_subscript[0], torch.fx.Node)
and match_cute_affine_range_iota(store_subscript[0]) is not None
)
return (
state.fx_node is not None
and len(state.fx_node.users) > 0
and all(compatible_store_user(user) for user in state.fx_node.users)
and len(subscript) == 1
and len(ast_subscript) == 1
and isinstance(ast_subscript[0], CuteAffineRangeIndex)
)
def _cute_positive_1d_slice_bounds(
tensor: torch.Tensor, index: object
) -> tuple[int, int, int, int] | None:
if not isinstance(index, slice) or index == slice(None):
return None
with contextlib.suppress(TypeError):
dim_size = int(tensor.shape[0])
start, stop, step = index.indices(dim_size)
if step <= 0:
return None
length = max(0, (stop - start + step - 1) // step)
return start, stop, step, length
return None
def _is_cute_strided_slice_load_for_store(
state: CodegenState,
tensor: torch.Tensor,
subscript: list[object] | tuple[object, ...],
) -> bool:
def compatible_store_user(user: torch.fx.Node) -> bool:
if (
user.op != "call_function"
or user.target is not store
or len(user.args) < 4
or user.args[2] is not state.fx_node
or user.args[3] is not None
):
return False
target_node = user.args[0]
if not isinstance(target_node, torch.fx.Node):
return False
target_tensor = target_node.meta.get("val")
if not isinstance(target_tensor, torch.Tensor) or target_tensor.ndim != 1:
return False
store_subscript = user.args[1]
return (
isinstance(store_subscript, (list, tuple))
and len(store_subscript) == 1
and _cute_positive_1d_slice_bounds(target_tensor, store_subscript[0])
is not None
)
return (
state.fx_node is not None
and len(state.fx_node.users) > 0
and all(compatible_store_user(user) for user in state.fx_node.users)
and tensor.ndim == 1
and len(subscript) == 1
and _cute_positive_1d_slice_bounds(tensor, subscript[0]) is not None
)
def _codegen_cute_strided_slice_store(
state: CodegenState,
tensor: torch.Tensor,
subscript: list[object] | tuple[object, ...],
value: object,
extra_mask: ast.AST | None,
value_node: torch.fx.Node | None = None,
) -> ast.AST | None:
from .._compiler.ast_extension import create
if tensor.ndim != 1 or len(subscript) != 1 or extra_mask is not None:
return None
target_bounds = _cute_positive_1d_slice_bounds(tensor, subscript[0])
if target_bounds is None:
return None
target_start, _target_stop, target_step, target_length = target_bounds
env = CompileEnvironment.current()
backend = env.backend
index_dtype = backend.dtype_str(env.index_dtype)
loop_var = state.device_function.new_var("slice_idx", dce=True)
target_index = f"{index_dtype}({target_start} + {loop_var} * {target_step})"
if (
value_node is not None
and value_node.op == "call_function"
and value_node.target is load
):
source_tensor_node = value_node.args[0]
if not isinstance(source_tensor_node, torch.fx.Node):
return None
source_tensor = source_tensor_node.meta.get("val")
if not isinstance(source_tensor, torch.Tensor) or source_tensor.ndim != 1:
return None
source_subscript = value_node.args[1]
if (
not isinstance(source_subscript, (list, tuple))
or len(source_subscript) != 1
):
return None
source_bounds = _cute_positive_1d_slice_bounds(
source_tensor, source_subscript[0]
)
if source_bounds is None:
return None
source_start, _source_stop, source_step, source_length = source_bounds
if source_length != target_length:
return None
source_index = f"{index_dtype}({source_start} + {loop_var} * {source_step})"
source_name = state.device_function.tensor_arg(source_tensor).name
value_expr = f"{source_name}[{source_index}]"
if source_tensor.dtype is torch.bool:
value_expr = f"({value_expr} != cutlass.Uint8(0))"
elif isinstance(value, ast.AST):
value_expr = ast.unparse(value)
elif isinstance(value, (int, float, bool)):
value_expr = repr(value)
else:
return None
target_name = state.device_function.tensor_arg(tensor).name
target_dtype = backend.dtype_str(tensor.dtype)
value_expr = backend.ast_to_dtype_expr(value_expr, target_dtype)
store_expr = f"{target_name}.__setitem__(({target_index},), {value_expr})"
return create(
ast.For,
target=create(ast.Name, id=loop_var, ctx=ast.Store()),
iter=expr_from_string(f"range({target_length})"),
body=[create(ast.Expr, value=expr_from_string(store_expr))],
orelse=[],
type_comment=None,
)
def _cute_combined_mask(
state: CodegenState,
subscript: list[object] | tuple[object, ...],
extra_mask: ast.AST | None,
tensor: torch.Tensor | None = None,
*,
include_tensor_index_masks: bool = True,
) -> str | None:
env = CompileEnvironment.current()
terms: list[str] = []
def mask_var_for_block_id(block_id: int) -> str | None:
loops = state.codegen.active_device_loops.get(block_id)
if loops:
return loops[-1].strategy.mask_var(block_id)
return None
def active_index_var(block_id: int) -> str | None:
loops = state.codegen.active_device_loops.get(block_id)
if loops:
return loops[-1].strategy.index_var(block_id)
grid_state = state.codegen.current_grid_state
if grid_state is not None and block_id in grid_state.block_ids:
return grid_state.strategy.index_var(block_id)
return None
def active_local_coord(block_id: int) -> str | None:
from .._compiler.cute.cute_reshape import _grid_local_coord_expr
loops = state.codegen.active_device_loops.get(block_id)
if loops:
thread_axis = loops[-1].block_thread_axes.get(block_id)
if thread_axis is not None:
return _grid_local_coord_expr(state.codegen, block_id, thread_axis)
grid_state = state.codegen.current_grid_state
if grid_state is not None:
thread_axis = grid_state.block_thread_axes.get(block_id)
if thread_axis is not None:
return _grid_local_coord_expr(state.codegen, block_id, thread_axis)
return None
def tile_begin_expr(block_id: int) -> str:
loops = state.codegen.active_device_loops.get(block_id)
if loops:
return state.codegen.offset_var(block_id)
global_index = active_index_var(block_id)
local_coord = active_local_coord(block_id)
if global_index is not None and local_coord is not None:
return state.codegen.lift(
expr_from_string(f"({global_index}) - ({local_coord})"),
dce=True,
prefix="tile_begin",
).id
if global_index is not None:
return global_index
return "0"
def tile_with_offset_mask_terms(
tile_info: TileWithOffsetInfo,
tensor_dim: int,
) -> list[str]:
block_id = tile_info.block_id
local_coord = active_local_coord(block_id)
begin_var = tile_begin_expr(block_id)
if local_coord is None:
if (idx_var := active_index_var(block_id)) is None:
raise exc.BackendUnsupported(
"cute",
(
"indexing dimension is not active in this scope "
f"(block_id={block_id})"
),
)
local_coord = f"({idx_var}) - ({begin_var})"
tile_terms = []
if tile_info.block_size is not None:
block_size_expr = state.device_function.literal_expr(tile_info.block_size)
tile_terms.append(f"({local_coord}) < cutlass.Int32({block_size_expr})")
if tensor is not None and tensor_dim < tensor.ndim:
offset_expr = state.device_function.literal_expr(tile_info.offset)
dim_size = _cute_tensor_dim_size_expr(state, tensor, tensor_dim)
tile_terms.append(
f"(({begin_var}) + cutlass.Int32({offset_expr}) + "
f"({local_coord})) < {dim_size}"
)
return tile_terms
if extra_mask is not None:
terms.append(state.codegen.lift(extra_mask, dce=True, prefix="mask").id)
seen: set[int] = set()
tensor_dim = 0
for pos, idx in enumerate(subscript):
block_id: int | None = None
if idx is None:
continue
if (
tile_info := _get_tile_with_offset_info(
idx, getattr(state, "fx_node", None), pos
)
) is not None and tile_info.block_size is not None:
seen.add(tile_info.block_id)
for term in tile_with_offset_mask_terms(tile_info, tensor_dim):
if term not in terms:
terms.append(term)
tensor_dim += 1
continue
if isinstance(idx, torch.SymInt):
block_id = env.get_block_id(idx)
elif isinstance(idx, slice) and idx == slice(None) and tensor is not None:
for bid in _matching_block_ids(env, tensor.shape[tensor_dim]):
if bid not in seen and mask_var_for_block_id(bid) is not None:
block_id = bid
break
elif isinstance(idx, torch.Tensor):
if not include_tensor_index_masks:
for dim_size in idx.shape:
for bid in _matching_block_ids(env, dim_size):
if bid in seen or not env.is_jagged_tile(bid):
continue
mask_var = mask_var_for_block_id(bid)
if mask_var is not None:
seen.add(bid)
if mask_var not in terms:
terms.append(mask_var)
break
tensor_dim += 1
continue
for dim_size in idx.shape:
for bid in _matching_block_ids(env, dim_size):
if bid in seen:
continue
mask_var = mask_var_for_block_id(bid)
if mask_var is not None:
seen.add(bid)
if mask_var not in terms:
terms.append(mask_var)
break
else:
continue
tensor_dim += 1
continue
else:
tensor_dim += 1
continue
if block_id is None or block_id in seen:
tensor_dim += 1
continue
seen.add(block_id)
if (mask_var := mask_var_for_block_id(block_id)) is not None:
if mask_var not in terms:
terms.append(mask_var)
tensor_dim += 1
if not terms:
return None
return " and ".join(f"({term})" for term in terms)
def _cute_tensor_dim_size_expr(
state: CodegenState, tensor: torch.Tensor, dim: int
) -> str:
return state.device_function.tensor_size(tensor, dim).name
def _cute_tile_begin_expr(state: CodegenState, idx: object) -> str:
env = CompileEnvironment.current()
def active_index_var(block_id: int) -> str | None:
loops = state.codegen.active_device_loops.get(block_id)
if loops:
return loops[-1].strategy.index_var(block_id)
grid_state = state.codegen.current_grid_state
if grid_state is not None and block_id in grid_state.block_ids:
return grid_state.strategy.index_var(block_id)
return None
def active_local_coord(block_id: int) -> str | None:
from .._compiler.cute.cute_reshape import _grid_local_coord_expr
loops = state.codegen.active_device_loops.get(block_id)
if loops:
thread_axis = loops[-1].block_thread_axes.get(block_id)
if thread_axis is not None:
return _grid_local_coord_expr(state.codegen, block_id, thread_axis)
grid_state = state.codegen.current_grid_state
if grid_state is not None:
thread_axis = grid_state.block_thread_axes.get(block_id)
if thread_axis is not None:
return _grid_local_coord_expr(state.codegen, block_id, thread_axis)
return None
def tile_begin_from_block_id(block_id: int) -> str:
loops = state.codegen.active_device_loops.get(block_id)
if loops:
return state.codegen.offset_var(block_id)
global_index = active_index_var(block_id)
local_coord = active_local_coord(block_id)
if global_index is not None and local_coord is not None:
return state.codegen.lift(
expr_from_string(f"({global_index}) - ({local_coord})"),
dce=True,
prefix="tile_begin",
).id
if global_index is not None:
return global_index
return "0"
if isinstance(idx, int):
return str(idx)
if not isinstance(idx, torch.SymInt):
raise exc.BackendUnsupported("cute", f"tile base index type: {type(idx)}")
expr = _symint_expr(idx)
if expr is not None:
origin_info = HostFunction.current().expr_to_origin.get(expr)
if origin_info is not None and isinstance(origin_info.origin, TileBeginOrigin):
return tile_begin_from_block_id(origin_info.origin.block_id)
block_id = env.get_block_id(idx)
if block_id is not None:
return tile_begin_from_block_id(block_id)
if expr is not None:
return state.sympy_expr(expr)
raise exc.BackendUnsupported("cute", f"unlowerable tile base index: {idx}")
def _codegen_cute_store_tcgen05_tile(
state: CodegenState,
tensor: torch.Tensor,
subscript: list[object] | tuple[object, ...],
ast_subscript: list[object] | tuple[object, ...],
extra_mask: ast.AST | None,
value_name: str,
) -> list[ast.AST] | ast.AST | None:
if extra_mask is not None or tensor.ndim != 2:
return None
tcgen05_value = state.device_function.get_cute_tcgen05_store_value(value_name)
if tcgen05_value is None:
return None
# Backstop for callers that bypass Config.normalize() validation;
# see _tcgen05_epi_warp_count docstring and cute_plan.md.
if tcgen05_value.epi_warp_count != 4:
raise exc.BackendUnsupported(
"cute",
f"tcgen05 SIMT-store epilogue requires "
f"tcgen05_num_epi_warps=4 (got {tcgen05_value.epi_warp_count}). "
"CUTLASS tmem_warp_shape_mn=(4,1) hard-codes a 4-warp t2r "
"partition for the supported tcgen05 path; per-warp "
"tcgen05.ld semantics make the partition uncoverable by "
"fewer warps. Lifts when the c_pipeline-driven multi-warp "
"epilogue lands (see cute_plan.md).",
)
backend = CompileEnvironment.current().backend
df = state.device_function
tensor_name = df.tensor_arg(tensor).name
target_dtype = backend.dtype_str(tensor.dtype)
base_indices = [_cute_tile_begin_expr(state, idx) for idx in subscript]
if len(base_indices) != 2:
return None
m_size = _cute_tensor_dim_size_expr(state, tensor, 0)
n_size = _cute_tensor_dim_size_expr(state, tensor, 1)
tile_coord_m = f"({base_indices[0]}) // cutlass.Int32({tcgen05_value.bm})"
tile_coord_n = f"({base_indices[1]}) // cutlass.Int32({tcgen05_value.bn})"
full_tile = df.new_var("tcgen05_full_tile")
gmem_tile = df.new_var("tcgen05_gC")
coord_tile = df.new_var("tcgen05_cC")
tcgc_base = df.new_var("tcgen05_tCgC_base")
tccc_base = df.new_var("tcgen05_tCcC_base")
tcgc = df.new_var("tcgen05_tCgC")
tcgc_planned = df.new_var("tcgen05_tCgC_planned")
tccc = df.new_var("tcgen05_tCcC")
tacc = df.new_var("tcgen05_tAcc")
epi_tile = df.new_var("tcgen05_store_epi_tile")
tiled_copy_t2r = df.new_var("tcgen05_tiled_copy_t2r")
thr_copy_t2r = df.new_var("tcgen05_thr_copy_t2r")
ttr_tacc_base = df.new_var("tcgen05_tTR_tAcc_base")
tcgc_epi = df.new_var("tcgen05_tCgC_epi")
tccc_epi = df.new_var("tcgen05_tCcC_epi")
ttr_gc = df.new_var("tcgen05_tTR_gC")
ttr_cc = df.new_var("tcgen05_tTR_cC")
ttr_racc = df.new_var("tcgen05_tTR_rAcc")
ttr_rd = df.new_var("tcgen05_tTR_rD")
ttr_tacc_stage = df.new_var("tcgen05_tTR_tAcc_stage")
ttr_tacc = df.new_var("tcgen05_tTR_tAcc")
ttr_gc_grouped = df.new_var("tcgen05_tTR_gC_grouped")
ttr_cc_grouped = df.new_var("tcgen05_tTR_cC_grouped")
ttr_tacc_mn = df.new_var("tcgen05_tTR_tAcc_mn")
ttr_gc_subtile = df.new_var("tcgen05_tTR_gC_subtile")
ttr_cc_subtile = df.new_var("tcgen05_tTR_cC_subtile")
pred_c = df.new_var("tcgen05_pred_C")
pred_c_shape = df.new_var("tcgen05_pred_C_shape")
acc_vec = df.new_var("tcgen05_acc_vec")
kernel_desc = df.new_var("tcgen05_kernel_desc")
mcld = df.new_var("tcgen05_mcld")
num_bits = df.new_var("tcgen05_num_bits")
simt_atom = df.new_var("tcgen05_simt_atom")
smem_d_layout = df.new_var("tcgen05_sD_layout")
smem_d_ptr = df.new_var("tcgen05_sD_ptr")
smem_d = df.new_var("tcgen05_sD")
tiled_copy_r2s = df.new_var("tcgen05_tiled_copy_r2s")
trs_rd = df.new_var("tcgen05_tRS_rD")
trs_racc = df.new_var("tcgen05_tRS_rAcc")
trs_sd = df.new_var("tcgen05_tRS_sD")
bsg_sd = df.new_var("tcgen05_bSG_sD")
bsg_gd_partitioned = df.new_var("tcgen05_bSG_gD_partitioned")
bsg_gd = df.new_var("tcgen05_bSG_gD")
c_buffer = df.new_var("tcgen05_c_buffer")
epilog_sync_barrier = df.new_var("tcgen05_epilog_sync_barrier")
c_pipeline_producer_group = df.new_var("tcgen05_c_pipeline_producer_group")
c_pipeline = df.new_var("tcgen05_c_pipeline")
subtile_count = df.new_var("tcgen05_subtile_count")
epi_warp_ids = ", ".join(
f"cutlass.Int32({i})" for i in range(tcgen05_value.epi_warp_count)
)
if tcgen05_value.epi_warp_count == 1:
epi_warp_ids += ","
if tcgen05_value.use_tma_store_epilogue:
df.placeholder_args.add(tensor_name)
df.wrapper_only_params.extend(
[tcgen05_value.tma_store_atom, tcgen05_value.tma_store_tensor]
)
if tcgen05_value.use_role_local_epi:
df.register_cute_tcgen05_epi_role_tile_counter(
tcgen05_value.role_local_tile_counter
)
state.codegen.cute_wrapper_plans.append(
{
"kind": "tcgen05_d_tma",
"d_name": tensor_name,
"bm": tcgen05_value.bm,
"bn": tcgen05_value.bn,
"c_stage_count": tcgen05_value.c_stage_count,
"output_dtype": target_dtype,
"kernel_args": [
tcgen05_value.tma_store_atom,
tcgen05_value.tma_store_tensor,
],
}
)
tcgen05_bm = tcgen05_value.bm
tcgen05_bn = tcgen05_value.bn
tcgen05_bk = tcgen05_value.bk
tcgen05_epilog_sync_barrier_id = tcgen05_value.epilog_sync_barrier_id
tcgen05_c_stage_count = tcgen05_value.c_stage_count
tcgen05_is_two_cta = tcgen05_value.is_two_cta
tcgen05_thr_mma = tcgen05_value.thr_mma
def store_common_setup(
gmem_tensor: str, *, include_full_tile: bool
) -> tuple[list[str], list[str]]:
static_setup = [
(
f"{kernel_desc} = type('Tcgen05KernelDesc', (), {{"
f"'cta_tile_shape_mnk': ({tcgen05_bm}, {tcgen05_bn}, {tcgen05_bk}), "
"'c_layout': cutlass.utils.layout.LayoutEnum.ROW_MAJOR, "
f"'c_dtype': {target_dtype}, "
"'acc_dtype': cutlass.Float32, "
f"'epilog_sync_bar_id': cutlass.Int32({tcgen05_epilog_sync_barrier_id}), "
f"'epilogue_warp_id': ({epi_warp_ids}), "
f"'num_c_stage': cutlass.Int32({tcgen05_c_stage_count}), "
f"'use_2cta_instrs': {tcgen05_is_two_cta!s}"
"})()"
),
(
f"{epi_tile} = cutlass.utils.blackwell_helpers.compute_epilogue_tile_shape("
f"({tcgen05_bm}, {tcgen05_bn}), False, "
f"cutlass.utils.layout.LayoutEnum.ROW_MAJOR, {target_dtype})"
),
]
tile_setup: list[str] = []
if include_full_tile:
tile_setup.append(
f"{full_tile} = "
f"({base_indices[0]}) + cutlass.Int32({tcgen05_bm}) <= {m_size} "
f"and ({base_indices[1]}) + cutlass.Int32({tcgen05_bn}) <= {n_size}"
)
tile_setup.extend(
[
(
f"{gmem_tile} = cute.local_tile("
f"{gmem_tensor}, ({tcgen05_bm}, {tcgen05_bn}), "
f"({tile_coord_m}, {tile_coord_n}))"
),
f"{tcgc_base} = {tcgen05_thr_mma}.partition_C({gmem_tile})",
]
)
return static_setup, tile_setup
simt_static_store_setup, simt_tile_store_setup = store_common_setup(
tensor_name, include_full_tile=True
)
tma_static_store_setup, tma_tile_store_setup = store_common_setup(
tcgen05_value.tma_store_tensor, include_full_tile=False
)
tma_c_buffer_expr = "cutlass.Int32(_tcgen05_subtile)"
if tcgen05_value.role_local_tile_counter:
tma_c_buffer_expr = (
f"{tcgen05_value.role_local_tile_counter} * "
f"cutlass.Int32({subtile_count}) + cutlass.Int32(_tcgen05_subtile)"
)
simt_store_body_core = [
*simt_static_store_setup,
*simt_tile_store_setup,
(
f"{tcgc} = cutlass.utils.gemm.sm100.transform_partitioned_tensor_layout("
f"{tcgc_base})"
),
(
f"{tcgc_planned} = cute.make_tensor("
f"{tcgc}.iterator, "
f"cute.append(cute.append(cute.append({tcgc}.layout, {tcgen05_value.epilogue_rest_mode}), {tcgen05_value.epilogue_rest_mode}), {tcgen05_value.epilogue_rest_mode}))"
),
(
f"{tacc} = cutlass.utils.gemm.sm100.transform_partitioned_tensor_layout("
f"{tcgen05_value.epi_acc_frag_base})"
),
(
f"{tiled_copy_t2r}, {ttr_tacc_base}, {ttr_racc} = "
"cutlass.utils.gemm.sm100.epilogue_tmem_copy_and_partition("
f"{kernel_desc}, {tcgen05_value.epi_tidx}, {tacc}, {tcgc_planned}, {epi_tile}, {tcgen05_value.is_two_cta!s})"
),
f"{thr_copy_t2r} = {tiled_copy_t2r}.get_slice({tcgen05_value.epi_tidx})",
f"{tcgc_epi} = cute.flat_divide({tcgc_planned}, {epi_tile})",
f"{ttr_gc} = {thr_copy_t2r}.partition_D({tcgc_epi})",
(
f"{ttr_tacc_stage} = {ttr_tacc_base}["
f"(None, None, None, None, None, {tcgen05_value.acc_consumer_state}.index)]"
),
(
f"if {tcgen05_value.epi_active}:\n"
f" {tcgen05_value.acc_pipeline}.consumer_wait({tcgen05_value.acc_consumer_state})"
),
f"{ttr_tacc} = cute.group_modes({ttr_tacc_stage}, 3, cute.rank({ttr_tacc_stage}))",
f"{ttr_gc_grouped} = cute.group_modes({ttr_gc}, 3, cute.rank({ttr_gc}))",
(
f"{ttr_racc} = cute.make_rmem_tensor("
f"{ttr_gc_grouped}[(None, None, None, 0)].shape, cutlass.Float32)"
),
f"{ttr_rd} = cute.make_rmem_tensor({ttr_racc}.shape, {target_dtype})",
(
f"{mcld} = cute.max_common_layout("
f"{ttr_rd}.layout, {ttr_gc_grouped}[(None, None, None, 0)].layout)"
),
(
f"{num_bits} = min("
f"{ttr_gc_grouped}.iterator.alignment * 8, "
f"cute.size({mcld}) * {target_dtype}.width, 256)"
),
(
f"{simt_atom} = cute.make_copy_atom("
f"cute.nvgpu.CopyUniversalOp(), {target_dtype}, "
f"num_bits_per_copy={num_bits}, "
f"l1c_evict_priority=cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE)"
),
f"{subtile_count} = cutlass.const_expr(cute.size({ttr_tacc}.shape, mode=[3]))",
(
# Per-subtile loop: TMEM->reg (t2r) first, then reg->GMEM (SIMT
# store). On the last subtile we release the acc consumer slot
# *before* the GMEM store so the next mainloop tile's MMA can
# producer_acquire the TMEM stage and begin issuing UMMAs while
# this tile's epilogue is still draining to GMEM. This mirrors the
# release-acc-inside-the-subtile-loop pattern in Quack's sm100
# gemm epilogue. Without c_pipeline SMEM staging we can only
# release after the final t2r (not per-subtile), but even one
# tile of overlap measurably improves the wide tcgen05 path on
# B200. `cutlass.range(..., unroll_full=True)` keeps the loop
# statically unrolled so `tiled_copy_t2r` (a TiledCopy that wraps
# a tcgen05 tmem_load atom) is not captured as an scf.for iter_arg
# — the cute-to-nvvm pass cannot legalize that conversion through
# iter_args and aborts during compile.
f"for _tcgen05_subtile in cutlass.range({subtile_count}, unroll_full=True):\n"
f" if {tcgen05_value.epi_active}:\n"
f" {ttr_tacc_mn} = {ttr_tacc}[(None, None, None, cutlass.Int32(_tcgen05_subtile))]\n"
f" {ttr_gc_subtile} = {ttr_gc_grouped}[(None, None, None, cutlass.Int32(_tcgen05_subtile))]\n"
f" cute.copy({tiled_copy_t2r}, {ttr_tacc_mn}, {ttr_racc})\n"
f" {acc_vec} = {ttr_racc}.load().to({target_dtype})\n"
f" {ttr_rd}.store({acc_vec})\n"
f" if _tcgen05_subtile == {subtile_count} - 1:\n"
# `cute.copy(t2r, ...)` issues async TMEM->reg loads. Releasing
# the acc consumer slot lets the MMA producer re-acquire the TMEM
# stage and issue UMMAs that overwrite TMEM, so we must fence the
# in-flight async TMEM loads first to avoid a race on the last
# subtile's `ttr_racc` / `ttr_rd` data. This matches Quack's
# sm100 gemm fence-before-release pattern.
f" cute.arch.fence_view_async_tmem_load()\n"
f" with cute.arch.elect_one():\n"
f" {tcgen05_value.acc_pipeline}.consumer_release({tcgen05_value.acc_consumer_state})\n"
f" if {full_tile}:\n"
f" cute.copy({simt_atom}, {ttr_rd}, {ttr_gc_subtile})\n"
f" else:\n"
f" {coord_tile} = cute.local_tile(cute.make_identity_tensor(({m_size}, {n_size})), ({tcgen05_value.bm}, {tcgen05_value.bn}), ({tile_coord_m}, {tile_coord_n}))\n"
f" {tccc_base} = {tcgen05_value.thr_mma}.partition_C({coord_tile})\n"
f" {tccc} = cutlass.utils.gemm.sm100.transform_partitioned_tensor_layout({tccc_base})\n"
f" {tccc_epi} = cute.flat_divide({tccc}, {epi_tile})\n"
f" {ttr_cc} = {thr_copy_t2r}.partition_D({tccc_epi})\n"
f" {ttr_cc_grouped} = cute.group_modes({ttr_cc}, 3, cute.rank({ttr_cc}))\n"
f" {ttr_cc_subtile} = {ttr_cc_grouped}[(None, None, None, cutlass.Int32(_tcgen05_subtile))]\n"
f" {pred_c_shape} = (1, *{ttr_cc_subtile}.shape[1:])\n"
f" {pred_c} = cute.make_rmem_tensor({pred_c_shape}, cutlass.Boolean)\n"
f" for _pred_m in range({ttr_cc_subtile}.shape[1]):\n"
f" for _pred_n in range({ttr_cc_subtile}.shape[2]):\n"
f" _coord = {ttr_cc_subtile}[(0, _pred_m, _pred_n)]\n"
f" {pred_c}[(0, _pred_m, _pred_n)] = cute.elem_less(_coord, ({m_size}, {n_size}))\n"
f" cute.copy({simt_atom}, {ttr_rd}, {ttr_gc_subtile}, pred={pred_c})\n"
# Advance is a per-thread local state update, so it intentionally
# stays outside elect_one; only the mbarrier release is elected.
f"if {tcgen05_value.epi_active}:\n"
+ emit_pipeline_advance(tcgen05_value.acc_consumer_state, indent=" ")
),
]
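    # C-store pipeline setup for the TMA epilogue: a CTA named barrier covering
    # the epilogue warps plus a PipelineTmaStore ring over the SMEM D stages.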
tma_store_pipeline_setup = [
(
f"{epilog_sync_barrier} = cutlass.pipeline.NamedBarrier("
f"barrier_id=cutlass.Int32({tcgen05_value.epilog_sync_barrier_id}), "
f"num_threads=cutlass.Int32({tcgen05_value.epi_warp_count * 32}))"
),
(
f"{c_pipeline_producer_group} = cutlass.pipeline.CooperativeGroup("
f"cutlass.pipeline.Agent.Thread, cutlass.Int32({tcgen05_value.epi_warp_count * 32}))"
),
(
f"{c_pipeline} = cutlass.pipeline.PipelineTmaStore.create("
f"num_stages={tcgen05_value.c_stage_count}, "
f"producer_group={c_pipeline_producer_group})"
),
]
tma_store_pipeline_tail = (
f"if {tcgen05_value.warp_idx} == cutlass.Int32(0):\n"
f" {c_pipeline}.producer_tail()"
)
c_acquire_placement = state.device_function.config.get(
TCGEN05_C_ACQUIRE_PLACEMENT_CONFIG_KEY,
TCGEN05_C_ACQUIRE_PLACEMENT_PRE_LOOP,
)
acc_wait_placement = state.device_function.config.get(
TCGEN05_ACC_WAIT_PLACEMENT_CONFIG_KEY,
TCGEN05_ACC_WAIT_PLACEMENT_SUBTILE_LOOP,
)
c_store_mode = state.device_function.config.get(
TCGEN05_C_STORE_MODE_CONFIG_KEY,
TCGEN05_C_STORE_MODE_NORMAL,
)
epilogue_layout = state.device_function.config.get(
TCGEN05_EPILOGUE_LAYOUT_CONFIG_KEY,
TCGEN05_EPILOGUE_LAYOUT_NORMAL,
)
diagnose_first_c_acquire_in_loop = (
c_acquire_placement == TCGEN05_C_ACQUIRE_PLACEMENT_FIRST_IN_LOOP
)
diagnose_later_c_acquire_before_barrier = (
c_acquire_placement == TCGEN05_C_ACQUIRE_PLACEMENT_LATER_BEFORE_BARRIER
)
diagnose_acc_wait_before_subtile_loop = (
acc_wait_placement == TCGEN05_ACC_WAIT_PLACEMENT_BEFORE_SUBTILE_LOOP
)
diagnose_skip_epilogue_store = (
c_store_mode == TCGEN05_C_STORE_MODE_SKIP_EPILOGUE_STORE
)
diagnose_split_first_t2r = (
epilogue_layout == TCGEN05_EPILOGUE_LAYOUT_SPLIT_FIRST_T2R
)
diagnose_split_acc_t2r_store_tail = (
epilogue_layout == TCGEN05_EPILOGUE_LAYOUT_SPLIT_ACC_T2R_STORE_TAIL
)
diagnose_module_helper_acc_t2r = (
epilogue_layout == TCGEN05_EPILOGUE_LAYOUT_MODULE_HELPER_ACC_T2R
)
diagnose_module_helper_store_tail = (
epilogue_layout == TCGEN05_EPILOGUE_LAYOUT_MODULE_HELPER_STORE_TAIL
)
diagnose_split_epilogue_layout = (
diagnose_split_first_t2r
or diagnose_split_acc_t2r_store_tail
or diagnose_module_helper_acc_t2r
or diagnose_module_helper_store_tail
)
if diagnose_split_epilogue_layout:
if not (
tcgen05_value.use_role_local_epi and tcgen05_value.use_tma_store_epilogue
):
raise exc.BackendUnsupported(
"cute",
f"{TCGEN05_EPILOGUE_LAYOUT_CONFIG_KEY}={epilogue_layout!r} "
"requires the "
"role-local TMA-store tcgen05 epilogue",
)
if not tcgen05_value.is_two_cta:
raise exc.BackendUnsupported(
"cute",
f"{TCGEN05_EPILOGUE_LAYOUT_CONFIG_KEY}={epilogue_layout!r} requires "
"CtaGroup.TWO",
)
# Conservative proxy for the validated static-full CtaGroup.TWO
# two-or-more-subtile envelope; the exact subtile count is only
# available after the CUTLASS epilogue partitioning below.
if tcgen05_value.bn < TCGEN05_TWO_CTA_BLOCK_N:
raise exc.BackendUnsupported(
"cute",
f"{TCGEN05_EPILOGUE_LAYOUT_CONFIG_KEY}={epilogue_layout!r} is only "
f"validated for CtaGroup.TWO block_n >= {TCGEN05_TWO_CTA_BLOCK_N}",
)
tma_store_first_subtile_acquire = (
[]
if diagnose_first_c_acquire_in_loop
else [
(
f"if {tcgen05_value.epi_active} and "
f"{tcgen05_value.warp_idx} == cutlass.Int32(0):\n"
f" {c_pipeline}.producer_acquire()"
)
]
)
tma_store_loop_first_subtile_acquire = (
(
f" if _tcgen05_subtile == 0 and "
f"{tcgen05_value.warp_idx} == cutlass.Int32(0):\n"
f" {c_pipeline}.producer_acquire()\n"
)
if diagnose_first_c_acquire_in_loop
else ""
)
tma_store_split_first_subtile_acquire = (
(
f" if {tcgen05_value.warp_idx} == cutlass.Int32(0):\n"
f" {c_pipeline}.producer_acquire()\n"
)
if diagnose_first_c_acquire_in_loop
else ""
)
tma_store_loop_later_subtile_acquire = (
""
if diagnose_later_c_acquire_before_barrier
else (
f" if _tcgen05_subtile != 0 and "
f"{tcgen05_value.warp_idx} == cutlass.Int32(0):\n"
f" {c_pipeline}.producer_acquire()\n"
)
)
tma_store_loop_late_later_subtile_acquire = (
(
f" if _tcgen05_subtile != 0 and "
f"{tcgen05_value.warp_idx} == cutlass.Int32(0):\n"
f" {c_pipeline}.producer_acquire()\n"
)
if diagnose_later_c_acquire_before_barrier
else ""
)
tma_store_pre_loop_acc_wait = (
[
(
f"if {tcgen05_value.epi_active}:\n"
f" {tcgen05_value.acc_pipeline}.consumer_wait({tcgen05_value.acc_consumer_state})"
)
]
if diagnose_acc_wait_before_subtile_loop
else []
)
tma_store_loop_acc_wait = (
""
if diagnose_acc_wait_before_subtile_loop
else (
f" if _tcgen05_subtile == 0:\n"
f" {tcgen05_value.acc_pipeline}.consumer_wait({tcgen05_value.acc_consumer_state})\n"
)
)
tma_store_split_first_acc_wait = (
""
if diagnose_acc_wait_before_subtile_loop
else (
f" {tcgen05_value.acc_pipeline}.consumer_wait({tcgen05_value.acc_consumer_state})\n"
)
)
tma_store_split_tail_later_subtile_acquire = (
""
if diagnose_later_c_acquire_before_barrier
else (
f" if {tcgen05_value.warp_idx} == cutlass.Int32(0):\n"
f" {c_pipeline}.producer_acquire()\n"
)
)
tma_store_split_tail_late_later_subtile_acquire = (
(
f" if {tcgen05_value.warp_idx} == cutlass.Int32(0):\n"
f" {c_pipeline}.producer_acquire()\n"
)
if diagnose_later_c_acquire_before_barrier
else ""
)
    # Pyrefly does not preserve the non-None tcgen05_value narrowing inside the
    # nested source formatters below, so keep local string aliases for the
    # attributes those closures read.
tcgen05_epi_active = tcgen05_value.epi_active
tcgen05_acc_pipeline = tcgen05_value.acc_pipeline
tcgen05_acc_consumer_state = tcgen05_value.acc_consumer_state
tcgen05_warp_idx = tcgen05_value.warp_idx
tcgen05_tma_store_atom = tcgen05_value.tma_store_atom
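    # Emits the per-subtile accumulator TMEM->register region: the optional acc
    # wait, the t2r copy, and (on the last subtile) the fence plus elected
    # consumer_release that lets the MMA producer reuse the TMEM stage early.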
def tma_store_acc_t2r_region(*, acc_wait: str) -> str:
return (
f"{acc_wait}"
f" {ttr_tacc_mn} = {ttr_tacc}[(None, None, None, cutlass.Int32(_tcgen05_subtile))]\n"
f" cute.copy({tiled_copy_t2r}, {ttr_tacc_mn}, {ttr_racc})\n"
f" {acc_vec} = {trs_racc}.load().to({target_dtype})\n"
f" if _tcgen05_subtile == {subtile_count} - 1:\n"
f" cute.arch.fence_view_async_tmem_load()\n"
f" with cute.arch.elect_one():\n"
f" {tcgen05_acc_pipeline}.consumer_release({tcgen05_acc_consumer_state})\n"
f" {trs_rd}.store({acc_vec})\n"
)
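    # Emits the per-subtile store tail: optional late C acquire, first barrier,
    # SMEM ring-index computation, register->SMEM copy, async-shared fence,
    # second barrier, then warp 0's TMA store and producer_commit.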
def tma_store_tail_region(*, late_later_subtile_acquire: str) -> str:
return (
f"{late_later_subtile_acquire}"
f" {epilog_sync_barrier}.arrive_and_wait()\n"
f" {c_buffer} = ({tma_c_buffer_expr}) % cutlass.Int32({tcgen05_c_stage_count})\n"
f" cute.copy({tiled_copy_r2s}, {trs_rd}, {trs_sd}[(None, None, None, {c_buffer})])\n"
f" cute.arch.fence_view_async_shared()\n"
f" {epilog_sync_barrier}.arrive_and_wait()\n"
f" if {tcgen05_warp_idx} == cutlass.Int32(0):\n"
f" cute.copy({tcgen05_tma_store_atom}, {bsg_sd}[(None, {c_buffer})], {bsg_gd}[(None, cutlass.Int32(_tcgen05_subtile))])\n"
f" {c_pipeline}.producer_commit()\n"
)
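    # Assembles the full per-subtile body (acquires, acc wait + t2r, store
    # tail) under the epi_active guard; the placement knobs decide which
    # acquire/wait fragments are non-empty.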
def tma_store_subtile_body(
*,
first_subtile_acquire: str,
later_subtile_acquire: str,
acc_wait: str,
late_later_subtile_acquire: str,
) -> str:
return (
f" if {tcgen05_epi_active}:\n"
f"{first_subtile_acquire}"
f"{later_subtile_acquire}"
f"{tma_store_acc_t2r_region(acc_wait=acc_wait)}"
f"{tma_store_tail_region(late_later_subtile_acquire=late_later_subtile_acquire)}"
)
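    # Indents a region one extra level for the constant-true diagnostic blocks
    # below; an empty region becomes a bare `pass`.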
def indented_diagnostic_region(source: str) -> str:
if not source:
return " pass\n"
return "".join(f" {line}" for line in source.splitlines(keepends=True))
def tma_store_helper_boundary_subtile_body(
*,
first_subtile_acquire: str,
later_subtile_acquire: str,
acc_wait: str,
late_later_subtile_acquire: str,
) -> str:
acquire_region = f"{first_subtile_acquire}{later_subtile_acquire}"
acc_region = tma_store_acc_t2r_region(acc_wait=acc_wait)
tail_region = tma_store_tail_region(
late_later_subtile_acquire=late_later_subtile_acquire
)
        # These constant-true `if True:` blocks mark diagnostic source
        # boundaries: they survive the generated-code AST round trip, whereas
        # emitted comments are not reliable line-info anchors.
return (
f" if {tcgen05_epi_active}:\n"
f" if True:\n"
f"{indented_diagnostic_region(acquire_region)}"
f" if True:\n"
f"{indented_diagnostic_region(acc_region)}"
f" if True:\n"
f"{indented_diagnostic_region(tail_region)}"
)
module_acc_t2r_helper_name = (
df.unique_name("tcgen05_acc_t2r_region")
if diagnose_module_helper_acc_t2r
else ""
)
module_store_tail_helper_name = (
df.unique_name("tcgen05_store_tail_region")
if diagnose_module_helper_store_tail
else ""
)
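    # The module-helper diagnostics move the acc-t2r or store-tail region into
    # a standalone @cute.jit function at module scope; the helper sources below
    # keep the emitted logic identical to the inline regions above.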
def tma_store_module_acc_t2r_helper_source(*, acc_wait: str) -> str:
return (
"@cute.jit\n"
f"def {module_acc_t2r_helper_name}("
"_tcgen05_subtile, "
"tcgen05_acc_pipeline, "
"tcgen05_acc_consumer_state, "
"tcgen05_tTR_tAcc, "
"tcgen05_tiled_copy_t2r, "
"tcgen05_tTR_rAcc, "
"tcgen05_tRS_rAcc, "
"tcgen05_tRS_rD, "
"tcgen05_subtile_count"
"):\n"
f"{acc_wait}"
" tcgen05_tTR_tAcc_mn = tcgen05_tTR_tAcc[(None, None, None, cutlass.Int32(_tcgen05_subtile))]\n"
" cute.copy(tcgen05_tiled_copy_t2r, tcgen05_tTR_tAcc_mn, tcgen05_tTR_rAcc)\n"
f" tcgen05_acc_vec = tcgen05_tRS_rAcc.load().to({target_dtype})\n"
" if _tcgen05_subtile == tcgen05_subtile_count - 1:\n"
" cute.arch.fence_view_async_tmem_load()\n"
" with cute.arch.elect_one():\n"
" tcgen05_acc_pipeline.consumer_release(tcgen05_acc_consumer_state)\n"
" tcgen05_tRS_rD.store(tcgen05_acc_vec)"
)
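    # Call site that forwards the loop-local and tile-local values into the
    # module-level acc-t2r helper.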
def tma_store_module_acc_t2r_helper_call() -> str:
return (
f" {module_acc_t2r_helper_name}("
f"_tcgen05_subtile, "
f"{tcgen05_acc_pipeline}, "
f"{tcgen05_acc_consumer_state}, "
f"{ttr_tacc}, "
f"{tiled_copy_t2r}, "
f"{ttr_racc}, "
f"{trs_racc}, "
f"{trs_rd}, "
f"{subtile_count})\n"
)
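    # Per-subtile body for the module-helper acc-t2r layout: inline acquires, a
    # call into the module helper for the t2r work, then the inline tail.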
def tma_store_module_helper_subtile_body(
*,
first_subtile_acquire: str,
later_subtile_acquire: str,
late_later_subtile_acquire: str,
) -> str:
return (
f" if {tcgen05_epi_active}:\n"
f"{first_subtile_acquire}"
f"{later_subtile_acquire}"
f"{tma_store_module_acc_t2r_helper_call()}"
f"{tma_store_tail_region(late_later_subtile_acquire=late_later_subtile_acquire)}"
)
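    # Module-level @cute.jit helper source for the store-tail region (R2S copy,
    # barriers, TMA store, producer_commit).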
def tma_store_module_tail_helper_source(*, late_later_subtile_acquire: str) -> str:
return (
"@cute.jit\n"
f"def {module_store_tail_helper_name}("
"_tcgen05_subtile, "
"tcgen05_tma_c_buffer_index, "
"tcgen05_epilog_sync_barrier, "
"tcgen05_tiled_copy_r2s, "
"tcgen05_tRS_rD, "
"tcgen05_tRS_sD, "
"tcgen05_tma_store_atom, "
"tcgen05_bSG_sD, "
"tcgen05_bSG_gD, "
"tcgen05_c_pipeline, "
"tcgen05_warp_idx"
"):\n"
f"{late_later_subtile_acquire}"
" tcgen05_epilog_sync_barrier.arrive_and_wait()\n"
f" tcgen05_c_buffer = tcgen05_tma_c_buffer_index % cutlass.Int32({tcgen05_c_stage_count})\n"
" cute.copy(tcgen05_tiled_copy_r2s, tcgen05_tRS_rD, tcgen05_tRS_sD[(None, None, None, tcgen05_c_buffer)])\n"
" cute.arch.fence_view_async_shared()\n"
" tcgen05_epilog_sync_barrier.arrive_and_wait()\n"
" if tcgen05_warp_idx == cutlass.Int32(0):\n"
" cute.copy(tcgen05_tma_store_atom, tcgen05_bSG_sD[(None, tcgen05_c_buffer)], tcgen05_bSG_gD[(None, cutlass.Int32(_tcgen05_subtile))])\n"
" tcgen05_c_pipeline.producer_commit()"
)
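    # Call site that forwards the SMEM/TMA operands into the module-level
    # store-tail helper.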
def tma_store_module_tail_helper_call() -> str:
return (
f" {module_store_tail_helper_name}("
f"_tcgen05_subtile, "
f"{tma_c_buffer_expr}, "
f"{epilog_sync_barrier}, "
f"{tiled_copy_r2s}, "
f"{trs_rd}, "
f"{trs_sd}, "
f"{tcgen05_tma_store_atom}, "
f"{bsg_sd}, "
f"{bsg_gd}, "
f"{c_pipeline}, "
f"{tcgen05_warp_idx})\n"
)
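    # Per-subtile body for the module-helper store-tail layout: inline acquires
    # and acc-t2r region, then a call into the tail helper.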
def tma_store_module_tail_subtile_body(
*,
first_subtile_acquire: str,
later_subtile_acquire: str,
acc_wait: str,
) -> str:
return (
f" if {tcgen05_epi_active}:\n"
f"{first_subtile_acquire}"
f"{later_subtile_acquire}"
f"{tma_store_acc_t2r_region(acc_wait=acc_wait)}"
f"{tma_store_module_tail_helper_call()}"
)
if diagnose_split_first_t2r:
tma_store_split_first_subtile_body = tma_store_subtile_body(
first_subtile_acquire=tma_store_split_first_subtile_acquire,
later_subtile_acquire="",
acc_wait=tma_store_split_first_acc_wait,
late_later_subtile_acquire="",
)
tma_store_split_tail_subtile_body = tma_store_subtile_body(
first_subtile_acquire="",
later_subtile_acquire=tma_store_split_tail_later_subtile_acquire,
acc_wait="",
late_later_subtile_acquire=(
tma_store_split_tail_late_later_subtile_acquire
),
)
# Diagnostic-only scaffolding: reuse the one-indent subtile formatter
# for a static first subtile without changing production source layout.
# The tail loop maps split-loop indices back to logical subtile ids 1..N-1;
# unroll_full=True keeps those subtile values compile-time constants.
tma_store_subtile_loop = (
"if True:\n"
f" _tcgen05_subtile = 0\n"
f"{tma_store_split_first_subtile_body}"
f"for _tcgen05_split_subtile in cutlass.range({subtile_count} - 1, unroll_full=True):\n"
f" _tcgen05_subtile = _tcgen05_split_subtile + 1\n"
f"{tma_store_split_tail_subtile_body}"
)
elif diagnose_split_acc_t2r_store_tail:
tma_store_helper_boundary_body = tma_store_helper_boundary_subtile_body(
first_subtile_acquire=tma_store_loop_first_subtile_acquire,
later_subtile_acquire=tma_store_loop_later_subtile_acquire,
acc_wait=tma_store_loop_acc_wait,
late_later_subtile_acquire=tma_store_loop_late_later_subtile_acquire,
)
tma_store_subtile_loop = (
f"for _tcgen05_subtile in cutlass.range({subtile_count}, unroll_full=True):\n"
f"{tma_store_helper_boundary_body}"
)
elif diagnose_module_helper_acc_t2r:
module_helper_acc_wait = (
""
if diagnose_acc_wait_before_subtile_loop
else (
" if _tcgen05_subtile == 0:\n"
" tcgen05_acc_pipeline.consumer_wait(tcgen05_acc_consumer_state)\n"
)
)
state.codegen.module_statements.append(
statement_from_string(
tma_store_module_acc_t2r_helper_source(acc_wait=module_helper_acc_wait)
)
)
tma_store_module_helper_body = tma_store_module_helper_subtile_body(
first_subtile_acquire=tma_store_loop_first_subtile_acquire,
later_subtile_acquire=tma_store_loop_later_subtile_acquire,
late_later_subtile_acquire=tma_store_loop_late_later_subtile_acquire,
)
tma_store_subtile_loop = (
f"for _tcgen05_subtile in cutlass.range({subtile_count}, unroll_full=True):\n"
f"{tma_store_module_helper_body}"
)
elif diagnose_module_helper_store_tail:
module_tail_late_later_subtile_acquire = (
(
" if _tcgen05_subtile != 0 and "
"tcgen05_warp_idx == cutlass.Int32(0):\n"
" tcgen05_c_pipeline.producer_acquire()\n"
)
if diagnose_later_c_acquire_before_barrier
else ""
)
state.codegen.module_statements.append(
statement_from_string(
tma_store_module_tail_helper_source(
late_later_subtile_acquire=module_tail_late_later_subtile_acquire
)
)
)
tma_store_module_tail_body = tma_store_module_tail_subtile_body(
first_subtile_acquire=tma_store_loop_first_subtile_acquire,
later_subtile_acquire=tma_store_loop_later_subtile_acquire,
acc_wait=tma_store_loop_acc_wait,
)
tma_store_subtile_loop = (
f"for _tcgen05_subtile in cutlass.range({subtile_count}, unroll_full=True):\n"
f"{tma_store_module_tail_body}"
)
else:
tma_store_default_subtile_body = tma_store_subtile_body(
first_subtile_acquire=tma_store_loop_first_subtile_acquire,
later_subtile_acquire=tma_store_loop_later_subtile_acquire,
acc_wait=tma_store_loop_acc_wait,
late_later_subtile_acquire=tma_store_loop_late_later_subtile_acquire,
)
tma_store_subtile_loop = (
f"for _tcgen05_subtile in cutlass.range({subtile_count}, unroll_full=True):\n"
f"{tma_store_default_subtile_body}"
)
tma_store_smem_setup = [
# Must match the wrapper-side `tcgen05_d_tma` TMA atom layout in
# `helion/runtime/__init__.py`; both describe one D SMEM stage.
(
f"{smem_d_layout} = cutlass.utils.blackwell_helpers.make_smem_layout_epi("
f"{target_dtype}, cutlass.utils.layout.LayoutEnum.ROW_MAJOR, "
f"{epi_tile}, {tcgen05_value.c_stage_count})"
),
(
f"{smem_d_ptr} = cute.arch.alloc_smem("
f"{target_dtype}, cute.cosize({smem_d_layout}.outer), alignment=1024)"
),
(
f"{smem_d} = cute.make_tensor("
f"cute.recast_ptr({smem_d_ptr}, {smem_d_layout}.inner, dtype={target_dtype}), "
f"{smem_d_layout}.outer)"
),
]
tma_store_acc_layout_setup = [
(
f"{tacc} = cutlass.utils.gemm.sm100.transform_partitioned_tensor_layout("
f"{tcgen05_value.epi_acc_frag_base})"
),
]
tma_store_role_invariant_setup = [
*tma_static_store_setup,
*tma_store_smem_setup,
*tma_store_acc_layout_setup,
]
suppressed_store_body_core = [
(
# Diagnostic-only invalid-output mode. Keep the accumulator
# pipeline draining so persistent kernels do not deadlock, but
# suppress C-pipeline acquire/commit, R2S/SMEM work, and TMA D
# stores to bound whether hot waits are tied to the C-store path.
f"if {tcgen05_value.epi_active}:\n"
f" {tcgen05_value.acc_pipeline}.consumer_wait({tcgen05_value.acc_consumer_state})\n"
f" with cute.arch.elect_one():\n"
f" {tcgen05_value.acc_pipeline}.consumer_release({tcgen05_value.acc_consumer_state})\n"
+ emit_pipeline_advance(
tcgen05_value.acc_consumer_state,
indent=" ",
)
)
]
# Non-role-local stores keep pipeline/SMEM setup before per-tile C
# partitioning so the hoisted role-local prefix matches the same
# invariant setup subset.
tma_store_body_core = [
*([] if tcgen05_value.use_role_local_epi else tma_static_store_setup),
*([] if tcgen05_value.use_role_local_epi else tma_store_pipeline_setup),
*([] if tcgen05_value.use_role_local_epi else tma_store_smem_setup),
*tma_store_first_subtile_acquire,
*tma_tile_store_setup,
(
f"{tcgc} = cutlass.utils.gemm.sm100.transform_partitioned_tensor_layout("
f"{tcgc_base})"
),
(
f"{tcgc_planned} = cute.make_tensor("
f"{tcgc}.iterator, "
f"cute.append(cute.append(cute.append({tcgc}.layout, {tcgen05_value.epilogue_rest_mode}), {tcgen05_value.epilogue_rest_mode}), {tcgen05_value.epilogue_rest_mode}))"
),
*([] if tcgen05_value.use_role_local_epi else tma_store_acc_layout_setup),
(
f"{tiled_copy_t2r}, {ttr_tacc_base}, {ttr_racc} = "
"cutlass.utils.gemm.sm100.epilogue_tmem_copy_and_partition("
f"{kernel_desc}, {tcgen05_value.epi_tidx}, {tacc}, {tcgc_planned}, {epi_tile}, {tcgen05_value.is_two_cta!s})"
),
(f"{ttr_rd} = cute.make_rmem_tensor({ttr_racc}.shape, {target_dtype})"),
(
f"{tiled_copy_r2s}, {trs_rd}, {trs_sd} = "
"cutlass.utils.gemm.sm100.epilogue_smem_copy_and_partition("
f"{kernel_desc}, {tiled_copy_t2r}, {ttr_rd}, "
f"{tcgen05_value.epi_tidx}, {smem_d})"
),
f"{trs_racc} = {tiled_copy_r2s}.retile({ttr_racc})",
f"{tcgc_epi} = cute.flat_divide({tcgc_planned}, {epi_tile})",
(
f"{bsg_sd}, {bsg_gd_partitioned} = cute.nvgpu.cpasync.tma_partition("
f"{tcgen05_value.tma_store_atom}, 0, cute.make_layout(1), "
f"cute.group_modes({smem_d}, 0, 2), "
f"cute.group_modes({tcgc_epi}, 0, 2))"
),
(
f"{bsg_gd} = {bsg_gd_partitioned}["
f"(None, None, None, cutlass.Int32(0), cutlass.Int32(0), cutlass.Int32(0))]"
),
f"{bsg_gd} = cute.group_modes({bsg_gd}, 1, cute.rank({bsg_gd}))",
(
f"{ttr_tacc_stage} = {ttr_tacc_base}["
f"(None, None, None, None, None, {tcgen05_value.acc_consumer_state}.index)]"
),
f"{ttr_tacc} = cute.group_modes({ttr_tacc_stage}, 3, cute.rank({ttr_tacc_stage}))",
f"{subtile_count} = cutlass.const_expr(cute.size({ttr_tacc}.shape, mode=[3]))",
*tma_store_pre_loop_acc_wait,
(
# Warp 0 pre-acquires the first TMA-store SMEM stage before
# per-tile C-store setup. The subtile loop acquires only later
# stages, so C-stage waits can overlap setup, the first
# acc-pipeline wait, and the other epi warps' TMEM
# load/conversion work on later subtile iterations. The
# diagnostic tcgen05_c_acquire_placement=first_in_loop moves only
# that first acquire into the subtile loop; later acquires and
# the accumulator wait keep their default order. The diagnostic
# later_before_barrier placement keeps the first acquire in
# production position and moves only later-subtile acquires just
# before the first epilogue barrier. The diagnostic
# tcgen05_acc_wait_placement=before_subtile_loop keeps both C
# acquire sites in production position and moves only the
# accumulator consumer wait before the subtile loop.
# A CTA-scoped named barrier ensures all epi warps have observed
# warp 0's acquire before they write SMEM; a second barrier ensures
# the SMEM writes and Quack-style async-shared fence are visible
# before warp 0 issues and commits the TMA operation.
# Compute the SMEM ring index after the first barrier so the
# acquire/barrier/index order stays aligned with Quack's
# TMA-store epilogue.
# The accumulator consumer state advances after the loop, matching
# Quack's call-site ordering while preserving the early release.
# After warp 0 commits the TMA store, the next subtile's
# producer_acquire plus the first named barrier are enough to
# keep all epi warps from writing a reused SMEM stage too early.
# Avoiding a post-commit barrier matches Quack's epilogue loop.
# The split_first_t2r diagnostic emits the first static subtile as
# a standalone source block, then loops over later subtile work.
# It is a layout discriminator for the hot acc-wait/T2R SASS row;
# the default production source shape remains the single loop.
tma_store_subtile_loop
# Advance is a per-thread local state update, so it intentionally
# stays outside elect_one; only the mbarrier release is elected.
+ f"if {tcgen05_value.epi_active}:\n"
+ emit_pipeline_advance(tcgen05_value.acc_consumer_state, indent=" ")
),
*([] if tcgen05_value.use_role_local_epi else [tma_store_pipeline_tail]),
]
store_body_core = (
suppressed_store_body_core
if diagnose_skip_epilogue_store
else (
tma_store_body_core
if tcgen05_value.use_tma_store_epilogue
else simt_store_body_core
)
)
main_stmts: list[ast.AST]
if tcgen05_value.use_role_local_epi:
# These setup statements intentionally remain virtual-pid-independent.
# The persistent splitter hoists them before the role-local scheduler
# loops; if future setup reads per-tile state, it must be registered
# as per-tile work instead.
tma_store_hoisted_stmts = (
[
statement_from_string(line)
for line in [
*tma_store_pipeline_setup,
*tma_store_role_invariant_setup,
]
]
if tcgen05_value.use_tma_store_epilogue and not diagnose_skip_epilogue_store
else []
)
sync_before_stmt = statement_from_string("cute.arch.sync_threads()")
main_stmt = statement_from_string(
"if True:\n" + textwrap.indent("\n".join(store_body_core), " ")
)
sync_after_stmt = statement_from_string("cute.arch.sync_threads()")
df.register_cute_tcgen05_per_tile_stmts(
[sync_before_stmt, main_stmt, sync_after_stmt]
)
df.register_cute_tcgen05_epi_role_stmts([main_stmt])
main_stmts = [
*tma_store_hoisted_stmts,
sync_before_stmt,
main_stmt,
sync_after_stmt,
]
else:
store_body = [
"cute.arch.sync_threads()",
*store_body_core,
"cute.arch.sync_threads()",
]
main_stmt = statement_from_string(
"if True:\n" + textwrap.indent("\n".join(store_body), " ")
)
main_stmts = [main_stmt]
# Pipeline drain + TMEM dealloc are one-shot cleanup. They must run
# AFTER all tiles have been processed (in the persistent path) and
# naturally land at the end of the kernel in the non-persistent path.
# Keep them as separate statements so the persistent splitter can
# extract them via the post-loop registration below.
post_loop_lines: list[str] = []
if (
tcgen05_value.use_tma_store_epilogue
and tcgen05_value.use_role_local_epi
and not diagnose_skip_epilogue_store
):
# Role-local persistent epilogues reuse the C-store pipeline across
# scheduler-recycled work tiles. Draining it inside each tile would
# serialize the next tile's epilogue against this tile's TMA stores.
# The tail must run before TMEM dealloc setup below.
post_loop_lines.append(tma_store_pipeline_tail)
if tcgen05_value.use_tma:
post_loop_lines.append(
f"if {tcgen05_value.tma_warp}:\n"
+ emit_producer_tail_tma_umma(
tcgen05_value.tma_pipeline,
tcgen05_value.tma_producer_state,
num_stages=tcgen05_value.ab_stage_count,
indent=" ",
skip_advances=tcgen05_value.skip_ab_producer_advance,
)
)
if tcgen05_value.is_two_cta:
# PDL parity with Quack/CUTLASS: after all MMAs are issued, hint
# dependent kernels before this role starts the final acc drain.
post_loop_lines.append(
f"if {tcgen05_value.exec_active}:\n"
" cute.arch.griddepcontrol_launch_dependents()"
)
post_loop_lines.extend(
[
(
f"if {tcgen05_value.exec_active}:\n"
f" {tcgen05_value.tmem_alloc_barrier}.arrive()"
),
(
f"if {tcgen05_value.exec_active}:\n"
+ emit_producer_tail_umma_async(
tcgen05_value.acc_pipeline,
tcgen05_value.acc_producer_state,
num_stages=tcgen05_value.acc_stage_count,
indent=" ",
)
),
(
f"{tcgen05_value.tmem_allocator} = cutlass.utils.TmemAllocator("
f"{tcgen05_value.tmem_holding_buf}, "
f"barrier_for_retrieve={tcgen05_value.tmem_alloc_barrier}, "
f"allocator_warp_id=0, is_two_cta={tcgen05_value.is_two_cta!s}, "
f"two_cta_tmem_dealloc_mbar_ptr={tcgen05_value.tmem_dealloc_mbar_ptr}, "
f"num_allocated_columns={tcgen05_value.acc_tmem_cols}"
f"{emit_dealloc_mbarrier_initialized_kwarg()})"
),
]
)
if not tcgen05_value.is_two_cta:
# Keep the long-validated cluster_m=1 teardown unchanged. The guarded
# CtaGroup.TWO path follows Quack's dealloc sequence without this CTA
# sync: epi warps synchronize through tmem_alloc_barrier before free.
post_loop_lines.append("cute.arch.sync_threads()")
post_loop_lines.extend(
[
(
f"if {tcgen05_value.epi_active}:\n"
f" {tcgen05_value.tmem_allocator}.relinquish_alloc_permit()"
),
(
f"if {tcgen05_value.epi_active}:\n"
f" {tcgen05_value.tmem_alloc_barrier}.arrive_and_wait()"
),
(
f"if {tcgen05_value.epi_active}:\n"
f" {tcgen05_value.tmem_allocator}.free({tcgen05_value.epi_acc_tmem_ptr})"
),
]
)
post_loop_stmts: list[ast.AST] = [
statement_from_string(line) for line in post_loop_lines
]
df.register_cute_tcgen05_post_loop_stmts(post_loop_stmts)
return [*main_stmts, *post_loop_stmts]
def _codegen_cute_store_loaded_index_trailing_slices(
state: CodegenState,
tensor: torch.Tensor,
subscript: list[object] | tuple[object, ...],
ast_subscript: list[object] | tuple[object, ...],
extra_mask: ast.AST | None,
value_node: torch.fx.Node,
) -> ast.AST | None:
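    # Rewrites `store(tensor, [prefix..., :, ...], load(src, [indexer, :, ...]))`
    # where every trailing index on both sides is a full slice into nested
    # scalar copy loops over the trailing source dims, applying the combined
    # source/target masks per element. Returns None when the pattern does not
    # match so the caller falls back to the generic store path.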
from .._compiler.ast_extension import create
if value_node.target is not load or len(value_node.args) < 2:
return None
source_tensor_node = value_node.args[0]
if not isinstance(source_tensor_node, torch.fx.Node):
return None
source_tensor = source_tensor_node.meta.get("val")
if not isinstance(source_tensor, torch.Tensor):
return None
source_subscript = value_node.args[1]
if not isinstance(source_subscript, (list, tuple)) or not source_subscript:
return None
indexer = source_subscript[0]
if not isinstance(indexer, torch.fx.Node):
return None
indexer_value = indexer.meta.get("val")
if not isinstance(indexer_value, torch.Tensor) or indexer_value.ndim == 0:
return None
trailing_source = [*source_subscript[1:]]
if not trailing_source or not all(idx == slice(None) for idx in trailing_source):
return None
if len(subscript) != indexer_value.ndim + len(trailing_source):
return None
trailing_store = subscript[indexer_value.ndim :]
if not all(idx == slice(None) for idx in trailing_store):
return None
ast_source_subscript = list(
map_arg(tuple(source_subscript), lambda arg: state.env[arg])
)
index_exprs = _cute_index_exprs(
state,
[indexer_value],
[ast_source_subscript[0]],
tensor=source_tensor,
inactive_singleton_slice_expr="0",
)
if len(index_exprs) != 1:
return None
prefix_subscript = [*subscript[: indexer_value.ndim]]
prefix_ast_subscript = [*ast_subscript[: indexer_value.ndim]]
target_prefix = _cute_index_exprs(
state,
prefix_subscript,
prefix_ast_subscript,
tensor=tensor,
inactive_singleton_slice_expr="0",
)
if len(target_prefix) != indexer_value.ndim:
return None
env = CompileEnvironment.current()
index_dtype = env.backend.dtype_str(env.index_dtype)
source_loop_vars = [
state.device_function.new_var("slice_idx", dce=True) for _ in trailing_source
]
source_indices = [
index_exprs[0],
*[f"{index_dtype}({var})" for var in source_loop_vars],
]
target_indices = [
*target_prefix,
*[f"{index_dtype}({var})" for var in source_loop_vars],
]
if len(source_indices) != source_tensor.ndim or len(target_indices) != tensor.ndim:
return None
source_name = state.device_function.tensor_arg(source_tensor).name
target_name = state.device_function.tensor_arg(tensor).name
source_dtype = env.backend.dtype_str(source_tensor.dtype)
target_dtype = env.backend.dtype_str(tensor.dtype)
source_mask = _cute_combined_mask(
state,
[indexer_value],
None,
tensor=source_tensor,
)
target_mask = _cute_combined_mask(
state,
prefix_subscript,
extra_mask,
tensor=tensor,
)
masks = [mask for mask in (source_mask, target_mask) if mask is not None]
mask_expr = " and ".join(f"({mask})" for mask in masks) if masks else None
load_expr = f"{source_name}[{', '.join(source_indices)}]"
if mask_expr is not None:
load_expr = f"({load_expr} if {mask_expr} else {source_dtype}(0))"
store_expr = (
f"{target_name}.__setitem__({_cute_index_tuple(target_indices)}, "
f"{env.backend.ast_to_dtype_expr(load_expr, target_dtype)})"
)
if mask_expr is not None:
store_expr = f"{store_expr} if {mask_expr} else None"
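    # Widen the recorded thread-block extents for any tile block sizes that the
    # store's prefix dims map onto a thread axis, so the emitted lane loops
    # cover this store.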
tensor_dim = 0
for idx in prefix_subscript:
block_id = None
if isinstance(idx, torch.SymInt):
block_id = env.get_block_id(idx)
elif idx == slice(None) and tensor_dim < tensor.ndim:
block_id = next(
(
candidate
for candidate in _matching_block_ids(env, tensor.shape[tensor_dim])
if candidate in state.codegen.active_device_loops
),
None,
)
tensor_dim += 1
if block_id is None:
continue
axis = None
grid_state = state.codegen.current_grid_state
if grid_state is not None:
axis = grid_state.block_thread_axes.get(block_id)
if axis is None:
loops = state.codegen.active_device_loops.get(block_id)
if loops:
axis = loops[-1].block_thread_axes.get(block_id)
if axis is None or not (0 <= axis < 3):
continue
block_size = env.block_sizes[block_id].from_config(state.config)
if not isinstance(block_size, int):
continue
state.codegen.max_thread_block_dims[axis] = max(
state.codegen.max_thread_block_dims[axis],
block_size,
)
state.codegen.referenced_thread_block_dims[axis] = max(
state.codegen.referenced_thread_block_dims[axis],
block_size,
)
stmt: ast.stmt = create(ast.Expr, value=expr_from_string(store_expr))
for loop_var, source_pos in reversed(
[*zip(source_loop_vars, range(1, len(source_subscript)), strict=True)]
):
extent = _cute_tensor_dim_size_expr(state, source_tensor, source_pos)
stmt = create(
ast.For,
target=create(ast.Name, id=loop_var, ctx=ast.Store()),
iter=expr_from_string(f"range({extent})"),
body=[stmt],
orelse=[],
type_comment=None,
)
state.add_statement(stmt)
return ast.Constant(value=None)
def _codegen_cute_store_permute_lane_loops(
state: CodegenState,
tensor: torch.Tensor,
subscript: list[object] | tuple[object, ...],
ast_subscript: list[object] | tuple[object, ...],
value: ast.AST,
extra_mask: ast.AST | None,
value_node: torch.fx.Node,
) -> ast.AST | None:
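    # Handles stores whose value is a permute/view/reshape of a tile when the
    # grid uses lane loops: a permute of a directly loaded tile becomes a
    # reindexed load/store expression, while other shape ops are staged through
    # a flat SMEM buffer (written at the source flat index, read back at the
    # permuted/reshaped flat index after a CTA sync).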
from .._compiler.cute.cute_reshape import _coords_from_flat_index
from .._compiler.cute.cute_reshape import _flat_index_from_coords
from .._compiler.cute.cute_reshape import _get_dim_local_coord
from .._compiler.cute.cute_reshape import _get_tile_shape
from .._compiler.cute.cute_reshape import _permute_reorders_active_dims
from .._compiler.cute.cute_reshape import _shape_op_needs_materialization
from .._compiler.cute.cute_reshape import _store_permute_info
from .._compiler.generate_ast import GenerateAST
from .._compiler.tile_strategy import DeviceGridState
if not isinstance(state.codegen, GenerateAST):
return None
grid_state = state.codegen.current_grid_state
if not isinstance(grid_state, DeviceGridState) or not grid_state.has_lane_loops():
return None
if _shape_op_needs_materialization(value_node):
return None
index_exprs = _cute_index_exprs(
state,
subscript,
ast_subscript,
tensor=tensor,
inactive_singleton_slice_expr="0",
)
index_tuple = _cute_index_tuple(index_exprs)
mask_expr = _cute_combined_mask(state, subscript, extra_mask, tensor=tensor)
tensor_name = state.device_function.tensor_arg(tensor).name
input_node: torch.fx.Node
output_val = value_node.meta.get("val")
read_flat: str
input_shape: list[int]
info = _store_permute_info(value_node)
if info is not None:
input_node, perm = info
input_val = input_node.meta.get("val")
if not isinstance(input_val, torch.Tensor) or not isinstance(
output_val, torch.Tensor
):
return None
if not _permute_reorders_active_dims(state.codegen, input_val, perm):
return None
source_tensor_node = input_node.args[0] if input_node.args else None
source_extra_mask = input_node.args[2] if len(input_node.args) > 2 else None
if (
input_node.op == "call_function"
and input_node.target is load
and isinstance(source_tensor_node, torch.fx.Node)
and source_extra_mask is None
):
source_tensor = source_tensor_node.meta.get("val")
if isinstance(source_tensor, torch.Tensor):
reordered_subscript = [
subscript[perm.index(i)] for i in range(len(perm))
]
reordered_ast_subscript = (
[ast_subscript[perm.index(i)] for i in range(len(perm))]
if isinstance(ast_subscript, (list, tuple))
else None
)
source_index_exprs = _cute_index_exprs(
state,
reordered_subscript,
ast_subscript=reordered_ast_subscript,
tensor=source_tensor,
inactive_singleton_slice_expr="0",
)
source_index_tuple = _cute_index_tuple(source_index_exprs)
source_name = state.device_function.tensor_arg(source_tensor).name
source_mask = _cute_combined_mask(
state,
reordered_subscript,
None,
tensor=source_tensor,
)
source_dtype = CompileEnvironment.current().backend.dtype_str(
source_tensor.dtype
)
return expr_from_string(
(
f"({tensor_name}.__setitem__({index_tuple}, "
f"({source_name}[{source_index_tuple}] if {source_mask} else {source_dtype}(0))) "
f"if {mask_expr} else None)"
)
if source_mask is not None and mask_expr is not None
else (
f"{tensor_name}.__setitem__({index_tuple}, "
f"{source_name}[{source_index_tuple}] if {source_mask} else {source_dtype}(0))"
if source_mask is not None
else (
f"({tensor_name}.__setitem__({index_tuple}, {source_name}[{source_index_tuple}]) "
f"if {mask_expr} else None)"
if mask_expr is not None
else f"{tensor_name}.__setitem__({index_tuple}, {source_name}[{source_index_tuple}])"
)
)
)
raise exc.BackendUnsupported("cute", "permute lane-loop source tensor")
env = CompileEnvironment.current()
df = state.device_function
input_shape = _get_tile_shape(input_val, env, df.config)
output_shape = _get_tile_shape(output_val, env, df.config)
src_coords = [
_get_dim_local_coord(state.codegen, input_val, i)
for i in range(len(input_shape))
]
current_flat = _flat_index_from_coords(src_coords, input_shape)
output_coords = _coords_from_flat_index(current_flat, output_shape)
read_coords = [output_coords[perm.index(i)] for i in range(len(perm))]
read_flat = _flat_index_from_coords(read_coords, input_shape)
elif value_node.target in {
torch.ops.aten.view.default,
torch.ops.aten.reshape.default,
}:
input_arg = value_node.args[0]
if not isinstance(input_arg, torch.fx.Node):
return None
input_node = input_arg
input_val = input_node.meta.get("val")
if not isinstance(input_val, torch.Tensor) or not isinstance(
output_val, torch.Tensor
):
return None
env = CompileEnvironment.current()
df = state.device_function
input_shape = _get_tile_shape(input_val, env, df.config)
output_shape = _get_tile_shape(output_val, env, df.config)
if input_shape == output_shape:
return None
input_non_unit = [s for s in input_shape if s != 1]
output_non_unit = [s for s in output_shape if s != 1]
if input_non_unit == output_non_unit:
return None
src_coords = [
_get_dim_local_coord(state.codegen, input_val, i)
for i in range(len(input_shape))
]
current_flat = _flat_index_from_coords(src_coords, input_shape)
output_coords = [
_get_dim_local_coord(state.codegen, output_val, i)
for i in range(len(output_shape))
]
read_flat = _flat_index_from_coords(output_coords, output_shape)
else:
return None
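    # Shape-op fallback: stage the value through a flat SMEM buffer sized to
    # the input tile, sync the CTA, then read back at the remapped flat index
    # and store into the target tensor (masked when needed).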
env = CompileEnvironment.current()
df = state.device_function
input_numel = 1
for size in input_shape:
input_numel *= size
dtype_str = env.backend.dtype_str(input_val.dtype)
smem_ptr = df.new_var("permute_smem_ptr")
smem = df.new_var("permute_smem")
state.codegen.add_statement(
statement_from_string(
f"{smem_ptr} = cute.arch.alloc_smem({dtype_str}, {input_numel})"
)
)
state.codegen.add_statement(
statement_from_string(
f"{smem} = cute.make_tensor({smem_ptr}, ({input_numel},))"
)
)
read_expr = (
f"{df.tensor_arg(tensor).name}.__setitem__({index_tuple}, {smem}[{read_flat}])"
if mask_expr is None
else (
f"({df.tensor_arg(tensor).name}.__setitem__({index_tuple}, {smem}[{read_flat}]) "
f"if {mask_expr} else None)"
)
)
return expr_from_string(
f"({smem}.__setitem__({current_flat}, {{value}}), "
f"cute.arch.sync_threads(), "
f"{read_expr})",
value=value,
)
@_decorators.codegen(store, "metal")
def _(state: CodegenState) -> ast.AST:
# Metal delegates to the same PointerIndexingStrategy as Triton.
# This produces tl.store(ptr + offset, val, mask) in the AST;
# the MSL walker translates it to Metal.
tensor = state.proxy_arg(0)
subscript = state.proxy_arg(1)
assert isinstance(subscript, (list, tuple))
value = state.ast_arg(2)
extra_mask = state.ast_args[3]
assert isinstance(extra_mask, (type(None), ast.AST))
if isinstance(tensor, torch.Tensor):
device_fn = state.device_function
device_fn.device_store_index += 1
indexing_idx = device_fn.device_memory_op_index
device_fn.device_memory_op_index += 1
strategy = device_fn.get_indexing_strategy(indexing_idx)
return strategy.codegen_store(state, tensor, [*subscript], value, extra_mask)
raise exc.BackendUnsupported("metal", f"store target type: {type(tensor)}")
@_decorators.codegen(store, "cute")
def _(state: CodegenState) -> ast.AST:
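    # Cute stores go through a chain of pattern rewrites (affine-range,
    # strided-slice, stack-tensor loads, loaded-index trailing slices, permute
    # lane loops, tcgen05 tiles) before falling back to a scalar masked
    # `__setitem__` on the target tensor.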
tensor = state.proxy_arg(0)
subscript = state.proxy_arg(1)
assert isinstance(subscript, (list, tuple))
ast_subscript = state.ast_args[1]
assert isinstance(ast_subscript, (list, tuple))
raw_value = state.ast_args[2]
extra_mask = state.ast_args[3]
assert isinstance(extra_mask, (type(None), ast.AST))
value_node = None
if state.fx_node is not None and len(state.fx_node.args) > 2:
maybe_value_node = state.fx_node.args[2]
if isinstance(maybe_value_node, torch.fx.Node):
value_node = maybe_value_node
if isinstance(tensor, torch.Tensor):
affine_range_store = _codegen_cute_affine_range_store(
state,
tensor,
subscript,
ast_subscript,
raw_value,
extra_mask,
value_node,
)
if affine_range_store is not None:
state.add_statement(affine_range_store)
return ast.Constant(value=None)
strided_slice_store = _codegen_cute_strided_slice_store(
state,
tensor,
subscript,
raw_value,
extra_mask,
value_node,
)
if strided_slice_store is not None:
state.add_statement(strided_slice_store)
return ast.Constant(value=None)
value = state.ast_arg(2)
if value_node is not None:
if value_node.op == "call_function":
if isinstance(tensor, torch.Tensor):
rewritten_stmt = _codegen_cute_store_stack_load(
state,
tensor,
subscript,
ast_subscript,
value,
extra_mask,
value_node,
)
if rewritten_stmt is not None:
return rewritten_stmt
rewritten_stmt = _codegen_cute_store_loaded_index_trailing_slices(
state,
tensor,
subscript,
ast_subscript,
extra_mask,
value_node,
)
if rewritten_stmt is not None:
return rewritten_stmt
rewritten_stmt = _codegen_cute_store_permute_lane_loops(
state,
tensor,
subscript,
ast_subscript,
value,
extra_mask,
value_node,
)
if rewritten_stmt is not None:
return rewritten_stmt
from .._compiler.cute.cute_reshape import codegen_cute_store_permute
rewritten = codegen_cute_store_permute(state, value, value_node)
if rewritten is not None:
value = rewritten
if isinstance(tensor, tuple):
stack_tensor_ast = state.ast_args[0]
assert isinstance(stack_tensor_ast, tuple)
assert len(stack_tensor_ast) == 2
_tensor_like_ast, dev_ptrs_ast = stack_tensor_ast
assert isinstance(dev_ptrs_ast, ast.AST)
tensor_like, dev_ptrs = tensor
offset_expr = _cute_stack_tensor_offset_expr(
state,
tensor_like,
[*subscript],
ast_subscript,
)
backend = CompileEnvironment.current().backend
target_dtype = backend.dtype_str(tensor_like.dtype)
value = expr_from_string(
backend.ast_to_dtype_expr("{value}", target_dtype),
value=value,
)
ptr_expr = _cute_stack_tensor_pointer_expr(
target_dtype, dev_ptrs_ast, offset_expr
)
store_expr = expr_from_string(
"({ptr}).store({value})", ptr=ptr_expr, value=value
)
mask_expr = _cute_stack_tensor_mask_expr(
state,
tensor_like,
dev_ptrs,
[*subscript],
extra_mask,
)
if mask_expr is None:
return store_expr
mask_ast = expr_from_string(mask_expr)
assert isinstance(mask_ast, ast.expr)
assert isinstance(store_expr, ast.expr)
state.add_statement(
ast.fix_missing_locations(
ast.If(
test=mask_ast,
body=[ast.Expr(value=store_expr)],
orelse=[],
)
)
)
return ast.Constant(value=None)
if not isinstance(tensor, torch.Tensor):
raise exc.BackendUnsupported("cute", f"store target type: {type(tensor)}")
_log_cute_layout(state, "store")
if isinstance(value, ast.Name):
rewritten_stmt = _codegen_cute_store_tcgen05_tile(
state,
tensor,
subscript,
ast_subscript,
extra_mask,
value.id,
)
if rewritten_stmt is not None:
stmts = (
rewritten_stmt if isinstance(rewritten_stmt, list) else [rewritten_stmt]
)
for stmt in stmts:
state.add_statement(stmt)
return ast.Constant(value=None)
tensor_name = state.device_function.tensor_arg(tensor).name
backend = CompileEnvironment.current().backend
target_dtype = backend.dtype_str(tensor.dtype)
value = expr_from_string(
backend.ast_to_dtype_expr("{value}", target_dtype),
value=value,
)
index_exprs = _cute_index_exprs(
state,
subscript,
ast_subscript,
tensor=tensor,
inactive_singleton_slice_expr="0",
)
topk_lane_expr: object | None = None
topk_k: object | None = None
if state.fx_node is not None and len(state.fx_node.args) > 2:
value_node = state.fx_node.args[2]
if (
isinstance(value_node, torch.fx.Node)
and value_node.target is operator.getitem
and isinstance(value_node.args[0], torch.fx.Node)
and value_node.args[0].target is torch.ops.aten.topk.default
):
topk_lane_expr = value_node.args[0].meta.get("cute_topk_lane_expr")
topk_k = value_node.args[0].meta.get("cute_topk_k")
if isinstance(topk_lane_expr, str) and isinstance(topk_k, int):
index_exprs[-1] = topk_lane_expr
store_uses_pointer = "None" not in index_exprs
store_expr = _cute_scalar_store_expr(tensor_name, index_exprs, "{value}")
assign_expr = expr_from_string(store_expr, value=value)
mask_expr = _cute_combined_mask(state, subscript, extra_mask, tensor=tensor)
if isinstance(topk_lane_expr, str) and isinstance(topk_k, int):
topk_mask = f"({topk_lane_expr}) < {topk_k}"
mask_expr = topk_mask if mask_expr is None else f"({mask_expr}) and {topk_mask}"
if mask_expr is None:
return assign_expr
if store_uses_pointer:
mask_ast = expr_from_string(mask_expr)
assert isinstance(mask_ast, ast.expr)
assert isinstance(assign_expr, ast.expr)
state.add_statement(
ast.fix_missing_locations(
ast.If(
test=mask_ast,
body=[ast.Expr(value=assign_expr)],
orelse=[],
)
)
)
return ast.Constant(value=None)
return expr_from_string(
f"({store_expr} if {mask_expr} else None)",
value=value,
)
# TODO(joydddd): Add support for stack tensor in ref mode.
@_decorators.ref(store)
def _(
tensor: torch.Tensor,
index: list[object],
value: torch.Tensor | torch.SymInt | float,
extra_mask: torch.Tensor | None = None,
) -> None:
from .ref_tile import RefTile
# Normalize indices and identify tensor indices
indices = []
tensor_idx_positions = []
for i, idx in enumerate(index):
if isinstance(idx, RefTile):
idx = idx.index
# pyrefly: ignore [bad-argument-type]
indices.append(idx)
if isinstance(idx, torch.Tensor):
tensor_idx_positions.append(i)
# Handle broadcasting for multiple tensor indices
if len(tensor_idx_positions) > 1:
grids = torch.meshgrid(
# pyrefly: ignore [bad-argument-type]
*(indices[i] for i in tensor_idx_positions),
indexing="ij",
)
for i, grid in zip(tensor_idx_positions, grids, strict=False):
# pyrefly: ignore [unsupported-operation]
indices[i] = grid
if extra_mask is not None:
mask = extra_mask.to(torch.bool)
# Check bounds for tensor indices
for i, idx in enumerate(indices):
if isinstance(idx, torch.Tensor):
mask = mask & (idx >= 0) & (idx < tensor.shape[i])
mask_count = int(mask.sum().item())
if mask_count == 0:
return
# Use index_put_ for masked stores
valid_indices = []
for idx in indices:
if isinstance(idx, torch.Tensor):
valid_indices.append(idx[mask].long())
else:
idx_val = int(idx) if isinstance(idx, torch.SymInt) else idx
valid_indices.append(
# pyrefly: ignore [no-matching-overload]
torch.full(
(mask_count,), idx_val, dtype=torch.long, device=tensor.device
)
)
if isinstance(value, torch.Tensor):
values = value[mask]
else:
val = int(value) if isinstance(value, torch.SymInt) else value
values = torch.full(
(mask_count,), val, dtype=tensor.dtype, device=tensor.device
)
# Check for duplicate indices - this is undefined behavior in Triton
if valid_indices:
stacked = torch.stack(valid_indices, dim=1)
unique_count = stacked.unique(dim=0).size(0)
if unique_count < stacked.size(0):
raise exc.DuplicateStoreIndicesError(
"hl.store with duplicate indices has undefined behavior in compiled mode. "
"The order in which values are written to the same memory location is "
"non-deterministic and may vary between Triton versions and backends."
)
tensor.index_put_(tuple(valid_indices), values, accumulate=False)
return
# Simple assignment
tensor[tuple(indices)] = ( # pyrefly: ignore[unsupported-operation]
int(value) if isinstance(value, torch.SymInt) else value
)
@_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
def load(
tensor: torch.Tensor | StackTensor,
index: list[object],
extra_mask: torch.Tensor | None = None,
eviction_policy: str | None = None,
) -> torch.Tensor:
"""Load a value from a tensor using a list of indices.
This function is equivalent to `tensor[index]` but allows
setting `extra_mask=` to mask elements beyond the default masking
based on the hl.tile range. It also accepts an optional
`eviction_policy` which is forwarded to the underlying Triton `tl.load`
call to control the cache eviction behavior (e.g., "evict_last").
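    Example (illustrative; assumes `x` is a kernel tensor argument and
    `tile_m`, `tile_n` come from `hl.tile`):
        val = hl.load(x, [tile_m, tile_n], eviction_policy="evict_last")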
Args:
tensor: The tensor / stack tensor to load from
index: The indices to use to index into the tensor
extra_mask: The extra mask (beyond automatic tile bounds masking) to apply to the tensor
eviction_policy: Optional Triton load eviction policy to hint cache behavior
Returns:
torch.Tensor: The loaded value
"""
raise exc.NotInsideKernel
@_decorators.prepare_args(load)
def _(
tensor: torch.Tensor | StackTensor,
index: list[object],
extra_mask: torch.Tensor | None = None,
eviction_policy: str | None = None,
) -> tuple[torch.Tensor | tuple, list[object], torch.Tensor | None, str | None]:
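    # Tile indices are normalized to sizes; StackTensor arguments are flattened
    # to a plain tuple (unpacked as (tensor_like, dev_ptrs) by the handlers
    # below), while plain tensors pass through unchanged.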
from .tile_proxy import Tile
index = Tile._tiles_to_sizes_for_index(index)
if isinstance(tensor, StackTensor):
return (tuple(tensor), index, extra_mask, eviction_policy)
assert isinstance(tensor, torch.Tensor)
return (tensor, index, extra_mask, eviction_policy)
@_decorators.register_fake(load)
def _(
tensor: torch.Tensor | tuple[object, ...],
index: list[object],
extra_mask: torch.Tensor | None = None,
eviction_policy: str | None = None,
) -> torch.Tensor:
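    # Fake-mode shape inference: subscripting yields the tile-indexed shape;
    # stack loads prepend the dev_ptrs shape to the per-tensor result shape.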
if isinstance(tensor, torch.Tensor):
target_shape = SubscriptIndexing.compute_shape(tensor, index)
env = CompileEnvironment.current()
env.backend.process_fake_tensor_load(tensor, index)
return env.new_index_result(tensor, target_shape)
if isinstance(tensor, tuple):
tensor_like, dev_ptrs = tensor
assert isinstance(tensor_like, torch.Tensor)
assert isinstance(dev_ptrs, torch.Tensor)
tensor_shape = SubscriptIndexing.compute_shape(tensor_like, index)
target_shape = list(dev_ptrs.size()) + tensor_shape
return tensor_like.new_empty(target_shape)
raise NotImplementedError(f"Unsupported tensor type: {type(tensor)}")
@_decorators.codegen(load, "triton")
def _(state: CodegenState) -> ast.AST:
tensor = state.proxy_arg(0)
subscript = state.proxy_arg(1)
assert isinstance(subscript, (list, tuple))
ast_subscript = state.ast_args[1]
assert isinstance(ast_subscript, (list, tuple))
extra_mask = state.ast_args[2]
assert isinstance(extra_mask, (type(None), ast.AST))
eviction_policy = state.ast_args[3] if len(state.ast_args) > 3 else None
device_fn = state.device_function
load_idx = device_fn.device_load_index
device_fn.device_load_index += 1
    # If no explicit eviction_policy was given and we're in device code, fall
    # back to the tunable per-load policy from the config.
if eviction_policy is None and state.codegen.on_device:
policies = state.config.load_eviction_policies
if load_idx < len(policies):
policy_value = policies[load_idx]
eviction_policy = _EVICTION_POLICY_MAP.get(policy_value, policy_value)
if eviction_policy is not None:
assert isinstance(eviction_policy, str)
eviction_policy = ast.Constant(value=eviction_policy)
if isinstance(tensor, torch.Tensor):
        # Special case: the load source is tile_index(...) and is only being
        # broadcast-indexed.
from ..language import tile_index
tensor_node = state.fx_node.args[0] if state.fx_node is not None else None
if (
isinstance(tensor_node, torch.fx.Node)
and tensor_node.op == "call_function"
and tensor_node.target == tile_index
):
# tile.index tensors are not real memory accesses; materialize the
# block index variable with the requested broadcast/reshape.
env = CompileEnvironment.current()
block_id = env.get_block_id(tensor.size(0))
assert block_id is not None
base_var = state.codegen.index_var(block_id)
parts = []
for idx in subscript:
if idx is None:
parts.append("None")
elif idx == slice(None):
parts.append(":")
else:
raise AssertionError(
f"Unexpected index type in tile_index load: {idx}"
)
return expr_from_string(f"{base_var}[{', '.join(parts)}]")
# Use the shared memory op index for indexing strategy
indexing_idx = device_fn.device_memory_op_index
device_fn.device_memory_op_index += 1
strategy = device_fn.get_indexing_strategy(indexing_idx)
if state.codegen.load_transform is not None:
return state.codegen.load_transform(
state,
tensor,
[*subscript],
extra_mask,
eviction_policy,
strategy.codegen_load,
)
return strategy.codegen_load(
state, tensor, [*subscript], extra_mask, eviction_policy
)
if isinstance(tensor, tuple):
from .._compiler.indexing_strategy import StackIndexingStrategy
# Fusion is not supported for stack loads (multi-tensor device pointers);
# fall through to the unfused path regardless of load_transform.
stack_tensor_ast = state.ast_args[0]
assert isinstance(stack_tensor_ast, tuple)
assert len(stack_tensor_ast) == 2
tensor_like_ast, dev_ptrs_ast = stack_tensor_ast
return StackIndexingStrategy.codegen_load(
state, tensor, dev_ptrs_ast, [*subscript], extra_mask, eviction_policy
)
raise NotImplementedError(f"Unsupported tensor type: {type(tensor)}")
@_decorators.codegen(load, "pallas")
def _(state: CodegenState) -> ast.AST:
tensor = state.proxy_arg(0)
subscript = state.proxy_arg(1)
assert isinstance(tensor, torch.Tensor)
assert isinstance(subscript, (list, tuple))
return pallas_codegen.load_expr(state, list(subscript), tensor)
@_decorators.codegen(load, "metal")
def _(state: CodegenState) -> ast.AST:
# Metal delegates to the same PointerIndexingStrategy as Triton.
# This produces tl.load(ptr + offset, mask, other=0) in the AST;
# the MSL walker translates it to Metal.
tensor = state.proxy_arg(0)
subscript = state.proxy_arg(1)
assert isinstance(subscript, (list, tuple))
ast_subscript = state.ast_args[1]
assert isinstance(ast_subscript, (list, tuple))
extra_mask = state.ast_args[2]
assert isinstance(extra_mask, (type(None), ast.AST))
eviction_policy = state.ast_args[3] if len(state.ast_args) > 3 else None
assert isinstance(eviction_policy, (type(None), ast.AST))
if isinstance(tensor, torch.Tensor):
device_fn = state.device_function
device_fn.device_load_index += 1
indexing_idx = device_fn.device_memory_op_index
device_fn.device_memory_op_index += 1
strategy = device_fn.get_indexing_strategy(indexing_idx)
return strategy.codegen_load(
state, tensor, [*subscript], extra_mask, eviction_policy
)
raise exc.BackendUnsupported("metal", f"load tensor type: {type(tensor)}")
@_decorators.codegen(load, "cute")
def _(state: CodegenState) -> object:
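    # Cute loads mirror the store dispatch: stack-tensor pointer loads,
    # tile_index materialization, suppressed/collective-handled loads, and
    # packed affine/RHS patterns short-circuit before the generic masked scalar
    # load expression; sort/topk consumers additionally get a CuteSortableLoad
    # recorded in the node meta.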
tensor = state.proxy_arg(0)
subscript = state.proxy_arg(1)
assert isinstance(subscript, (list, tuple))
ast_subscript = state.ast_args[1]
assert isinstance(ast_subscript, (list, tuple))
extra_mask = state.ast_args[2]
assert isinstance(extra_mask, (type(None), ast.AST))
if isinstance(tensor, tuple):
stack_tensor_ast = state.ast_args[0]
assert isinstance(stack_tensor_ast, tuple)
assert len(stack_tensor_ast) == 2
tensor_like_ast, dev_ptrs_ast = stack_tensor_ast
assert isinstance(dev_ptrs_ast, ast.AST)
tensor_like, dev_ptrs = tensor
offset_expr = _cute_stack_tensor_offset_expr(
state,
tensor_like,
[*subscript],
ast_subscript,
)
backend = CompileEnvironment.current().backend
target_dtype = backend.dtype_str(tensor_like.dtype)
ptr_expr = _cute_stack_tensor_pointer_expr(
target_dtype, dev_ptrs_ast, offset_expr
)
load_expr = f"({ast.unparse(ptr_expr)}).load()"
mask_expr = _cute_stack_tensor_mask_expr(
state,
tensor_like,
dev_ptrs,
[*subscript],
extra_mask,
)
if tensor_like.dtype is torch.bool:
load_expr = f"({load_expr} != cutlass.Uint8(0))"
if mask_expr is None:
return expr_from_string(load_expr)
return expr_from_string(
f"({load_expr} if {mask_expr} else cutlass.Boolean(0))"
)
if mask_expr is None:
return expr_from_string(load_expr)
return expr_from_string(f"({load_expr} if {mask_expr} else {target_dtype}(0))")
if not isinstance(tensor, torch.Tensor):
raise exc.BackendUnsupported("cute", f"load tensor type: {type(tensor)}")
_log_cute_layout(state, "load")
from ..language import tile_index
tensor_node = state.fx_node.args[0] if state.fx_node is not None else None
if (
isinstance(tensor_node, torch.fx.Node)
and tensor_node.op == "call_function"
and tensor_node.target == tile_index
):
env = CompileEnvironment.current()
block_id = env.get_block_id(tensor.size(0))
if block_id is None:
raise exc.BackendUnsupported("cute", "tile_index load block id")
index_var = _cute_active_index_var(state, block_id)
if index_var is None:
raise exc.BackendUnsupported("cute", "inactive tile_index load")
for idx in subscript:
if idx is None or idx == slice(None):
continue
raise exc.BackendUnsupported(
"cute", f"tile_index load index type: {type(idx)}"
)
return expr_from_string(index_var)
if state.device_function.suppress_cute_root_lane_loops or (
state.fx_node is not None
and state.device_function.is_cute_collective_handled_load(state.fx_node.name)
):
zero = CompileEnvironment.current().backend.dtype_str(tensor.dtype)
return expr_from_string(f"{zero}(0)")
packed_affine_lhs = _maybe_codegen_cute_packed_affine_lhs_load(
state, tensor, subscript, extra_mask
)
if packed_affine_lhs is not None:
return packed_affine_lhs
packed_rhs_load = _maybe_codegen_cute_packed_rhs_load(
state, tensor, subscript, extra_mask
)
if packed_rhs_load is not None:
return packed_rhs_load
if _is_cute_affine_range_load_for_store(state, subscript, ast_subscript):
zero = CompileEnvironment.current().backend.dtype_str(tensor.dtype)
return expr_from_string(f"{zero}(0)")
if _is_cute_strided_slice_load_for_store(state, tensor, subscript):
zero = CompileEnvironment.current().backend.dtype_str(tensor.dtype)
return expr_from_string(f"{zero}(0)")
tensor_name = state.device_function.tensor_arg(tensor).name
index_exprs = _cute_index_exprs(
state,
subscript,
ast_subscript,
tensor=tensor,
inactive_slice_expr="None",
inactive_singleton_slice_expr="0",
)
load_expr = _cute_scalar_load_expr(tensor_name, index_exprs)
mask_expr = _cute_combined_mask(
state,
subscript,
extra_mask,
tensor=tensor,
include_tensor_index_masks=False,
)
if tensor.dtype is torch.bool:
load_expr = f"({load_expr} != cutlass.Uint8(0))"
if mask_expr is None:
return expr_from_string(load_expr)
return expr_from_string(f"({load_expr} if {mask_expr} else cutlass.Boolean(0))")
    if state.fx_node is not None and any(
        user.target in (torch.ops.aten.sort.default, torch.ops.aten.topk.default)
        or getattr(user.target, "__name__", None) == "_associative_scan"
        for user in state.fx_node.users
    ):
        from .._compiler.cute.indexing import CuteSortableLoad

        tensor_dim = 0
        sort_index_pos = -1
        for idx in subscript:
            if idx is None:
                continue
            if tensor_dim == tensor.ndim - 1:
                sort_index_pos = tensor_dim
                break
            tensor_dim += 1
        if sort_index_pos < 0:
            raise exc.BackendUnsupported("cute", "sort/topk input rank")
        sortable_load = CuteSortableLoad(
            expr=expr_from_string(
                load_expr
                if mask_expr is None
                else f"({load_expr} if {mask_expr} else {CompileEnvironment.current().backend.dtype_str(tensor.dtype)}(0))"
            ),
            tensor_name=tensor_name,
            index_exprs=tuple(index_exprs),
            sort_index_pos=sort_index_pos,
            mask_expr=mask_expr,
            dtype=tensor.dtype,
        )
        state.fx_node.meta["cute_sortable_load"] = sortable_load
        return sortable_load.expr
    if mask_expr is None:
        return expr_from_string(load_expr)
    zero = CompileEnvironment.current().backend.dtype_str(tensor.dtype)
    return expr_from_string(f"({load_expr} if {mask_expr} else {zero}(0))")


@_decorators.get_masked_value(load)
def _(node: torch.fx.Node) -> int:
    return 0  # loads are always masked to 0


# TODO(joydddd): Add support for stack tensor in ref mode.
@_decorators.ref(load)
def _(
    tensor: torch.Tensor,
    index: list[object],
    extra_mask: torch.Tensor | None = None,
    eviction_policy: str | None = None,
) -> torch.Tensor:
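    """Eager reference implementation of load.

    Without an extra mask this is plain advanced indexing; with a mask,
    tensor indices are clamped for the gather and masked-off elements are
    filled with zeros.
    """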
    from .ref_tile import RefTile

    if extra_mask is None:
        # Convert RefTiles to indices
        indices = [idx.index if isinstance(idx, RefTile) else idx for idx in index]
        # Use meshgrid for Cartesian product when we have multiple tensor indices
        tensor_idxs = [
            i for i, idx in enumerate(indices) if isinstance(idx, torch.Tensor)
        ]
        if len(tensor_idxs) > 1:
            # pyrefly: ignore [bad-argument-type]
            grids = torch.meshgrid(*(indices[i] for i in tensor_idxs), indexing="ij")
            for i, grid in zip(tensor_idxs, grids, strict=False):
                indices[i] = grid
        # pyrefly: ignore [bad-argument-type, bad-index]
        return tensor[tuple(indices)]
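
    # Masked path: gather with clamped (in-bounds) indices, then zero out any
    # element whose mask or bounds check fails.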
    # Create zero result matching mask shape
    result = torch.zeros(extra_mask.shape, dtype=tensor.dtype, device=tensor.device)

    # Process indices: convert RefTiles and clamp tensor indices
    orig_indices, safe_indices, is_tensor_mask = [], [], []
    for i, idx in enumerate(index):
        if isinstance(idx, RefTile):
            idx = idx.index  # Convert RefTile to tensor
        if isinstance(idx, torch.Tensor):
            dim_size = tensor.shape[i] if i < len(tensor.shape) else tensor.numel()
            orig_indices.append(idx)
            safe_indices.append(torch.clamp(idx, 0, dim_size - 1))
            is_tensor_mask.append(True)
        else:
            orig_indices.append(idx)
            safe_indices.append(idx)
            is_tensor_mask.append(False)
    # Apply broadcasting if we have multiple tensor indices
    tensor_positions = [i for i, is_tensor in enumerate(is_tensor_mask) if is_tensor]
    if len(tensor_positions) > 1:
        # Add unsqueeze operations for broadcasting
        broadcast_indices = []
        for i, (idx, is_tensor) in enumerate(
            zip(safe_indices, is_tensor_mask, strict=False)
        ):
            if is_tensor:
                new_idx = idx
                # Add dimension for each other tensor index
                for j, other_pos in enumerate(tensor_positions):
                    if other_pos != i:
                        new_idx = new_idx.unsqueeze(j if other_pos < i else -1)
                broadcast_indices.append(new_idx)
            else:
                broadcast_indices.append(idx)
        values = tensor[tuple(broadcast_indices)]
    else:
        values = tensor[tuple(safe_indices)]
    # Build validity mask
    valid_mask = extra_mask.clone()
    for i, (orig_idx, is_tensor) in enumerate(
        zip(orig_indices, is_tensor_mask, strict=False)
    ):
        if is_tensor:
            dim_size = tensor.shape[i] if i < len(tensor.shape) else tensor.numel()
            in_bounds = (orig_idx >= 0) & (orig_idx < dim_size)
            # Broadcast to match mask shape by adding dimensions
            # Count how many tensor indices come before and after this one
            n_before = sum(1 for j in range(i) if is_tensor_mask[j])
            n_after = sum(
                1 for j in range(i + 1, len(is_tensor_mask)) if is_tensor_mask[j]
            )
            # Add dimensions: n_after dimensions at the end, n_before at the beginning
            for _ in range(n_after):
                in_bounds = in_bounds.unsqueeze(-1)
            for _ in range(n_before):
                in_bounds = in_bounds.unsqueeze(0)
            valid_mask = valid_mask & in_bounds

    return torch.where(valid_mask, values, result)