Source code for helion.runtime

from __future__ import annotations

from contextlib import suppress
import contextvars
import inspect
import linecache
import os
import sys
from typing import TYPE_CHECKING
from typing import Any
from typing import cast

import torch

from .. import _compat as _compat  # ensure Triton compatibility patches run
from .. import exc
from .._utils import triton_is_available
from .config import Config as Config
from .kernel import Kernel as Kernel
from .kernel import kernel as kernel

if TYPE_CHECKING:
    from collections.abc import Callable

_CUTLASS_SHUTDOWN_PATCHED = False


def _patch_cutlass_jit_shutdown_unload() -> None:
    """Avoid CUDA library unload hangs during interpreter shutdown.

    On current CUTLASS DSL builds, ``CudaDialectJitModule.__del__`` unconditionally
    calls ``cudaLibraryUnload``. On B200 this can hang during Python finalization
    after a CuTe kernel has already finished executing. Skipping that unload during
    interpreter teardown lets the process exit cleanly while preserving the normal
    unload path during regular runtime GC.
    """

    global _CUTLASS_SHUTDOWN_PATCHED
    if _CUTLASS_SHUTDOWN_PATCHED:
        return

    try:
        import cutlass.cutlass_dsl.cuda_jit_executor as cuda_jit_executor
    except ImportError:
        return

    module_type = cuda_jit_executor.CudaDialectJitModule
    if getattr(module_type, "_helion_shutdown_patch", False):
        _CUTLASS_SHUTDOWN_PATCHED = True
        return

    original_del = cast("Any", module_type.__del__)

    def _helion_del(self: object) -> None:
        module = cast("Any", self)
        if sys.is_finalizing():
            with suppress(Exception):
                module._unloaded = True
            return
        original_del(module)

    module_type.__del__ = _helion_del
    module_type._helion_shutdown_patch = True
    _CUTLASS_SHUTDOWN_PATCHED = True


if triton_is_available():
    import triton

    def _alloc_fn(size: int, alignment: int, stream: int | None) -> torch.Tensor:
        # Dynamically get device from Triton backend
        current_target = triton.runtime.driver.active.get_current_target()
        if current_target is None:
            raise RuntimeError("No active Triton target available")
        backend = current_target.backend
        return torch.empty(size, device=backend, dtype=torch.int8)

    def set_triton_allocator() -> None:
        try:
            from triton import set_allocator
            from triton.runtime._allocation import NullAllocator
            from triton.runtime._allocation import _allocator
        except ImportError:
            return
        if isinstance(_allocator, contextvars.ContextVar):
            existing = _allocator.get()
        else:  # older versions of Triton
            existing = _allocator
        # if allocator isn't NullAllocator, we assume it is set by the user
        if isinstance(existing, NullAllocator):
            set_allocator(_alloc_fn)
else:

    def set_triton_allocator() -> None:  # type: ignore[misc]
        pass
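

# Usage sketch (illustrative, not part of the module): ``set_triton_allocator``
# routes Triton-internal scratch allocations through ``torch.empty`` and is a
# no-op both when Triton is unavailable and when a custom allocator is already
# installed, so it is safe to call eagerly and repeatedly:
#
#     set_triton_allocator()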


def get_num_sm(device: torch.device, *, reserved_sms: int = 0) -> int:
    """
    Get the number of streaming multiprocessors (SMs) for the specified device.

    Args:
        device: Device to query.
        reserved_sms: Number of SMs to keep free for other work (e.g.,
            communication kernels). Defaults to 0, meaning all device SMs are
            available to Helion.

    Returns:
        Grid size to use for a persistent kernel on the device after accounting
        for any reserved SMs. Always at least 1.
    """
    available_sms: int
    assert device.type in [
        "cuda",
        "xpu",
        "mtia",
        "mps",
    ], "TODO: implement for other devices"
    if device.type == "cuda":
        available_sms = torch.cuda.get_device_properties(
            device.index
        ).multi_processor_count
    # TODO(EikanWang): gpu_subslice_count is an out-of-date term; we should
    # update it to the XeCore count.
    elif device.type == "xpu":
        available_sms = torch.xpu.get_device_properties(device.index).gpu_subslice_count
    elif device.type == "mps":
        available_sms = torch.backends.mps.get_core_count()
    elif device.type == "mtia":
        device_props = torch.mtia.get_device_properties(device.index)
        if "max_grid_height" in device_props and "max_grid_width" in device_props:
            available_sms = (
                device_props["max_grid_height"] * device_props["max_grid_width"]
            )
        else:
            raise RuntimeError(
                f"Unable to determine SM count for MTIA device. "
                f"Available properties: {list(device_props.keys())}"
            )
    else:
        raise NotImplementedError(
            f"get_num_sm not implemented for device type: {device.type}"
        )
    if reserved_sms <= 0:
        return available_sms
    return max(available_sms - reserved_sms, 1)
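

# Usage sketch (illustrative): size a persistent kernel's grid while leaving
# two SMs free for overlapped communication kernels.
#
#     if torch.cuda.is_available():
#         grid = (get_num_sm(torch.device("cuda"), reserved_sms=2),)
#         # The result is clamped to at least 1 even if reserved_sms is large.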


def default_launcher(
    triton_kernel: object,
    grid: tuple[int, ...],
    *args: object,
    num_warps: int,
    num_stages: int,
    ptx_options: str | None = None,
    launch_cooperative_grid: bool = False,
    **kwargs: dict,
) -> object:
    """Default launcher function that executes the kernel immediately."""
    # For both CUDA and MTIA, use the same kernel execution
    run_kwargs: dict = {
        "grid": grid,
        "warmup": False,
        "num_warps": num_warps,
        "num_stages": num_stages,
        "launch_cooperative_grid": launch_cooperative_grid,
        **kwargs,
    }
    if ptx_options is not None:
        run_kwargs["ptx_options"] = ptx_options
    try:
        return triton_kernel.run(  # type: ignore[union-attr]
            *args,
            **run_kwargs,
        )
    except Exception as error:
        message = str(error)
        if "Cannot make_shape_compatible: incompatible dimensions" in message:
            raise exc.ShapeMismatch("kernel operands", message) from error
        raise
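

# Sketch of a custom launcher (illustrative; ``my_logging_launcher`` is a
# hypothetical name): a launcher shares this signature, so it can wrap
# ``default_launcher`` to add cross-cutting behavior such as logging.
#
#     def my_logging_launcher(triton_kernel, grid, *args, **kwargs):
#         print(f"launching {triton_kernel} with grid={grid}")
#         return default_launcher(triton_kernel, grid, *args, **kwargs)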


def _pallas_make_block_spec(
    pl: object,
    jnp: object,
    pltpu: object,
    tensor: torch.Tensor,
    entry: tuple[tuple[int | None, ...], tuple[int | tuple[int, int, int] | None, ...]]
    | None,
    should_use_smem: bool = False,
) -> object:
    """Build one ``pl.BlockSpec`` from compile-time ``(block_shape, grid_dims)``."""
    memory_space = None  # default value (pallas will default to VMEM)
    if should_use_smem:
        # pyrefly: ignore[missing-attribute]
        memory_space = pltpu.SMEM
    if entry is None:
        ndim = tensor.ndim
        full_shape = tuple(tensor.shape)

        def index_map_full(*grid_args: object, _nd: int = ndim) -> tuple[object, ...]:
            # pyrefly: ignore[missing-attribute]
            return tuple(jnp.int32(0) for _ in range(_nd))

        return pl.BlockSpec(full_shape, index_map_full, memory_space=memory_space)  # type: ignore[union-attr]

    block_shape_template, grid_dims = entry
    block_shape = tuple(
        min(bs, tensor.shape[d]) if bs is not None else tensor.shape[d]
        for d, bs in enumerate(block_shape_template)
    )

    def _index_for_dim(
        grid_args: tuple[object, ...],
        g: int | tuple[int, int, int] | None,
        jnp: object = jnp,
    ) -> object:
        if g is None:
            return jnp.int32(0)  # pyrefly: ignore[missing-attribute]
        if isinstance(g, tuple):
            # Flat grid decomposition: (grid_dim, stride, num_blocks)
            grid_dim, stride, num_blocks = g
            val = grid_args[grid_dim]
            if stride > 1:
                val = val // stride  # type: ignore[operator]
            val = val % num_blocks  # type: ignore[operator]
            return jnp.int32(val)  # pyrefly: ignore[missing-attribute]
        return jnp.int32(grid_args[g])  # pyrefly: ignore[missing-attribute]

    def index_map(
        *grid_args: object,
        _grid_dims: tuple[int | tuple[int, int, int] | None, ...] = grid_dims,
    ) -> tuple[object, ...]:
        return tuple(_index_for_dim(grid_args, g) for g in _grid_dims)

    return pl.BlockSpec(block_shape, index_map, memory_space=memory_space)  # type: ignore[union-attr]


_CACHED_VMEM_LIMIT_BYTES: int | None = None


def _get_vmem_limit_bytes(pltpu: object) -> int:
    """Safely retrieves the TPU VMEM capacity without crashing on hardware locks."""
    global _CACHED_VMEM_LIMIT_BYTES
    if _CACHED_VMEM_LIMIT_BYTES is not None:
        return _CACHED_VMEM_LIMIT_BYTES
    try:
        get_tpu_info = pltpu.get_tpu_info  # pyrefly: ignore[missing-attribute]
        _CACHED_VMEM_LIMIT_BYTES = get_tpu_info().vmem_capacity_bytes
    except Exception:
        # Fallback if JAX fails to acquire the TPU backend lock (e.g., in a
        # precompile fork). Default to 16MB (safe baseline for v4 and v5e
        # per-core VMEM).
        _CACHED_VMEM_LIMIT_BYTES = 16 * 1024 * 1024
    return _CACHED_VMEM_LIMIT_BYTES


def _estimate_pallas_vmem_bytes(
    pl: object,
    pltpu: object,
    in_specs: list[object] | None,
    out_specs: list[object] | object | None,
    scratch_shapes: list[object] | list[Any] | None,
    args: tuple[object, ...],
    tensor_arg_indices: list[int],
    output_indices: list[int],
    pallas_aliases: dict[int, int] | None,
) -> int:
    """Estimates the VMEM required by the Pallas kernel."""
    total_bytes = 0
    in_spec_bytes = [0] * len(tensor_arg_indices)
    out_spec_bytes = [0] * len(output_indices)

    def _bytes_per_element(t: object) -> int:
        import torch

        if isinstance(t, torch.Tensor):
            return t.element_size()
        dtype = getattr(t, "dtype", None)
        if dtype is not None:
            # Works for torch.dtype and np.dtype/jnp.dtype
            itemsize = getattr(dtype, "itemsize", None)
            if itemsize is not None:
                return itemsize
        return 4

    if in_specs:
        for i, idx in enumerate(tensor_arg_indices):
            spec = in_specs[i]
            # pl.BlockSpec will have block_shape and memory_space.
            # HBM is pl.ANY. We only count VMEM (which is not pl.ANY).
            if spec is not None and getattr(spec, "memory_space", None) is not getattr(
                pl, "ANY", None
            ):
                block_shape = getattr(spec, "block_shape", None)
                if block_shape is not None:
                    numel = 1
                    for d in block_shape:
                        numel *= int(d)
                    in_spec_bytes[i] = numel * _bytes_per_element(args[idx])

    if out_specs:
        out_specs_list = (
            out_specs if isinstance(out_specs, (list, tuple)) else [out_specs]
        )
        for i, idx in enumerate(output_indices):
            if i < len(out_specs_list):
                spec = out_specs_list[i]
                if spec is not None and getattr(
                    spec, "memory_space", None
                ) is not getattr(pl, "ANY", None):
                    block_shape = getattr(spec, "block_shape", None)
                    if block_shape is not None:
                        numel = 1
                        for d in block_shape:
                            numel *= int(d)
                        out_spec_bytes[i] = numel * _bytes_per_element(args[idx])

    pallas_aliases = pallas_aliases or {}
    aliased_out_positions = set()
    for in_pos, out_pos in pallas_aliases.items():
        aliased_out_positions.add(out_pos)
        if in_pos < len(in_spec_bytes) and out_pos < len(out_spec_bytes):
            in_spec_bytes[in_pos] = max(in_spec_bytes[in_pos], out_spec_bytes[out_pos])
    for out_pos in aliased_out_positions:
        if out_pos < len(out_spec_bytes):
            out_spec_bytes[out_pos] = 0

    # Pallas pipelines and default launchers natively double buffer their BlockSpecs.
    multiplier = 2
    total_bytes += sum(in_spec_bytes) * multiplier
    total_bytes += sum(out_spec_bytes) * multiplier

    if scratch_shapes:
        for scratch in scratch_shapes:
            if type(scratch).__name__ == "VMEM":
                numel = 1
                shape = getattr(scratch, "shape", ())
                for d in shape:
                    numel *= int(d)
                dtype_size = getattr(getattr(scratch, "dtype", None), "itemsize", 4)
                total_bytes += numel * dtype_size

    return total_bytes


# Per-tensor block spec info: see ``_pallas_make_block_spec``.
# grid_dims entries are int (direct grid dim), tuple (flat decomposition),
# or None (untiled dim).
_BlockSpecInfo = list[
    tuple[tuple[int | None, ...], tuple[int | tuple[int, int, int] | None, ...]] | None
]


def _pallas_build_block_specs(
    pl: object,
    jnp: object,
    pltpu: object,
    grid: tuple[int, ...],
    args: tuple[object, ...],
    tensor_arg_indices: list[int],
    output_indices: list[int],
    block_spec_info: _BlockSpecInfo | None = None,
    _smem_arg_indices: list[int] | None = None,
    output_only_indices: list[int] | None = None,
) -> tuple[list[object] | None, object | None]:
    """Build ``in_specs`` and ``out_specs`` for ``pl.pallas_call``.

    ``block_spec_info`` is indexed by position among *all* tensor args.
    ``output_only_indices`` lists tensor positions excluded from
    ``tensor_arg_indices``; they are merged back to compute the mapping.
    """
    if block_spec_info is None or len(grid) == 0:
        return None, None
    all_positions = sorted(set(tensor_arg_indices) | set(output_only_indices or []))
    all_arg_to_tensor_pos = {orig: tpos for tpos, orig in enumerate(all_positions)}
    in_specs = []
    for idx in tensor_arg_indices:
        t = args[idx]
        assert isinstance(t, torch.Tensor)
        tensor_pos = all_arg_to_tensor_pos[idx]
        should_use_smem = tensor_pos in (_smem_arg_indices or [])
        in_specs.append(
            _pallas_make_block_spec(
                pl, jnp, pltpu, t, block_spec_info[tensor_pos], should_use_smem
            )
        )
    out_specs_list = []
    for idx in output_indices:
        t = args[idx]
        assert isinstance(t, torch.Tensor)
        tensor_pos = all_arg_to_tensor_pos[idx]
        should_use_smem = tensor_pos in (_smem_arg_indices or [])
        out_specs_list.append(
            _pallas_make_block_spec(
                pl,
                jnp,
                pltpu,
                t,
                block_spec_info[tensor_pos],
                should_use_smem,
            )
        )
    out_specs = out_specs_list if len(out_specs_list) > 1 else out_specs_list[0]
    return in_specs, out_specs
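

# Illustrative ``_BlockSpecInfo`` entry (hypothetical values, shown for
# clarity): a 2-D tensor tiled by 128 rows, where grid dim 0 selects the row
# block and the column dim is untiled, would be described as
#
#     ((128, None), (0, None))
#
# i.e. ``block_shape_template`` followed by ``grid_dims``, exactly as
# consumed by ``_pallas_make_block_spec`` above.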


def _pallas_build_pipeline_specs(
    pl: object,
    jnp: object,
    pltpu: object,
    grid: tuple[int, ...],
    args: tuple[object, ...],
    tensor_arg_indices: list[int],
    output_indices: list[int],
    block_spec_info: _BlockSpecInfo,
    pipeline_arg_indices: list[int] | None,
    output_only_indices: list[int] | None = None,
    smem_arg_indices: list[int] | None = None,
) -> tuple[list[object], object]:
    """Build in/out specs for pipeline launchers.

    Pipeline-body tensors (listed in *pipeline_arg_indices*) get HBM refs.
    All other tensors get proper BlockSpecs for automatic VMEM prefetch.

    Tensors in *smem_arg_indices* (only ever accessed by scalar index, e.g.
    group offset tables) are placed in SMEM so dynamic scalar reads don't
    require 128-lane alignment proofs against a small VMEM ref.
    """
    pipeline_set = set(pipeline_arg_indices or [])
    smem_set = set(smem_arg_indices or [])
    all_positions = sorted(set(tensor_arg_indices) | set(output_only_indices or []))
    arg_to_tpos = {orig: tpos for tpos, orig in enumerate(all_positions)}

    def _spec_for(idx: int) -> object:
        if idx in pipeline_set:
            return pl.BlockSpec(memory_space=pltpu.HBM)  # type: ignore[union-attr]
        tpos = arg_to_tpos[idx]
        t = args[idx]
        assert isinstance(t, torch.Tensor)
        return _pallas_make_block_spec(
            pl, jnp, pltpu, t, block_spec_info[tpos], tpos in smem_set
        )

    in_specs = [_spec_for(idx) for idx in tensor_arg_indices]
    out_specs_list = [_spec_for(idx) for idx in output_indices]
    out_specs = out_specs_list if len(out_specs_list) > 1 else out_specs_list[0]
    return in_specs, out_specs


def _jax_placeholder_for_tensor(t: torch.Tensor) -> object:
    """Create a JAX ShapeDtypeStruct placeholder for a torch.Tensor.

    Used as a fallback when ``torch_tpu`` is not available (e.g. interpret
    mode on CPU).
    """
    import jax
    from torch._inductor.runtime.runtime_utils import torch_dtype_to_jax_runtime

    jax_dtype = torch_dtype_to_jax_runtime(t.dtype)
    return jax.ShapeDtypeStruct(tuple(t.shape), jax_dtype)
""" import jax from torch._inductor.runtime.runtime_utils import torch_dtype_to_jax_runtime jax_dtype = torch_dtype_to_jax_runtime(t.dtype) return jax.ShapeDtypeStruct(tuple(t.shape), jax_dtype) def _pallas_jnp_dtype_map() -> dict[str, object]: import jax.numpy as jnp return { "jnp.float32": jnp.float32, "jnp.float16": jnp.float16, "jnp.bfloat16": jnp.bfloat16, "jnp.int32": jnp.int32, "jnp.int16": jnp.int16, "jnp.int8": jnp.int8, "jnp.uint8": jnp.uint8, "jnp.bool_": jnp.bool_, } def _pallas_check_dtypes(args: tuple[object, ...]) -> None: """Raise if any tensor arg uses a dtype unsupported on TPU.""" from .._compiler.backend import _PALLAS_UNSUPPORTED_DTYPES for a in args: if isinstance(a, torch.Tensor) and a.dtype in _PALLAS_UNSUPPORTED_DTYPES: raise TypeError( f"Pallas/TPU does not support {a.dtype} tensors. " f"Cast to a 32-bit type before calling the kernel." ) def _pallas_prepare_args( args: tuple[object, ...], _output_indices: list[int], _inplace_indices: list[int] | None = None, ) -> tuple[ list[int], list[int], dict[int, object], int, dict[int, int], set[int], tuple[object, ...], dict[int, int], ]: """Extract and organize tensor/non-tensor args for Pallas launchers. Returns a tuple of: - tensor_arg_indices: positions of tensor args passed as pallas_call inputs - output_only_indices: positions of output-only tensors (excluded from inputs) - non_tensor_args: mapping of non-tensor arg positions to values - n_tensor_inputs: count of tensor inputs (excl. output-only) - arg_to_tensor_pos: mapping from original position to tensor-only position - inplace_positions: positions that are both input and output - out_shapes: JAX placeholders for output shapes """ from .settings import is_pallas_interpret if is_pallas_interpret(): placeholder_fn = _jax_placeholder_for_tensor else: from torch_tpu._internal.pallas.pallas import ( # pyrefly: ignore[missing-import] jax_placeholder, ) placeholder_fn = jax_placeholder output_set = set(_output_indices) inplace_set = set(_inplace_indices) if _inplace_indices is not None else output_set output_only = output_set - inplace_set all_tensor_positions = [ i for i in range(len(args)) if isinstance(args[i], torch.Tensor) ] output_only_indices = [i for i in all_tensor_positions if i in output_only] tensor_arg_indices = [i for i in all_tensor_positions if i not in output_only] non_tensor_args: dict[int, object] = { i: args[i] for i in range(len(args)) if not isinstance(args[i], torch.Tensor) } n_tensor_inputs = len(tensor_arg_indices) arg_to_tensor_pos = {orig: tpos for tpos, orig in enumerate(tensor_arg_indices)} inplace_positions = output_set & set(tensor_arg_indices) out_shapes = tuple(placeholder_fn(args[i]) for i in _output_indices) # type: ignore[arg-type] pallas_aliases = { arg_to_tensor_pos[orig_pos]: out_idx for out_idx, orig_pos in enumerate(_output_indices) if orig_pos in arg_to_tensor_pos } return ( tensor_arg_indices, output_only_indices, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, inplace_positions, out_shapes, pallas_aliases, ) def _pallas_make_reordered_kernel( pallas_kernel: object, args: tuple[object, ...], tensor_arg_indices: list[int], non_tensor_args: dict[int, object], n_tensor_inputs: int, _output_indices: list[int], inplace_positions: set[int], arg_to_tensor_pos: dict[int, int], n_extra_refs: int = 0, skip_inplace_copy: set[int] | None = None, _smem_arg_indices: list[int] | None = None, ) -> object: """Create a wrapper kernel that reorders pallas_call refs to the original arg order. 


def _pallas_make_reordered_kernel(
    pallas_kernel: object,
    args: tuple[object, ...],
    tensor_arg_indices: list[int],
    non_tensor_args: dict[int, object],
    n_tensor_inputs: int,
    _output_indices: list[int],
    inplace_positions: set[int],
    arg_to_tensor_pos: dict[int, int],
    n_extra_refs: int = 0,
    skip_inplace_copy: set[int] | None = None,
    _smem_arg_indices: list[int] | None = None,
) -> object:
    """Create a wrapper kernel that reorders pallas_call refs to the original arg order.

    ``pallas_call`` provides refs as ``[inputs..., outputs...]``, but Helion
    kernels expect the original parameter order. When *n_extra_refs* > 0
    (e.g. scratch buffers), those trailing refs are appended after the
    reordered args.

    *skip_inplace_copy* is a set of original-arg positions for which the
    initial ``out_ref[...] = in_ref[...]`` copy should be skipped. Used by
    pipeline/fori launchers for pipeline-body tensors backed by HBM refs
    where direct load/store is not allowed.
    """
    _skip_copy = skip_inplace_copy or set()

    def reordered_kernel(*refs: object) -> None:
        n_kernel_params = len(args)
        original_order: list[object] = [None] * n_kernel_params
        for tensor_pos, orig_pos in enumerate(tensor_arg_indices):
            original_order[orig_pos] = refs[tensor_pos]
        for orig_pos, value in non_tensor_args.items():
            original_order[orig_pos] = value
        for out_idx, orig_pos in enumerate(_output_indices):
            out_ref = refs[n_tensor_inputs + out_idx]
            if orig_pos in inplace_positions and orig_pos not in _skip_copy:
                in_ref = refs[arg_to_tensor_pos[orig_pos]]
                if _smem_arg_indices is not None and orig_pos in _smem_arg_indices:
                    # [...] cannot be used for SMEMs,
                    # TODO(dunfanlu): handle in-place copy for SMEM refs
                    pass
                else:
                    out_ref[...] = in_ref[...]  # type: ignore[index]
            original_order[orig_pos] = out_ref
        extra_refs = refs[n_tensor_inputs + len(_output_indices) :]
        pallas_kernel(*original_order, *extra_refs)  # type: ignore[operator]

    return reordered_kernel
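

# Ref-order example (illustrative): for ``args = (x, y)`` where ``y`` is both
# input and output, ``pallas_call`` hands the wrapper
# ``refs = (x_ref, y_in_ref, y_out_ref)``; the wrapper first copies
# ``y_out_ref[...] = y_in_ref[...]`` and then invokes the Helion kernel as
# ``pallas_kernel(x_ref, y_out_ref)``, restoring the original parameter order.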


def _pallas_build_callable(
    pallas_kernel: object,
    grid: tuple[int, ...],
    jit_fn: Callable[..., object],
    _output_indices: list[int],
    arg_to_tensor_pos: dict[int, int],
    tensor_arg_indices: list[int],
    cache_attr: str,
    call_aliases: dict[int, int],
    trace_key_suffix: str = "",
) -> object:
    """Build a ``JaxCallable``, cache it on the kernel, and return it.

    When ``torch_tpu`` is available, wraps the function in a ``JaxCallable``
    for efficient torch<->JAX interop. Otherwise (interpret mode on CPU),
    returns a thin wrapper that converts tensors manually.
    """

    def _make_interpret_callable() -> _PallasInterpretCallable:
        # Map (out_idx in _output_indices) -> tensor_pos for inplace outputs.
        # out_idx must match jax_results ordering (all outputs), not filtered.
        inplace_output_mapping = [
            (out_idx, arg_to_tensor_pos[orig_pos])
            for out_idx, orig_pos in enumerate(_output_indices)
            if orig_pos in arg_to_tensor_pos
        ]
        callable_obj = _PallasInterpretCallable(jit_fn, inplace_output_mapping)
        setattr(
            pallas_kernel,
            cache_attr,
            (grid, callable_obj, tensor_arg_indices, arg_to_tensor_pos),
        )
        return callable_obj

    if _pallas_interpret_flag():
        return _make_interpret_callable()

    import jax
    from torch_tpu._internal.pallas.pallas import (  # pyrefly: ignore[missing-import]
        JaxCallable,
    )

    kernel_name = getattr(pallas_kernel, "__name__", "pallas_kernel")
    jax.config.update("jax_export_ignore_forward_compatibility", True)
    jax_callable = JaxCallable(
        name=kernel_name,
        jit_fn=jax.jit(jit_fn),
        trace_key=f"{kernel_name}_{id(pallas_kernel)}_{grid}{trace_key_suffix}",
        input_output_aliases=call_aliases,
    )
    setattr(
        pallas_kernel,
        cache_attr,
        (grid, jax_callable, tensor_arg_indices, arg_to_tensor_pos),
    )
    return jax_callable


class _PallasInterpretCallable:
    """Thin wrapper that converts torch tensors <-> JAX arrays for interpret mode.

    In interpret mode, ``pallas_call`` runs on CPU and returns JAX arrays.
    This wrapper:

    1. Converts input torch tensors to JAX arrays
    2. Runs the pallas_call function
    3. For inplace outputs (donated tensors): copies JAX results back into
       the original torch tensors via ``copy_()``
    4. Returns raw JAX results so ``_pallas_invoke_and_return`` can handle
       output-only tensors (which are not in the input list)

    ``inplace_output_mapping`` maps each inplace output to its JAX result: a
    list of ``(out_idx, tensor_pos)`` where ``out_idx`` indexes into
    ``jax_results`` and ``tensor_pos`` indexes into ``input_tensors``.
    """

    def __init__(
        self,
        jit_fn: Callable[..., object],
        inplace_output_mapping: list[tuple[int, int]],
    ) -> None:
        self._jit_fn = jit_fn
        self._inplace_output_mapping = inplace_output_mapping

    def __call__(self, *input_tensors: torch.Tensor) -> tuple[object, ...]:
        jax_inputs = [_torch_to_jax(t) for t in input_tensors]
        jax_results = self._jit_fn(*jax_inputs)  # type: ignore[operator]
        if not isinstance(jax_results, (tuple, list)):
            jax_results = (jax_results,)
        # Write inplace results back into the original output tensors.
        for out_idx, tensor_pos in self._inplace_output_mapping:
            out_tensor = input_tensors[tensor_pos]
            result_data = _jax_to_torch(
                jax_results[out_idx], device=out_tensor.device, dtype=out_tensor.dtype
            )
            out_tensor.copy_(result_data)
        # Return JAX results so output-only tensors can be handled
        # by _pallas_invoke_and_return.
        return tuple(jax_results)


def _pallas_interpret_flag() -> bool:
    """Return True if ``HELION_PALLAS_INTERPRET=1`` is set.

    As a side effect, registers a synthetic CPU TpuInfo entry so that
    ``emit_pipeline`` / ``fori_loop`` interpret paths don't fail.
    """
    from .settings import is_pallas_interpret

    result = is_pallas_interpret()
    if result:
        _ensure_cpu_tpu_info()
    return result


def _ensure_cpu_tpu_info() -> None:
    """Register a synthetic TpuInfo for ``"cpu"`` so that ``emit_pipeline`` /
    ``fori_loop`` interpret paths don't fail.
    """
    try:
        from jax._src.pallas.mosaic.tpu_info import ChipVersion
        from jax._src.pallas.mosaic.tpu_info import _get_tpu_info_impl
        from jax._src.pallas.mosaic.tpu_info import registry
    except ImportError:
        return
    if "cpu" not in registry:
        registry["cpu"] = lambda: _get_tpu_info_impl(ChipVersion.TPU_7X, 1)


def _pallas_invoke_and_return(
    jax_callable: object,
    args: tuple[object, ...],
    tensor_arg_indices: list[int],
    arg_to_tensor_pos: dict[int, int],
    _output_indices: list[int],
    _ds_pad_dims: list[tuple[int, int, int, int]] | None = None,
    _orig_output_tensors: dict[int, torch.Tensor] | None = None,
) -> object:
    """Run the JaxCallable and return output-only results.

    Output-only tensors (those not in ``arg_to_tensor_pos``) are not passed
    as pallas_call inputs, so the JaxCallable returns new buffers for them.
    Returns a single tensor, a tuple of tensors, or None.

    When ``_ds_pad_dims`` is provided, also handles:
    - Copying sliced results back into original (unpadded) in-place output tensors
    - Slicing padded output-only result tensors back to original shapes
    """
    input_tensors = [
        cast("torch.Tensor", args[i]).contiguous() for i in tensor_arg_indices
    ]
    results = jax_callable(*input_tensors)  # type: ignore[operator]
    if results is None:
        return None
    if not isinstance(results, (tuple, list)):
        results = (results,)
    output_only_results = []
    for out_idx, orig_pos in enumerate(_output_indices):
        if orig_pos not in arg_to_tensor_pos:
            result = results[out_idx]
            if not isinstance(result, torch.Tensor):
                # Interpret mode: pallas_call returns JAX arrays, convert to torch.
                # On TPU, JaxCallable returns torch tensors directly.
                out_tensor = cast("torch.Tensor", args[orig_pos])
                # Output-only tensors are allocated with ``device='meta'`` to
                # avoid HBM; fall back to the first real input's device in
                # interpret mode so the converted tensor lands somewhere real.
                device = out_tensor.device
                if device.type == "meta" and tensor_arg_indices:
                    device = cast("torch.Tensor", args[tensor_arg_indices[0]]).device
                result = _jax_to_torch(
                    result,
                    device=device,
                    dtype=out_tensor.dtype,
                )
            output_only_results.append(result)

    # Handle padding copy-back and result slicing
    if _ds_pad_dims and _orig_output_tensors:
        # _ds_pad_dims contains (arg_idx, dim, block_size, extra_pad).
        # Build a map from arg_idx → [(dim, ...)] for padded output args.
        padded_dims_by_arg: dict[int, list[int]] = {}
        for arg_idx, dim, _bs, _extra in _ds_pad_dims:
            if arg_idx in _orig_output_tensors:
                padded_dims_by_arg.setdefault(arg_idx, []).append(dim)

        # Copy sliced results back into original in-place output tensors.
        # Skip output-only tensors (not in arg_to_tensor_pos) — their
        # results come from output_only_results, not from args.
        for arg_idx, orig_tensor in _orig_output_tensors.items():
            if arg_idx not in arg_to_tensor_pos:
                continue
            dims = padded_dims_by_arg.get(arg_idx)
            if not dims:
                continue
            padded = cast("torch.Tensor", args[arg_idx])
            slices = [slice(None)] * padded.ndim
            for dim in dims:
                slices[dim] = slice(None, orig_tensor.shape[dim])
            orig_tensor.copy_(padded[tuple(slices)])

        # Slice padded output-only results back to original shapes
        if output_only_results:
            compacted_idx = 0
            for orig_pos in _output_indices:
                if orig_pos not in arg_to_tensor_pos:
                    orig = _orig_output_tensors.get(orig_pos)
                    dims = padded_dims_by_arg.get(orig_pos)
                    if (
                        orig is not None
                        and dims
                        and compacted_idx < len(output_only_results)
                    ):
                        t = output_only_results[compacted_idx]
                        if isinstance(t, torch.Tensor):
                            slices = [slice(None)] * t.ndim
                            for dim in dims:
                                slices[dim] = slice(None, orig.shape[dim])
                            output_only_results[compacted_idx] = t[tuple(slices)]
                    compacted_idx += 1

    if len(output_only_results) == 1:
        return output_only_results[0]
    return tuple(output_only_results) if output_only_results else None


def _pallas_apply_ds_padding(
    args: tuple[object, ...],
    _output_indices: list[int],
    _ds_pad_dims: list[tuple[int, int, int, int]],
) -> tuple[tuple[object, ...], dict[int, torch.Tensor]]:
    """Pad tensor args so ``pl.ds(offset, block_size)`` never reads OOB.

    ``_ds_pad_dims`` contains ``(arg_index, dim, block_size, extra_pad)``
    tuples. The pad amount is ``(-tensor.shape[dim]) % block_size + extra_pad``,
    where *extra_pad* accounts for non-zero loop begins.

    Returns the padded args tuple and a dict mapping output arg indices to
    their original (unpadded) tensors for post-call copy-back.
    """
    args_list = list(args)
    orig_output_tensors: dict[int, torch.Tensor] = {}
    output_set = set(_output_indices)
    for arg_idx, dim, block_size, extra_pad in _ds_pad_dims:
        a = args_list[arg_idx]
        if not isinstance(a, torch.Tensor):
            continue
        pad_amount = (-a.shape[dim]) % block_size + extra_pad
        if pad_amount == 0:
            continue
        if arg_idx in output_set and arg_idx not in orig_output_tensors:
            orig_output_tensors[arg_idx] = a
        pad_widths = [0] * (2 * a.ndim)
        pad_widths[2 * (a.ndim - 1 - dim) + 1] = pad_amount
        args_list[arg_idx] = torch.nn.functional.pad(a, pad_widths)
    return tuple(args_list), orig_output_tensors
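

# Padding arithmetic example (illustrative): a 1-D tensor of shape (5,) with
# ``block_size=4`` and ``extra_pad=0`` gets ``pad_amount = (-5) % 4 + 0 = 3``,
# so the launcher runs on a padded (8,) buffer and the extra elements are
# sliced off (or copied back) after the call.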


def default_pallas_launcher(
    pallas_kernel: object,
    grid: tuple[int, ...],
    *args: object,
    _output_indices: list[int] | None = None,
    _inplace_indices: list[int] | None = None,
    _block_spec_info: _BlockSpecInfo | None = None,
    _smem_arg_indices: list[int] | None = None,
    _ds_pad_dims: list[tuple[int, int, int, int]] | None = None,
    **kwargs: object,
) -> object:
    """Default launcher for Pallas kernels on TPU (or CPU with interpret=True).

    Uses ``JaxCallable`` from ``torch_tpu`` to compile and run the Pallas
    kernel on TPU. When ``torch_tpu`` is not available (interpret mode),
    falls back to direct torch<->JAX conversion.

    Output tensors are donated via ``input_output_aliases`` so the kernel
    writes directly into their buffers (zero-copy on TPU). Output-only
    tensors (in ``_output_indices`` but not in ``_inplace_indices``) are
    excluded from pallas_call inputs to save VMEM. Their results are
    returned as torch tensors.
    """
    if _output_indices is None:
        _output_indices = []
    _orig_output_tensors: dict[int, torch.Tensor] | None = None
    if _ds_pad_dims:
        args, _orig_output_tensors = _pallas_apply_ds_padding(
            args, _output_indices, _ds_pad_dims
        )
    _pallas_check_dtypes(args)
    cache = getattr(pallas_kernel, "_pallas_cache", None)
    if cache is not None and cache[0] == grid:
        _, jax_callable, tensor_arg_indices, arg_to_tensor_pos = cache
    else:
        from jax.experimental import pallas as pl
        from jax.experimental.pallas import tpu as pltpu
        import jax.numpy as jnp

        (
            tensor_arg_indices,
            output_only_indices,
            non_tensor_args,
            n_tensor_inputs,
            arg_to_tensor_pos,
            inplace_positions,
            out_shapes,
            pallas_aliases,
        ) = _pallas_prepare_args(args, _output_indices, _inplace_indices)
        in_specs, out_specs = _pallas_build_block_specs(
            pl,
            jnp,
            pltpu,
            grid,
            args,
            tensor_arg_indices,
            _output_indices,
            _block_spec_info,
            _smem_arg_indices,
            output_only_indices,
        )
        reordered_kernel = _pallas_make_reordered_kernel(
            pallas_kernel,
            args,
            tensor_arg_indices,
            non_tensor_args,
            n_tensor_inputs,
            _output_indices,
            inplace_positions,
            arg_to_tensor_pos,
            _smem_arg_indices=_smem_arg_indices,
        )
        out_shape_arg = out_shapes if len(out_shapes) > 1 else out_shapes[0]
        estimated_vmem = _estimate_pallas_vmem_bytes(
            pl,
            pltpu,
            in_specs,
            out_specs,
            None,
            args,
            tensor_arg_indices,
            _output_indices,
            pallas_aliases,
        )
        vmem_limit_bytes = _get_vmem_limit_bytes(pltpu)
        if estimated_vmem > vmem_limit_bytes:
            raise RuntimeError(
                f"XLA:TPU compile permanent error. Ran out of memory in memory space vmem. "
                f"Estimated {estimated_vmem / 1e6:.2f}MB exceeds {vmem_limit_bytes / 1e6:.2f}MB vmem capacity."
            )
        pallas_call_kwargs: dict[str, object] = {
            "out_shape": out_shape_arg,
            "grid": grid,
        }
        if _pallas_interpret_flag():
            pallas_call_kwargs["interpret"] = True
        if in_specs is not None:
            pallas_call_kwargs["in_specs"] = in_specs
            pallas_call_kwargs["out_specs"] = out_specs
        jit_fn = pl.pallas_call(
            reordered_kernel,  # pyrefly: ignore[bad-argument-type]
            **pallas_call_kwargs,  # type: ignore[arg-type]
        )
        jax_callable = _pallas_build_callable(
            pallas_kernel,
            grid,
            jit_fn,
            _output_indices,
            arg_to_tensor_pos,
            tensor_arg_indices,
            cache_attr="_pallas_cache",
            call_aliases=pallas_aliases,
        )
    return _pallas_invoke_and_return(
        jax_callable,
        args,
        tensor_arg_indices,
        arg_to_tensor_pos,
        _output_indices,
        _ds_pad_dims,
        _orig_output_tensors,
    )


def default_pallas_pipeline_launcher(
    pallas_kernel: object,
    grid: tuple[int, ...],
    *args: object,
    _output_indices: list[int] | None = None,
    _inplace_indices: list[int] | None = None,
    _block_spec_info: _BlockSpecInfo | None = None,
    _scratch_shapes: list[tuple[tuple[int, ...], str]] | None = None,
    _pipeline_arg_indices: list[int] | None = None,
    _ds_pad_dims: list[tuple[int, int, int, int]] | None = None,
    _smem_arg_indices: list[int] | None = None,
    **kwargs: object,
) -> object:
    """Launcher for Pallas kernels using PrefetchScalarGridSpec with scratch memory.

    Used when ``pallas_loop_type='emit_pipeline'``. Pipeline-body tensors
    (listed in ``_pipeline_arg_indices``) use HBM refs; all other tensors get
    proper BlockSpecs for automatic VMEM prefetch.
    """
    if _output_indices is None:
        _output_indices = []
    if _scratch_shapes is None:
        _scratch_shapes = []
    _orig_output_tensors: dict[int, torch.Tensor] | None = None
    if _ds_pad_dims:
        args, _orig_output_tensors = _pallas_apply_ds_padding(
            args, _output_indices, _ds_pad_dims
        )
    _pallas_check_dtypes(args)
    cache = getattr(pallas_kernel, "_pallas_pipeline_cache", None)
    if cache is not None and cache[0] == grid:
        _, jax_callable, tensor_arg_indices, arg_to_tensor_pos = cache
    else:
        from jax.experimental import pallas as pl
        from jax.experimental.pallas import tpu as pltpu
        import jax.numpy as jnp

        (
            tensor_arg_indices,
            output_only_indices,
            non_tensor_args,
            n_tensor_inputs,
            arg_to_tensor_pos,
            inplace_positions,
            out_shapes,
            pallas_aliases,
        ) = _pallas_prepare_args(args, _output_indices, _inplace_indices)
        # Build scratch shapes for VMEM
        _jnp_dtype_map = _pallas_jnp_dtype_map()
        scratch_shapes = []
        for scratch_entry in _scratch_shapes:
            if len(scratch_entry) == 3:
                shape, dtype_str, scratch_type = scratch_entry
            else:
                shape, dtype_str = scratch_entry  # type: ignore[misc]
                scratch_type = "vmem"
            if scratch_type == "dma_semaphore":
                scratch_shapes.append(pltpu.SemaphoreType.DMA(()))
            else:
                jnp_dtype = _jnp_dtype_map.get(dtype_str, jnp.float32)
                scratch_shapes.append(
                    pltpu.VMEM(shape, jnp_dtype)  # pyrefly: ignore[bad-argument-type]
                )
        assert _block_spec_info is not None, (
            "emit_pipeline launcher requires _block_spec_info from codegen"
        )
        in_specs_list, out_specs = _pallas_build_pipeline_specs(
            pl,
            jnp,
            pltpu,
            grid,
            args,
            tensor_arg_indices,
            _output_indices,
            _block_spec_info,
            _pipeline_arg_indices,
            output_only_indices,
            smem_arg_indices=_smem_arg_indices,
        )
        _pipeline_set = set(_pipeline_arg_indices or [])
        reordered_kernel = _pallas_make_reordered_kernel(
            pallas_kernel,
            args,
            tensor_arg_indices,
            non_tensor_args,
            n_tensor_inputs,
            _output_indices,
            inplace_positions,
            arg_to_tensor_pos,
            n_extra_refs=len(scratch_shapes),
            skip_inplace_copy=_pipeline_set,
            _smem_arg_indices=_smem_arg_indices,
        )
        out_shape_arg = out_shapes if len(out_shapes) > 1 else out_shapes[0]
        grid_spec = pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=in_specs_list,
            out_specs=out_specs,
            scratch_shapes=scratch_shapes,
            grid=grid,
        )
        estimated_vmem = _estimate_pallas_vmem_bytes(
            pl,
            pltpu,
            in_specs_list,
            out_specs,
            scratch_shapes,
            args,
            tensor_arg_indices,
            _output_indices,
            pallas_aliases,
        )
        vmem_limit_bytes = _get_vmem_limit_bytes(pltpu)
        if estimated_vmem > vmem_limit_bytes:
            raise RuntimeError(
                f"XLA:TPU compile permanent error. Ran out of memory in memory space vmem. "
                f"Estimated {estimated_vmem / 1e6:.2f}MB exceeds {vmem_limit_bytes / 1e6:.2f}MB vmem capacity."
            )
        pallas_call_kwargs: dict[str, object] = {
            "out_shape": out_shape_arg,
            "grid_spec": grid_spec,
            "compiler_params": pltpu.CompilerParams(  # pyrefly: ignore[bad-instantiation]
                dimension_semantics=tuple("parallel" for _ in grid),
            ),
        }
        if _pallas_interpret_flag():
            pallas_call_kwargs["interpret"] = True
        jit_fn = pl.pallas_call(
            reordered_kernel,  # pyrefly: ignore[bad-argument-type]
            **pallas_call_kwargs,  # type: ignore[arg-type]
        )
        jax_callable = _pallas_build_callable(
            pallas_kernel,
            grid,
            jit_fn,
            _output_indices,
            arg_to_tensor_pos,
            tensor_arg_indices,
            cache_attr="_pallas_pipeline_cache",
            call_aliases=pallas_aliases,
            trace_key_suffix="_pipeline",
        )
    return _pallas_invoke_and_return(
        jax_callable,
        args,
        tensor_arg_indices,
        arg_to_tensor_pos,
        _output_indices,
        _ds_pad_dims,
        _orig_output_tensors,
    )


def default_pallas_fori_launcher(
    pallas_kernel: object,
    grid: tuple[int, ...],
    *args: object,
    _output_indices: list[int] | None = None,
    _inplace_indices: list[int] | None = None,
    _block_spec_info: _BlockSpecInfo | None = None,
    _scratch_shapes: list[tuple[tuple[int, ...], str | None, str]] | None = None,
    _ds_pad_dims: list[tuple[int, int, int, int]] | None = None,
    _smem_arg_indices: list[int] | None = None,
    **kwargs: object,
) -> object:
    """Launcher for Pallas kernels using fori_loop with manual DMA.

    Used when ``pallas_loop_type="fori_loop"``. Passes all tensors as
    ``memory_space=pl.ANY`` (HBM refs) and adds scratch buffers as
    ``pltpu.VMEM`` shapes plus ``pltpu.SemaphoreType.DMA`` for async copies.
    The kernel uses ``jax.lax.fori_loop`` with ``pltpu.make_async_copy``
    internally for DMA control.
    """
    if _output_indices is None:
        _output_indices = []
    if _scratch_shapes is None:
        _scratch_shapes = []
    _orig_output_tensors: dict[int, torch.Tensor] | None = None
    if _ds_pad_dims:
        args, _orig_output_tensors = _pallas_apply_ds_padding(
            args, _output_indices, _ds_pad_dims
        )
    _pallas_check_dtypes(args)
    cache = getattr(pallas_kernel, "_pallas_fori_cache", None)
    if cache is not None and cache[0] == grid:
        _, jax_callable, tensor_arg_indices, arg_to_tensor_pos = cache
    else:
        from jax.experimental import pallas as pl
        from jax.experimental.pallas import tpu as pltpu
        import jax.numpy as jnp

        (
            tensor_arg_indices,
            output_only_indices,
            non_tensor_args,
            n_tensor_inputs,
            arg_to_tensor_pos,
            inplace_positions,
            out_shapes,
            pallas_aliases,
        ) = _pallas_prepare_args(args, _output_indices, _inplace_indices)
        # Build scratch shapes: VMEM buffers + DMA semaphores
        _jnp_dtype_map = _pallas_jnp_dtype_map()
        scratch_shapes = []
        for shape, dtype_str, scratch_type in _scratch_shapes:
            if scratch_type == "dma_semaphore":
                scratch_shapes.append(pltpu.SemaphoreType.DMA(()))
            else:  # "vmem"
                assert dtype_str is not None
                jnp_dtype = _jnp_dtype_map.get(dtype_str, jnp.float32)
                scratch_shapes.append(
                    pltpu.VMEM(shape, jnp_dtype)  # pyrefly: ignore[bad-argument-type]
                )
        # Build in_specs/out_specs: proper BlockSpecs for outer grid dims,
        # HBM refs for tensors used in the fori_loop body (DMA handles tiling).
        _fori_pipeline_indices = kwargs.get("_pipeline_arg_indices")
        assert _block_spec_info is not None, (
            "fori_loop launcher requires _block_spec_info from codegen"
        )
        in_specs_list, out_specs = _pallas_build_pipeline_specs(
            pl,
            jnp,
            pltpu,
            grid,
            args,
            tensor_arg_indices,
            _output_indices,
            _block_spec_info,
            _fori_pipeline_indices,  # type: ignore[arg-type]
            output_only_indices,
            smem_arg_indices=_smem_arg_indices,
        )
        _fori_pipeline_set = set(_fori_pipeline_indices or [])  # type: ignore[arg-type]
        reordered_kernel = _pallas_make_reordered_kernel(
            pallas_kernel,
            args,
            tensor_arg_indices,
            non_tensor_args,
            n_tensor_inputs,
            _output_indices,
            inplace_positions,
            arg_to_tensor_pos,
            n_extra_refs=len(scratch_shapes),
            skip_inplace_copy=_fori_pipeline_set,
            _smem_arg_indices=_smem_arg_indices,
        )
        out_shape_arg = out_shapes if len(out_shapes) > 1 else out_shapes[0]
        grid_spec = pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=in_specs_list,
            out_specs=out_specs,
            scratch_shapes=scratch_shapes,
            grid=grid,
        )
        estimated_vmem = _estimate_pallas_vmem_bytes(
            pl,
            pltpu,
            in_specs_list,
            out_specs,
            scratch_shapes,
            args,
            tensor_arg_indices,
            _output_indices,
            pallas_aliases,
        )
        vmem_limit_bytes = _get_vmem_limit_bytes(pltpu)
        if estimated_vmem > vmem_limit_bytes:
            raise RuntimeError(
                f"XLA:TPU compile permanent error. Ran out of memory in memory space vmem. "
                f"Estimated {estimated_vmem / 1e6:.2f}MB exceeds {vmem_limit_bytes / 1e6:.2f}MB vmem capacity."
            )
        pallas_call_kwargs: dict[str, object] = {
            "out_shape": out_shape_arg,
            "grid_spec": grid_spec,
            "compiler_params": pltpu.CompilerParams(  # pyrefly: ignore[bad-instantiation]
                dimension_semantics=tuple("parallel" for _ in grid),
            ),
        }
        if _pallas_interpret_flag():
            pallas_call_kwargs["interpret"] = True
        jit_fn = pl.pallas_call(
            reordered_kernel,  # pyrefly: ignore[bad-argument-type]
            **pallas_call_kwargs,  # type: ignore[arg-type]
        )
        jax_callable = _pallas_build_callable(
            pallas_kernel,
            grid,
            jit_fn,
            _output_indices,
            arg_to_tensor_pos,
            tensor_arg_indices,
            cache_attr="_pallas_fori_cache",
            call_aliases=pallas_aliases,
            trace_key_suffix="_fori",
        )
    return _pallas_invoke_and_return(
        jax_callable,
        args,
        tensor_arg_indices,
        arg_to_tensor_pos,
        _output_indices,
        _ds_pad_dims,
        _orig_output_tensors,
    )


def _torch_to_jax(t: torch.Tensor) -> object:
    """Convert a torch.Tensor to a JAX array via numpy (for interpret mode on CPU)."""
    import jax.numpy as jnp
    import numpy as np

    return jnp.array(np.asarray(t.detach().cpu()))


def _jax_to_torch(
    arr: object, *, device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
    """Convert a JAX array back to a torch.Tensor via numpy (for interpret mode on CPU)."""
    import numpy as np

    return torch.from_numpy(np.asarray(arr)).to(dtype=dtype, device=device)


def _torch_dtype_to_cutlass(dtype: torch.dtype) -> object:
    _patch_cutlass_jit_shutdown_unload()
    import cutlass

    mapping: dict[torch.dtype, object] = {
        torch.float16: cutlass.Float16,
        torch.float32: cutlass.Float32,
        torch.float64: cutlass.Float64,
        torch.bfloat16: cutlass.BFloat16,
        # CuTe does not support i1 global-memory tensors; torch.bool is stored
        # as one byte, so pass bool tensor pointers as uint8 and let load
        # lowering convert nonzero bytes back to cutlass.Boolean registers.
        torch.bool: cutlass.Uint8,
        torch.int8: cutlass.Int8,
        torch.int16: cutlass.Int16,
        torch.int32: cutlass.Int32,
        torch.int64: cutlass.Int64,
        torch.uint8: cutlass.Uint8,
        torch.uint64: cutlass.Int64,
    }
    if dtype not in mapping:
        raise exc.BackendUnsupported("cute", f"dtype: {dtype}")
    return mapping[dtype]


def _normalize_cute_scalar(arg: object) -> tuple[str, object]:
    if isinstance(arg, (bool, torch.SymBool)):
        return ("bool", bool(arg))
    if isinstance(arg, (int, torch.SymInt)):
        return ("int", int(arg))
    if isinstance(arg, (float, torch.SymFloat)):
        return ("float", float(arg))
    raise exc.BackendUnsupported("cute", f"launcher scalar argument type: {type(arg)}")


def _cute_scalar_annotation(kind: str) -> str:
    mapping = {
        "bool": "cutlass.Boolean",
        "int": "cutlass.Int64",
        "float": "cutlass.Float32",
    }
    return mapping[kind]


def _cute_kernel_param_is_constexpr(cute_kernel: object) -> tuple[bool, ...]:
    """Return per-parameter Constexpr flags for a ``@cute.kernel``.

    Cached on the kernel object to avoid repeated signature inspection.

    The newer cutlass DSL (>=4.5) enforces region isolation: a runtime scalar
    passed through the wrapper cannot satisfy a kernel parameter declared as
    ``cutlass.Constexpr``. When the wrapper sees a Constexpr-typed kernel
    parameter, it must propagate the value as a Constexpr (i.e., baked into
    the compiled wrapper) rather than as a runtime ``cutlass.Int64``.
    """
    cached = getattr(cast("Any", cute_kernel), "_helion_cute_param_constexpr", None)
    if cached is not None:
        return cast("tuple[bool, ...]", cached)
    import cutlass

    try:
        sig = inspect.signature(cute_kernel)  # type: ignore[arg-type]
    except (TypeError, ValueError):
        flags: tuple[bool, ...] = ()
    else:
        from typing import get_origin
        from typing import get_type_hints

        # Helion-emitted kernels use ``from __future__ import annotations`` so
        # ``param.annotation`` is the source string. ``get_type_hints`` resolves
        # those strings against the function's globals (which include
        # ``cutlass``).
        try:
            hints = get_type_hints(cute_kernel)  # type: ignore[arg-type]
        except Exception:
            hints = {}
        flags_list: list[bool] = []
        for name, param in sig.parameters.items():
            ann = hints.get(name, param.annotation)
            is_constexpr = ann is cutlass.Constexpr or get_origin(ann) is (
                cutlass.Constexpr
            )
            flags_list.append(is_constexpr)
        flags = tuple(flags_list)
    with suppress(AttributeError, TypeError):
        cast("Any", cute_kernel)._helion_cute_param_constexpr = flags
    return flags


def _append_cute_wrapper_plan(
    body: list[str],
    call_args: list[str],
    plan: dict[str, object],
) -> None:
    def plan_int(key: str, default: int | None = None) -> int:
        value = plan.get(key, default) if default is not None else plan[key]
        assert isinstance(value, int)
        return value

    kind = plan["kind"]
    if kind == "tcgen05_d_tma":
        d_idx = plan_int("d_idx")
        bm = plan_int("bm")
        bn = plan_int("bn")
        c_stage_count = plan_int("c_stage_count")
        output_dtype = str(plan["output_dtype"])
        kernel_args = [str(arg) for arg in cast("list[object]", plan["kernel_args"])]
        assert len(kernel_args) == 2
        tma_atom_d, tma_tensor_d = kernel_args
        epi_tile = f"{tma_atom_d}_epi_tile"
        smem_layout = f"{tma_atom_d}_smem_layout"
        cta_v_layout = f"{tma_atom_d}_cta_v_layout"
        # Keep these layout arguments in sync with the device-side
        # `make_smem_layout_epi` call in `_codegen_cute_store_tcgen05_tile`;
        # the TMA atom slices the same SMEM stage that the kernel allocates.
        body.extend(
            (
                (
                    f"    {epi_tile} = "
                    "cutlass.utils.blackwell_helpers.compute_epilogue_tile_shape("
                    f"({bm}, {bn}), False, "
                    "cutlass.utils.layout.LayoutEnum.ROW_MAJOR, "
                    f"{output_dtype})"
                ),
                (
                    f"    {smem_layout} = cutlass.utils.blackwell_helpers."
                    "make_smem_layout_epi("
                    f"{output_dtype}, cutlass.utils.layout.LayoutEnum.ROW_MAJOR, "
                    f"{epi_tile}, {c_stage_count})"
                ),
                (
                    f"    {cta_v_layout} = cute.composition("
                    f"cute.make_identity_layout(arg{d_idx}.shape), {epi_tile})"
                ),
                (
                    f"    {tma_atom_d}, {tma_tensor_d} = "
                    "cute.nvgpu.cpasync.make_tiled_tma_atom("
                    "cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(), "
                    f"arg{d_idx}, cute.slice_({smem_layout}, (None, None, 0)), "
                    f"{cta_v_layout})"
                ),
            )
        )
        call_args.extend(kernel_args)
        return
    if kind != "tcgen05_ab_tma":
        raise exc.BackendUnsupported("cute", f"wrapper plan kind: {kind}")
    lhs_idx_key = "lhs_idx" if "lhs_idx" in plan else "lhsidx"
    rhs_idx_key = "rhs_idx" if "rhs_idx" in plan else "rhsidx"
    lhs_idx = plan_int(lhs_idx_key)
    rhs_idx = plan_int(rhs_idx_key)
    bm = plan_int("bm")
    bn = plan_int("bn")
    bk = plan_int("bk")
    cluster_m = plan_int("cluster_m", 1)
    cluster_n = plan_int("cluster_n", 1)
    input_dtype = str(plan["input_dtype"])
    acc_dtype = str(plan["acc_dtype"])
    ab_stage_count = plan_int("ab_stage_count", 2)
    kernel_args = [str(arg) for arg in cast("list[object]", plan["kernel_args"])]
    assert len(kernel_args) == 4
    tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b = kernel_args
    cta_group = (
        "cute.nvgpu.tcgen05.CtaGroup.TWO"
        if cluster_m * cluster_n == 2 and bm == 256
        else "cute.nvgpu.tcgen05.CtaGroup.ONE"
    )
    cluster_shape = f"({cluster_m}, {cluster_n}, 1)"
    tiled_mma = f"{tma_atom_a}_tiled_mma"
    cluster_layout_vmnk = f"{tma_atom_a}_cluster_layout_vmnk"
    smem_a_layout = f"{tma_atom_a}_smem_layout"
    smem_b_layout = f"{tma_atom_b}_smem_layout"
    rhs_tma = f"{tma_atom_b}_rhs_tma"
    body.extend(
        (
            (
                f"    {tiled_mma} = cutlass.utils.blackwell_helpers.make_trivial_tiled_mma("
                f"{input_dtype}, "
                "cute.nvgpu.tcgen05.OperandMajorMode.K, "
                "cute.nvgpu.tcgen05.OperandMajorMode.MN, "
                f"{acc_dtype}, "
                f"{cta_group}, "
                f"({bm}, {bn}), "
                "cute.nvgpu.tcgen05.OperandSource.SMEM)"
            ),
            (
                f"    {cluster_layout_vmnk} = cute.tiled_divide("
                f"cute.make_layout({cluster_shape}), ({tiled_mma}.thr_id.shape,))"
            ),
            (
                f"    {smem_a_layout} = cutlass.utils.blackwell_helpers.make_smem_layout_a("
                f"{tiled_mma}, ({bm}, {bn}, {bk}), {input_dtype}, {ab_stage_count})"
            ),
            (
                f"    {smem_b_layout} = cutlass.utils.blackwell_helpers.make_smem_layout_b("
                f"{tiled_mma}, ({bm}, {bn}, {bk}), {input_dtype}, {ab_stage_count})"
            ),
            (
                f"    {rhs_tma} = cute.make_tensor("
                f"arg{rhs_idx}.iterator, "
                "layout=cute.make_layout("
                f"(arg{rhs_idx}_shape1, arg{rhs_idx}_shape0), "
                f"stride=(arg{rhs_idx}_stride1, arg{rhs_idx}_stride0)))"
            ),
            f"    {rhs_tma}.mark_layout_dynamic(leading_dim=0)",
            (
                f"    {tma_atom_a}, {tma_tensor_a} = cute.nvgpu.make_tiled_tma_atom_A("
                "cutlass.utils.blackwell_helpers.cluster_shape_to_tma_atom_A("
                f"{cluster_shape}, {tiled_mma}.thr_id), "
                f"arg{lhs_idx}, "
                f"cute.slice_({smem_a_layout}, (None, None, None, 0)), "
                f"({bm}, {bn}, {bk}), {tiled_mma})"
            ),
            (
                f"    {tma_atom_b}, {tma_tensor_b} = cute.nvgpu.make_tiled_tma_atom_B("
                "cutlass.utils.blackwell_helpers.cluster_shape_to_tma_atom_B("
                f"{cluster_shape}, {tiled_mma}.thr_id), "
                f"{rhs_tma}, "
                f"cute.slice_({smem_b_layout}, (None, None, None, 0)), "
                f"({bm}, {bn}, {bk}), {tiled_mma}, {cluster_layout_vmnk}.shape)"
            ),
        )
    )
    call_args.extend(kernel_args)


def _cute_cluster_shape_from_wrapper_plans(
    wrapper_plans: list[dict[str, object]],
) -> tuple[int, int, int] | None:
    cluster_m = 1
    cluster_n = 1
    for plan in wrapper_plans:
        if plan.get("kind") != "tcgen05_ab_tma":
            continue
        plan_cluster_m = plan.get("cluster_m", 1)
        plan_cluster_n = plan.get("cluster_n", 1)
        assert isinstance(plan_cluster_m, int)
        assert isinstance(plan_cluster_n, int)
        cluster_m = max(cluster_m, plan_cluster_m)
        cluster_n = max(cluster_n, plan_cluster_n)
    if cluster_m * cluster_n <= 1:
        return None
    return (cluster_m, cluster_n, 1)


def _cute_cluster_shape(
    cute_kernel: object, wrapper_plans: list[dict[str, object]]
) -> tuple[int, int, int] | None:
    explicit_cluster_shape = getattr(
        cast("Any", cute_kernel), "_helion_cute_cluster_shape", None
    )
    if explicit_cluster_shape is not None:
        if (
            isinstance(explicit_cluster_shape, tuple)
            and len(explicit_cluster_shape) == 3
            and all(isinstance(dim, int) for dim in explicit_cluster_shape)
        ):
            return cast("tuple[int, int, int]", explicit_cluster_shape)
        raise exc.BackendUnsupported(
            "cute",
            f"invalid _helion_cute_cluster_shape: {explicit_cluster_shape!r}",
        )
    return _cute_cluster_shape_from_wrapper_plans(wrapper_plans)


def _create_cute_wrapper(
    cute_kernel: object,
    schema_key: tuple[tuple[object, ...], ...],
    block: tuple[int, int, int],
) -> object:
    _patch_cutlass_jit_shutdown_unload()
    import cutlass
    import cutlass.cute as cute

    kernel_name = getattr(cast("Any", cute_kernel), "__name__", "cute_kernel")
    kernel_tag = f"{kernel_name}_{id(cute_kernel):x}"
    func_name = f"_helion_cute_launch_{kernel_tag}"
    params: list[str] = []
    body: list[str] = []
    call_args: list[str] = []
    for i, entry in enumerate(schema_key):
        kind = entry[0]
        if kind == "tensor":
            (_, _dtype, rank) = entry
            assert isinstance(rank, int)
            ptr_name = f"arg{i}_ptr"
            params.append(f"{ptr_name}: cute.Pointer")
            shape_names = [f"arg{i}_shape{d}" for d in range(rank)]
            stride_names = [f"arg{i}_stride{d}" for d in range(rank)]
            params.extend(f"{name}: cutlass.Int64" for name in shape_names)
            params.extend(f"{name}: cutlass.Int64" for name in stride_names)
            shape_tuple = (
                f"({shape_names[0]},)" if rank == 1 else f"({', '.join(shape_names)})"
            )
            stride_tuple = (
                f"({stride_names[0]},)" if rank == 1 else f"({', '.join(stride_names)})"
            )
            body.append(
                f"    arg{i} = cute.make_tensor({ptr_name}, layout=cute.make_layout({shape_tuple}, stride={stride_tuple}))"
            )
            call_args.append(f"arg{i}")
            continue
        if kind == "scalar_constexpr":
            (_, scalar_kind, scalar_value) = entry
            assert isinstance(scalar_kind, str)
            literal = repr(scalar_value)
            body.append(f"    arg{i} = {literal}")
            call_args.append(f"arg{i}")
            continue
        assert kind == "scalar"
        (_, scalar_kind) = entry
        assert isinstance(scalar_kind, str)
        scalar_name = f"arg{i}"
        params.append(f"{scalar_name}: {_cute_scalar_annotation(scalar_kind)}")
        call_args.append(scalar_name)
    params.extend(
        (
            "grid_x: cutlass.Int32",
            "grid_y: cutlass.Int32",
            "grid_z: cutlass.Int32",
        )
    )
    wrapper_plans = [
        cast("dict[str, object]", plan)
        for plan in getattr(cast("Any", cute_kernel), "_helion_cute_wrapper_plans", [])
    ]
    for plan in wrapper_plans:
        _append_cute_wrapper_plan(body, call_args, plan)
    launch_suffix = f", block={block!r}"
    cluster_shape = _cute_cluster_shape(cute_kernel, wrapper_plans)
    if cluster_shape is not None:
        launch_suffix += f", cluster={list(cluster_shape)!r}"
    body.extend(
        (
            f"    _helion_cute_kernel_tag = {kernel_tag!r}",
            "    _kernel("
            + ", ".join(call_args)
            + f").launch(grid=(grid_x, grid_y, grid_z){launch_suffix})",
        )
    )
    source = "\n".join(
        [
            "@cute.jit",
            f"def {func_name}({', '.join(params)}) -> None:",
            *body,
        ]
    )
    namespace: dict[str, Any] = {
        "cutlass": cutlass,
        "cute": cute,
        "_kernel": cute_kernel,
    }
    filename = f"<helion_cute_launcher:{kernel_tag}:{schema_key!r}:{block!r}>"
    linecache.cache[filename] = (
        len(source),
        None,
        [line + "\n" for line in source.splitlines()],
        filename,
    )
    exec(compile(source, filename, "exec"), namespace)
    return namespace[func_name]
f"<helion_cute_launcher:{kernel_tag}:{schema_key!r}:{block!r}>" linecache.cache[filename] = ( len(source), None, [line + "\n" for line in source.splitlines()], filename, ) exec(compile(source, filename, "exec"), namespace) return namespace[func_name] class _CompiledCuteLauncher: """Lazily compile a Helion ``@cute.jit`` wrapper via ``cute.compile``. The first call uses ``cute.compile(jit_func, *args)`` to produce a compiled callable; subsequent calls invoke the compiled callable directly. This bypasses the per-launch ``@cute.jit`` argument-handling/dispatch path, matching Quack's pattern (see ``gemm_tvm_ffi_utils.py``). On B200 this collapses ~200ms of per-launch host overhead into ~0.1ms. """ __slots__ = ("_compile_options", "_compiled", "_jit_func") def __init__(self, jit_func: object, compile_options: str | None) -> None: self._jit_func = jit_func self._compile_options = compile_options self._compiled: object = None def __call__(self, *args: object) -> object: compiled = self._compiled if compiled is None: import cutlass.cute as cute if self._compile_options is None: compiled = cute.compile(self._jit_func, *args) else: compiled = cute.compile( self._jit_func, *args, options=self._compile_options, ) self._compiled = compiled return cast("Any", compiled)(*args) def _get_compiled_cute_launcher( cute_kernel: object, schema_key: tuple[tuple[object, ...], ...], block: tuple[int, int, int], compile_options: str | None = None, ) -> object: try: # pyrefly: ignore [missing-attribute] cache = cute_kernel._helion_cute_compiled_launchers except AttributeError: cache = {} # pyrefly: ignore [missing-attribute] cute_kernel._helion_cute_compiled_launchers = cache wrapper_plans = tuple( repr(plan) for plan in getattr(cast("Any", cute_kernel), "_helion_cute_wrapper_plans", []) ) cluster_shape = getattr( cast("Any", cute_kernel), "_helion_cute_cluster_shape", None ) cache_key = ( schema_key, block, wrapper_plans, repr(cluster_shape), compile_options, ) cached = cache.get(cache_key) if cached is not None: return cached jit_func = _create_cute_wrapper(cute_kernel, schema_key, block) launcher = _CompiledCuteLauncher(jit_func, compile_options) cache[cache_key] = launcher return launcher def _build_cute_schema_and_args( cute_kernel: object, args: tuple[object, ...], grid: tuple[int, int, int], ) -> tuple[tuple[tuple[object, ...], ...], tuple[object, ...]]: _patch_cutlass_jit_shutdown_unload() import cutlass.cute as cute from cutlass.cute.runtime import make_ptr _ensure_cute_dsl_arch_env(args) constexpr_flags = _cute_kernel_param_is_constexpr(cute_kernel) schema: list[tuple[object, ...]] = [] launch_args: list[object] = [] for i, arg in enumerate(args): if isinstance(arg, torch.Tensor): if arg.device.type != "cuda": raise exc.BackendUnsupported("cute", "launcher requires CUDA tensors") if arg.ndim <= 0: raise exc.BackendUnsupported( "cute", "launcher requires tensor rank >= 1" ) schema.append(("tensor", str(arg.dtype), arg.ndim)) launch_args.append( make_ptr( cast("Any", _torch_dtype_to_cutlass(arg.dtype)), arg.data_ptr(), cute.AddressSpace.gmem, assumed_align=16, ) ) launch_args.extend(int(arg.size(d)) for d in range(arg.ndim)) launch_args.extend(int(arg.stride(d)) for d in range(arg.ndim)) continue scalar_kind, scalar_value = _normalize_cute_scalar(arg) is_constexpr = i < len(constexpr_flags) and constexpr_flags[i] if is_constexpr: # Bake Constexpr values into the wrapper / cache key. 


def _ensure_cute_dsl_arch_env(args: tuple[object, ...]) -> None:
    tensor_args = [arg for arg in args if isinstance(arg, torch.Tensor)]
    if tensor_args:
        device = tensor_args[0].device
        if device.type != "cuda":
            return
        with torch.cuda.device(device):
            major, minor = torch.cuda.get_device_capability(device)
    elif not torch.cuda.is_available():
        return
    else:
        major, minor = torch.cuda.get_device_capability()
    # CUTLASS DSL distinguishes post-Hopper arch variants such as sm_90a/sm_100a,
    # while torch.cuda.get_device_capability() only returns major/minor.
    suffix = "a" if major >= 9 else ""
    desired = f"sm_{major}{minor}{suffix}"
    if os.environ.get("CUTE_DSL_ARCH") != desired:
        os.environ["CUTE_DSL_ARCH"] = desired


def default_cute_launcher(
    cute_kernel: object,
    grid: tuple[int, ...],
    *args: object,
    **kwargs: object,
) -> object:
    block = kwargs.pop("block", (256, 1, 1))
    cute_compile_options = kwargs.pop("cute_compile_options", None)
    if cute_compile_options is not None and not isinstance(cute_compile_options, str):
        raise ValueError(f"Invalid CuTe compile options: {cute_compile_options!r}")
    if not isinstance(block, tuple) or len(block) < 1:
        raise ValueError(f"Invalid block specification: {block}")
    if not isinstance(grid, tuple) or len(grid) < 1:
        raise ValueError(f"Invalid grid specification: {grid}")
    if kwargs:
        raise exc.BackendUnsupported("cute", f"launcher kwargs: {sorted(kwargs)}")
    grid_xyz = (
        int(grid[0]),
        int(grid[1]) if len(grid) > 1 else 1,
        int(grid[2]) if len(grid) > 2 else 1,
    )
    block_xyz = (
        int(block[0]),
        int(block[1]) if len(block) > 1 else 1,
        int(block[2]) if len(block) > 2 else 1,
    )
    if any(dim <= 0 for dim in grid_xyz):
        return None
    schema_key, launch_args = _build_cute_schema_and_args(
        cute_kernel, tuple(args), grid_xyz
    )
    compiled = _get_compiled_cute_launcher(
        cute_kernel,
        schema_key,
        block_xyz,
        compile_options=cute_compile_options,
    )
    return cast("Any", compiled)(*launch_args)


def default_metal_launcher(
    metal_kernel: object,
    grid: tuple[int, ...],
    *args: object,
    _block_dims: tuple[int, int, int] = (256, 1, 1),
    **kwargs: object,
) -> None:
    """Default launcher for Metal kernels on Apple MPS devices.

    The ``metal_kernel`` is a ``@metal_jit`` decorated function that
    translates its Python AST body to MSL and compiles it via
    ``torch.mps.compile_shader`` on each call. This launcher dispatches the
    compiled kernel with the given grid and threadgroup dimensions.

    Uses a 3D threadgroup dispatch model: ``_block_dims`` specifies the
    threadgroup size as ``(x, y, z)``. The grid specifies the number of
    threadgroups per dimension.
    """
    kwargs.pop("num_warps", None)
    kwargs.pop("num_stages", None)
    if kwargs:
        raise exc.BackendUnsupported(
            "metal", f"unexpected launcher kwargs: {sorted(kwargs)}"
        )
    lib, kernel_name = metal_kernel(*args)  # type: ignore[operator]
    tensor_args = [a for a in args if isinstance(a, torch.Tensor)]
    dispatch_fn = getattr(lib, kernel_name)
    bx, by, bz = _block_dims
    # Pad grid to 3D
    gx = grid[0] if len(grid) > 0 else 1
    gy = grid[1] if len(grid) > 1 else 1
    gz = grid[2] if len(grid) > 2 else 1
    total_threads = (gx * bx, gy * by, gz * bz)
    group_size = (bx, by, bz)
    dispatch_fn(*tensor_args, threads=total_threads, group_size=group_size)
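

# Dispatch arithmetic example (illustrative): with ``grid=(4, 2)`` and the
# default ``_block_dims=(256, 1, 1)``, the shader is dispatched with
# ``threads=(1024, 2, 1)`` and ``group_size=(256, 1, 1)``, i.e. 4 x 2
# threadgroups of 256 threads each.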