from __future__ import annotations
import ast
from typing import TYPE_CHECKING
import torch
from torch.fx import has_side_effect
from .. import exc
from .._compiler.indexing_strategy import SubscriptIndexing
from . import _decorators
from .stack_tensor import StackTensor
if TYPE_CHECKING:
from .._compiler.inductor_lowering import CodegenState
__all__ = ["load", "store"]
# Map short config names to full Triton API names for eviction policies
_EVICTION_POLICY_MAP = {
"": None,
"first": "evict_first",
"last": "evict_last",
}
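# For example, a tuned config with `load_eviction_policies=["", "last"]` leaves the
# first device load's policy unset and emits `eviction_policy="evict_last"` for the
# second one (illustrative values; the `load` codegen below indexes this list by
# device_load_index).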
@has_side_effect
@_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
def store(
tensor: torch.Tensor | StackTensor,
index: list[object],
value: torch.Tensor | torch.SymInt | float,
extra_mask: torch.Tensor | None = None,
) -> None:
"""Store a value to a tensor using a list of indices.
This function is equivalent to `tensor[index] = value` but additionally
accepts `extra_mask=` to mask out elements beyond the automatic masking
derived from the `hl.tile` range.
Args:
tensor: The tensor / stack tensor to store to
index: The indices to use to index into the tensor
value: The value to store
extra_mask: The extra mask (beyond automatic tile bounds masking) to apply to the tensor
Returns:
None
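Example (an illustrative sketch, not part of this module; `dst`, `src`, and
the enclosing kernel are hypothetical, assuming the usual `hl.tile` loop):

    for tile in hl.tile(dst.size(0)):
        # equivalent to dst[tile] = src[tile], with automatic tile masking
        hl.store(dst, [tile], hl.load(src, [tile]))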
"""
raise exc.NotInsideKernel
@_decorators.prepare_args(store)
def _(
tensor: torch.Tensor | StackTensor,
index: list[object],
value: torch.Tensor | torch.SymInt | float,
extra_mask: torch.Tensor | None = None,
) -> tuple[
torch.Tensor | tuple,
list[object],
torch.Tensor | torch.SymInt | float | int,
torch.Tensor | None,
]:
from .tile_proxy import Tile
if isinstance(value, torch.Tensor) and value.dtype != tensor.dtype:
value = value.to(tensor.dtype)
index = Tile._tiles_to_sizes(index)
if isinstance(tensor, StackTensor):
return (tuple(tensor), index, value, extra_mask)
if isinstance(tensor, torch.Tensor):
return (tensor, index, value, extra_mask)
raise NotImplementedError(f"Cannot store to type: {type(tensor)}")
@_decorators.register_fake(store)
def _(
tensor: torch.Tensor | tuple[object, ...],
index: list[object],
value: torch.Tensor | torch.SymInt | float,
extra_mask: torch.Tensor | None = None,
) -> None:
return None
@_decorators.codegen(store)
def _(state: CodegenState) -> ast.AST:
tensor = state.proxy_arg(0)
subscript = state.proxy_arg(1)
assert isinstance(subscript, (list, tuple))
value = state.ast_arg(2)
extra_mask = state.ast_args[3]
assert isinstance(extra_mask, (type(None), ast.AST))
if isinstance(tensor, torch.Tensor):
device_fn = state.device_function
device_fn.device_store_index += 1
# Use the memory-op index (a single counter shared by loads and stores) to select the indexing strategy
indexing_idx = device_fn.device_memory_op_index
device_fn.device_memory_op_index += 1
strategy = device_fn.get_indexing_strategy(indexing_idx)
return strategy.codegen_store(state, tensor, [*subscript], value, extra_mask)
if isinstance(tensor, tuple):
from .._compiler.indexing_strategy import StackIndexingStrategy
stack_tensor_ast = state.ast_args[0]
assert isinstance(stack_tensor_ast, tuple)
assert len(stack_tensor_ast) == 2
tensor_like_ast, dev_ptrs_ast = stack_tensor_ast
return StackIndexingStrategy.codegen_store(
state, tensor, dev_ptrs_ast, [*subscript], value, extra_mask
)
raise NotImplementedError(f"Cannot store to type: {type(tensor)}")
# TODO(joydddd): Add support for stack tensor in ref mode.
@_decorators.ref(store)
def _(
tensor: torch.Tensor,
index: list[object],
value: torch.Tensor | torch.SymInt | float,
extra_mask: torch.Tensor | None = None,
) -> None:
from .ref_tile import RefTile
# Normalize indices and identify tensor indices
indices = []
tensor_idx_positions = []
for i, idx in enumerate(index):
if isinstance(idx, RefTile):
idx = idx.index
indices.append(idx)
if isinstance(idx, torch.Tensor):
tensor_idx_positions.append(i)
# Handle broadcasting for multiple tensor indices
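# e.g. two 1-D indices of lengths M and N are expanded via meshgrid into (M, N)
# grids, so the store covers their Cartesian product (outer indexing)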
if len(tensor_idx_positions) > 1:
grids = torch.meshgrid(
*(indices[i] for i in tensor_idx_positions), indexing="ij"
)
for i, grid in zip(tensor_idx_positions, grids, strict=False):
indices[i] = grid
if extra_mask is not None:
mask = extra_mask.to(torch.bool)
# Check bounds for tensor indices
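# (out-of-bounds positions are masked off and skipped, not clamped)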
for i, idx in enumerate(indices):
if isinstance(idx, torch.Tensor):
mask = mask & (idx >= 0) & (idx < tensor.shape[i])
mask_count = int(mask.sum().item())
if mask_count == 0:
return
# Use index_put_ for masked stores
valid_indices = []
for idx in indices:
if isinstance(idx, torch.Tensor):
valid_indices.append(idx[mask].long())
else:
idx_val = int(idx) if isinstance(idx, torch.SymInt) else idx
valid_indices.append(
torch.full(
(mask_count,), idx_val, dtype=torch.long, device=tensor.device
)
)
if isinstance(value, torch.Tensor):
values = value[mask]
else:
val = int(value) if isinstance(value, torch.SymInt) else value
values = torch.full(
(mask_count,), val, dtype=tensor.dtype, device=tensor.device
)
tensor.index_put_(tuple(valid_indices), values, accumulate=False)
return
# Simple assignment
tensor[tuple(indices)] = int(value) if isinstance(value, torch.SymInt) else value
@_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
def load(
tensor: torch.Tensor | StackTensor,
index: list[object],
extra_mask: torch.Tensor | None = None,
eviction_policy: str | None = None,
) -> torch.Tensor:
"""Load a value from a tensor using a list of indices.
This function is equivalent to `tensor[index]` but additionally accepts
`extra_mask=` to mask out elements beyond the automatic masking derived
from the `hl.tile` range. It also accepts an optional `eviction_policy`,
which is forwarded to the underlying Triton `tl.load` call to control
cache eviction behavior (e.g., "evict_last").
Args:
tensor: The tensor / stack tensor to load from
index: The indices to use to index into the tensor
extra_mask: The extra mask (beyond automatic tile bounds masking) to apply to the tensor
eviction_policy: Optional Triton load eviction policy to hint cache behavior
Returns:
torch.Tensor: The loaded value
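Example (an illustrative sketch, not part of this module; `weights` and the
enclosing kernel are hypothetical, assuming the usual `hl.tile` loop):

    for tile in hl.tile(weights.size(0)):
        # hint Triton to evict these cache lines first once consumed
        w = hl.load(weights, [tile], eviction_policy="evict_first")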
"""
raise exc.NotInsideKernel
@_decorators.prepare_args(load)
def _(
tensor: torch.Tensor | StackTensor,
index: list[object],
extra_mask: torch.Tensor | None = None,
eviction_policy: str | None = None,
) -> tuple[torch.Tensor | tuple, list[object], torch.Tensor | None, str | None]:
from .tile_proxy import Tile
index = Tile._tiles_to_sizes(index)
if isinstance(tensor, StackTensor):
return (tuple(tensor), index, extra_mask, eviction_policy)
assert isinstance(tensor, torch.Tensor)
return (tensor, index, extra_mask, eviction_policy)
@_decorators.register_fake(load)
def _(
tensor: torch.Tensor | tuple[object, ...],
index: list[object],
extra_mask: torch.Tensor | None = None,
eviction_policy: str | None = None,
) -> torch.Tensor:
if isinstance(tensor, torch.Tensor):
target_shape = SubscriptIndexing.compute_shape(tensor, index)
return tensor.new_empty(target_shape)
if isinstance(tensor, tuple):
tensor_like, dev_ptrs = tensor
assert isinstance(tensor_like, torch.Tensor)
assert isinstance(dev_ptrs, torch.Tensor)
tensor_shape = SubscriptIndexing.compute_shape(tensor_like, index)
target_shape = list(dev_ptrs.size()) + tensor_shape
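# e.g. dev_ptrs of shape (4,) loading a (BLOCK,) tile yields a (4, BLOCK) result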
return tensor_like.new_empty(target_shape)
raise NotImplementedError(f"Unsupported tensor type: {type(tensor)}")
@_decorators.codegen(load)
def _(state: CodegenState) -> ast.AST:
tensor = state.proxy_arg(0)
subscript = state.proxy_arg(1)
assert isinstance(subscript, (list, tuple))
extra_mask = state.ast_args[2]
assert isinstance(extra_mask, (type(None), ast.AST))
eviction_policy = state.ast_args[3] if len(state.ast_args) > 3 else None
device_fn = state.device_function
load_idx = device_fn.device_load_index
device_fn.device_load_index += 1
# If no explicit eviction_policy was given and we are generating device code,
# fall back to the tunable value from the config
if eviction_policy is None and state.codegen.on_device:
policies = state.config.load_eviction_policies
if load_idx < len(policies):
policy_value = policies[load_idx]
eviction_policy = _EVICTION_POLICY_MAP.get(policy_value, policy_value)
if eviction_policy is not None:
assert isinstance(eviction_policy, str)
eviction_policy = ast.Constant(value=eviction_policy)
if isinstance(tensor, torch.Tensor):
# Use the memory-op index (a single counter shared by loads and stores) to select the indexing strategy
indexing_idx = device_fn.device_memory_op_index
device_fn.device_memory_op_index += 1
strategy = device_fn.get_indexing_strategy(indexing_idx)
return strategy.codegen_load(
state, tensor, [*subscript], extra_mask, eviction_policy
)
if isinstance(tensor, tuple):
from .._compiler.indexing_strategy import StackIndexingStrategy
stack_tensor_ast = state.ast_args[0]
assert isinstance(stack_tensor_ast, tuple)
assert len(stack_tensor_ast) == 2
tensor_like_ast, dev_ptrs_ast = stack_tensor_ast
return StackIndexingStrategy.codegen_load(
state, tensor, dev_ptrs_ast, [*subscript], extra_mask, eviction_policy
)
raise NotImplementedError(f"Unsupported tensor type: {type(tensor)}")
@_decorators.get_masked_value(load)
def _(node: torch.fx.Node) -> int:
return 0 # loads are always masked to 0
# TODO(joydddd): Add support for stack tensor in ref mode.
@_decorators.ref(load)
def _(
tensor: torch.Tensor,
index: list[object],
extra_mask: torch.Tensor | None = None,
eviction_policy: str | None = None,
) -> torch.Tensor:
from .ref_tile import RefTile
if extra_mask is None:
return tensor[tuple(index)] # pyright: ignore[reportArgumentType]
# Create zero result matching mask shape
result = torch.zeros(extra_mask.shape, dtype=tensor.dtype, device=tensor.device)
# Process indices: convert RefTiles and clamp tensor indices
orig_indices, safe_indices, is_tensor_mask = [], [], []
for i, idx in enumerate(index):
if isinstance(idx, RefTile):
idx = idx.index # Convert RefTile to tensor
if isinstance(idx, torch.Tensor):
dim_size = tensor.shape[i] if i < len(tensor.shape) else tensor.numel()
orig_indices.append(idx)
safe_indices.append(torch.clamp(idx, 0, dim_size - 1))
is_tensor_mask.append(True)
else:
orig_indices.append(idx)
safe_indices.append(idx)
is_tensor_mask.append(False)
# Apply broadcasting if we have multiple tensor indices
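# e.g. 1-D row/column indices of lengths M and N gain singleton dims so the
# gather produces an (M, N) result (the same outer-indexing convention as the
# store ref above)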
tensor_positions = [i for i, is_tensor in enumerate(is_tensor_mask) if is_tensor]
if len(tensor_positions) > 1:
# Add unsqueeze operations for broadcasting
broadcast_indices = []
for i, (idx, is_tensor) in enumerate(
zip(safe_indices, is_tensor_mask, strict=False)
):
if is_tensor:
new_idx = idx
# Add dimension for each other tensor index
for j, other_pos in enumerate(tensor_positions):
if other_pos != i:
new_idx = new_idx.unsqueeze(j if other_pos < i else -1)
broadcast_indices.append(new_idx)
else:
broadcast_indices.append(idx)
values = tensor[tuple(broadcast_indices)]
else:
values = tensor[tuple(safe_indices)]
# Build validity mask
valid_mask = extra_mask.clone()
for i, (orig_idx, is_tensor) in enumerate(
zip(orig_indices, is_tensor_mask, strict=False)
):
if is_tensor:
dim_size = tensor.shape[i] if i < len(tensor.shape) else tensor.numel()
in_bounds = (orig_idx >= 0) & (orig_idx < dim_size)
# Broadcast to match mask shape by adding dimensions
# Count how many tensor indices come before and after this one
n_before = sum(1 for j in range(i) if is_tensor_mask[j])
n_after = sum(
1 for j in range(i + 1, len(is_tensor_mask)) if is_tensor_mask[j]
)
# Add dimensions: n_after dimensions at the end, n_before at the beginning
for _ in range(n_after):
in_bounds = in_bounds.unsqueeze(-1)
for _ in range(n_before):
in_bounds = in_bounds.unsqueeze(0)
valid_mask = valid_mask & in_bounds
return torch.where(valid_mask, values, result)