Source code for helion.language.memory_ops

from __future__ import annotations

import ast
from typing import TYPE_CHECKING

import torch
from torch._inductor.codegen.simd import constant_repr
from torch.fx import has_side_effect

from .. import exc
from .._compiler.ast_extension import expr_from_string
from .._compiler.indexing_strategy import SubscriptIndexing
from . import _decorators

if TYPE_CHECKING:
    from .._compiler.inductor_lowering import CodegenState

__all__ = ["atomic_add", "load", "store"]


@has_side_effect
@_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
def store(
    tensor: torch.Tensor,
    index: list[object],
    value: torch.Tensor | torch.SymInt | float,
    extra_mask: torch.Tensor | None = None,
) -> None:
    """Store a value to a tensor using a list of indices.

    This function is equivalent to `tensor[index] = value` but allows
    setting `extra_mask=` to mask elements beyond the default masking
    based on the hl.tile range.

    Args:
        tensor: The tensor to store to
        index: The indices to use to index into the tensor
        value: The value to store
        extra_mask: The extra mask (beyond automatic tile bounds masking)
            to apply to the tensor

    Returns:
        None
    """
    raise exc.NotInsideKernel
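

# Usage sketch (illustrative, not part of the original module): `hl.store`
# with `extra_mask=` writes only the lanes where the mask is True. The kernel
# below is a hypothetical example assuming the public `helion` /
# `helion.language as hl` entry points; imports are kept local so defining
# this helper has no import-time side effects, and running it requires a GPU.
def _example_masked_store() -> None:
    import helion
    import helion.language as hl

    @helion.kernel
    def copy_positive(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
        for tile in hl.tile(src.size(0)):
            vals = hl.load(src, [tile])
            # Write only positive values; masked-off lanes keep dst's data
            hl.store(dst, [tile], vals, extra_mask=vals > 0)
        return dst

    src = torch.randn(1024, device="cuda")
    dst = torch.zeros_like(src)
    copy_positive(src, dst)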


@_decorators.prepare_args(store)
def _(
    tensor: torch.Tensor,
    index: list[object],
    value: torch.Tensor | torch.SymInt | float,
    extra_mask: torch.Tensor | None = None,
) -> tuple[
    torch.Tensor, list[object], torch.Tensor | torch.SymInt | float, torch.Tensor | None
]:
    from .tile_proxy import Tile

    if isinstance(value, torch.Tensor) and value.dtype != tensor.dtype:
        value = value.to(tensor.dtype)
    index = Tile._tiles_to_sizes(index)
    return (tensor, index, value, extra_mask)


@_decorators.register_fake(store)
def _(
    tensor: torch.Tensor,
    index: list[object],
    value: torch.Tensor | torch.SymInt | float,
    extra_mask: torch.Tensor | None = None,
) -> None:
    return None


@_decorators.codegen(store)
def _(state: CodegenState) -> ast.AST:
    tensor = state.proxy_arg(0)
    assert isinstance(tensor, torch.Tensor)
    subscript = state.proxy_arg(1)
    assert isinstance(subscript, (list, tuple))
    value = state.ast_arg(2)
    extra_mask = state.ast_args[3]
    assert isinstance(extra_mask, (type(None), ast.AST))
    return state.device_function.indexing_strategy.codegen_store(
        state, tensor, [*subscript], value, extra_mask
    )


@_decorators.ref(store)
def _(
    tensor: torch.Tensor,
    index: list[object],
    value: torch.Tensor | torch.SymInt | float,
    extra_mask: torch.Tensor | None = None,
) -> None:
    # Convert index list to tuple for tensor indexing
    index_tuple = tuple(index)

    # Apply extra mask if provided
    if extra_mask is not None:
        # Only store where the mask is True
        if isinstance(value, torch.Tensor):
            tensor[index_tuple] = torch.where(extra_mask, value, tensor[index_tuple])  # pyright: ignore[reportArgumentType]
        else:
            # For scalar values, we need to create a tensor of the right shape
            current = tensor[index_tuple]  # pyright: ignore[reportArgumentType]
            # Cast value to a proper numeric type for full_like
            if isinstance(value, torch.SymInt):
                numeric_value = int(value)
            else:
                numeric_value = value
            tensor[index_tuple] = torch.where(  # pyright: ignore[reportArgumentType]
                extra_mask, torch.full_like(current, numeric_value), current
            )
    else:
        # Handle SymInt case for assignment
        if isinstance(value, torch.SymInt):
            tensor[index_tuple] = int(value)  # pyright: ignore[reportArgumentType]
        else:
            tensor[index_tuple] = value  # pyright: ignore[reportArgumentType]
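

# Editor's sketch: a plain-PyTorch check of the masked-store semantics the
# ref path above implements. Lanes where the mask is False keep their
# previous contents; only masked-in lanes are overwritten.
def _check_store_ref_semantics() -> None:
    dst = torch.zeros(4)
    vals = torch.tensor([1.0, 2.0, 3.0, 4.0])
    mask = torch.tensor([True, False, True, False])
    idx = (torch.arange(4),)
    # Same torch.where pattern as the ref implementation above
    dst[idx] = torch.where(mask, vals, dst[idx])
    assert dst.tolist() == [1.0, 0.0, 3.0, 0.0]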


@_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
def load(
    tensor: torch.Tensor,
    index: list[object],
    extra_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Load a value from a tensor using a list of indices.

    This function is equivalent to `tensor[index]` but allows setting
    `extra_mask=` to mask elements beyond the default masking based on
    the hl.tile range.

    Args:
        tensor: The tensor to load from
        index: The indices to use to index into the tensor
        extra_mask: The extra mask (beyond automatic tile bounds masking)
            to apply to the tensor

    Returns:
        torch.Tensor: The loaded value
    """
    raise exc.NotInsideKernel
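

# Usage sketch (illustrative, not part of the original module): `extra_mask=`
# on `hl.load` guards indices computed beyond the tile's automatic bounds
# masking. Assumes `tile.index` yields the per-lane indices of a tile; the
# kernel is hypothetical, and masked-off lanes read as 0.
def _example_masked_load() -> None:
    import helion
    import helion.language as hl

    @helion.kernel
    def shift_left(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
        n = x.size(0)
        for tile in hl.tile(n):
            idx = tile.index + 1  # read each element's right-hand neighbor
            # The last element's neighbor is out of range; mask it so the
            # load returns 0 there instead of reading out of bounds
            vals = hl.load(x, [idx], extra_mask=idx < n)
            hl.store(out, [tile], vals)
        return out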


@_decorators.register_fake(load)
def _(
    tensor: torch.Tensor, index: list[object], extra_mask: torch.Tensor | None = None
) -> torch.Tensor:
    return tensor.new_empty(SubscriptIndexing.compute_shape(tensor, index))


@_decorators.codegen(load)
def _(state: CodegenState) -> ast.AST:
    tensor = state.proxy_arg(0)
    assert isinstance(tensor, torch.Tensor)
    subscript = state.proxy_arg(1)
    assert isinstance(subscript, (list, tuple))
    extra_mask = state.ast_args[2]
    assert isinstance(extra_mask, (type(None), ast.AST))
    return state.device_function.indexing_strategy.codegen_load(
        state, tensor, [*subscript], extra_mask
    )


@_decorators.get_masked_value(load)
def _(node: torch.fx.Node) -> int:
    return 0  # loads are always masked to 0


@_decorators.ref(load)
def _(
    tensor: torch.Tensor,
    index: list[object],
    extra_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    from .ref_tile import RefTile

    if extra_mask is None:
        return tensor[tuple(index)]  # pyright: ignore[reportArgumentType]

    # Create zero result matching mask shape
    result = torch.zeros(extra_mask.shape, dtype=tensor.dtype, device=tensor.device)

    # Process indices: convert RefTiles and clamp tensor indices
    orig_indices, safe_indices, is_tensor_mask = [], [], []
    for i, idx in enumerate(index):
        if isinstance(idx, RefTile):
            idx = idx.index  # Convert RefTile to tensor
        if isinstance(idx, torch.Tensor):
            dim_size = tensor.shape[i] if i < len(tensor.shape) else tensor.numel()
            orig_indices.append(idx)
            safe_indices.append(torch.clamp(idx, 0, dim_size - 1))
            is_tensor_mask.append(True)
        else:
            orig_indices.append(idx)
            safe_indices.append(idx)
            is_tensor_mask.append(False)

    # Apply broadcasting if we have multiple tensor indices
    tensor_positions = [i for i, is_tensor in enumerate(is_tensor_mask) if is_tensor]
    if len(tensor_positions) > 1:
        # Add unsqueeze operations for broadcasting
        broadcast_indices = []
        for i, (idx, is_tensor) in enumerate(
            zip(safe_indices, is_tensor_mask, strict=False)
        ):
            if is_tensor:
                new_idx = idx
                # Add dimension for each other tensor index
                for j, other_pos in enumerate(tensor_positions):
                    if other_pos != i:
                        new_idx = new_idx.unsqueeze(j if other_pos < i else -1)
                broadcast_indices.append(new_idx)
            else:
                broadcast_indices.append(idx)
        values = tensor[tuple(broadcast_indices)]
    else:
        values = tensor[tuple(safe_indices)]

    # Build validity mask
    valid_mask = extra_mask.clone()
    for i, (orig_idx, is_tensor) in enumerate(
        zip(orig_indices, is_tensor_mask, strict=False)
    ):
        if is_tensor:
            dim_size = tensor.shape[i] if i < len(tensor.shape) else tensor.numel()
            in_bounds = (orig_idx >= 0) & (orig_idx < dim_size)
            # Broadcast to match mask shape by adding dimensions
            # Count how many tensor indices come before and after this one
            n_before = sum(1 for j in range(i) if is_tensor_mask[j])
            n_after = sum(
                1 for j in range(i + 1, len(is_tensor_mask)) if is_tensor_mask[j]
            )
            # Add dimensions: n_after dimensions at the end, n_before at the beginning
            for _ in range(n_after):
                in_bounds = in_bounds.unsqueeze(-1)
            for _ in range(n_before):
                in_bounds = in_bounds.unsqueeze(0)
            valid_mask = valid_mask & in_bounds

    return torch.where(valid_mask, values, result)
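

# Editor's sketch in plain PyTorch of the clamp-then-mask trick used by the
# ref load above: out-of-range indices are first clamped so the gather is
# memory-safe, then torch.where zeroes the lanes that were actually invalid.
def _check_load_ref_semantics() -> None:
    x = torch.tensor([10.0, 20.0, 30.0])
    idx = torch.tensor([0, 2, 5])  # 5 is out of range
    mask = idx < x.numel()  # plays the role of extra_mask
    safe = torch.clamp(idx, 0, x.numel() - 1)
    vals = torch.where(mask, x[safe], torch.zeros_like(x[safe]))
    assert vals.tolist() == [10.0, 30.0, 0.0]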


@has_side_effect
@_decorators.api(allow_host_tensor=True)
def atomic_add(
    target: torch.Tensor,
    index: list[object],
    value: torch.Tensor | float,
    sem: str = "relaxed",
) -> None:
    """
    Atomically add a value to a target tensor.

    Performs an atomic read-modify-write operation that adds value to
    target[index]. This is safe for concurrent access from multiple
    threads/blocks.

    Args:
        target: The tensor to add to
        index: Indices into target for accumulating values
        value: The value to add (tensor or scalar)
        sem: Memory ordering semantics (default: 'relaxed')
            - 'relaxed': No ordering constraints
            - 'acquire': Acquire semantics
            - 'release': Release semantics
            - 'acq_rel': Acquire-release semantics

    Returns:
        None

    Examples:
        .. code-block:: python

            @helion.kernel
            def global_sum(x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
                # Each tile computes a local sum, then atomically adds it to the global result
                for tile in hl.tile(x.size(0)):
                    local_data = x[tile]
                    local_sum = local_data.sum()
                    hl.atomic_add(result, [0], local_sum)
                return result

    See Also:
        - :func:`~helion.language.store`: For non-atomic stores
        - :func:`~helion.language.load`: For non-atomic loads

    Note:
        - Required for race-free accumulation across parallel execution
        - Performance depends on memory access patterns and contention
        - Consider using regular operations when atomicity isn't needed
        - Stronger memory semantics (acquire/release) add performance overhead
    """
    raise exc.NotInsideKernel


@_decorators.prepare_args(atomic_add)
def _(
    target: torch.Tensor,
    index: list[object],
    value: torch.Tensor | float,
    sem: str = "relaxed",
) -> tuple[torch.Tensor, object, torch.Tensor | float | int, str]:
    from .tile_proxy import Tile

    valid_sems = {"relaxed", "acquire", "release", "acq_rel"}
    if sem not in valid_sems:
        raise ValueError(
            f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
        )
    index = Tile._prepare_index(index)
    index = Tile._tiles_to_sizes(index)
    return (target, index, value, sem)


@_decorators.register_fake(atomic_add)
def _(
    target: torch.Tensor, index: list[object], value: torch.Tensor, sem: str = "relaxed"
) -> None:
    return None


@_decorators.ref(atomic_add)
def _(
    target: torch.Tensor,
    index: list[object],
    value: torch.Tensor | float,
    sem: str = "relaxed",
) -> None:
    """Reference implementation of atomic_add for interpret mode."""
    from .. import exc
    from .ref_tile import RefTile

    # Validate sem parameter
    if sem not in ["relaxed", "acquire", "release", "acq_rel"]:
        raise exc.InternalError(
            ValueError(
                f"Invalid memory semantic '{sem}'. Valid options are: relaxed, acquire, release, acq_rel"
            )
        )

    # Convert indices to proper format
    processed_index = []
    for idx in index:
        if isinstance(idx, RefTile):
            processed_index.append(idx._slice)
        elif isinstance(idx, torch.Tensor) and idx.numel() == 1:
            processed_index.append(int(idx.item()))
        else:
            processed_index.append(idx)

    # Find tensor indices that need element-wise processing
    tensor_indices = [
        (i, idx)
        for i, idx in enumerate(processed_index)
        if isinstance(idx, torch.Tensor) and idx.numel() > 1
    ]

    if tensor_indices:
        # Element-wise processing for tensor indices
        i, tensor_idx = tensor_indices[0]  # Handle first tensor index
        for j, elem in enumerate(tensor_idx):
            new_index = processed_index.copy()
            new_index[i] = int(elem.item())
            val = (
                value[j]
                if isinstance(value, torch.Tensor) and value.numel() > 1
                else value
            )
            target[tuple(new_index)] += val
    else:
        # Direct atomic add
        target[tuple(processed_index)] += value


@_decorators.codegen(atomic_add)
def _(state: CodegenState) -> ast.AST:
    target = state.proxy_arg(0)
    index = state.proxy_arg(1)
    sem = expr_from_string(repr(state.proxy_arg(3)))
    assert isinstance(target, torch.Tensor)
    assert isinstance(index, list)
    indices = SubscriptIndexing.create(state, target, index)
    name = state.device_function.tensor_arg(target).name
    value_expr = state.ast_args[2]
    if isinstance(value_expr, (int, float, bool)):
        value_expr = expr_from_string(constant_repr(value_expr))
    assert isinstance(value_expr, ast.AST)
    return expr_from_string(
        f"tl.atomic_add({name} + offset, value, mask=mask, sem=sem)",
        value=value_expr,
        offset=indices.index_expr,
        mask=indices.mask_expr,
        sem=sem,
    )
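

# Editor's sketch: why the ref path above loops element-wise over tensor
# indices rather than writing `target[idx] += value`. With duplicate indices,
# advanced-indexing `+=` keeps only one update per destination, while the
# element-wise loop (like a true atomic add) accumulates all of them.
def _check_duplicate_index_accumulation() -> None:
    idx = torch.tensor([0, 0, 1])
    vals = torch.tensor([1.0, 2.0, 3.0])

    lossy = torch.zeros(2)
    lossy[idx] += vals  # the two updates to index 0 collapse to one write
    assert lossy[0].item() != 3.0

    exact = torch.zeros(2)
    for j in range(idx.numel()):  # mirrors the ref loop above
        exact[int(idx[j])] += float(vals[j])
    assert exact.tolist() == [3.0, 3.0]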