Source code for helion.runtime

from __future__ import annotations

import functools

import torch
import triton

from .config import Config as Config
from .kernel import Kernel as Kernel
from .kernel import kernel as kernel
from .triton_helpers import triton_send_signal as triton_send_signal
from .triton_helpers import triton_wait_multiple_signal as triton_wait_multiple_signal
from .triton_helpers import triton_wait_signal as triton_wait_signal


def _alloc_fn(size: int, alignment: int, stream: int | None) -> torch.Tensor:
    # Allocator callback with the (size, alignment, stream) signature expected
    # by triton.set_allocator; serves scratch requests as a raw CUDA byte buffer.
    return torch.empty(size, device="cuda", dtype=torch.int8)
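
For intuition, the callback can be exercised directly; this mirrors the (size, alignment, stream) request Triton's runtime makes internally (a sketch only; it requires a CUDA device):

buf = _alloc_fn(1024, 128, None)  # 1 KiB request; this implementation ignores alignment
assert buf.numel() == 1024 and buf.dtype == torch.int8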


@functools.cache
def set_triton_allocator() -> None:
    try:
        from triton.runtime._allocation import NullAllocator
        from triton.runtime._allocation import _allocator

        if not isinstance(_allocator, NullAllocator):
            return  # an allocator is already installed; leave it in place
    except ImportError:
        pass
    triton.set_allocator(_alloc_fn)
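
As a usage sketch (assuming a CUDA-enabled Triton build), the functools.cache decorator makes repeated calls free, so callers can invoke this defensively before every launch:

set_triton_allocator()  # installs _alloc_fn unless an allocator is already set
set_triton_allocator()  # cached no-op on subsequent calls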
def get_num_sm(device: torch.device) -> int:
    """
    Get the number of streaming multiprocessors (SMs) for the specified device.

    Args:
        device: Device to query.

    Returns:
        Grid size to use for a persistent kernel on the device.
    """
    assert device.type == "cuda", "TODO: implement for other devices"
    return torch.cuda.get_device_properties(device.index).multi_processor_count
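
A hedged example of sizing a persistent kernel's grid from the SM count (my_kernel is a hypothetical Triton kernel, not part of this module):

device = torch.device("cuda", 0)
grid = (get_num_sm(device),)  # one persistent program per SM
# my_kernel[grid](...)        # hypothetical launch at that grid size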
def default_launcher(
    triton_kernel: triton.JITFunction,
    grid: tuple[int, ...],
    *args: object,
    num_warps: int,
    num_stages: int,
) -> object:
    """Default launcher function that executes the kernel immediately."""
    return triton_kernel.run(
        *args, grid=grid, warmup=False, num_warps=num_warps, num_stages=num_stages
    )
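
A replacement launcher would need to accept the same signature. As a sketch (logging_launcher is illustrative, not part of helion), one that reports the launch configuration before delegating:

def logging_launcher(
    triton_kernel: triton.JITFunction,
    grid: tuple[int, ...],
    *args: object,
    num_warps: int,
    num_stages: int,
) -> object:
    # Illustrative only: print the configuration, then launch eagerly
    # exactly as default_launcher does.
    print(f"grid={grid} num_warps={num_warps} num_stages={num_stages}")
    return triton_kernel.run(
        *args, grid=grid, warmup=False, num_warps=num_warps, num_stages=num_stages
    )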