from __future__ import annotations
import triton
import triton.language as tl
__all__ = ["triton_send_signal", "triton_wait_multiple_signal", "triton_wait_signal"]
@triton.jit
def triton_send_signal(
    addr: tl.tensor,
    update: tl.constexpr,
    sem: tl.constexpr,
    scope: tl.constexpr,
    op: tl.constexpr,
    skip_sync: tl.constexpr,
) -> tl.tensor:
    """
    Signal global memory barrier(s).

    Atomically updates the barrier(s) at ``addr`` with ``update`` so that other
    CTAs waiting on the barrier(s) can observe the new value.

    Args:
        addr: Memory address of the barrier(s) to signal
        update: Value to exchange into the barrier ("atomic_xchg") or to add to
            it ("atomic_add")
        sem: Memory semantics for the atomic operation. Options: "release", "relaxed".
        scope: Scope of the atomic operation. Options: "gpu", "sys"
        op: Atomic operation type: "atomic_xchg", "atomic_add"
        skip_sync: Skip the CTA synchronization (``bar.sync 0``) that is otherwise
            issued before updating the barrier. No default; must be passed.

    Returns:
        The old value of the barrier(s) before the update.
    """
    # Validate all compile-time options up front, mirroring the wait helpers,
    # so that a bad argument fails with a clear static_assert message.
    tl.static_assert(
        sem == "release" or sem == "relaxed",
        "Invalid memory semantic. options: 'release', 'relaxed'. ",
    )
    tl.static_assert(
        scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu','sys'. "
    )
    tl.static_assert(
        op == "atomic_xchg" or op == "atomic_add",
        "Invalid op. options: 'atomic_xchg', 'atomic_add'. ",
    )

    # CTA-wide barrier before the signal is published, unless the caller opts out.
    if not skip_sync:
        tl.inline_asm_elementwise(
            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
        )

    if op == "atomic_xchg":
        barrier_status = tl.atomic_xchg(addr, update, sem=sem, scope=scope)
    elif op == "atomic_add":
        barrier_status = tl.atomic_add(addr, update, sem=sem, scope=scope)
    else:
        # Kept as a fallback; the static_assert above already rejects other ops.
        raise NotImplementedError(
            f"Unsupported op '{op}' for send signal on gmem barrier. "
        )
    return barrier_status
[docs]
@triton.jit
def triton_wait_signal(
    addr: tl.tensor,
    expect: tl.constexpr,
    update: tl.constexpr,
    sem: tl.constexpr,
    scope: tl.constexpr,
    op: tl.constexpr,
    skip_sync: tl.constexpr,
    sync_before: tl.constexpr = False,  # pyright: ignore[reportArgumentType]
) -> None:
    """
    Wait for a global memory barrier to reach the expected value.

    This function implements a spin-wait loop that continuously checks a memory
    location until it reaches the expected value, providing synchronization
    across CTAs.

    Args:
        addr: Memory address of the barrier to wait on (Must be a scalar)
        expect: Expected value to wait for
        update: Value to store into the barrier once acquired. Only used by
            "atomic_cas"; must be 0 when ``op`` is "ld".
        sem: Memory semantics for the atomic operation.
            Options: "acquire", "relaxed", "release".
        scope: Scope of the atomic operation. Options: "gpu", "sys"
        op: Atomic operation type: "ld", "atomic_cas"
        skip_sync: Skip the CTA sync after acquiring the barrier. No default;
            must be passed.
        sync_before: Add a CTA sync before the wait (default: False)
    """
    tl.static_assert(
        addr.type.is_ptr(),  # pyright: ignore[reportAttributeAccessIssue]
        "Barrier address must be a scalar. Do you want to use 'triton_wait_multiple_signal'? ",
    )
    tl.static_assert(
        (sem == "acquire" or sem == "relaxed") or sem == "release",
        "Invalid memory semantic. options: 'acquire', 'relaxed', 'release'. ",
    )
    tl.static_assert(
        scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu', 'sys'. "
    )
    tl.static_assert(
        op == "ld" or op == "atomic_cas",
        "Invalid op. options: 'ld', 'atomic_cas'. ",
    )

    if sync_before:
        tl.inline_asm_elementwise(
            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
        )

    # Spin-wait loop:
    # Uses atomic_add with update=0 for ld.global.{sem}.{scope}
    # Triton generates smem broadcasting of tl.atomic_add return value in ptx,
    # but it is optimized away by ptxas in SASS, hence no performance overhead.
    if op == "ld":
        tl.static_assert(
            update == 0, "ld wait on gmem_barriers cannot update the lock. "
        )
        while tl.atomic_add(addr, 0, sem=sem, scope=scope) != expect:
            pass
    elif op == "atomic_cas":
        # Retry until the compare-and-swap observes `expect` (and thereby
        # stores `update`).
        while tl.atomic_cas(addr, expect, update, sem=sem, scope=scope) != expect:
            pass
    else:
        # Kept as a fallback; the static_assert above already rejects other ops.
        raise NotImplementedError(
            f"Unsupported op '{op}' for wait signal on gmem barrier. "
        )

    if not skip_sync:
        tl.inline_asm_elementwise(
            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
        )
        # tl.debug_barrier() causes significant performance loss. (Perhaps it breaks triton prefetching?)
@triton.jit
def triton_wait_multiple_signal(
    addr: tl.tensor,
    expect: tl.constexpr,
    update: tl.constexpr,
    sem: tl.constexpr,
    scope: tl.constexpr,
    op: tl.constexpr,
    skip_sync: tl.constexpr,
    sync_before: tl.constexpr = False,  # pyright: ignore[reportArgumentType]
) -> None:
    """
    Simultaneously wait for multiple global memory barriers to reach the expected value.

    Threads of the CTA whose ``tid.x`` is below the number of barriers spin-wait
    in inline PTX, repeatedly checking a barrier location until it holds the
    expected value, providing synchronization across CTAs.

    Args:
        addr: Memory addresses of the barriers to wait on (int32 pointers,
            maximum 32 barriers)
        expect: Expected value to wait for
        update: Value to store into the barrier once acquired (only used by
            "atomic_cas")
        sem: Memory semantics for the atomic operation.
            Options: "acquire", "relaxed", "release".
        scope: Scope of the atomic operation. Options: "gpu", "sys"
        op: Atomic operation type: "ld", "atomic_cas"
        skip_sync: Skip CTA synchronization after acquiring the barrier.
            No default; must be passed.
        sync_before: Add a CTA sync before the wait (default: False)
    """
    tl.static_assert(
        (sem == "acquire" or sem == "relaxed") or sem == "release",
        "Invalid memory semantic. options: 'acquire', 'relaxed' 'release'. ",
    )
    tl.static_assert(
        scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu', 'sys'. "
    )
    tl.static_assert(
        op == "ld" or op == "atomic_cas",
        "Invalid op. options: 'ld', 'atomic_cas'. ",
    )
    tl.static_assert(
        addr.dtype == tl.pointer_type(tl.int32),
        "Invalid barrier value type. Only supports int32 for multi barrier signal. ",
    )

    # Flatten so the barrier count is simply addr.shape[0].
    addr = tl.ravel(addr)
    tl.static_assert(len(addr.shape) == 1, "addr must be a 1D tensor. ")
    tl.static_assert(addr.shape[0] <= 32, "Wait on at most 32 barriers at a time. ")

    # Honor sync_before the same way triton_wait_signal does; previously the
    # flag was accepted but silently ignored here.
    if sync_before:
        tl.inline_asm_elementwise(
            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
        )

    # Assume Triton always sets tid.y == tid.z == 0.
    # In the PTX below: $0 = loaded value (output), $1 = barrier address,
    # $2 = number of barriers, $3 = expect, $4 = update (atomic_cas only).
    # %p1 selects threads with tid.x < $2; only those threads load and loop.
    # The PTX text is intentionally left flush-left so the emitted string is
    # unchanged.
    # NOTE(review): `sem` is spliced directly into the `ld` instruction; confirm
    # that every value accepted above (in particular 'release') is a valid
    # qualifier for PTX `ld`.
    if op == "ld":
        tl.inline_asm_elementwise(
            f"""
{{
.reg .u32 %tmp32_<3>;
.reg .pred %p<2>;
mov.u32 %tmp32_0, %tid.x;
setp.lt.s32 %p1, %tmp32_0, $2;
mov.u32 $0, 0;
// initialize tmp_0 to 0
wait_block:
@%p1 ld.global.{sem}.{scope}.u32 $0, [$1];
setp.ne.u32 %p0, $0, $3;
and.pred %p0, %p0, %p1;
@%p0 bra wait_block;
}}
""",
            "=r, l, r, r",
            [addr, addr.shape[0], expect],
            dtype=addr.dtype.element_ty,
            is_pure=False,
            pack=1,
        )
    elif op == "atomic_cas":
        tl.inline_asm_elementwise(
            f"""
{{
.reg .u32 %tmp32_<3>;
.reg .pred %p<2>;
mov.u32 %tmp32_0, %tid.x;
setp.lt.s32 %p1, %tmp32_0, $2;
mov.u32 $0, 0;
// initialize tmp_0 to 0
wait_block:
@%p1 atom.global.{sem}.{scope}.cas.b32 $0, [$1], $3, $4;
setp.ne.u32 %p0, $0, $3;
and.pred %p0, %p0, %p1;
@%p0 bra wait_block;
}}
""",
            "=r, l, r, r, r",
            [addr, addr.shape[0], expect, update],
            dtype=addr.dtype.element_ty,
            is_pure=False,
            pack=1,
        )
    else:
        # Kept as a fallback; the static_assert above already rejects other ops.
        raise NotImplementedError(
            f"Unsupported op '{op}' for wait signal on gmem barrier. "
        )

    if not skip_sync:
        tl.inline_asm_elementwise(
            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
        )