Source code for helion.language.atomic_ops

from __future__ import annotations

import ast
from typing import TYPE_CHECKING
from typing import Callable

import torch
from torch._inductor.codegen.simd import constant_repr
from torch.fx import has_side_effect

from .. import exc
from .._compiler.ast_extension import expr_from_string
from .._compiler.indexing_strategy import SubscriptIndexing
from . import _decorators

if TYPE_CHECKING:
    from .._compiler.inductor_lowering import CodegenState

__all__ = [
    "atomic_add",
    "atomic_and",
    "atomic_cas",
    "atomic_max",
    "atomic_min",
    "atomic_or",
    "atomic_xchg",
    "atomic_xor",
]


_VALID_SEMS: set[str] = {"relaxed", "acquire", "release", "acq_rel"}


def _validate_sem(sem: str) -> None:
    if sem not in _VALID_SEMS:
        raise exc.InternalError(
            ValueError(
                f"Invalid memory semantic '{sem}'. Valid options are: relaxed, acquire, release, acq_rel"
            )
        )


def _prepare_mem_args(
    target: torch.Tensor,
    index: list[object],
    *values: object,
    sem: str = "relaxed",
) -> tuple:
    from .tile_proxy import Tile

    _validate_sem(sem)
    index = Tile._prepare_index(index)
    index = Tile._tiles_to_sizes(index)
    return (target, index, *values, sem)


def _codegen_common(
    tl_func: str, state: CodegenState, value_exprs: list[ast.AST]
) -> ast.AST:
    target = state.proxy_arg(0)
    index = state.proxy_arg(1)
    sem = expr_from_string(repr(state.proxy_arg(len(state.ast_args) - 1)))

    assert isinstance(target, torch.Tensor)
    assert isinstance(index, list)

    indices = SubscriptIndexing.create(state, target, index)
    name = state.device_function.tensor_arg(target).name

    placeholder_names = [f"v{i}" for i in range(len(value_exprs))]
    values_section = (
        ", " + ", ".join([f"{{{n}}}" for n in placeholder_names]) if value_exprs else ""
    )
    placeholders = dict(zip(placeholder_names, value_exprs, strict=False))
    return expr_from_string(
        f"tl.{tl_func}({name} + {{offset}}{values_section}, mask={{mask}}, sem={{sem}})",
        offset=indices.index_expr,
        mask=indices.mask_expr,
        sem=sem,
        **placeholders,
    )
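
# A hedged illustration (assumed typical output, not captured from a real
# run): for a tensor argument named `out` and a single value expression, the
# template above renders to something shaped like
#
#     tl.atomic_add(out + <offset>, <value>, mask=<mask>, sem='relaxed')
#
# with the offset and mask expressions supplied by SubscriptIndexing.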


def _to_ast_values(values: list[object]) -> list[ast.AST]:
    out: list[ast.AST] = []
    for v in values:
        if isinstance(v, (int, float, bool)):
            out.append(expr_from_string(constant_repr(v)))
        else:
            assert isinstance(v, ast.AST)
            out.append(v)
    return out


def _ref_apply(
    target: torch.Tensor,
    index: list[object],
    apply_fn: Callable[[torch.Tensor, tuple, object], None],
    value: object,
) -> None:
    from .ref_tile import RefTile

    # Convert indices to proper format
    processed_index: list[object] = []
    for idx in index:
        if isinstance(idx, RefTile):
            processed_index.append(idx._slice)
        elif isinstance(idx, torch.Tensor) and idx.numel() == 1:
            processed_index.append(int(idx.item()))
        else:
            processed_index.append(idx)

    # Find tensor indices that need element-wise processing
    tensor_indices = [
        (i, idx)
        for i, idx in enumerate(processed_index)
        if isinstance(idx, torch.Tensor) and idx.numel() > 1
    ]

    if tensor_indices:
        # Element-wise processing for tensor indices (handle first tensor index)
        i, tensor_idx = tensor_indices[0]
        for j, elem in enumerate(tensor_idx):
            new_index = processed_index.copy()
            new_index[i] = int(elem.item())
            val = (
                value[j]
                if isinstance(value, torch.Tensor) and value.numel() > 1
                else value
            )
            apply_fn(target, tuple(new_index), val)
    else:
        apply_fn(target, tuple(processed_index), value)


# -- atomic_add --


@has_side_effect
@_decorators.api(allow_host_tensor=True, tiles_as_sizes=True)
def atomic_add(
    target: torch.Tensor,
    index: list[object],
    value: torch.Tensor | float,
    sem: str = "relaxed",
) -> torch.Tensor:
    """
    Atomically add a value to a target tensor.

    Performs an atomic read-modify-write that adds ``value`` to
    ``target[index]``. This is safe for concurrent access from multiple
    threads/blocks.

    Args:
        target: Tensor to update.
        index: Indices selecting elements to update. Can include tiles.
        value: Value(s) to add (tensor or scalar).
        sem: Memory ordering semantics. One of ``"relaxed"``, ``"acquire"``,
            ``"release"``, ``"acq_rel"``. Defaults to ``"relaxed"``.

    Returns:
        torch.Tensor: The previous value(s) stored at ``target[index]``
        before the update.

    Example:
        @helion.kernel
        def global_sum(x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
            for tile in hl.tile(x.size(0)):
                hl.atomic_add(result, [0], x[tile].sum())
            return result

    Notes:
        - Use for race-free accumulation across parallel execution.
        - Higher memory semantics may reduce performance.
    """
    raise exc.NotInsideKernel


@_decorators.prepare_args(atomic_add)
def _(
    target: torch.Tensor,
    index: list[object],
    value: torch.Tensor | float,
    sem: str = "relaxed",
) -> tuple[torch.Tensor, object, torch.Tensor | float | int, str]:
    return _prepare_mem_args(target, index, value, sem=sem)


@_decorators.register_fake(atomic_add)
def _(
    target: torch.Tensor, index: list[object], value: torch.Tensor, sem: str = "relaxed"
) -> torch.Tensor:
    target_shape = SubscriptIndexing.compute_shape(target, index)
    return target.new_empty(target_shape)


@_decorators.ref(atomic_add)
def _(
    target: torch.Tensor,
    index: list[object],
    value: torch.Tensor | float,
    sem: str = "relaxed",
) -> torch.Tensor:
    _validate_sem(sem)
    from .ref_tile import RefTile

    # Convert indices and detect tensor indices for element-wise updates
    processed_index: list[object] = []
    tensor_indices: list[tuple[int, torch.Tensor]] = []
    for i, idx in enumerate(index):
        if isinstance(idx, RefTile):
            processed_index.append(idx._slice)
        elif isinstance(idx, torch.Tensor):
            if idx.numel() == 1:
                processed_index.append(int(idx.item()))
            else:
                processed_index.append(idx)
                tensor_indices.append((i, idx))
        else:
            processed_index.append(idx)

    if tensor_indices:
        # Element-wise processing for the first tensor index to ensure
        # correct semantics when the same element is updated more than once
        i, idx_tensor = tensor_indices[0]
        ret = torch.empty_like(idx_tensor, dtype=target.dtype, device=target.device)
        # Flatten to assign easily
        flat_ret = ret.reshape(-1)
        flat_idx = idx_tensor.reshape(-1)
        # Prepare value per element
        if isinstance(value, torch.Tensor) and value.numel() > 1:
            flat_val = value.reshape(-1)
        else:
            flat_val = None
        for j, elem in enumerate(flat_idx):
            new_index = list(processed_index)
            new_index[i] = int(elem.item())
            new_index_t = tuple(new_index)
            # Clone the snapshot; basic indexing returns a view, which would
            # otherwise observe the in-place update below
            prev = target[new_index_t].clone()  # pyright: ignore[reportArgumentType]
            vj = flat_val[j] if flat_val is not None else value
            # Convert scalar to tensor on device
            vj_t = (
                vj
                if isinstance(vj, torch.Tensor)
                else torch.as_tensor(vj, dtype=target.dtype, device=target.device)
            )
            target[new_index_t] = target[new_index_t] + vj_t  # pyright: ignore[reportArgumentType]
            flat_ret[j] = prev  # pyright: ignore[reportArgumentType]
        return ret

    # Scalar or simple indexing path
    idx_tuple = tuple(processed_index)
    prev = target[idx_tuple].clone()  # pyright: ignore[reportArgumentType]
    val = (
        value
        if isinstance(value, torch.Tensor)
        else torch.as_tensor(value, dtype=target.dtype, device=target.device)
    )
    target[idx_tuple] = target[idx_tuple] + val  # pyright: ignore[reportArgumentType]
    return prev


@_decorators.codegen(atomic_add)
def _(state: CodegenState) -> ast.AST:
    value_expr = state.ast_args[2]
    return _codegen_common("atomic_add", state, _to_ast_values([value_expr]))
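
# Ref-mode illustration (hypothetical values, not from the source): given
# x = torch.zeros(4) and a call equivalent to
# atomic_add(x, [torch.tensor([0, 0, 2])], 1.0), the element-wise loop above
# serializes the duplicate index, leaving x == [2., 0., 1., 0.] and returning
# the previous values [0., 1., 0.].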


# -- atomic_xchg --


@has_side_effect
@_decorators.api(allow_host_tensor=True, tiles_as_sizes=True)
def atomic_xchg(
    target: torch.Tensor,
    index: list[object],
    value: torch.Tensor | float | bool,
    sem: str = "relaxed",
) -> torch.Tensor:
    """
    Atomically exchange (set) a value at ``target[index]``.

    Args:
        target: Tensor to update.
        index: Indices selecting elements to update. Can include tiles.
        value: New value(s) to set.
        sem: Memory ordering semantics. One of ``"relaxed"``, ``"acquire"``,
            ``"release"``, ``"acq_rel"``. Defaults to ``"relaxed"``.

    Returns:
        torch.Tensor: The previous value(s) stored at ``target[index]``
        before the update.
    """
    raise exc.NotInsideKernel


@_decorators.prepare_args(atomic_xchg)
def _(
    target: torch.Tensor,
    index: list[object],
    value: torch.Tensor | float | bool,
    sem: str = "relaxed",
) -> tuple[torch.Tensor, object, object, str]:
    return _prepare_mem_args(target, index, value, sem=sem)


@_decorators.register_fake(atomic_xchg)
def _(
    target: torch.Tensor, index: list[object], value: torch.Tensor, sem: str = "relaxed"
) -> torch.Tensor:
    target_shape = SubscriptIndexing.compute_shape(target, index)
    return target.new_empty(target_shape)


@_decorators.ref(atomic_xchg)
def _(
    target: torch.Tensor,
    index: list[object],
    value: torch.Tensor | float | bool,
    sem: str = "relaxed",
) -> torch.Tensor:
    _validate_sem(sem)
    from .ref_tile import RefTile

    processed_index: list[object] = []
    for idx in index:
        if isinstance(idx, RefTile):
            processed_index.append(idx._slice)
        elif isinstance(idx, torch.Tensor) and idx.numel() == 1:
            processed_index.append(int(idx.item()))
        else:
            processed_index.append(idx)

    idx_tuple = tuple(processed_index)
    prev = target[idx_tuple].clone()  # pyright: ignore[reportArgumentType]
    val = (
        value
        if isinstance(value, torch.Tensor)
        else torch.as_tensor(value, dtype=target.dtype, device=target.device)
    )
    target[idx_tuple] = val  # pyright: ignore[reportArgumentType]
    return prev


@_decorators.codegen(atomic_xchg)
def _(state: CodegenState) -> ast.AST:
    value_expr = state.ast_args[2]
    return _codegen_common("atomic_xchg", state, _to_ast_values([value_expr]))
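
# Usage sketch (hedged; assumes a Helion kernel context where `hl` exposes
# these ops, and a hypothetical `lock` tensor and index `i`): release a
# simple lock by swapping in 0 while observing the previous value:
#
#     prev = hl.atomic_xchg(lock, [i], 0)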


# -- atomic_and/or/xor --


@has_side_effect
@_decorators.api(allow_host_tensor=True, tiles_as_sizes=True)
def atomic_and(
    target: torch.Tensor,
    index: list[object],
    value: torch.Tensor | int | bool,
    sem: str = "relaxed",
) -> torch.Tensor:
    """
    Atomically apply bitwise AND with ``value`` to ``target[index]``.

    Args:
        target: Tensor to update (integer/bool dtype).
        index: Indices selecting elements to update. Can include tiles.
        value: Value(s) to AND with.
        sem: Memory ordering semantics. One of ``"relaxed"``, ``"acquire"``,
            ``"release"``, ``"acq_rel"``. Defaults to ``"relaxed"``.

    Returns:
        torch.Tensor: The previous value(s) stored at ``target[index]``
        before the update.
    """
    raise exc.NotInsideKernel


@_decorators.prepare_args(atomic_and)
def _(
    target: torch.Tensor, index: list[object], value: object, sem: str = "relaxed"
) -> tuple[torch.Tensor, object, object, str]:
    return _prepare_mem_args(target, index, value, sem=sem)


@_decorators.register_fake(atomic_and)
def _(
    target: torch.Tensor, index: list[object], value: torch.Tensor, sem: str = "relaxed"
) -> torch.Tensor:
    target_shape = SubscriptIndexing.compute_shape(target, index)
    return target.new_empty(target_shape)


@_decorators.ref(atomic_and)
def _(
    target: torch.Tensor,
    index: list[object],
    value: torch.Tensor | int | bool,
    sem: str = "relaxed",
) -> torch.Tensor:
    _validate_sem(sem)
    from .ref_tile import RefTile

    processed_index: list[object] = []
    for idx in index:
        if isinstance(idx, RefTile):
            processed_index.append(idx._slice)
        elif isinstance(idx, torch.Tensor) and idx.numel() == 1:
            processed_index.append(int(idx.item()))
        else:
            processed_index.append(idx)

    idx_tuple = tuple(processed_index)
    prev = target[idx_tuple].clone()  # pyright: ignore[reportArgumentType]
    val = (
        value
        if isinstance(value, torch.Tensor)
        else torch.as_tensor(value, dtype=target.dtype, device=target.device)
    )
    target[idx_tuple] = target[idx_tuple] & val  # pyright: ignore[reportArgumentType]
    return prev


@_decorators.codegen(atomic_and)
def _(state: CodegenState) -> ast.AST:
    value_expr = state.ast_args[2]
    return _codegen_common("atomic_and", state, _to_ast_values([value_expr]))
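
# Usage sketch (hedged; hypothetical `flags` tensor and index `i`): clear
# bit 0 of an integer flag word while keeping the remaining bits:
#
#     prev = hl.atomic_and(flags, [i], ~1)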


@has_side_effect
@_decorators.api(allow_host_tensor=True, tiles_as_sizes=True)
def atomic_or(
    target: torch.Tensor,
    index: list[object],
    value: torch.Tensor | int | bool,
    sem: str = "relaxed",
) -> torch.Tensor:
    """
    Atomically apply bitwise OR with ``value`` to ``target[index]``.

    Args:
        target: Tensor to update (integer/bool dtype).
        index: Indices selecting elements to update. Can include tiles.
        value: Value(s) to OR with.
        sem: Memory ordering semantics. One of ``"relaxed"``, ``"acquire"``,
            ``"release"``, ``"acq_rel"``. Defaults to ``"relaxed"``.

    Returns:
        torch.Tensor: The previous value(s) stored at ``target[index]``
        before the update.
    """
    raise exc.NotInsideKernel


@_decorators.prepare_args(atomic_or)
def _(
    target: torch.Tensor, index: list[object], value: object, sem: str = "relaxed"
) -> tuple[torch.Tensor, object, object, str]:
    return _prepare_mem_args(target, index, value, sem=sem)


@_decorators.register_fake(atomic_or)
def _(
    target: torch.Tensor, index: list[object], value: torch.Tensor, sem: str = "relaxed"
) -> torch.Tensor:
    target_shape = SubscriptIndexing.compute_shape(target, index)
    return target.new_empty(target_shape)


@_decorators.ref(atomic_or)
def _(
    target: torch.Tensor,
    index: list[object],
    value: torch.Tensor | int | bool,
    sem: str = "relaxed",
) -> torch.Tensor:
    _validate_sem(sem)
    from .ref_tile import RefTile

    processed_index: list[object] = []
    for idx in index:
        if isinstance(idx, RefTile):
            processed_index.append(idx._slice)
        elif isinstance(idx, torch.Tensor) and idx.numel() == 1:
            processed_index.append(int(idx.item()))
        else:
            processed_index.append(idx)

    idx_tuple = tuple(processed_index)
    prev = target[idx_tuple].clone()  # pyright: ignore[reportArgumentType]
    val = (
        value
        if isinstance(value, torch.Tensor)
        else torch.as_tensor(value, dtype=target.dtype, device=target.device)
    )
    target[idx_tuple] = target[idx_tuple] | val  # pyright: ignore[reportArgumentType]
    return prev


@_decorators.codegen(atomic_or)
def _(state: CodegenState) -> ast.AST:
    value_expr = state.ast_args[2]
    return _codegen_common("atomic_or", state, _to_ast_values([value_expr]))
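
# Usage sketch (hedged; hypothetical `flags` tensor and index `i`): set
# bit 2 of a flag word, returning the bits that were set before:
#
#     prev = hl.atomic_or(flags, [i], 1 << 2)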


@has_side_effect
@_decorators.api(allow_host_tensor=True, tiles_as_sizes=True)
def atomic_xor(
    target: torch.Tensor,
    index: list[object],
    value: torch.Tensor | int | bool,
    sem: str = "relaxed",
) -> torch.Tensor:
    """
    Atomically apply bitwise XOR with ``value`` to ``target[index]``.

    Args:
        target: Tensor to update (integer/bool dtype).
        index: Indices selecting elements to update. Can include tiles.
        value: Value(s) to XOR with.
        sem: Memory ordering semantics. One of ``"relaxed"``, ``"acquire"``,
            ``"release"``, ``"acq_rel"``. Defaults to ``"relaxed"``.

    Returns:
        torch.Tensor: The previous value(s) stored at ``target[index]``
        before the update.
    """
    raise exc.NotInsideKernel


@_decorators.prepare_args(atomic_xor)
def _(
    target: torch.Tensor, index: list[object], value: object, sem: str = "relaxed"
) -> tuple[torch.Tensor, object, object, str]:
    return _prepare_mem_args(target, index, value, sem=sem)


@_decorators.register_fake(atomic_xor)
def _(
    target: torch.Tensor, index: list[object], value: torch.Tensor, sem: str = "relaxed"
) -> torch.Tensor:
    target_shape = SubscriptIndexing.compute_shape(target, index)
    return target.new_empty(target_shape)


@_decorators.ref(atomic_xor)
def _(
    target: torch.Tensor,
    index: list[object],
    value: torch.Tensor | int | bool,
    sem: str = "relaxed",
) -> torch.Tensor:
    _validate_sem(sem)
    from .ref_tile import RefTile

    processed_index: list[object] = []
    for idx in index:
        if isinstance(idx, RefTile):
            processed_index.append(idx._slice)
        elif isinstance(idx, torch.Tensor) and idx.numel() == 1:
            processed_index.append(int(idx.item()))
        else:
            processed_index.append(idx)

    idx_tuple = tuple(processed_index)
    prev = target[idx_tuple].clone()  # pyright: ignore[reportArgumentType]
    val = (
        value
        if isinstance(value, torch.Tensor)
        else torch.as_tensor(value, dtype=target.dtype, device=target.device)
    )
    target[idx_tuple] = target[idx_tuple] ^ val  # pyright: ignore[reportArgumentType]
    return prev


@_decorators.codegen(atomic_xor)
def _(state: CodegenState) -> ast.AST:
    value_expr = state.ast_args[2]
    return _codegen_common("atomic_xor", state, _to_ast_values([value_expr]))
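
# Usage sketch (hedged; hypothetical `flags` tensor and index `i`): toggle
# bit 0 on each visit, e.g. to track the parity of updates:
#
#     prev = hl.atomic_xor(flags, [i], 1)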


# -- atomic_max/min --


@has_side_effect
@_decorators.api(allow_host_tensor=True, tiles_as_sizes=True)
def atomic_max(
    target: torch.Tensor,
    index: list[object],
    value: torch.Tensor | float,
    sem: str = "relaxed",
) -> torch.Tensor:
    """
    Atomically update ``target[index]`` with the maximum of the current
    value and ``value``.

    Args:
        target: Tensor to update.
        index: Indices selecting elements to update. Can include tiles.
        value: Value(s) to compare with.
        sem: Memory ordering semantics. One of ``"relaxed"``, ``"acquire"``,
            ``"release"``, ``"acq_rel"``. Defaults to ``"relaxed"``.

    Returns:
        torch.Tensor: The previous value(s) stored at ``target[index]``
        before the update.
    """
    raise exc.NotInsideKernel


@_decorators.prepare_args(atomic_max)
def _(
    target: torch.Tensor, index: list[object], value: object, sem: str = "relaxed"
) -> tuple[torch.Tensor, object, object, str]:
    return _prepare_mem_args(target, index, value, sem=sem)


@_decorators.register_fake(atomic_max)
def _(
    target: torch.Tensor, index: list[object], value: torch.Tensor, sem: str = "relaxed"
) -> torch.Tensor:
    target_shape = SubscriptIndexing.compute_shape(target, index)
    return target.new_empty(target_shape)


# Note: unlike the other refs, this implementation updates in place via
# _ref_apply and does not return the previous values.
@_decorators.ref(atomic_max)
def _(
    target: torch.Tensor,
    index: list[object],
    value: torch.Tensor | float,
    sem: str = "relaxed",
) -> None:
    _validate_sem(sem)

    def apply(t: torch.Tensor, idx: tuple, v: object) -> None:
        t[idx] = torch.maximum(
            t[idx], torch.as_tensor(v, dtype=t[idx].dtype, device=t.device)
        )  # pyright: ignore[reportArgumentType]

    _ref_apply(target, index, apply, value)


@_decorators.codegen(atomic_max)
def _(state: CodegenState) -> ast.AST:
    value_expr = state.ast_args[2]
    return _codegen_common("atomic_max", state, _to_ast_values([value_expr]))
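
# Usage sketch (hedged; assumes a Helion kernel context, mirroring the
# atomic_add docstring example): reduce a global running maximum across
# tiles:
#
#     @helion.kernel
#     def global_max(x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
#         for tile in hl.tile(x.size(0)):
#             hl.atomic_max(result, [0], x[tile].amax())
#         return result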


@has_side_effect
@_decorators.api(allow_host_tensor=True, tiles_as_sizes=True)
def atomic_min(
    target: torch.Tensor,
    index: list[object],
    value: torch.Tensor | float,
    sem: str = "relaxed",
) -> torch.Tensor:
    """
    Atomically update ``target[index]`` with the minimum of the current
    value and ``value``.

    Args:
        target: Tensor to update.
        index: Indices selecting elements to update. Can include tiles.
        value: Value(s) to compare with.
        sem: Memory ordering semantics. One of ``"relaxed"``, ``"acquire"``,
            ``"release"``, ``"acq_rel"``. Defaults to ``"relaxed"``.

    Returns:
        torch.Tensor: The previous value(s) stored at ``target[index]``
        before the update.
    """
    raise exc.NotInsideKernel


@_decorators.prepare_args(atomic_min)
def _(
    target: torch.Tensor, index: list[object], value: object, sem: str = "relaxed"
) -> tuple[torch.Tensor, object, object, str]:
    return _prepare_mem_args(target, index, value, sem=sem)


@_decorators.register_fake(atomic_min)
def _(
    target: torch.Tensor, index: list[object], value: torch.Tensor, sem: str = "relaxed"
) -> torch.Tensor:
    target_shape = SubscriptIndexing.compute_shape(target, index)
    return target.new_empty(target_shape)


@_decorators.ref(atomic_min)
def _(
    target: torch.Tensor,
    index: list[object],
    value: torch.Tensor | float,
    sem: str = "relaxed",
) -> torch.Tensor:
    _validate_sem(sem)
    from .ref_tile import RefTile

    processed_index: list[object] = []
    for idx in index:
        if isinstance(idx, RefTile):
            processed_index.append(idx._slice)
        elif isinstance(idx, torch.Tensor) and idx.numel() == 1:
            processed_index.append(int(idx.item()))
        else:
            processed_index.append(idx)

    idx_tuple = tuple(processed_index)
    prev = target[idx_tuple].clone()  # pyright: ignore[reportArgumentType]
    val = (
        value
        if isinstance(value, torch.Tensor)
        else torch.as_tensor(value, dtype=target.dtype, device=target.device)
    )
    target[idx_tuple] = torch.minimum(target[idx_tuple], val)  # pyright: ignore[reportArgumentType]
    return prev


@_decorators.codegen(atomic_min)
def _(state: CodegenState) -> ast.AST:
    value_expr = state.ast_args[2]
    return _codegen_common("atomic_min", state, _to_ast_values([value_expr]))
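
# Usage sketch (hedged; mirrors the atomic_max example above):
#
#     hl.atomic_min(result, [0], x[tile].amin())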


# -- atomic_cas --


@has_side_effect
@_decorators.api(allow_host_tensor=True, tiles_as_sizes=True)
def atomic_cas(
    target: torch.Tensor,
    index: list[object],
    expected: torch.Tensor | float | bool,
    value: torch.Tensor | float | bool,
    sem: str = "relaxed",
) -> torch.Tensor:
    """
    Atomically compare-and-swap a value at ``target[index]``.

    If the current value equals ``expected``, writes ``value``. Otherwise
    leaves memory unchanged.

    Args:
        target: Tensor to update.
        index: Indices selecting elements to update. Can include tiles.
        expected: Expected current value(s) used for comparison.
        value: New value(s) to write if the comparison succeeds.
        sem: Memory ordering semantics. One of ``"relaxed"``, ``"acquire"``,
            ``"release"``, ``"acq_rel"``. Defaults to ``"relaxed"``.

    Returns:
        torch.Tensor: The previous value(s) stored at ``target[index]``
        before the compare-and-swap.

    Note:
        Triton CAS doesn't support a masked form; our generated code uses an
        unmasked CAS and relies on index masking to avoid OOB.
    """
    raise exc.NotInsideKernel


@_decorators.prepare_args(atomic_cas)
def _(
    target: torch.Tensor,
    index: list[object],
    expected: object,
    value: object,
    sem: str = "relaxed",
) -> tuple[torch.Tensor, object, object, object, str]:
    return _prepare_mem_args(target, index, expected, value, sem=sem)


@_decorators.register_fake(atomic_cas)
def _(
    target: torch.Tensor,
    index: list[object],
    expected: torch.Tensor,
    value: torch.Tensor,
    sem: str = "relaxed",
) -> torch.Tensor:
    target_shape = SubscriptIndexing.compute_shape(target, index)
    return target.new_empty(target_shape)


@_decorators.ref(atomic_cas)
def _(
    target: torch.Tensor,
    index: list[object],
    expected: torch.Tensor | float | bool,
    value: torch.Tensor | float | bool,
    sem: str = "relaxed",
) -> torch.Tensor:
    _validate_sem(sem)
    from .ref_tile import RefTile

    processed_index: list[object] = []
    for idx in index:
        if isinstance(idx, RefTile):
            processed_index.append(idx._slice)
        elif isinstance(idx, torch.Tensor) and idx.numel() == 1:
            processed_index.append(int(idx.item()))
        else:
            processed_index.append(idx)

    idx_tuple = tuple(processed_index)
    prev = target[idx_tuple].clone()  # pyright: ignore[reportArgumentType]
    exp_t = (
        expected
        if isinstance(expected, torch.Tensor)
        else torch.as_tensor(expected, dtype=target.dtype, device=target.device)
    )
    val_t = (
        value
        if isinstance(value, torch.Tensor)
        else torch.as_tensor(value, dtype=target.dtype, device=target.device)
    )
    mask = target[idx_tuple] == exp_t  # pyright: ignore[reportArgumentType]
    target[idx_tuple] = torch.where(mask, val_t, target[idx_tuple])  # pyright: ignore[reportArgumentType]
    return prev


@_decorators.codegen(atomic_cas)
def _(state: CodegenState) -> ast.AST:
    exp_expr = state.ast_args[2]
    val_expr = state.ast_args[3]
    target = state.proxy_arg(0)
    index = state.proxy_arg(1)
    sem = expr_from_string(repr(state.proxy_arg(len(state.ast_args) - 1)))

    assert isinstance(target, torch.Tensor)
    assert isinstance(index, list)

    indices = SubscriptIndexing.create(state, target, index)
    name = state.device_function.tensor_arg(target).name
    exp_ast, val_ast = _to_ast_values([exp_expr, val_expr])
    return expr_from_string(
        f"tl.atomic_cas({name} + {{offset}}, {{exp}}, {{val}}, sem={{sem}})",
        offset=indices.index_expr,
        exp=exp_ast,
        val=val_ast,
        sem=sem,
    )
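
# Usage sketch (hedged; assumes a Helion kernel context and a hypothetical
# `lock` tensor and index `i`): attempt to acquire a slot by transitioning
# it from 0 to 1, succeeding only if the previous value was 0:
#
#     prev = hl.atomic_cas(lock, [i], 0, 1)
#     acquired = prev == 0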