Rate this Page

Source code for helion.runtime

from __future__ import annotations

from contextlib import suppress
import contextvars
import hashlib
import linecache
import sys
from typing import Any
from typing import cast

import torch

from .. import _compat as _compat  # ensure Triton compatibility patches run
from .. import exc
from .._utils import triton_is_available
from .config import Config as Config
from .kernel import Kernel as Kernel
from .kernel import kernel as kernel

# Set once the CUTLASS destructor patch has been applied (or found already applied).
_CUTLASS_SHUTDOWN_PATCHED = False


def _patch_cutlass_jit_shutdown_unload() -> None:
    """Skip the CUTLASS JIT module unload while the interpreter is finalizing.

    Current CUTLASS DSL builds call ``cudaLibraryUnload`` unconditionally from
    ``CudaDialectJitModule.__del__``. On B200 that call can hang during Python
    finalization even though the CuTe kernel has long since finished. The
    replacement destructor installed here only marks the module as unloaded
    when ``sys.is_finalizing()`` is true, and otherwise defers to the original
    destructor so regular runtime GC still unloads the library.

    The function is idempotent and silently does nothing when the CUTLASS DSL
    cannot be imported.
    """
    global _CUTLASS_SHUTDOWN_PATCHED

    if not _CUTLASS_SHUTDOWN_PATCHED:
        try:
            import cutlass.cutlass_dsl.cuda_jit_executor as jit_executor
        except ImportError:
            # CUTLASS DSL is not installed: nothing to patch, and the flag stays
            # unset so a later call can retry.
            return

        jit_module_cls = jit_executor.CudaDialectJitModule
        # The class-level marker guards against wrapping the destructor twice,
        # e.g. when this module is imported under more than one name.
        if not getattr(jit_module_cls, "_helion_shutdown_patch", False):
            wrapped_del = cast("Any", jit_module_cls.__del__)

            def _helion_del(self: object) -> None:
                instance = cast("Any", self)
                if not sys.is_finalizing():
                    wrapped_del(instance)
                    return
                # Interpreter teardown: flag the module as unloaded instead of
                # calling into CUDA; failures to set the flag are ignored.
                with suppress(Exception):
                    instance._unloaded = True

            jit_module_cls.__del__ = _helion_del
            jit_module_cls._helion_shutdown_patch = True

        _CUTLASS_SHUTDOWN_PATCHED = True


if triton_is_available():
    import triton

    from .triton_helpers import triton_send_signal as triton_send_signal
    from .triton_helpers import (
        triton_wait_multiple_signal as triton_wait_multiple_signal,
    )
    from .triton_helpers import triton_wait_signal as triton_wait_signal

    def _alloc_fn(size: int, alignment: int, stream: int | None) -> torch.Tensor:
        """Allocate ``size`` bytes of scratch memory for Triton.

        Args:
            size: Number of ``int8`` elements (bytes) to allocate.
            alignment: Accepted but not used by this implementation.
            stream: Accepted but not used by this implementation.

        Returns:
            An uninitialized 1-D ``torch.int8`` tensor placed on the device
            named by the active Triton target's backend.

        Raises:
            RuntimeError: If Triton reports no current target.
        """
        # Dynamically get device from Triton backend
        current_target = triton.runtime.driver.active.get_current_target()
        if current_target is None:
            raise RuntimeError("No active Triton target available")
        backend = current_target.backend
        return torch.empty(size, device=backend, dtype=torch.int8)

    def set_triton_allocator() -> None:
        """Install ``_alloc_fn`` as Triton's allocator unless one is already set.

        Does nothing when the installed Triton does not expose the allocator
        API, or when the current allocator is not a ``NullAllocator``.
        """
        try:
            from triton import set_allocator
            from triton.runtime._allocation import NullAllocator
            from triton.runtime._allocation import _allocator
        except ImportError:
            return
        if isinstance(_allocator, contextvars.ContextVar):
            existing = _allocator.get()
        else:  # older versions of Triton
            existing = _allocator
        # if allocator isn't NullAllocator, we assume it is set by the user
        if isinstance(existing, NullAllocator):
            set_allocator(_alloc_fn)
else:

    def set_triton_allocator() -> None:  # type: ignore[misc]
        """No-op fallback used when Triton is not available."""
        pass
def get_num_sm(device: torch.device, *, reserved_sms: int = 0) -> int:
    """
    Get the number of streaming multiprocessors (SMs) for the specified device.

    Args:
        device: Device to query. Supported device types are ``"cuda"``,
            ``"xpu"``, ``"mps"`` and ``"mtia"``.
        reserved_sms: Number of SMs to keep free for other work (e.g.,
            communication kernels). Defaults to 0 meaning all device SMs are
            available to Helion. Values <= 0 are treated as "reserve nothing".

    Returns:
        Grid size to use for a persistent kernel on the device after accounting
        for any reserved SMs. When ``reserved_sms`` is positive the result is
        clamped to at least 1.

    Raises:
        AssertionError: If ``device.type`` is not one of the supported types.
        RuntimeError: If an MTIA device does not report both
            ``max_grid_height`` and ``max_grid_width``.
    """
    available_sms: int
    assert device.type in [
        "cuda",
        "xpu",
        "mtia",
        "mps",
    ], "TODO: implement for other devices"
    if device.type == "cuda":
        available_sms = torch.cuda.get_device_properties(
            device.index
        ).multi_processor_count
    # TODO(EikanWang): gpu_subslice_count is an out-of-date term; update it to
    # the XeCore count.
    elif device.type == "xpu":
        available_sms = torch.xpu.get_device_properties(device.index).gpu_subslice_count
    elif device.type == "mps":
        available_sms = torch.backends.mps.get_core_count()
    elif device.type == "mtia":
        device_props = torch.mtia.get_device_properties(device.index)
        if "max_grid_height" in device_props and "max_grid_width" in device_props:
            # MTIA exposes a 2-D grid limit; its area is used as the SM count.
            available_sms = (
                device_props["max_grid_height"] * device_props["max_grid_width"]
            )
        else:
            raise RuntimeError(
                f"Unable to determine SM count for MTIA device. "
                f"Available properties: {list(device_props.keys())}"
            )
    else:
        # Only reachable when the assert above is stripped (``python -O``).
        raise NotImplementedError(
            f"get_num_sm not implemented for device type: {device.type}"
        )

    if reserved_sms <= 0:
        return available_sms
    return max(available_sms - reserved_sms, 1)
def default_launcher( triton_kernel: object, grid: tuple[int, ...], *args: object, num_warps: int, num_stages: int, ptx_options: str | None = None, launch_cooperative_grid: bool = False, **kwargs: dict, ) -> object: """Default launcher function that executes the kernel immediately.""" # For both CUDA and MTIA, use the same kernel execution run_kwargs: dict = { "grid": grid, "warmup": False, "num_warps": num_warps, "num_stages": num_stages, "launch_cooperative_grid": launch_cooperative_grid, **kwargs, } if ptx_options is not None: run_kwargs["ptx_options"] = ptx_options return triton_kernel.run( # type: ignore[union-attr] *args, **run_kwargs, ) def _pallas_make_block_spec( pl: object, jnp: object, tensor: torch.Tensor, entry: tuple[tuple[int | None, ...], tuple[int | tuple[int, int, int] | None, ...]] | None, ) -> object: """Build one ``pl.BlockSpec`` from compile-time ``(block_shape, grid_dims)``.""" if entry is None: ndim = tensor.ndim full_shape = tuple(tensor.shape) def index_map_full(*grid_args: object, _nd: int = ndim) -> tuple[object, ...]: # pyrefly: ignore[missing-attribute] return tuple(jnp.int32(0) for _ in range(_nd)) return pl.BlockSpec(full_shape, index_map_full) # type: ignore[union-attr] block_shape_template, grid_dims = entry block_shape = tuple( min(bs, tensor.shape[d]) if bs is not None else tensor.shape[d] for d, bs in enumerate(block_shape_template) ) def _index_for_dim( grid_args: tuple[object, ...], g: int | tuple[int, int, int] | None, jnp: object = jnp, ) -> object: if g is None: return jnp.int32(0) # pyrefly: ignore[missing-attribute] if isinstance(g, tuple): # Flat grid decomposition: (grid_dim, stride, num_blocks) grid_dim, stride, num_blocks = g val = grid_args[grid_dim] if stride > 1: val = val // stride # type: ignore[operator] val = val % num_blocks # type: ignore[operator] return jnp.int32(val) # pyrefly: ignore[missing-attribute] return jnp.int32(grid_args[g]) # pyrefly: ignore[missing-attribute] def index_map( *grid_args: object, 
_grid_dims: tuple[int | tuple[int, int, int] | None, ...] = grid_dims, ) -> tuple[object, ...]: return tuple(_index_for_dim(grid_args, g) for g in _grid_dims) return pl.BlockSpec(block_shape, index_map) # type: ignore[union-attr] # Per-tensor block spec info: see ``_pallas_make_block_spec``. # grid_dims entries are int (direct grid dim), tuple (flat decomposition), # or None (untiled dim). _BlockSpecInfo = list[ tuple[tuple[int | None, ...], tuple[int | tuple[int, int, int] | None, ...]] | None ] def _pallas_build_block_specs( pl: object, jnp: object, grid: tuple[int, ...], args: tuple[object, ...], tensor_arg_indices: list[int], output_indices: list[int], block_spec_info: _BlockSpecInfo | None = None, ) -> tuple[list[object] | None, object | None]: """Build ``in_specs`` and ``out_specs`` for ``pl.pallas_call``.""" if block_spec_info is None or len(grid) == 0: return None, None in_specs = [] for tensor_pos, idx in enumerate(tensor_arg_indices): t = args[idx] assert isinstance(t, torch.Tensor) in_specs.append( _pallas_make_block_spec(pl, jnp, t, block_spec_info[tensor_pos]) ) arg_to_tensor_pos = {orig: tpos for tpos, orig in enumerate(tensor_arg_indices)} out_specs_list = [] for idx in output_indices: t = args[idx] assert isinstance(t, torch.Tensor) out_specs_list.append( _pallas_make_block_spec(pl, jnp, t, block_spec_info[arg_to_tensor_pos[idx]]) ) out_specs = out_specs_list if len(out_specs_list) > 1 else out_specs_list[0] return in_specs, out_specs def _pallas_prepare_args( args: tuple[object, ...], _output_indices: list[int], ) -> tuple[ set[int], list[int], dict[int, object], int, dict[int, int], list[object], set[int], tuple[object, ...], ]: """Extract and organize tensor/non-tensor args for Pallas launchers. 
Returns a tuple of: - output_set: set of output arg positions - tensor_arg_indices: positions of tensor args - non_tensor_args: mapping of non-tensor arg positions to values - n_tensor_inputs: count of tensor args - arg_to_tensor_pos: mapping from original position to tensor-only position - outputs: list of output tensors - inplace_positions: positions that are both input and output - out_shapes: JAX placeholders for output shapes """ from torch_tpu._internal.pallas.pallas import ( # pyrefly: ignore[missing-import] jax_placeholder, ) output_set = set(_output_indices) tensor_arg_indices = [ i for i in range(len(args)) if isinstance(args[i], torch.Tensor) ] non_tensor_args: dict[int, object] = { i: args[i] for i in range(len(args)) if not isinstance(args[i], torch.Tensor) } n_tensor_inputs = len(tensor_arg_indices) arg_to_tensor_pos = {orig: tpos for tpos, orig in enumerate(tensor_arg_indices)} outputs = [args[i] for i in _output_indices] inplace_positions = output_set & set(tensor_arg_indices) out_shapes = tuple(jax_placeholder(out) for out in outputs) # type: ignore[arg-type] return ( output_set, tensor_arg_indices, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, outputs, inplace_positions, out_shapes, ) def _pallas_make_reordered_kernel( pallas_kernel: object, args: tuple[object, ...], tensor_arg_indices: list[int], non_tensor_args: dict[int, object], n_tensor_inputs: int, _output_indices: list[int], inplace_positions: set[int], arg_to_tensor_pos: dict[int, int], n_extra_refs: int = 0, skip_inplace_copy: bool = False, ) -> object: """Create a wrapper kernel that reorders pallas_call refs to the original arg order. ``pallas_call`` provides refs as ``[inputs..., outputs...]``, but Helion kernels expect the original parameter order. When *n_extra_refs* > 0 (e.g. scratch buffers), those trailing refs are appended after the reordered args. When *skip_inplace_copy* is True, the initial ``out_ref[...] = in_ref[...]`` copy for inplace positions is skipped. 
This is needed for the pipeline launcher where refs are in HBM (``pl.ANY``) and direct load/store is not allowed — ``input_output_aliases`` already handles the aliasing. """ def reordered_kernel(*refs: object) -> None: n_kernel_params = len(args) original_order: list[object] = [None] * n_kernel_params for tensor_pos, orig_pos in enumerate(tensor_arg_indices): original_order[orig_pos] = refs[tensor_pos] for orig_pos, value in non_tensor_args.items(): original_order[orig_pos] = value for out_idx, orig_pos in enumerate(_output_indices): out_ref = refs[n_tensor_inputs + out_idx] if orig_pos in inplace_positions and not skip_inplace_copy: in_ref = refs[arg_to_tensor_pos[orig_pos]] out_ref[...] = in_ref[...] # type: ignore[index] original_order[orig_pos] = out_ref extra_refs = refs[n_tensor_inputs + len(_output_indices) :] pallas_kernel(*original_order, *extra_refs) # type: ignore[operator] return reordered_kernel def _pallas_build_callable( pallas_kernel: object, grid: tuple[int, ...], jit_fn: object, _output_indices: list[int], arg_to_tensor_pos: dict[int, int], tensor_arg_indices: list[int], cache_attr: str, trace_key_suffix: str = "", ) -> object: """Build a ``JaxCallable``, cache it on the kernel, and return it.""" import jax from torch_tpu._internal.pallas.pallas import ( # pyrefly: ignore[missing-import] JaxCallable, ) kernel_name = getattr(pallas_kernel, "__name__", "pallas_kernel") call_aliases: dict[int, int] = {} for out_idx, orig_pos in enumerate(_output_indices): call_aliases[arg_to_tensor_pos[orig_pos]] = out_idx jax_callable = JaxCallable( name=kernel_name, jit_fn=jax.jit(jit_fn), # pyrefly: ignore[no-matching-overload] trace_key=f"{kernel_name}_{id(pallas_kernel)}_{grid}{trace_key_suffix}", input_output_aliases=call_aliases, ) setattr(pallas_kernel, cache_attr, (grid, jax_callable, tensor_arg_indices)) return jax_callable def default_pallas_launcher( pallas_kernel: object, grid: tuple[int, ...], *args: object, _output_indices: list[int] | None = None, 
_block_spec_info: _BlockSpecInfo | None = None, **kwargs: object, ) -> None: """Default launcher for Pallas kernels on TPU. Uses ``JaxCallable`` from ``torch_tpu`` to compile and run the Pallas kernel on TPU. Output tensors are donated via ``input_output_aliases`` so the kernel writes directly into their buffers (zero-copy). """ if _output_indices is None: _output_indices = [] cache = getattr(pallas_kernel, "_pallas_cache", None) if cache is not None and cache[0] == grid: _, jax_callable, tensor_arg_indices = cache else: from jax.experimental import pallas as pl import jax.numpy as jnp ( output_set, tensor_arg_indices, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, outputs, inplace_positions, out_shapes, ) = _pallas_prepare_args(args, _output_indices) in_specs, out_specs = _pallas_build_block_specs( pl, jnp, grid, args, tensor_arg_indices, _output_indices, _block_spec_info, ) reordered_kernel = _pallas_make_reordered_kernel( pallas_kernel, args, tensor_arg_indices, non_tensor_args, n_tensor_inputs, _output_indices, inplace_positions, arg_to_tensor_pos, ) out_shape_arg = out_shapes if len(out_shapes) > 1 else out_shapes[0] pallas_aliases = { arg_to_tensor_pos[orig_pos]: out_idx for out_idx, orig_pos in enumerate(_output_indices) } pallas_call_kwargs: dict[str, object] = { "out_shape": out_shape_arg, "input_output_aliases": pallas_aliases, "grid": grid, } if in_specs is not None: pallas_call_kwargs["in_specs"] = in_specs pallas_call_kwargs["out_specs"] = out_specs jit_fn = pl.pallas_call( reordered_kernel, # pyrefly: ignore[bad-argument-type] **pallas_call_kwargs, # type: ignore[arg-type] ) jax_callable = _pallas_build_callable( pallas_kernel, grid, jit_fn, _output_indices, arg_to_tensor_pos, tensor_arg_indices, cache_attr="_pallas_cache", ) input_tensors = [ cast("torch.Tensor", args[i]).contiguous() for i in tensor_arg_indices ] jax_callable(*input_tensors) # type: ignore[operator] def default_pallas_pipeline_launcher( pallas_kernel: object, grid: tuple[int, 
...], *args: object, _output_indices: list[int] | None = None, _block_spec_info: _BlockSpecInfo | None = None, _scratch_shapes: list[tuple[tuple[int, ...], str]] | None = None, **kwargs: object, ) -> None: """Launcher for Pallas kernels using PrefetchScalarGridSpec with scratch memory. Used when ``pallas_loop_type='emit_pipeline'``. Passes all tensors as ``memory_space=pl.ANY`` (HBM refs) and adds scratch buffers as ``pltpu.VMEM`` shapes. The kernel uses ``pltpu.emit_pipeline`` internally for DMA pipelining. """ if _output_indices is None: _output_indices = [] if _scratch_shapes is None: _scratch_shapes = [] cache = getattr(pallas_kernel, "_pallas_pipeline_cache", None) if cache is not None and cache[0] == grid: _, jax_callable, tensor_arg_indices = cache else: from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp ( output_set, tensor_arg_indices, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, outputs, inplace_positions, out_shapes, ) = _pallas_prepare_args(args, _output_indices) # Build scratch shapes for VMEM _jnp_dtype_map: dict[str, object] = { "jnp.float32": jnp.float32, "jnp.float16": jnp.float16, "jnp.bfloat16": jnp.bfloat16, "jnp.int32": jnp.int32, "jnp.int16": jnp.int16, "jnp.int8": jnp.int8, "jnp.uint8": jnp.uint8, "jnp.bool_": jnp.bool_, } scratch_shapes = [] for scratch_entry in _scratch_shapes: if len(scratch_entry) == 3: shape, dtype_str, scratch_type = scratch_entry else: shape, dtype_str = scratch_entry # type: ignore[misc] scratch_type = "vmem" if scratch_type == "dma_semaphore": scratch_shapes.append(pltpu.SemaphoreType.DMA(())) else: jnp_dtype = _jnp_dtype_map.get(dtype_str, jnp.float32) scratch_shapes.append( pltpu.VMEM(shape, jnp_dtype) # pyrefly: ignore[bad-argument-type] ) # Build in_specs/out_specs with memory_space=pl.ANY (HBM refs) in_specs_list = [pl.BlockSpec(memory_space=pl.ANY) for _ in tensor_arg_indices] out_specs_list = [pl.BlockSpec(memory_space=pl.ANY) for _ in 
_output_indices] out_specs = out_specs_list if len(out_specs_list) > 1 else out_specs_list[0] reordered_kernel = _pallas_make_reordered_kernel( pallas_kernel, args, tensor_arg_indices, non_tensor_args, n_tensor_inputs, _output_indices, inplace_positions, arg_to_tensor_pos, n_extra_refs=len(scratch_shapes), skip_inplace_copy=True, ) out_shape_arg = out_shapes if len(out_shapes) > 1 else out_shapes[0] pallas_aliases = { arg_to_tensor_pos[orig_pos]: out_idx for out_idx, orig_pos in enumerate(_output_indices) } grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=in_specs_list, out_specs=out_specs, scratch_shapes=scratch_shapes, grid=grid, ) jit_fn = pl.pallas_call( reordered_kernel, # pyrefly: ignore[bad-argument-type] out_shape=out_shape_arg, input_output_aliases=pallas_aliases, grid_spec=grid_spec, compiler_params=pltpu.CompilerParams( # pyrefly: ignore[bad-instantiation] dimension_semantics=tuple("parallel" for _ in grid), ), ) jax_callable = _pallas_build_callable( pallas_kernel, grid, jit_fn, _output_indices, arg_to_tensor_pos, tensor_arg_indices, cache_attr="_pallas_pipeline_cache", trace_key_suffix="_pipeline", ) input_tensors = [ cast("torch.Tensor", args[i]).contiguous() for i in tensor_arg_indices ] jax_callable(*input_tensors) # type: ignore[operator] def default_pallas_fori_launcher( pallas_kernel: object, grid: tuple[int, ...], *args: object, _output_indices: list[int] | None = None, _block_spec_info: _BlockSpecInfo | None = None, _scratch_shapes: list[tuple[tuple[int, ...], str | None, str]] | None = None, **kwargs: object, ) -> None: """Launcher for Pallas kernels using fori_loop with manual DMA. Used when ``pallas_loop_type="fori_loop"``. Passes all tensors as ``memory_space=pl.ANY`` (HBM refs) and adds scratch buffers as ``pltpu.VMEM`` shapes plus ``pltpu.SemaphoreType.DMA`` for async copies. The kernel uses ``jax.lax.fori_loop`` with ``pltpu.make_async_copy`` internally for DMA control. 
""" if _output_indices is None: _output_indices = [] if _scratch_shapes is None: _scratch_shapes = [] cache = getattr(pallas_kernel, "_pallas_fori_cache", None) if cache is not None and cache[0] == grid: _, jax_callable, tensor_arg_indices = cache else: from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp ( output_set, tensor_arg_indices, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, outputs, inplace_positions, out_shapes, ) = _pallas_prepare_args(args, _output_indices) # Build scratch shapes: VMEM buffers + DMA semaphores _jnp_dtype_map: dict[str, object] = { "jnp.float32": jnp.float32, "jnp.float16": jnp.float16, "jnp.bfloat16": jnp.bfloat16, "jnp.int32": jnp.int32, "jnp.int16": jnp.int16, "jnp.int8": jnp.int8, "jnp.uint8": jnp.uint8, "jnp.bool_": jnp.bool_, } scratch_shapes = [] for shape, dtype_str, scratch_type in _scratch_shapes: if scratch_type == "dma_semaphore": scratch_shapes.append(pltpu.SemaphoreType.DMA(())) else: # "vmem" assert dtype_str is not None jnp_dtype = _jnp_dtype_map.get(dtype_str, jnp.float32) scratch_shapes.append( pltpu.VMEM(shape, jnp_dtype) # pyrefly: ignore[bad-argument-type] ) # Build in_specs/out_specs with memory_space=pl.ANY (HBM refs) in_specs_list = [pl.BlockSpec(memory_space=pl.ANY) for _ in tensor_arg_indices] out_specs_list = [pl.BlockSpec(memory_space=pl.ANY) for _ in _output_indices] out_specs = out_specs_list if len(out_specs_list) > 1 else out_specs_list[0] reordered_kernel = _pallas_make_reordered_kernel( pallas_kernel, args, tensor_arg_indices, non_tensor_args, n_tensor_inputs, _output_indices, inplace_positions, arg_to_tensor_pos, n_extra_refs=len(scratch_shapes), skip_inplace_copy=True, ) out_shape_arg = out_shapes if len(out_shapes) > 1 else out_shapes[0] pallas_aliases = { arg_to_tensor_pos[orig_pos]: out_idx for out_idx, orig_pos in enumerate(_output_indices) } grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=in_specs_list, 
out_specs=out_specs, scratch_shapes=scratch_shapes, grid=grid, ) jit_fn = pl.pallas_call( reordered_kernel, # pyrefly: ignore[bad-argument-type] out_shape=out_shape_arg, input_output_aliases=pallas_aliases, grid_spec=grid_spec, compiler_params=pltpu.CompilerParams( # pyrefly: ignore[bad-instantiation] dimension_semantics=tuple("parallel" for _ in grid), ), ) jax_callable = _pallas_build_callable( pallas_kernel, grid, jit_fn, _output_indices, arg_to_tensor_pos, tensor_arg_indices, cache_attr="_pallas_fori_cache", trace_key_suffix="_fori", ) input_tensors = [ cast("torch.Tensor", args[i]).contiguous() for i in tensor_arg_indices ] jax_callable(*input_tensors) # type: ignore[operator] def _torch_dtype_to_cutlass(dtype: torch.dtype) -> object: _patch_cutlass_jit_shutdown_unload() import cutlass mapping: dict[torch.dtype, object] = { torch.float16: cutlass.Float16, torch.float32: cutlass.Float32, torch.float64: cutlass.Float64, torch.bfloat16: cutlass.BFloat16, torch.int8: cutlass.Int8, torch.int16: cutlass.Int16, torch.int32: cutlass.Int32, torch.int64: cutlass.Int64, torch.uint8: cutlass.Uint8, } if dtype not in mapping: raise exc.BackendUnsupported("cute", f"dtype: {dtype}") return mapping[dtype] def _normalize_cute_scalar(arg: object) -> tuple[str, object]: if isinstance(arg, (bool, torch.SymBool)): return ("bool", bool(arg)) if isinstance(arg, (int, torch.SymInt)): return ("int", int(arg)) if isinstance(arg, (float, torch.SymFloat)): return ("float", float(arg)) raise exc.BackendUnsupported("cute", f"launcher scalar argument type: {type(arg)}") def _cute_scalar_annotation(kind: str) -> str: mapping = { "bool": "cutlass.Boolean", "int": "cutlass.Int64", "float": "cutlass.Float64", } return mapping[kind] def _create_cute_wrapper( cute_kernel: object, schema_key: tuple[tuple[object, ...], ...], ) -> object: _patch_cutlass_jit_shutdown_unload() import cutlass import cutlass.cute as cute kernel_name = getattr(cast("Any", cute_kernel), "__name__", "cute_kernel") 
kernel_tag = f"{kernel_name}_{id(cute_kernel):x}" func_name = f"_helion_cute_launch_{kernel_tag}" params: list[str] = [] body: list[str] = [] call_args: list[str] = [] for i, entry in enumerate(schema_key): kind = entry[0] if kind == "tensor": (_, _dtype, rank) = entry assert isinstance(rank, int) ptr_name = f"arg{i}_ptr" params.append(f"{ptr_name}: cute.Pointer") shape_names = [f"arg{i}_shape{d}" for d in range(rank)] stride_names = [f"arg{i}_stride{d}" for d in range(rank)] params.extend(f"{name}: cutlass.Int64" for name in shape_names) params.extend(f"{name}: cutlass.Int64" for name in stride_names) shape_tuple = ( f"({shape_names[0]},)" if rank == 1 else f"({', '.join(shape_names)})" ) stride_tuple = ( f"({stride_names[0]},)" if rank == 1 else f"({', '.join(stride_names)})" ) body.append( f" arg{i} = cute.make_tensor({ptr_name}, layout=cute.make_layout({shape_tuple}, stride={stride_tuple}))" ) call_args.append(f"arg{i}") continue assert kind == "scalar" (_, scalar_kind) = entry assert isinstance(scalar_kind, str) scalar_name = f"arg{i}" params.append(f"{scalar_name}: {_cute_scalar_annotation(scalar_kind)}") call_args.append(scalar_name) params.extend( ( "grid_x: cutlass.Int32", "grid_y: cutlass.Int32", "grid_z: cutlass.Int32", "block_x: cutlass.Int32", "block_y: cutlass.Int32", "block_z: cutlass.Int32", ) ) body.extend( ( f" _helion_cute_kernel_tag = {kernel_tag!r}", " _kernel(" + ", ".join(call_args) + ").launch(grid=(grid_x, grid_y, grid_z), block=(block_x, block_y, block_z))", ) ) source = "\n".join( [ "@cute.jit", f"def {func_name}({', '.join(params)}) -> None:", *body, ] ) namespace: dict[str, Any] = { "cutlass": cutlass, "cute": cute, "_kernel": cute_kernel, } filename = f"<helion_cute_launcher:{kernel_tag}:{schema_key!r}>" linecache.cache[filename] = ( len(source), None, [line + "\n" for line in source.splitlines()], filename, ) exec(compile(source, filename, "exec"), namespace) return namespace[func_name] def _get_compiled_cute_launcher( cute_kernel: 
object, schema_key: tuple[tuple[object, ...], ...], launch_args: tuple[object, ...], ) -> object: try: # pyrefly: ignore [missing-attribute] cache = cute_kernel._helion_cute_compiled_launchers except AttributeError: cache = {} # pyrefly: ignore [missing-attribute] cute_kernel._helion_cute_compiled_launchers = cache cached = cache.get(schema_key) if cached is not None: return cached wrapper = _create_cute_wrapper(cute_kernel, schema_key) cache[schema_key] = wrapper return wrapper def _build_cute_schema_and_args( args: tuple[object, ...], grid: tuple[int, int, int], block: tuple[int, int, int], ) -> tuple[tuple[tuple[object, ...], ...], tuple[object, ...]]: _patch_cutlass_jit_shutdown_unload() import cutlass.cute as cute from cutlass.cute.runtime import make_ptr schema: list[tuple[object, ...]] = [] launch_args: list[object] = [] for arg in args: if isinstance(arg, torch.Tensor): if arg.device.type != "cuda": raise exc.BackendUnsupported("cute", "launcher requires CUDA tensors") if arg.ndim <= 0: raise exc.BackendUnsupported( "cute", "launcher requires tensor rank >= 1" ) schema.append(("tensor", str(arg.dtype), arg.ndim)) launch_args.append( make_ptr( cast("Any", _torch_dtype_to_cutlass(arg.dtype)), arg.data_ptr(), cute.AddressSpace.gmem, assumed_align=16, ) ) launch_args.extend(int(arg.size(d)) for d in range(arg.ndim)) launch_args.extend(int(arg.stride(d)) for d in range(arg.ndim)) continue scalar_kind, scalar_value = _normalize_cute_scalar(arg) schema.append(("scalar", scalar_kind)) launch_args.append(scalar_value) launch_args.extend((*grid, *block)) return tuple(schema), tuple(launch_args) def default_cute_launcher( cute_kernel: object, grid: tuple[int, ...], *args: object, **kwargs: object, ) -> object: block = kwargs.pop("block", (256, 1, 1)) if not isinstance(block, tuple) or len(block) < 1: raise ValueError(f"Invalid block specification: {block}") if not isinstance(grid, tuple) or len(grid) < 1: raise ValueError(f"Invalid grid specification: {grid}") if 
kwargs: raise exc.BackendUnsupported("cute", f"launcher kwargs: {sorted(kwargs)}") grid_xyz = ( int(grid[0]), int(grid[1]) if len(grid) > 1 else 1, int(grid[2]) if len(grid) > 2 else 1, ) block_xyz = ( int(block[0]), int(block[1]) if len(block) > 1 else 1, int(block[2]) if len(block) > 2 else 1, ) if any(dim <= 0 for dim in grid_xyz): return None schema_key, launch_args = _build_cute_schema_and_args( tuple(args), grid_xyz, block_xyz ) compiled = _get_compiled_cute_launcher(cute_kernel, schema_key, launch_args) return cast("Any", compiled)(*launch_args) def default_metal_launcher( metal_kernel: object, grid: tuple[int, ...], *args: object, _block_size: int = 256, **kwargs: object, ) -> None: """Default launcher for Metal kernels on Apple MPS devices. Compiles MSL source via ``torch.mps.compile_shader()`` and dispatches using the compiled library. Caches the compiled library on the kernel object to avoid recompilation on subsequent calls. Only 1D grids are currently supported. """ kwargs.pop("num_warps", None) kwargs.pop("num_stages", None) if kwargs: raise exc.BackendUnsupported( "metal", f"unexpected launcher kwargs: {sorted(kwargs)}" ) assert len(grid) == 1, ( f"Metal launcher only supports 1D grids, got {len(grid)}D: {grid}" ) msl_source, kernel_name = metal_kernel() # type: ignore[operator] source_hash = hashlib.sha256(msl_source.encode()).digest() cache = getattr(metal_kernel, "_metal_cache", None) if cache is not None and cache[0] == source_hash: lib = cache[1] else: lib = torch.mps.compile_shader(msl_source) # type: ignore[attr-defined] metal_kernel._metal_cache = (source_hash, lib) # type: ignore[attr-defined] tensor_args = [a for a in args if isinstance(a, torch.Tensor)] dispatch_fn = getattr(lib, kernel_name) total_threads = grid[0] * _block_size dispatch_fn(*tensor_args, threads=total_threads, group_size=_block_size)