from __future__ import annotations
import ast
import itertools
from typing import TYPE_CHECKING
from typing import Callable
import torch
from torch._inductor.codegen.simd import constant_repr
from torch.fx import has_side_effect
from .. import exc
from .._compiler.ast_extension import expr_from_string
from .._compiler.compile_environment import _symint_expr
from .._compiler.host_function import HostFunction
from .._compiler.indexing_strategy import SubscriptIndexing
from .._compiler.variable_origin import GridOrigin
from . import _decorators
from . import _tracing_ops
if TYPE_CHECKING:
from .._compiler.inductor_lowering import CodegenState
__all__ = [
"atomic_add",
"atomic_and",
"atomic_cas",
"atomic_max",
"atomic_min",
"atomic_or",
"atomic_xchg",
"atomic_xor",
]
_VALID_SEMS: set[str] = {"relaxed", "acquire", "release", "acq_rel"}
def _validate_sem(sem: str) -> None:
if sem not in _VALID_SEMS:
raise exc.InternalError(
ValueError(
f"Invalid memory semantic '{sem}'. Valid options are: relaxed, acquire, release, acq_rel"
)
)
def _prepare_mem_args(
target: torch.Tensor,
index: list[object],
*values: object,
sem: str = "relaxed",
) -> tuple:
from .tile_proxy import Tile
_validate_sem(sem)
index = Tile._prepare_index(index)
index = Tile._tiles_to_sizes_for_index(index)
return (target, index, *values, sem)
def _codegen_common(
op: str, state: CodegenState, value_exprs: list[ast.AST]
) -> ast.AST:
"""Route any single-value atomic op through the atomic_indexing strategy."""
target = state.proxy_arg(0)
index = state.proxy_arg(1)
sem = expr_from_string(repr(state.proxy_arg(len(state.ast_args) - 1)))
assert isinstance(target, torch.Tensor)
assert isinstance(index, list)
host_function = HostFunction.current()
if target not in host_function.tensor_to_origin:
raise exc.AtomicOnDeviceTensor(op)
device_fn = state.device_function
indexing_idx = device_fn.atomic_op_index
device_fn.atomic_op_index += 1
strategy = device_fn.get_atomic_indexing_strategy(indexing_idx)
return strategy.codegen_atomic(op, state, target, index, value_exprs[0], sem)
def _cute_pointer_expr(
state: CodegenState,
target: torch.Tensor,
index: list[object],
ast_index: list[object] | tuple[object, ...] | None = None,
) -> str:
from .memory_ops import _cute_index_exprs
index_exprs = _cute_index_exprs(state, index, ast_index)
name = state.device_function.tensor_arg(target).name
coord = (
f"({index_exprs[0]},)"
if len(index_exprs) == 1
else f"({', '.join(index_exprs)})"
)
return f"({name}.iterator + cute.crd2idx({coord}, {name}.layout)).llvm_ptr"
def _resolve_cute_atomic_kwargs(cute_func: str, requested: list[str]) -> list[str]:
"""Map our intended ``cute.arch.<cute_func>`` kwarg names onto whatever
the live signature actually exposes.
Helion's emitted code refers to ``cute.arch.atomic_*`` parameters by
name (``val``, ``cmp``). Some nvidia-cutlass-dsl wheels have shipped
with these renamed (e.g. ``val`` -> ``value``); the old emission
style then trips a ``TypeError`` deep inside CUTLASS at run time.
Probe the signature at codegen time and rewrite the kwarg names to
match what the live wrapper accepts. Falls back to the requested
name when none of the rename candidates appears, so healthy installs
are unaffected.
"""
import inspect
try:
import cutlass.cute as cute # type: ignore[import-not-found]
except ImportError:
return list(requested)
func = getattr(getattr(cute, "arch", None), cute_func, None)
if func is None:
return list(requested)
try:
params = set(inspect.signature(func).parameters)
except (TypeError, ValueError):
return list(requested)
rename_candidates: dict[str, tuple[str, ...]] = {
"val": ("val", "value", "rhs", "src", "a"),
"cmp": ("cmp", "compare", "expected", "exp"),
}
resolved: list[str] = []
for name in requested:
candidates = rename_candidates.get(name, (name,))
chosen = next((c for c in candidates if c in params), name)
resolved.append(chosen)
return resolved
def _codegen_common_cute(
cute_func: str,
state: CodegenState,
*,
value_exprs: list[ast.AST],
keyword_names: list[str],
) -> ast.AST:
from .._compiler.compile_environment import CompileEnvironment
target = state.proxy_arg(0)
index = state.proxy_arg(1)
sem = expr_from_string(repr(state.proxy_arg(len(state.ast_args) - 1)))
assert isinstance(target, torch.Tensor)
assert isinstance(index, list)
host_function = HostFunction.current()
if target not in host_function.tensor_to_origin:
raise exc.AtomicOnDeviceTensor(cute_func)
backend = CompileEnvironment.current().backend
target_dtype = backend.dtype_str(target.dtype)
cast_value_exprs = [
expr_from_string(
backend.ast_to_dtype_expr("{value}", target_dtype),
value=value_expr,
)
for value_expr in value_exprs
]
tensor_index_stmt = _codegen_tensor_index_common_cute(
cute_func,
state,
target,
index,
sem,
cast_value_exprs,
keyword_names,
)
if tensor_index_stmt is not None:
return tensor_index_stmt
ast_index = state.ast_args[1]
assert isinstance(ast_index, (list, tuple))
pointer = _cute_pointer_expr(state, target, index, ast_index)
resolved_kwargs = _resolve_cute_atomic_kwargs(cute_func, keyword_names)
values_section = ", ".join(
f"{actual}={{{intent}}}"
for intent, actual in zip(keyword_names, resolved_kwargs, strict=True)
)
placeholders = dict(zip(keyword_names, cast_value_exprs, strict=True))
atomic_expr = expr_from_string(
f"cute.arch.{cute_func}({{ptr}}, {values_section}, sem={{sem}})",
ptr=expr_from_string(pointer),
sem=sem,
**placeholders,
)
return _guard_cute_atomic_expr(state, index, target_dtype, atomic_expr)
def _guard_cute_atomic_expr(
state: CodegenState,
index: list[object],
target_dtype: str,
atomic_expr: ast.AST,
*,
extra_predicates: list[str] | None = None,
) -> ast.AST:
predicates = [
predicate
for predicate in (
_cute_active_mask_predicate(state),
_cute_leader_thread_predicate(state, index),
_cute_unindexed_axis_leader_predicate(state, index),
*(extra_predicates or []),
)
if predicate is not None
]
if not predicates:
return atomic_expr
predicate_expr = expr_from_string(" and ".join(predicates))
assert isinstance(predicate_expr, ast.expr)
assert isinstance(atomic_expr, ast.expr)
if state.fx_node is not None and len(state.fx_node.users) == 0:
state.codegen.add_statement(
ast.fix_missing_locations(
ast.If(
test=predicate_expr,
body=[ast.Expr(value=atomic_expr)],
orelse=[],
)
)
)
return ast.Constant(value=None)
result_var = state.device_function.new_var("_atomic_prev", dce=True)
zero_value = expr_from_string(f"{target_dtype}(0)")
assert isinstance(zero_value, ast.expr)
state.codegen.add_statement(
ast.fix_missing_locations(
ast.Assign(
targets=[ast.Name(id=result_var, ctx=ast.Store())],
value=zero_value,
)
)
)
state.codegen.add_statement(
ast.fix_missing_locations(
ast.If(
test=predicate_expr,
body=[
ast.Assign(
targets=[ast.Name(id=result_var, ctx=ast.Store())],
value=atomic_expr,
)
],
orelse=[],
)
)
)
return expr_from_string(result_var)
def _cute_tensor_index_leader_predicate(
state: CodegenState,
tensor_index: torch.Tensor,
) -> str | None:
from .._compiler.compile_environment import CompileEnvironment
env = CompileEnvironment.current()
block_id = env.resolve_block_id(tensor_index.shape[0])
if block_id is None:
return None
assert state.fx_node is not None
block_id = env.resolve_codegen_block_id(
block_id, state.codegen, state.fx_node.graph
)
index_axes: set[int] = set()
other_axes: set[int] = set()
grid_state = state.codegen.current_grid_state
if grid_state is not None:
for candidate_block_id, thread_axis in grid_state.block_thread_axes.items():
if candidate_block_id == block_id:
index_axes.add(thread_axis)
else:
other_axes.add(thread_axis)
for loops in state.codegen.active_device_loops.values():
for loop_state in loops:
for candidate_block_id, thread_axis in loop_state.block_thread_axes.items():
if candidate_block_id == block_id:
index_axes.add(thread_axis)
else:
other_axes.add(thread_axis)
leader_axes = sorted(axis for axis in other_axes if axis not in index_axes)
if not leader_axes:
return None
return " and ".join(
f"(cute.arch.thread_idx()[{axis}] == 0)" for axis in leader_axes
)
def _cute_leader_thread_predicate(
state: CodegenState,
index: list[object],
) -> str | None:
scalar_origin_block_ids: set[int] = set()
for idx in index:
if not isinstance(idx, torch.SymInt):
continue
expr = _symint_expr(idx)
if expr is None:
continue
origin_info = HostFunction.current().expr_to_origin.get(expr)
if origin_info is None or not isinstance(origin_info.origin, GridOrigin):
continue
if type(origin_info.origin) is GridOrigin:
continue
scalar_origin_block_ids.add(origin_info.origin.block_id)
if not scalar_origin_block_ids:
return None
axes: set[int] = set()
grid_state = state.codegen.current_grid_state
if grid_state is not None:
for block_id in scalar_origin_block_ids:
thread_axis = grid_state.block_thread_axes.get(block_id)
if thread_axis is not None:
axes.add(thread_axis)
for loops in state.codegen.active_device_loops.values():
for loop_state in loops:
for block_id in scalar_origin_block_ids:
thread_axis = loop_state.block_thread_axes.get(block_id)
if thread_axis is not None:
axes.add(thread_axis)
if not axes:
return None
return " and ".join(
f"(cute.arch.thread_idx()[{axis}] == 0)" for axis in sorted(axes)
)
def _cute_unindexed_axis_leader_predicate(
state: CodegenState,
index: list[object],
) -> str | None:
"""Return predicate restricting atomics to one thread per unindexed parallel axis.
When an atomic op is invoked inside ``hl.tile([m, n])`` but the index only
covers a subset of the tile dimensions (e.g. ``hl.atomic_add(dy, [tile_i],
reduced)`` after a reduction across ``tile_j``), every thread on the
unindexed axis would otherwise re-issue the atomic with the same value.
Restrict those axes to ``thread_idx[axis] == 0`` so the op fires once per
unique (indexed) coordinate.
"""
from .._compiler.compile_environment import CompileEnvironment
from .._compiler.variable_origin import BlockSizeOrigin
env = CompileEnvironment.current()
indexed_block_ids: set[int] = set()
has_block_size_index = False
for idx in index:
if not isinstance(idx, torch.SymInt):
continue
expr = _symint_expr(idx)
if expr is None:
continue
origin_info = HostFunction.current().expr_to_origin.get(expr)
if origin_info is None or not isinstance(origin_info.origin, BlockSizeOrigin):
continue
has_block_size_index = True
assert state.fx_node is not None
indexed_block_ids.add(
env.resolve_codegen_block_id(
origin_info.origin.block_id,
state.codegen,
state.fx_node.graph,
)
)
if not has_block_size_index:
return None
assert state.fx_node is not None
fx_graph = state.fx_node.graph
leader_axes: set[int] = set()
def collect(thread_axes: dict[int, int]) -> None:
for candidate_block_id, thread_axis in thread_axes.items():
resolved = env.resolve_codegen_block_id(
candidate_block_id, state.codegen, fx_graph
)
if resolved in indexed_block_ids:
continue
leader_axes.add(thread_axis)
grid_state = state.codegen.current_grid_state
if grid_state is not None:
collect(grid_state.block_thread_axes)
for loops in state.codegen.active_device_loops.values():
for loop_state in loops:
collect(loop_state.block_thread_axes)
if not leader_axes:
return None
return " and ".join(
f"(cute.arch.thread_idx()[{axis}] == 0)" for axis in sorted(leader_axes)
)
def _cute_active_mask_predicate(state: CodegenState) -> str | None:
masks: list[str] = []
seen_blocks: set[int] = set()
for block_id, loops in state.codegen.active_device_loops.items():
if block_id in seen_blocks or not loops:
continue
seen_blocks.add(block_id)
mask_var = loops[-1].strategy.mask_var(block_id)
if mask_var is not None:
masks.append(f"({mask_var})")
grid_state = state.codegen.current_grid_state
if grid_state is not None:
for block_id in grid_state.block_ids:
if block_id in seen_blocks:
continue
seen_blocks.add(block_id)
mask_var = grid_state.strategy.mask_var(block_id)
if mask_var is not None:
masks.append(f"({mask_var})")
if not masks:
return None
return " and ".join(masks)
def _resolve_tensor_index_iota_node(
state: CodegenState, index_node: torch.fx.Node
) -> torch.fx.Node | None:
from .._compiler.device_ir import NodeArgsGraphInfo
current = index_node
visited: set[torch.fx.Node] = set()
while True:
if current in visited:
return None
visited.add(current)
if current.target is torch.ops.prims.iota.default:
return current
if current.op == "call_function" and current.target in {
_tracing_ops._new_var,
_tracing_ops._phi,
torch.ops.aten.clone.default,
torch.ops.aten.detach.default,
torch.ops.prims.convert_element_type.default,
}:
arg = current.args[0] if current.args else None
if not isinstance(arg, torch.fx.Node):
return None
current = arg
continue
if current.op != "placeholder":
return None
graph_infos = [
graph_info
for graph_info in state.codegen.codegen_graphs
if graph_info.graph is current.graph
]
if len(graph_infos) != 1:
return None
graph_info = graph_infos[0]
if not isinstance(graph_info, NodeArgsGraphInfo):
return None
outer_node = graph_info.placeholder_to_outer_arg(current)
if not isinstance(outer_node, torch.fx.Node):
return None
current = outer_node
def _codegen_tensor_index_common_cute(
cute_func: str,
state: CodegenState,
target: torch.Tensor,
index: list[object],
sem: ast.AST,
value_exprs: list[ast.AST],
keyword_names: list[str],
) -> ast.AST | None:
from .._compiler.compile_environment import CompileEnvironment
from .memory_ops import _cute_active_index_var
fx_node = state.fx_node
if fx_node is None or len(index) != 1 or len(fx_node.args) < 2:
return None
tensor_index = index[0] if isinstance(index[0], torch.Tensor) else None
fx_index = fx_node.args[1]
if not isinstance(fx_index, (list, tuple)) or len(fx_index) != 1:
return None
index_node = fx_index[0]
if not isinstance(index_node, torch.fx.Node):
return None
iota_node = _resolve_tensor_index_iota_node(state, index_node)
if iota_node is None:
return None
iota_val = iota_node.meta.get("val")
if isinstance(iota_val, torch.Tensor) and iota_val.ndim == 1:
tensor_index = iota_val
if tensor_index is None or tensor_index.ndim != 1:
return None
iota_start = iota_node.kwargs.get("start", 0)
iota_step = iota_node.kwargs.get("step", 1)
if iota_step != 1 or not isinstance(iota_start, int):
return _codegen_tensor_index_loop_common_cute(
cute_func,
state,
target,
tensor_index,
index_node,
sem,
value_exprs,
keyword_names,
)
env = CompileEnvironment.current()
block_id = env.resolve_block_id(tensor_index.shape[0])
if block_id is None:
return _codegen_tensor_index_loop_common_cute(
cute_func,
state,
target,
tensor_index,
index_node,
sem,
value_exprs,
keyword_names,
)
block_id = env.resolve_codegen_block_id(block_id, state.codegen, fx_node.graph)
if (index_var := _cute_active_index_var(state, block_id)) is None:
return _codegen_tensor_index_loop_common_cute(
cute_func,
state,
target,
tensor_index,
index_node,
sem,
value_exprs,
keyword_names,
)
tensor_name = state.device_function.tensor_arg(target).name
resolved_kwargs = _resolve_cute_atomic_kwargs(cute_func, keyword_names)
values_section = ", ".join(
f"{actual}={{{intent}}}"
for intent, actual in zip(keyword_names, resolved_kwargs, strict=True)
)
placeholders = dict(zip(keyword_names, value_exprs, strict=True))
atomic_expr = expr_from_string(
"cute.arch."
+ cute_func
+ "("
+ f"({tensor_name}.iterator + "
+ f"cute.crd2idx((cutlass.Int32({iota_start}) + {index_var},), {tensor_name}.layout)).llvm_ptr, "
+ values_section
+ ", sem={sem})",
sem=sem,
**placeholders,
)
target_dtype = env.backend.dtype_str(target.dtype)
extra_predicates = [
predicate
for predicate in (_cute_tensor_index_leader_predicate(state, tensor_index),)
if predicate is not None
]
return _guard_cute_atomic_expr(
state,
index,
target_dtype,
atomic_expr,
extra_predicates=extra_predicates,
)
def _codegen_tensor_index_loop_common_cute(
cute_func: str,
state: CodegenState,
target: torch.Tensor,
tensor_index: torch.Tensor,
index_node: torch.fx.Node,
sem: ast.AST,
value_exprs: list[ast.AST],
keyword_names: list[str],
) -> ast.AST | None:
from .._compiler.ast_extension import statement_from_string
fx_node = state.fx_node
if fx_node is None or len(fx_node.users) > 0:
return None
if tensor_index.ndim != 1:
return None
extent = tensor_index.shape[0]
if not isinstance(extent, int):
return None
ast_index = state.ast_args[1]
if not isinstance(ast_index, (list, tuple)) or len(ast_index) != 1:
return None
ast_index_expr = ast_index[0]
if not isinstance(ast_index_expr, ast.AST):
return None
iota_node = _resolve_tensor_index_iota_node(state, index_node)
indexed_values: list[ast.AST] = []
value_arg_offset = 2
for value_expr, _keyword_name in zip(value_exprs, keyword_names, strict=True):
value_proxy = state.proxy_arg(value_arg_offset)
value_arg_offset += 1
if isinstance(value_proxy, torch.Tensor) and value_proxy.ndim == 1:
tensor_arg = state.device_function.tensor_arg(value_proxy)
indexed_values.append(
expr_from_string(
"{value}[{idx}]",
value=expr_from_string(tensor_arg.name),
idx=expr_from_string("_tensor_index_i"),
)
)
continue
if extent != 1:
return None
indexed_values.append(value_expr)
if iota_node is not None:
start = iota_node.kwargs.get("start", 0)
step = iota_node.kwargs.get("step", 1)
if not isinstance(start, int) or not isinstance(step, int):
return None
index_expr = expr_from_string(
f"cutlass.Int32({start}) + cutlass.Int32({step}) * cutlass.Int32(_tensor_index_i)"
)
else:
index_expr = expr_from_string(
"cutlass.Int32({index}[{idx}])",
index=ast_index_expr,
idx=expr_from_string("_tensor_index_i"),
)
tensor_name = state.device_function.tensor_arg(target).name
resolved_kwargs = _resolve_cute_atomic_kwargs(cute_func, keyword_names)
values_section = ", ".join(
f"{actual}={{{intent}}}"
for intent, actual in zip(keyword_names, resolved_kwargs, strict=True)
)
placeholders = dict(zip(keyword_names, indexed_values, strict=True))
atomic_expr = expr_from_string(
"cute.arch."
+ cute_func
+ "("
+ f"({tensor_name}.iterator + "
+ f"cute.crd2idx(({{index}},), {tensor_name}.layout)).llvm_ptr, "
+ values_section
+ ", sem={sem})",
index=index_expr,
sem=sem,
**placeholders,
)
assert isinstance(atomic_expr, ast.expr)
predicate_terms = [
predicate
for predicate in (
_cute_active_mask_predicate(state),
_cute_tensor_index_leader_predicate(state, tensor_index),
)
if predicate is not None
]
predicate_expr = (
ast.parse(" and ".join(predicate_terms), mode="eval").body
if predicate_terms
else None
)
inner = (
ast.fix_missing_locations(
ast.If(
test=predicate_expr,
body=[ast.Expr(value=atomic_expr)],
orelse=[],
)
)
if predicate_expr is not None
else ast.Expr(value=atomic_expr)
)
loop = statement_from_string(f"for _tensor_index_i in range({extent}):\n pass")
assert isinstance(loop, ast.For)
loop.body = [inner]
state.codegen.add_statement(loop)
return ast.Constant(value=None)
def _pallas_atomic_load_prev(
state: CodegenState,
) -> tuple[str, str, str]:
"""Load previous value for a Pallas atomic op.
On TPU, each kernel instance has exclusive access to its tile, so
atomics are implemented as regular load-compute-store sequences.
Returns (tensor_name, index_str, prev_var_name).
"""
from .._compiler.ast_extension import statement_from_string
from .._compiler.pallas import codegen as pallas_codegen
target = state.proxy_arg(0)
index = state.proxy_arg(1)
assert isinstance(target, torch.Tensor)
assert isinstance(index, (list, tuple))
host_function = HostFunction.current()
if target not in host_function.tensor_to_origin:
raise exc.AtomicOnDeviceTensor("pallas atomic")
name = state.device_function.tensor_arg(target).name
index_str, _ = pallas_codegen.index_str(state, index, target)
prev_var = state.device_function.new_var("_prev", dce=True)
state.codegen.add_statement(
statement_from_string(f"{prev_var} = {name}[{index_str}]")
)
return name, index_str, prev_var # pyrefly: ignore[bad-return]
def _to_ast_values(values: list[object]) -> list[ast.AST]:
out: list[ast.AST] = []
for v in values:
if isinstance(v, (int, float, bool)):
out.append(expr_from_string(constant_repr(v)))
else:
assert isinstance(v, ast.AST)
out.append(v)
return out
def _ref_atomic_binop(
target: torch.Tensor,
index: list[object],
value: torch.Tensor | float | bool,
op: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
"""Shared ref implementation for simple atomic binary ops (xchg/and/or/xor/max/min).
Processes indices, clones the previous value, applies the op, and returns prev.
For xchg, pass op=lambda old, val: val.
"""
from .ref_tile import RefTile
processed_index: list[object] = []
for idx in index:
if isinstance(idx, RefTile):
processed_index.append(idx._slice)
elif isinstance(idx, torch.Tensor) and idx.numel() == 1:
processed_index.append(int(idx.item()))
else:
processed_index.append(idx)
idx_tuple = tuple(processed_index)
# pyrefly: ignore [bad-index]
prev = target[idx_tuple].clone()
val = (
value
if isinstance(value, torch.Tensor)
else torch.as_tensor(value, dtype=target.dtype, device=target.device)
)
# pyrefly: ignore [bad-index, unsupported-operation]
target[idx_tuple] = op(target[idx_tuple], val)
return prev
def _ref_apply(
target: torch.Tensor,
index: list[object],
apply_fn: Callable[[torch.Tensor, tuple, object], None],
value: object,
) -> None:
from .ref_tile import RefTile
# Convert indices to proper format
processed_index: list[object] = []
for idx in index:
if isinstance(idx, RefTile):
processed_index.append(idx._slice)
elif isinstance(idx, torch.Tensor) and idx.numel() == 1:
processed_index.append(int(idx.item()))
else:
processed_index.append(idx)
# Find tensor indices that need element-wise processing
tensor_indices = [
(i, idx)
for i, idx in enumerate(processed_index)
if isinstance(idx, torch.Tensor) and idx.numel() > 1
]
if tensor_indices:
# Element-wise processing for tensor indices (handle first tensor index)
i, tensor_idx = tensor_indices[0]
if tensor_idx.ndim == 0:
coords_iter = [()]
else:
ranges = [range(dim) for dim in tensor_idx.shape]
coords_iter = itertools.product(*ranges)
for coords in coords_iter:
elem = tensor_idx[coords].item()
new_index = processed_index.copy()
new_index[i] = int(elem)
if isinstance(value, torch.Tensor) and value.numel() > 1:
next_value = value[coords]
else:
next_value = value
_ref_apply(target, new_index, apply_fn, next_value)
else:
apply_fn(target, tuple(processed_index), value)
# -- atomic_add --
[docs]
@has_side_effect
@_decorators.api(allow_host_tensor=True, tiles_as_sizes=True)
def atomic_add(
target: torch.Tensor,
index: list[object],
value: torch.Tensor | float,
sem: str = "relaxed",
) -> torch.Tensor:
"""
Atomically add a value to a target tensor.
Performs an atomic read-modify-write that adds ``value`` to
``target[index]``. This is safe for concurrent access from multiple
threads/blocks.
Args:
target: Tensor to update.
index: Indices selecting elements to update. Can include tiles.
value: Value(s) to add (tensor or scalar).
sem: Memory ordering semantics. One of ``"relaxed"``, ``"acquire"``,
``"release"``, ``"acq_rel"``. Defaults to ``"relaxed"``.
Returns:
torch.Tensor: The previous value(s) stored at ``target[index]`` before the update.
Example:
@helion.kernel
def global_sum(x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
for tile in hl.tile(x.size(0)):
hl.atomic_add(result, [0], x[tile].sum())
return result
Notes:
- Use for race-free accumulation across parallel execution.
- Higher memory semantics may reduce performance.
"""
raise exc.NotInsideKernel
@_decorators.prepare_args(atomic_add)
def _(
target: torch.Tensor,
index: list[object],
value: torch.Tensor | float,
sem: str = "relaxed",
) -> tuple[torch.Tensor, object, torch.Tensor | float | int, str]:
return _prepare_mem_args(target, index, value, sem=sem)
@_decorators.register_fake(atomic_add)
def _(
target: torch.Tensor, index: list[object], value: torch.Tensor, sem: str = "relaxed"
) -> torch.Tensor:
target_shape = SubscriptIndexing.compute_shape(target, index)
return target.new_empty(target_shape)
@_decorators.ref(atomic_add)
def _(
target: torch.Tensor,
index: list[object],
value: torch.Tensor | float,
sem: str = "relaxed",
) -> torch.Tensor:
_validate_sem(sem)
from .ref_tile import RefTile
# Convert indices for shape computation and fast path detection
processed_index: list[object] = []
has_tensor_index = False
for idx in index:
if isinstance(idx, RefTile):
processed_index.append(idx._slice)
elif isinstance(idx, torch.Tensor):
if idx.numel() == 1:
processed_index.append(int(idx.item()))
else:
processed_index.append(idx)
has_tensor_index = True
else:
processed_index.append(idx)
def _convert_value_to_target_dtype(val: object) -> torch.Tensor:
if isinstance(val, torch.Tensor):
vt = val.to(device=target.device)
if vt.dtype != target.dtype:
vt = vt.to(dtype=target.dtype)
return vt
return torch.as_tensor(val, dtype=target.dtype, device=target.device)
if has_tensor_index:
ret_shape = SubscriptIndexing.compute_shape(target, processed_index)
prev_chunks: list[torch.Tensor] = []
def apply(t: torch.Tensor, idx_tuple: tuple, v: object) -> None:
prev_val = t[idx_tuple].clone()
val_tensor = _convert_value_to_target_dtype(v)
t[idx_tuple] = t[idx_tuple] + val_tensor
prev_chunks.append(prev_val.reshape(-1))
_ref_apply(target, index, apply, value)
if prev_chunks:
flat_prev = torch.cat(prev_chunks)
else:
flat_prev = target.new_empty(0, dtype=target.dtype, device=target.device)
return flat_prev.reshape(ret_shape)
idx_tuple = tuple(processed_index)
# pyrefly: ignore [bad-index]
prev = target[idx_tuple].clone()
val_tensor = _convert_value_to_target_dtype(value)
# pyrefly: ignore [bad-index, unsupported-operation]
target[idx_tuple] = target[idx_tuple] + val_tensor
return prev
@_decorators.codegen(atomic_add, "triton")
def _(state: CodegenState) -> ast.AST:
value_expr = state.ast_args[2]
return _codegen_common("atomic_add", state, _to_ast_values([value_expr]))
@_decorators.codegen(atomic_add, "cute")
def _(state: CodegenState) -> ast.AST:
value_expr = state.ast_args[2]
return _codegen_common_cute(
"atomic_add",
state,
value_exprs=_to_ast_values([value_expr]),
keyword_names=["val"],
)
@_decorators.codegen(atomic_add, "pallas")
def _(state: CodegenState) -> ast.AST:
from .._compiler.ast_extension import statement_from_string
from .._compiler.compile_environment import CompileEnvironment
name, index_str, prev_var = _pallas_atomic_load_prev(state)
value_ast = _to_ast_values([state.ast_args[2]])[0]
target = state.proxy_arg(0)
assert isinstance(target, torch.Tensor)
backend = CompileEnvironment.current().backend
target_dtype = backend.dtype_str(target.dtype)
# Cast the sum to the target dtype so the store doesn't fail when
# the value dtype differs (e.g. float32 accumulator into bfloat16 ref).
cast = backend.cast_expr(f"{prev_var} + {{value}}", target_dtype)
state.codegen.add_statement(
statement_from_string(f"{name}[{index_str}] = {cast}", value=value_ast)
)
return expr_from_string(prev_var)
# -- atomic_xchg --
[docs]
@has_side_effect
@_decorators.api(allow_host_tensor=True, tiles_as_sizes=True)
def atomic_xchg(
target: torch.Tensor,
index: list[object],
value: torch.Tensor | float | bool,
sem: str = "relaxed",
) -> torch.Tensor:
"""
Atomically exchange (set) a value at ``target[index]``.
Args:
target: Tensor to update.
index: Indices selecting elements to update. Can include tiles.
value: New value(s) to set.
sem: Memory ordering semantics. One of ``"relaxed"``, ``"acquire"``,
``"release"``, ``"acq_rel"``. Defaults to ``"relaxed"``.
Returns:
torch.Tensor: The previous value(s) stored at ``target[index]`` before the update.
"""
raise exc.NotInsideKernel
@_decorators.prepare_args(atomic_xchg)
def _(
target: torch.Tensor,
index: list[object],
value: torch.Tensor | float | bool,
sem: str = "relaxed",
) -> tuple[torch.Tensor, object, object, str]:
return _prepare_mem_args(target, index, value, sem=sem)
@_decorators.register_fake(atomic_xchg)
def _(
target: torch.Tensor, index: list[object], value: torch.Tensor, sem: str = "relaxed"
) -> torch.Tensor:
target_shape = SubscriptIndexing.compute_shape(target, index)
return target.new_empty(target_shape)
@_decorators.ref(atomic_xchg)
def _(
target: torch.Tensor,
index: list[object],
value: torch.Tensor | float | bool,
sem: str = "relaxed",
) -> torch.Tensor:
_validate_sem(sem)
return _ref_atomic_binop(target, index, value, lambda old, val: val)
@_decorators.codegen(atomic_xchg, "triton")
def _(state: CodegenState) -> ast.AST:
value_expr = state.ast_args[2]
return _codegen_common("atomic_xchg", state, _to_ast_values([value_expr]))
@_decorators.codegen(atomic_xchg, "cute")
def _(state: CodegenState) -> ast.AST:
value_expr = state.ast_args[2]
return _codegen_common_cute(
"atomic_exch",
state,
value_exprs=_to_ast_values([value_expr]),
keyword_names=["val"],
)
@_decorators.codegen(atomic_xchg, "pallas")
def _(state: CodegenState) -> ast.AST:
from .._compiler.ast_extension import statement_from_string
name, index_str, prev_var = _pallas_atomic_load_prev(state)
value_ast = _to_ast_values([state.ast_args[2]])[0]
state.codegen.add_statement(
statement_from_string(f"{name}[{index_str}] = {{value}}", value=value_ast)
)
return expr_from_string(prev_var)
# -- atomic_and/or/xor --
[docs]
@has_side_effect
@_decorators.api(allow_host_tensor=True, tiles_as_sizes=True)
def atomic_and(
target: torch.Tensor,
index: list[object],
value: torch.Tensor | int | bool,
sem: str = "relaxed",
) -> torch.Tensor:
"""
Atomically apply bitwise AND with ``value`` to ``target[index]``.
Args:
target: Tensor to update (integer/bool dtype).
index: Indices selecting elements to update. Can include tiles.
value: Value(s) to AND with.
sem: Memory ordering semantics. One of ``"relaxed"``, ``"acquire"``,
``"release"``, ``"acq_rel"``. Defaults to ``"relaxed"``.
Returns:
torch.Tensor: The previous value(s) stored at ``target[index]`` before the update.
"""
raise exc.NotInsideKernel
@_decorators.prepare_args(atomic_and)
def _(
target: torch.Tensor, index: list[object], value: object, sem: str = "relaxed"
) -> tuple[torch.Tensor, object, object, str]:
return _prepare_mem_args(target, index, value, sem=sem)
@_decorators.register_fake(atomic_and)
def _(
target: torch.Tensor, index: list[object], value: torch.Tensor, sem: str = "relaxed"
) -> torch.Tensor:
target_shape = SubscriptIndexing.compute_shape(target, index)
return target.new_empty(target_shape)
@_decorators.ref(atomic_and)
def _(
target: torch.Tensor,
index: list[object],
value: torch.Tensor | int | bool,
sem: str = "relaxed",
) -> torch.Tensor:
_validate_sem(sem)
return _ref_atomic_binop(target, index, value, torch.bitwise_and)
@_decorators.codegen(atomic_and, "triton")
def _(state: CodegenState) -> ast.AST:
value_expr = state.ast_args[2]
return _codegen_common("atomic_and", state, _to_ast_values([value_expr]))
@_decorators.codegen(atomic_and, "cute")
def _(state: CodegenState) -> ast.AST:
value_expr = state.ast_args[2]
return _codegen_common_cute(
"atomic_and",
state,
value_exprs=_to_ast_values([value_expr]),
keyword_names=["val"],
)
@_decorators.codegen(atomic_and, "pallas")
def _(state: CodegenState) -> ast.AST:
from .._compiler.ast_extension import statement_from_string
name, index_str, prev_var = _pallas_atomic_load_prev(state)
value_ast = _to_ast_values([state.ast_args[2]])[0]
state.codegen.add_statement(
statement_from_string(
f"{name}[{index_str}] = {prev_var} & {{value}}", value=value_ast
)
)
return expr_from_string(prev_var)
[docs]
@has_side_effect
@_decorators.api(allow_host_tensor=True, tiles_as_sizes=True)
def atomic_or(
target: torch.Tensor,
index: list[object],
value: torch.Tensor | int | bool,
sem: str = "relaxed",
) -> torch.Tensor:
"""
Atomically apply bitwise OR with ``value`` to ``target[index]``.
Args:
target: Tensor to update (integer/bool dtype).
index: Indices selecting elements to update. Can include tiles.
value: Value(s) to OR with.
sem: Memory ordering semantics. One of ``"relaxed"``, ``"acquire"``,
``"release"``, ``"acq_rel"``. Defaults to ``"relaxed"``.
Returns:
torch.Tensor: The previous value(s) stored at ``target[index]`` before the update.
"""
raise exc.NotInsideKernel
@_decorators.prepare_args(atomic_or)
def _(
target: torch.Tensor, index: list[object], value: object, sem: str = "relaxed"
) -> tuple[torch.Tensor, object, object, str]:
return _prepare_mem_args(target, index, value, sem=sem)
@_decorators.register_fake(atomic_or)
def _(
target: torch.Tensor, index: list[object], value: torch.Tensor, sem: str = "relaxed"
) -> torch.Tensor:
target_shape = SubscriptIndexing.compute_shape(target, index)
return target.new_empty(target_shape)
@_decorators.ref(atomic_or)
def _(
target: torch.Tensor,
index: list[object],
value: torch.Tensor | int | bool,
sem: str = "relaxed",
) -> torch.Tensor:
_validate_sem(sem)
return _ref_atomic_binop(target, index, value, torch.bitwise_or)
@_decorators.codegen(atomic_or, "triton")
def _(state: CodegenState) -> ast.AST:
value_expr = state.ast_args[2]
return _codegen_common("atomic_or", state, _to_ast_values([value_expr]))
@_decorators.codegen(atomic_or, "cute")
def _(state: CodegenState) -> ast.AST:
value_expr = state.ast_args[2]
return _codegen_common_cute(
"atomic_or",
state,
value_exprs=_to_ast_values([value_expr]),
keyword_names=["val"],
)
@_decorators.codegen(atomic_or, "pallas")
def _(state: CodegenState) -> ast.AST:
from .._compiler.ast_extension import statement_from_string
name, index_str, prev_var = _pallas_atomic_load_prev(state)
value_ast = _to_ast_values([state.ast_args[2]])[0]
state.codegen.add_statement(
statement_from_string(
f"{name}[{index_str}] = {prev_var} | {{value}}", value=value_ast
)
)
return expr_from_string(prev_var)
[docs]
@has_side_effect
@_decorators.api(allow_host_tensor=True, tiles_as_sizes=True)
def atomic_xor(
target: torch.Tensor,
index: list[object],
value: torch.Tensor | int | bool,
sem: str = "relaxed",
) -> torch.Tensor:
"""
Atomically apply bitwise XOR with ``value`` to ``target[index]``.
Args:
target: Tensor to update (integer/bool dtype).
index: Indices selecting elements to update. Can include tiles.
value: Value(s) to XOR with.
sem: Memory ordering semantics. One of ``"relaxed"``, ``"acquire"``,
``"release"``, ``"acq_rel"``. Defaults to ``"relaxed"``.
Returns:
torch.Tensor: The previous value(s) stored at ``target[index]`` before the update.
"""
raise exc.NotInsideKernel
@_decorators.prepare_args(atomic_xor)
def _(
target: torch.Tensor, index: list[object], value: object, sem: str = "relaxed"
) -> tuple[torch.Tensor, object, object, str]:
return _prepare_mem_args(target, index, value, sem=sem)
@_decorators.register_fake(atomic_xor)
def _(
target: torch.Tensor, index: list[object], value: torch.Tensor, sem: str = "relaxed"
) -> torch.Tensor:
target_shape = SubscriptIndexing.compute_shape(target, index)
return target.new_empty(target_shape)
@_decorators.ref(atomic_xor)
def _(
target: torch.Tensor,
index: list[object],
value: torch.Tensor | int | bool,
sem: str = "relaxed",
) -> torch.Tensor:
_validate_sem(sem)
return _ref_atomic_binop(target, index, value, torch.bitwise_xor)
@_decorators.codegen(atomic_xor, "triton")
def _(state: CodegenState) -> ast.AST:
value_expr = state.ast_args[2]
return _codegen_common("atomic_xor", state, _to_ast_values([value_expr]))
@_decorators.codegen(atomic_xor, "cute")
def _(state: CodegenState) -> ast.AST:
value_expr = state.ast_args[2]
return _codegen_common_cute(
"atomic_xor",
state,
value_exprs=_to_ast_values([value_expr]),
keyword_names=["val"],
)
@_decorators.codegen(atomic_xor, "pallas")
def _(state: CodegenState) -> ast.AST:
from .._compiler.ast_extension import statement_from_string
name, index_str, prev_var = _pallas_atomic_load_prev(state)
value_ast = _to_ast_values([state.ast_args[2]])[0]
state.codegen.add_statement(
statement_from_string(
f"{name}[{index_str}] = {prev_var} ^ {{value}}", value=value_ast
)
)
return expr_from_string(prev_var)
# -- atomic_max/min --
[docs]
@has_side_effect
@_decorators.api(allow_host_tensor=True, tiles_as_sizes=True)
def atomic_max(
target: torch.Tensor,
index: list[object],
value: torch.Tensor | float,
sem: str = "relaxed",
) -> torch.Tensor:
"""
Atomically update ``target[index]`` with the maximum of current value
and ``value``.
Args:
target: Tensor to update.
index: Indices selecting elements to update. Can include tiles.
value: Value(s) to compare with.
sem: Memory ordering semantics. One of ``"relaxed"``, ``"acquire"``,
``"release"``, ``"acq_rel"``. Defaults to ``"relaxed"``.
Returns:
torch.Tensor: The previous value(s) stored at ``target[index]`` before the update.
"""
raise exc.NotInsideKernel
@_decorators.prepare_args(atomic_max)
def _(
target: torch.Tensor, index: list[object], value: object, sem: str = "relaxed"
) -> tuple[torch.Tensor, object, object, str]:
return _prepare_mem_args(target, index, value, sem=sem)
@_decorators.register_fake(atomic_max)
def _(
target: torch.Tensor, index: list[object], value: torch.Tensor, sem: str = "relaxed"
) -> torch.Tensor:
target_shape = SubscriptIndexing.compute_shape(target, index)
return target.new_empty(target_shape)
@_decorators.ref(atomic_max)
def _(
target: torch.Tensor,
index: list[object],
value: torch.Tensor | float,
sem: str = "relaxed",
) -> torch.Tensor:
_validate_sem(sem)
return _ref_atomic_binop(target, index, value, torch.maximum)
@_decorators.codegen(atomic_max, "triton")
def _(state: CodegenState) -> ast.AST:
value_expr = state.ast_args[2]
return _codegen_common("atomic_max", state, _to_ast_values([value_expr]))
@_decorators.codegen(atomic_max, "cute")
def _(state: CodegenState) -> ast.AST:
value_expr = state.ast_args[2]
return _codegen_common_cute(
"atomic_max",
state,
value_exprs=_to_ast_values([value_expr]),
keyword_names=["val"],
)
@_decorators.codegen(atomic_max, "pallas")
def _(state: CodegenState) -> ast.AST:
from .._compiler.ast_extension import statement_from_string
name, index_str, prev_var = _pallas_atomic_load_prev(state)
value_ast = _to_ast_values([state.ast_args[2]])[0]
state.codegen.add_statement(
statement_from_string(
f"{name}[{index_str}] = jnp.maximum({prev_var}, {{value}})",
value=value_ast,
)
)
return expr_from_string(prev_var)
[docs]
@has_side_effect
@_decorators.api(allow_host_tensor=True, tiles_as_sizes=True)
def atomic_min(
target: torch.Tensor,
index: list[object],
value: torch.Tensor | float,
sem: str = "relaxed",
) -> torch.Tensor:
"""
Atomically update ``target[index]`` with the minimum of current value
and ``value``.
Args:
target: Tensor to update.
index: Indices selecting elements to update. Can include tiles.
value: Value(s) to compare with.
sem: Memory ordering semantics. One of ``"relaxed"``, ``"acquire"``,
``"release"``, ``"acq_rel"``. Defaults to ``"relaxed"``.
Returns:
torch.Tensor: The previous value(s) stored at ``target[index]`` before the update.
"""
raise exc.NotInsideKernel
@_decorators.prepare_args(atomic_min)
def _(
target: torch.Tensor, index: list[object], value: object, sem: str = "relaxed"
) -> tuple[torch.Tensor, object, object, str]:
return _prepare_mem_args(target, index, value, sem=sem)
@_decorators.register_fake(atomic_min)
def _(
target: torch.Tensor, index: list[object], value: torch.Tensor, sem: str = "relaxed"
) -> torch.Tensor:
target_shape = SubscriptIndexing.compute_shape(target, index)
return target.new_empty(target_shape)
@_decorators.ref(atomic_min)
def _(
target: torch.Tensor,
index: list[object],
value: torch.Tensor | float,
sem: str = "relaxed",
) -> torch.Tensor:
_validate_sem(sem)
return _ref_atomic_binop(target, index, value, torch.minimum)
@_decorators.codegen(atomic_min, "triton")
def _(state: CodegenState) -> ast.AST:
value_expr = state.ast_args[2]
return _codegen_common("atomic_min", state, _to_ast_values([value_expr]))
@_decorators.codegen(atomic_min, "cute")
def _(state: CodegenState) -> ast.AST:
value_expr = state.ast_args[2]
return _codegen_common_cute(
"atomic_min",
state,
value_exprs=_to_ast_values([value_expr]),
keyword_names=["val"],
)
@_decorators.codegen(atomic_min, "pallas")
def _(state: CodegenState) -> ast.AST:
from .._compiler.ast_extension import statement_from_string
name, index_str, prev_var = _pallas_atomic_load_prev(state)
value_ast = _to_ast_values([state.ast_args[2]])[0]
state.codegen.add_statement(
statement_from_string(
f"{name}[{index_str}] = jnp.minimum({prev_var}, {{value}})",
value=value_ast,
)
)
return expr_from_string(prev_var)
# -- atomic_cas --
[docs]
@has_side_effect
@_decorators.api(allow_host_tensor=True, tiles_as_sizes=True)
def atomic_cas(
target: torch.Tensor,
index: list[object],
expected: torch.Tensor | float | bool,
value: torch.Tensor | float | bool,
sem: str = "relaxed",
) -> torch.Tensor:
"""
Atomically compare-and-swap a value at ``target[index]``.
If the current value equals ``expected``, writes ``value``. Otherwise
leaves memory unchanged.
Args:
target: Tensor to update.
index: Indices selecting elements to update. Can include tiles.
expected: Expected current value(s) used for comparison.
value: New value(s) to write if comparison succeeds.
sem: Memory ordering semantics. One of ``"relaxed"``, ``"acquire"``,
``"release"``, ``"acq_rel"``. Defaults to ``"relaxed"``.
Returns:
torch.Tensor: The previous value(s) stored at ``target[index]`` before the compare-and-swap.
Note:
Triton CAS doesn’t support a masked form; our generated code uses
an unmasked CAS and relies on index masking to avoid OOB.
"""
raise exc.NotInsideKernel
@_decorators.prepare_args(atomic_cas)
def _(
target: torch.Tensor,
index: list[object],
expected: object,
value: object,
sem: str = "relaxed",
) -> tuple[torch.Tensor, object, object, object, str]:
return _prepare_mem_args(target, index, expected, value, sem=sem)
@_decorators.register_fake(atomic_cas)
def _(
target: torch.Tensor,
index: list[object],
expected: torch.Tensor,
value: torch.Tensor,
sem: str = "relaxed",
) -> torch.Tensor:
target_shape = SubscriptIndexing.compute_shape(target, index)
return target.new_empty(target_shape)
@_decorators.ref(atomic_cas)
def _(
target: torch.Tensor,
index: list[object],
expected: torch.Tensor | float | bool,
value: torch.Tensor | float | bool,
sem: str = "relaxed",
) -> torch.Tensor:
_validate_sem(sem)
from .ref_tile import RefTile
processed_index: list[object] = []
for idx in index:
if isinstance(idx, RefTile):
processed_index.append(idx._slice)
elif isinstance(idx, torch.Tensor) and idx.numel() == 1:
processed_index.append(int(idx.item()))
else:
processed_index.append(idx)
idx_tuple = tuple(processed_index)
# pyrefly: ignore [bad-index]
prev = target[idx_tuple].clone()
exp_t = (
expected
if isinstance(expected, torch.Tensor)
else torch.as_tensor(expected, dtype=target.dtype, device=target.device)
)
val_t = (
value
if isinstance(value, torch.Tensor)
else torch.as_tensor(value, dtype=target.dtype, device=target.device)
)
# pyrefly: ignore [bad-index]
mask = target[idx_tuple] == exp_t
# pyrefly: ignore [bad-index, unsupported-operation]
target[idx_tuple] = torch.where(mask, val_t, target[idx_tuple])
return prev
@_decorators.codegen(atomic_cas, "triton")
def _(state: CodegenState) -> ast.AST:
exp_expr = state.ast_args[2]
val_expr = state.ast_args[3]
target = state.proxy_arg(0)
index = state.proxy_arg(1)
sem = expr_from_string(repr(state.proxy_arg(len(state.ast_args) - 1)))
assert isinstance(target, torch.Tensor)
assert isinstance(index, list)
# CAS always uses pointer (not a TMA reduction op, two values),
# but increment the counter to keep per-op atomic_indexing aligned.
device_fn = state.device_function
device_fn.atomic_op_index += 1
indices = SubscriptIndexing.create(state, target, index)
name = state.device_function.tensor_arg(target).name
exp_ast, val_ast = _to_ast_values([exp_expr, val_expr])
return expr_from_string(
f"tl.atomic_cas({name} + {{offset}}, {{exp}}, {{val}}, sem={{sem}})",
offset=indices.index_expr,
exp=exp_ast,
val=val_ast,
sem=sem,
)
@_decorators.codegen(atomic_cas, "cute")
def _(state: CodegenState) -> ast.AST:
exp_expr = state.ast_args[2]
val_expr = state.ast_args[3]
target = state.proxy_arg(0)
index = state.proxy_arg(1)
sem = expr_from_string(repr(state.proxy_arg(len(state.ast_args) - 1)))
assert isinstance(target, torch.Tensor)
assert isinstance(index, list)
host_function = HostFunction.current()
if target not in host_function.tensor_to_origin:
raise exc.AtomicOnDeviceTensor("atomic_cas")
pointer = _cute_pointer_expr(state, target, index)
exp_ast, val_ast = _to_ast_values([exp_expr, val_expr])
cmp_kw, val_kw = _resolve_cute_atomic_kwargs("atomic_cas", ["cmp", "val"])
return expr_from_string(
f"cute.arch.atomic_cas({{ptr}}, {cmp_kw}={{exp}}, {val_kw}={{val}}, sem={{sem}})",
ptr=expr_from_string(pointer),
exp=exp_ast,
val=val_ast,
sem=sem,
)
@_decorators.codegen(atomic_cas, "pallas")
def _(state: CodegenState) -> ast.AST:
from .._compiler.ast_extension import statement_from_string
from .._compiler.pallas import codegen as pallas_codegen
target = state.proxy_arg(0)
index = state.proxy_arg(1)
assert isinstance(target, torch.Tensor)
assert isinstance(index, (list, tuple))
host_function = HostFunction.current()
if target not in host_function.tensor_to_origin:
raise exc.AtomicOnDeviceTensor("pallas atomic_cas")
name = state.device_function.tensor_arg(target).name
index_str, _ = pallas_codegen.index_str(state, index, target)
prev_var = state.device_function.new_var("_prev", dce=True)
state.codegen.add_statement(
statement_from_string(f"{prev_var} = {name}[{index_str}]")
)
exp_ast, val_ast = _to_ast_values([state.ast_args[2], state.ast_args[3]])
state.codegen.add_statement(
statement_from_string(
f"{name}[{index_str}] = jnp.where({prev_var} == {{exp}}, {{val}}, {prev_var})",
exp=exp_ast,
val=val_ast,
)
)
return expr_from_string(prev_var)
ATOMIC_OPS: frozenset[Callable[..., object]] = frozenset(
{
atomic_add,
atomic_and,
atomic_cas,
atomic_max,
atomic_min,
atomic_or,
atomic_xchg,
atomic_xor,
}
)