Source code for helion.runtime
from __future__ import annotations
import contextvars
import os
import torch
import triton
from .. import _compat as _compat # ensure Triton compatibility patches run
from .config import Config as Config
from .kernel import Kernel as Kernel
from .kernel import kernel as kernel
from .triton_helpers import triton_send_signal as triton_send_signal
from .triton_helpers import triton_wait_multiple_signal as triton_wait_multiple_signal
from .triton_helpers import triton_wait_signal as triton_wait_signal


def _alloc_fn(size: int, alignment: int, stream: int | None) -> torch.Tensor:
    # Dynamically get device from Triton backend
    current_target = triton.runtime.driver.active.get_current_target()
    if current_target is None:
        raise RuntimeError("No active Triton target available")
    backend = current_target.backend
    return torch.empty(size, device=backend, dtype=torch.int8)


def set_triton_allocator() -> None:
    try:
        from triton import set_allocator
        from triton.runtime._allocation import NullAllocator
        from triton.runtime._allocation import _allocator
    except ImportError:
        return
    if isinstance(_allocator, contextvars.ContextVar):
        existing = _allocator.get()
    else:  # older versions of Triton
        existing = _allocator
    # if allocator isn't NullAllocator, we assume it is set by the user
    if isinstance(existing, NullAllocator):
        set_allocator(_alloc_fn)
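

# Usage sketch (illustrative, with an assumed call site): ``set_triton_allocator``
# is typically called once, before launching kernels that may need Triton-managed
# scratch allocations. Per the logic above, it is a no-op if this Triton build has
# no ``set_allocator`` or if a user-defined allocator is already registered.
#
#   >>> from helion.runtime import set_triton_allocator
#   >>> set_triton_allocator()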


def get_num_sm(device: torch.device, *, reserved_sms: int = 0) -> int:
    """
    Get the number of streaming multiprocessors (SMs) for the specified device.

    Args:
        device: Device to query.
        reserved_sms: Number of SMs to keep free for other work (e.g., communication
            kernels). Defaults to 0, meaning all device SMs are available to Helion.

    Returns:
        Grid size to use for a persistent kernel on the device after accounting
        for any reserved SMs. Always at least 1.
    """
    available_sms: int
    assert device.type in [
        "cuda",
        "xpu",
        "cpu",
        "mtia",
    ], "TODO: implement for other devices"
    if device.type == "cpu":
        try:
            num_threads = int(torch.get_num_threads())
        except Exception:
            num_threads = 0
        available_sms = num_threads if num_threads > 0 else int(os.cpu_count() or 1)
    elif device.type == "cuda":
        available_sms = torch.cuda.get_device_properties(
            device.index
        ).multi_processor_count
    # TODO(EikanWang): gpu_subslice_count is an out-of-date term; update it to the XeCore count.
    elif device.type == "xpu":
        available_sms = torch.xpu.get_device_properties(device.index).gpu_subslice_count
    elif device.type == "mtia":
        device_props = torch.mtia.get_device_properties(device.index)
        if "max_grid_height" in device_props and "max_grid_width" in device_props:
            available_sms = (
                device_props["max_grid_height"] * device_props["max_grid_width"]
            )
        else:
            raise RuntimeError(
                f"Unable to determine SM count for MTIA device. "
                f"Available properties: {list(device_props.keys())}"
            )
    else:
        raise NotImplementedError(
            f"get_num_sm not implemented for device type: {device.type}"
        )
    if reserved_sms <= 0:
        return available_sms
    return max(available_sms - reserved_sms, 1)
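

# Usage sketch (illustrative; the device and ``reserved_sms`` value below are
# assumptions, not values from this module): size a persistent kernel's grid
# from the device's SM count while holding a few SMs back for other work.
#
#   >>> from helion.runtime import get_num_sm
#   >>> grid_size = get_num_sm(torch.device("cuda", 0), reserved_sms=8)
#   >>> grid_size >= 1   # clamped to at least 1 even if reserved_sms is large
#   True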


def default_launcher(
    triton_kernel: triton.JITFunction,
    grid: tuple[int, ...],
    *args: object,
    num_warps: int,
    num_stages: int,
    **kwargs: dict,
) -> object:
    """Default launcher function that executes the kernel immediately."""
    # For both CUDA and MTIA, use the same kernel execution path.
    return triton_kernel.run(
        *args,
        grid=grid,
        warmup=False,
        num_warps=num_warps,
        num_stages=num_stages,
        **kwargs,
    )
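

# Sketch (illustrative): a replacement launcher is assumed to follow
# ``default_launcher``'s signature. The hypothetical wrapper below only logs
# the launch configuration and then delegates to ``default_launcher``; how a
# custom launcher is wired into a Helion ``Kernel`` is outside this module.
#
#   >>> def logging_launcher(triton_kernel, grid, *args, num_warps, num_stages, **kwargs):
#   ...     print(f"grid={grid} num_warps={num_warps} num_stages={num_stages}")
#   ...     return default_launcher(
#   ...         triton_kernel, grid, *args,
#   ...         num_warps=num_warps, num_stages=num_stages, **kwargs,
#   ...     )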