# Source code for helion.language.memory_ops

from __future__ import annotations

import ast
from typing import TYPE_CHECKING

import torch
from torch.fx import has_side_effect

from .. import exc
from .._compiler.indexing_strategy import SubscriptIndexing
from . import _decorators
from .stack_tensor import StackTensor

if TYPE_CHECKING:
    from .._compiler.inductor_lowering import CodegenState

# Public API of this module.
__all__ = ["load", "store"]

# Map short config names to full Triton API names for eviction policies
# (e.g. "first" -> "evict_first"); "" means no policy hint is emitted.
_EVICTION_POLICY_MAP = {
    "": None,
    "first": "evict_first",
    "last": "evict_last",
}


@has_side_effect
@_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
def store(
    tensor: torch.Tensor | StackTensor,
    index: list[object],
    value: torch.Tensor | torch.SymInt | float,
    extra_mask: torch.Tensor | None = None,
) -> None:
    """Store a value to a tensor using a list of indices.

    This function is equivalent to ``tensor[index] = value`` but allows
    setting ``extra_mask=`` to mask elements beyond the default masking
    based on the hl.tile range.

    Args:
        tensor: The tensor / stack tensor to store to
        index: The indices to use to index into the tensor
        value: The value to store
        extra_mask: The extra mask (beyond automatic tile bounds masking)
            to apply to the tensor

    Returns:
        None
    """
    # The stub only runs on the host; inside a kernel the call is replaced
    # by the registered codegen/ref implementations below.
    raise exc.NotInsideKernel
@_decorators.prepare_args(store)
def _(
    tensor: torch.Tensor | StackTensor,
    index: list[object],
    value: torch.Tensor | torch.SymInt | float,
    extra_mask: torch.Tensor | None = None,
) -> tuple[
    torch.Tensor | tuple,
    list[object],
    torch.Tensor | torch.SymInt | float | int,
    torch.Tensor | None,
]:
    """Normalize store() arguments before tracing."""
    from .tile_proxy import Tile

    # Cast the stored value up front so later stages always see the
    # destination dtype.
    if isinstance(value, torch.Tensor) and value.dtype != tensor.dtype:
        value = value.to(tensor.dtype)
    index = Tile._tiles_to_sizes(index)
    if isinstance(tensor, StackTensor):
        # A StackTensor is flattened into its component tuple.
        return (tuple(tensor), index, value, extra_mask)
    if isinstance(tensor, torch.Tensor):
        return (tensor, index, value, extra_mask)
    raise NotImplementedError(f"Cannot store to type: {type(tensor)}")


@_decorators.register_fake(store)
def _(
    tensor: torch.Tensor | tuple[object, ...],
    index: list[object],
    value: torch.Tensor | torch.SymInt | float,
    extra_mask: torch.Tensor | None = None,
) -> None:
    # Stores produce no value under fake tensor tracing.
    return None


@_decorators.codegen(store)
def _(state: CodegenState) -> ast.AST:
    """Emit the Triton AST for a store()."""
    tensor = state.proxy_arg(0)
    subscript = state.proxy_arg(1)
    assert isinstance(subscript, (list, tuple))
    value = state.ast_arg(2)
    extra_mask = state.ast_args[3]
    assert isinstance(extra_mask, (type(None), ast.AST))

    if isinstance(tensor, torch.Tensor):
        device_fn = state.device_function
        device_fn.device_store_index += 1
        # Use the shared memory op index for indexing strategy
        indexing_idx = device_fn.device_memory_op_index
        device_fn.device_memory_op_index += 1
        strategy = device_fn.get_indexing_strategy(indexing_idx)
        return strategy.codegen_store(state, tensor, [*subscript], value, extra_mask)

    if isinstance(tensor, tuple):
        # Tuple means a flattened StackTensor (see prepare_args above).
        from .._compiler.indexing_strategy import StackIndexingStrategy

        stacked_ast = state.ast_args[0]
        assert isinstance(stacked_ast, tuple)
        assert len(stacked_ast) == 2
        _tensor_like_ast, dev_ptrs_ast = stacked_ast
        return StackIndexingStrategy.codegen_store(
            state, tensor, dev_ptrs_ast, [*subscript], value, extra_mask
        )

    raise NotImplementedError(f"Cannot store to type: {type(tensor)}")
# TODO(joydddd): Add support for stack tensor in ref mode.
@_decorators.ref(store)
def _(
    tensor: torch.Tensor,
    index: list[object],
    value: torch.Tensor | torch.SymInt | float,
    extra_mask: torch.Tensor | None = None,
) -> None:
    """Eager (ref-mode) implementation of store()."""
    from .ref_tile import RefTile

    # Normalize indices (RefTile -> underlying index tensor) and record
    # which dimensions are indexed by tensors.
    indices: list[object] = []
    tensor_dims: list[int] = []
    for dim, idx in enumerate(index):
        if isinstance(idx, RefTile):
            idx = idx.index
        indices.append(idx)
        if isinstance(idx, torch.Tensor):
            tensor_dims.append(dim)

    # Broadcast multiple tensor indices against each other via meshgrid.
    if len(tensor_dims) > 1:
        grids = torch.meshgrid(*(indices[d] for d in tensor_dims), indexing="ij")
        for dim, grid in zip(tensor_dims, grids, strict=False):
            indices[dim] = grid

    if extra_mask is not None:
        mask = extra_mask.to(torch.bool)
        # Additionally mask out any tensor indices that fall out of bounds.
        for dim, idx in enumerate(indices):
            if isinstance(idx, torch.Tensor):
                mask = mask & (idx >= 0) & (idx < tensor.shape[dim])
        n_valid = int(mask.sum().item())
        if n_valid == 0:
            return
        # Use index_put_ for masked stores
        put_indices: list[torch.Tensor] = []
        for idx in indices:
            if isinstance(idx, torch.Tensor):
                put_indices.append(idx[mask].long())
            else:
                idx_val = int(idx) if isinstance(idx, torch.SymInt) else idx
                put_indices.append(
                    torch.full(
                        (n_valid,), idx_val, dtype=torch.long, device=tensor.device
                    )
                )
        if isinstance(value, torch.Tensor):
            values = value[mask]
        else:
            val = int(value) if isinstance(value, torch.SymInt) else value
            values = torch.full(
                (n_valid,), val, dtype=tensor.dtype, device=tensor.device
            )
        tensor.index_put_(tuple(put_indices), values, accumulate=False)
        return

    # Simple unmasked assignment.
    tensor[tuple(indices)] = int(value) if isinstance(value, torch.SymInt) else value
@_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
def load(
    tensor: torch.Tensor | StackTensor,
    index: list[object],
    extra_mask: torch.Tensor | None = None,
    eviction_policy: str | None = None,
) -> torch.Tensor:
    """Load a value from a tensor using a list of indices.

    This function is equivalent to ``tensor[index]`` but allows setting
    ``extra_mask=`` to mask elements beyond the default masking based on
    the hl.tile range. It also accepts an optional ``eviction_policy``
    which is forwarded to the underlying Triton ``tl.load`` call to
    control the cache eviction behavior (e.g., "evict_last").

    Args:
        tensor: The tensor / stack tensor to load from
        index: The indices to use to index into the tensor
        extra_mask: The extra mask (beyond automatic tile bounds masking)
            to apply to the tensor
        eviction_policy: Optional Triton load eviction policy to hint
            cache behavior

    Returns:
        torch.Tensor: The loaded value
    """
    # The stub only runs on the host; inside a kernel the call is replaced
    # by the registered codegen/ref implementations below.
    raise exc.NotInsideKernel
@_decorators.prepare_args(load)
def _(
    tensor: torch.Tensor | StackTensor,
    index: list[object],
    extra_mask: torch.Tensor | None = None,
    eviction_policy: str | None = None,
) -> tuple[torch.Tensor | tuple, list[object], torch.Tensor | None, str | None]:
    """Normalize load() arguments before tracing."""
    from .tile_proxy import Tile

    index = Tile._tiles_to_sizes(index)
    if isinstance(tensor, StackTensor):
        # A StackTensor is flattened into its component tuple.
        return (tuple(tensor), index, extra_mask, eviction_policy)
    assert isinstance(tensor, torch.Tensor)
    return (tensor, index, extra_mask, eviction_policy)


@_decorators.register_fake(load)
def _(
    tensor: torch.Tensor | tuple[object, ...],
    index: list[object],
    extra_mask: torch.Tensor | None = None,
    eviction_policy: str | None = None,
) -> torch.Tensor:
    """Fake-tensor shape propagation for load()."""
    if isinstance(tensor, torch.Tensor):
        return tensor.new_empty(SubscriptIndexing.compute_shape(tensor, index))
    if isinstance(tensor, tuple):
        tensor_like, dev_ptrs = tensor
        assert isinstance(tensor_like, torch.Tensor)
        assert isinstance(dev_ptrs, torch.Tensor)
        # Stacked loads are indexed first by the dev-ptr grid, then by the
        # subscript shape of the underlying tensor.
        inner_shape = SubscriptIndexing.compute_shape(tensor_like, index)
        return tensor_like.new_empty(list(dev_ptrs.size()) + inner_shape)
    raise NotImplementedError(f"Unsupported tensor type: {type(tensor)}")


@_decorators.codegen(load)
def _(state: CodegenState) -> ast.AST:
    """Emit the Triton AST for a load()."""
    tensor = state.proxy_arg(0)
    subscript = state.proxy_arg(1)
    assert isinstance(subscript, (list, tuple))
    extra_mask = state.ast_args[2]
    assert isinstance(extra_mask, (type(None), ast.AST))
    eviction_policy = state.ast_args[3] if len(state.ast_args) > 3 else None

    device_fn = state.device_function
    load_idx = device_fn.device_load_index
    device_fn.device_load_index += 1

    # If no explicit eviction_policy and we're in device code, use tunable
    if eviction_policy is None and state.codegen.on_device:
        policies = state.config.load_eviction_policies
        if load_idx < len(policies):
            policy_value = policies[load_idx]
            # Unknown short names pass through unchanged.
            eviction_policy = _EVICTION_POLICY_MAP.get(policy_value, policy_value)

    if eviction_policy is not None:
        # NOTE(review): this assumes ast_args[3] is a plain string when set
        # (not an ast.AST) — confirm against how _decorators populates ast_args.
        assert isinstance(eviction_policy, str)
        eviction_policy = ast.Constant(value=eviction_policy)

    if isinstance(tensor, torch.Tensor):
        # Use the shared memory op index for indexing strategy
        indexing_idx = device_fn.device_memory_op_index
        device_fn.device_memory_op_index += 1
        strategy = device_fn.get_indexing_strategy(indexing_idx)
        return strategy.codegen_load(
            state, tensor, [*subscript], extra_mask, eviction_policy
        )

    if isinstance(tensor, tuple):
        # Tuple means a flattened StackTensor (see prepare_args above).
        from .._compiler.indexing_strategy import StackIndexingStrategy

        stacked_ast = state.ast_args[0]
        assert isinstance(stacked_ast, tuple)
        assert len(stacked_ast) == 2
        _tensor_like_ast, dev_ptrs_ast = stacked_ast
        return StackIndexingStrategy.codegen_load(
            state, tensor, dev_ptrs_ast, [*subscript], extra_mask, eviction_policy
        )

    raise NotImplementedError(f"Unsupported tensor type: {type(tensor)}")


@_decorators.get_masked_value(load)
def _(node: torch.fx.Node) -> int:
    return 0  # loads are always masked to 0


# TODO(joydddd): Add support for stack tensor in ref mode.
@_decorators.ref(load)
def _(
    tensor: torch.Tensor,
    index: list[object],
    extra_mask: torch.Tensor | None = None,
    eviction_policy: str | None = None,
) -> torch.Tensor:
    """Eager (ref-mode) implementation of load(); masked-off lanes read 0."""
    from .ref_tile import RefTile

    if extra_mask is None:
        # Fast path: plain advanced indexing when no extra mask is supplied.
        return tensor[tuple(index)]  # pyright: ignore[reportArgumentType]

    # Create zero result matching mask shape
    result = torch.zeros(extra_mask.shape, dtype=tensor.dtype, device=tensor.device)

    # Process indices: convert RefTiles and clamp tensor indices so the
    # gather itself never goes out of bounds (bad lanes are masked later).
    raw_indices: list[object] = []
    clamped_indices: list[object] = []
    tensor_flags: list[bool] = []
    for dim, idx in enumerate(index):
        if isinstance(idx, RefTile):
            idx = idx.index  # Convert RefTile to tensor
        if isinstance(idx, torch.Tensor):
            dim_size = tensor.shape[dim] if dim < len(tensor.shape) else tensor.numel()
            raw_indices.append(idx)
            clamped_indices.append(torch.clamp(idx, 0, dim_size - 1))
            tensor_flags.append(True)
        else:
            raw_indices.append(idx)
            clamped_indices.append(idx)
            tensor_flags.append(False)

    # Apply broadcasting if we have multiple tensor indices
    tensor_positions = [i for i, flag in enumerate(tensor_flags) if flag]
    if len(tensor_positions) > 1:
        # Add unsqueeze operations for broadcasting
        broadcast_indices: list[object] = []
        for i, (idx, flag) in enumerate(
            zip(clamped_indices, tensor_flags, strict=False)
        ):
            if flag:
                new_idx = idx
                # Add a dimension for each other tensor index
                for j, other_pos in enumerate(tensor_positions):
                    if other_pos != i:
                        new_idx = new_idx.unsqueeze(j if other_pos < i else -1)
                broadcast_indices.append(new_idx)
            else:
                broadcast_indices.append(idx)
        values = tensor[tuple(broadcast_indices)]
    else:
        values = tensor[tuple(clamped_indices)]

    # Build validity mask
    valid_mask = extra_mask.clone()
    for i, (orig_idx, flag) in enumerate(
        zip(raw_indices, tensor_flags, strict=False)
    ):
        if flag:
            dim_size = tensor.shape[i] if i < len(tensor.shape) else tensor.numel()
            in_bounds = (orig_idx >= 0) & (orig_idx < dim_size)
            # Broadcast to match mask shape by adding dimensions.
            # Count how many tensor indices come before and after this one.
            n_before = sum(1 for j in range(i) if tensor_flags[j])
            n_after = sum(
                1 for j in range(i + 1, len(tensor_flags)) if tensor_flags[j]
            )
            # Add dimensions: n_after dimensions at the end, n_before at the beginning
            for _unused in range(n_after):
                in_bounds = in_bounds.unsqueeze(-1)
            for _unused in range(n_before):
                in_bounds = in_bounds.unsqueeze(0)
            valid_mask = valid_mask & in_bounds

    return torch.where(valid_mask, values, result)