Source code for helion.runtime.triton_helpers

from __future__ import annotations

import triton
import triton.language as tl

__all__ = ["triton_send_signal", "triton_wait_multiple_signal", "triton_wait_signal"]


@triton.jit
def triton_send_signal(
    addr: tl.tensor,
    update: tl.constexpr,
    sem: tl.constexpr,
    scope: tl.constexpr,
    op: tl.constexpr,
    skip_sync: tl.constexpr,
) -> tl.tensor:
    """
    Signal global memory barrier(s).

    This function atomically sets global memory barrier(s) to an update value,
    signaling to other CTAs waiting on the barrier(s).

    Args:
        addr: Memory address of the barrier(s) to signal
        update: Value to set the barrier(s) to
        sem: Memory semantics for the atomic operation. Options: "release", "relaxed".
        scope: Scope of the atomic operation. Options: "gpu", "sys"
        op: Atomic operation type: "atomic_xchg", "atomic_add"
        skip_sync: Skip CTA synchronization before setting the barrier. (default: False)
    Returns:
        The old value of the barrier(s) before the update.
    """
    if not skip_sync:
        tl.inline_asm_elementwise(
            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
        )

    tl.static_assert(
        sem == "release" or sem == "relaxed",
        "Invalid memory semantic. options: 'release', 'relaxed'. ",
    )
    tl.static_assert(
        scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu','sys'. "
    )

    if op == "atomic_xchg":
        barrier_status = tl.atomic_xchg(addr, update, sem=sem, scope=scope)
    elif op == "atomic_add":
        barrier_status = tl.atomic_add(addr, update, sem=sem, scope=scope)
    else:
        raise NotImplementedError(
            f"Unsupported op '{op}' for send signal on gmem barrier. "
        )
    return barrier_status
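
# A hypothetical usage sketch (not part of this module; kernel and pointer
# names are assumptions): a producer Triton kernel could publish one tile of
# output and then flip that tile's barrier slot to 1 with release semantics so
# consumers can observe the stores:
#
#     @triton.jit
#     def _producer_kernel(out_ptr, barrier_ptr, BLOCK: tl.constexpr):
#         pid = tl.program_id(0)
#         offs = pid * BLOCK + tl.arange(0, BLOCK)
#         tl.store(out_ptr + offs, tl.full([BLOCK], 1.0, tl.float32))
#         # Publish this tile: exchange its barrier slot to 1 (release).
#         triton_send_signal(
#             barrier_ptr + pid, update=1, sem="release", scope="gpu",
#             op="atomic_xchg", skip_sync=False,
#         )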


@triton.jit
def triton_wait_signal(
    addr: tl.tensor,
    expect: tl.constexpr,
    update: tl.constexpr,
    sem: tl.constexpr,
    scope: tl.constexpr,
    op: tl.constexpr,
    skip_sync: tl.constexpr,
    sync_before: tl.constexpr = False,  # pyright: ignore[reportArgumentType]
) -> None:
    """
    Wait for a global memory barrier to reach the expected value.

    This function implements a spin-wait loop that continuously checks a memory
    location until it reaches the expected value, providing synchronization
    across CTAs.

    Args:
        addr: Memory address of the barrier to wait on (must be a scalar)
        expect: Expected value to wait for
        update: Value to update the barrier with once acquired
        sem: Memory semantics for the atomic operation. Options: "acquire", "relaxed".
        scope: Scope of the atomic operation. Options: "gpu", "sys"
        op: Atomic operation type: "ld", "atomic_cas"
        skip_sync: Skip CTA sync after acquiring the barrier (default: False)
        sync_before: Add a CTA sync before the wait (default: False)
    """
    tl.static_assert(
        addr.type.is_ptr(),  # pyright: ignore[reportAttributeAccessIssue]
        "Barrier address must be a scalar. Do you want to use 'triton_wait_multiple_signal'? ",
    )
    tl.static_assert(
        sem == "acquire" or sem == "relaxed" or sem == "release",
        "Invalid memory semantic. options: 'acquire', 'relaxed', 'release'. ",
    )
    tl.static_assert(
        scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu', 'sys'. "
    )
    tl.static_assert(
        op == "ld" or op == "atomic_cas",
        "Invalid op. options: 'ld', 'atomic_cas'. ",
    )

    if sync_before:
        tl.inline_asm_elementwise(
            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
        )

    # Spin-wait loop:
    #   Uses atomic_add with update=0 for ld.global.{sem}.{scope}.
    #   Triton generates smem broadcasting of the tl.atomic_add return value in
    #   PTX, but it is optimized away by ptxas in SASS, hence no performance
    #   overhead.
    if op == "ld":
        tl.static_assert(
            update == 0, "ld wait on gmem_barriers cannot update the lock. "
        )
        while tl.atomic_add(addr, 0, sem=sem, scope=scope) != expect:
            pass
    elif op == "atomic_cas":
        while tl.atomic_cas(addr, expect, update, sem=sem, scope=scope) != expect:
            pass
    else:
        raise NotImplementedError(
            f"Unsupported op '{op}' for wait signal on gmem barrier. "
        )

    if not skip_sync:
        tl.inline_asm_elementwise(
            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
        )
    # Note: tl.debug_barrier() causes a significant performance loss here.
    # (Perhaps it breaks Triton prefetching?)
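
# A hypothetical usage sketch (not part of this module; kernel and pointer
# names are assumptions): a consumer Triton kernel could block until the
# producer has published tile `pid`; acquire semantics make the producer's
# stores visible to the loads that follow:
#
#     @triton.jit
#     def _consumer_kernel(in_ptr, out_ptr, barrier_ptr, BLOCK: tl.constexpr):
#         pid = tl.program_id(0)
#         triton_wait_signal(
#             barrier_ptr + pid, expect=1, update=0, sem="acquire", scope="gpu",
#             op="ld", skip_sync=False,
#         )
#         offs = pid * BLOCK + tl.arange(0, BLOCK)
#         tl.store(out_ptr + offs, tl.load(in_ptr + offs) * 2.0)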

@triton.jit
def triton_wait_multiple_signal(
    addr: tl.tensor,
    expect: tl.constexpr,
    update: tl.constexpr,
    sem: tl.constexpr,
    scope: tl.constexpr,
    op: tl.constexpr,
    skip_sync: tl.constexpr,
    sync_before: tl.constexpr = False,  # pyright: ignore[reportArgumentType]
) -> None:
    """
    Simultaneously wait for multiple global memory barriers to reach the expected value.

    In this function, each thread in a CTA spin-waits on one barrier,
    continuously checking its memory location until it reaches the expected
    value, providing synchronization across CTAs.

    Args:
        addr: Memory addresses of the barriers to wait on (maximum 32 barriers)
        expect: Expected value to wait for
        update: Value to update the barrier with once acquired
        sem: Memory semantics for the atomic operation. Options: "acquire", "relaxed".
        scope: Scope of the atomic operation. Options: "gpu", "sys"
        op: Atomic operation type: "ld", "atomic_cas"
        skip_sync: Skip CTA synchronization after acquiring the barrier. (default: False)
    """
    tl.static_assert(
        sem == "acquire" or sem == "relaxed" or sem == "release",
        "Invalid memory semantic. options: 'acquire', 'relaxed', 'release'. ",
    )
    tl.static_assert(
        scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu', 'sys'. "
    )
    tl.static_assert(
        op == "ld" or op == "atomic_cas",
        "Invalid op. options: 'ld', 'atomic_cas'. ",
    )
    tl.static_assert(
        addr.dtype == tl.pointer_type(tl.int32),
        "Invalid barrier value type. Only supports int32 for multi barrier signal. ",
    )

    addr = tl.ravel(addr)

    tl.static_assert(len(addr.shape) == 1, "addr must be a 1D tensor. ")
    tl.static_assert(addr.shape[0] <= 32, "Wait on at most 32 barriers at a time. ")

    # Assume Triton always sets tid.y == tid.z == 0.
    if op == "ld":
        tl.inline_asm_elementwise(
            f"""
            {{
                .reg .u32   %tmp32_<3>;
                .reg .pred  %p<2>;

                mov.u32     %tmp32_0, %tid.x;
                setp.lt.s32 %p1, %tmp32_0, $2;
                mov.u32     $0, 0;  // initialize tmp_0 to 0

                wait_block:
                    @%p1 ld.global.{sem}.{scope}.u32 $0, [$1];
                    setp.ne.u32 %p0, $0, $3;
                    and.pred %p0, %p0, %p1;
                    @%p0 bra wait_block;
            }}
            """,
            "=r, l, r, r",
            [addr, addr.shape[0], expect],
            dtype=addr.dtype.element_ty,
            is_pure=False,
            pack=1,
        )
    elif op == "atomic_cas":
        tl.inline_asm_elementwise(
            f"""
            {{
                .reg .u32   %tmp32_<3>;
                .reg .pred  %p<2>;

                mov.u32     %tmp32_0, %tid.x;
                setp.lt.s32 %p1, %tmp32_0, $2;
                mov.u32     $0, 0;  // initialize tmp_0 to 0

                wait_block:
                    @%p1 atom.global.{sem}.{scope}.cas.b32 $0, [$1], $3, $4;
                    setp.ne.u32 %p0, $0, $3;
                    and.pred %p0, %p0, %p1;
                    @%p0 bra wait_block;
            }}
            """,
            "=r, l, r, r, r",
            [addr, addr.shape[0], expect, update],
            dtype=addr.dtype.element_ty,
            is_pure=False,
            pack=1,
        )
    else:
        raise NotImplementedError(
            f"Unsupported op '{op}' for wait signal on gmem barrier. "
        )

    if not skip_sync:
        tl.inline_asm_elementwise(
            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
        )
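
# A hypothetical usage sketch (not part of this module; kernel and pointer
# names are assumptions): a kernel that consumes a whole row of producer tiles
# could wait on all of their int32 barrier slots in one call (TILES must be a
# power of two and at most 32):
#
#     @triton.jit
#     def _row_consumer_kernel(barrier_ptr, TILES: tl.constexpr):
#         row = tl.program_id(0)
#         offs = row * TILES + tl.arange(0, TILES)
#         triton_wait_multiple_signal(
#             barrier_ptr + offs, expect=1, update=0, sem="acquire", scope="gpu",
#             op="ld", skip_sync=False,
#         )
#         # All TILES producer tiles for this row are now visible to this CTA.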