Rate this Page

Source code for helion.language.view_ops

from __future__ import annotations

import collections
from typing import TYPE_CHECKING
from typing import cast

import torch

from .. import exc
from .._compiler.ast_extension import expr_from_string
from .._compiler.compile_environment import CompileEnvironment
from ..exc import NotInsideKernel
from . import _decorators

if TYPE_CHECKING:
    import ast

    from .._compiler.inductor_lowering import CodegenState

__all__ = ["join", "split", "subscript"]


[docs] @_decorators.api(tiles_as_sizes=True) def subscript(tensor: torch.Tensor, index: list[object]) -> torch.Tensor: """ Equivalent to tensor[index] where tensor is a kernel-tensor (not a host-tensor). Can be used to add dimensions to the tensor, e.g. tensor[None, :] or tensor[:, None]. Args: tensor: The kernel tensor to index index: List of indices, including None for new dimensions and : for existing dimensions Returns: torch.Tensor: The indexed tensor with potentially modified dimensions Examples: .. code-block:: python @helion.kernel def broadcast_multiply(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # x has shape (N,), y has shape (M,) result = torch.empty( [x.size(0), y.size(0)], dtype=x.dtype, device=x.device ) for tile_i, tile_j in hl.tile([x.size(0), y.size(0)]): # Get tile data x_tile = x[tile_i] y_tile = y[tile_j] # Make x broadcastable: (tile_size, 1) # same as hl.subscript(x_tile, [slice(None), None]) x_expanded = x_tile[:, None] # Make y broadcastable: (1, tile_size) # same as hl.subscript(y_tile, [None, slice(None)]) y_expanded = y_tile[None, :] result[tile_i, tile_j] = x_expanded * y_expanded return result See Also: - :func:`~helion.language.load`: For loading tensor values - :func:`~helion.language.store`: For storing tensor values Note: - Only supports None and : (slice(None)) indexing - Used for reshaping kernel tensors by adding dimensions - Prefer direct indexing syntax when possible: ``tensor[None, :]`` - Does not support integer indexing or slicing with start/stop """ raise NotInsideKernel
@_decorators.register_fake(subscript) def _(tensor: torch.Tensor, index: list[object]) -> torch.Tensor: input_size = collections.deque(tensor.size()) output_size = [] for val in index: if val is None: output_size.append(1) elif isinstance(val, slice) and repr(val) == "slice(None, None, None)": output_size.append(input_size.popleft()) else: raise exc.InvalidIndexingType(repr(val)) assert len(input_size) == 0 env = CompileEnvironment.current() return env.new_index_result(tensor, output_size) @_decorators.codegen(subscript, "common") def _(state: CodegenState) -> ast.AST: output_keys = [] # pyrefly: ignore [not-iterable] for val in state.proxy_arg(1): if val is None: output_keys.append("None") elif isinstance(val, slice) and repr(val) == "slice(None, None, None)": output_keys.append(":") else: raise exc.InvalidIndexingType(repr(val)) return expr_from_string( f"{{base}}[{', '.join(output_keys)}]", base=state.ast_arg(0), ) @_decorators.codegen(subscript, "cute") def _(state: CodegenState) -> ast.AST: # CuTe kernels currently execute scalarized pointwise code, so shape-only # indexing used for broadcast setup is a no-op. return state.ast_arg(0) @_decorators.ref(subscript) def _(tensor: torch.Tensor, indices: list[object]) -> torch.Tensor: # pyrefly: ignore [bad-index] return tensor[indices] @_decorators.get_masked_value(subscript) def _(node: torch.fx.Node) -> float | bool | None: from .._compiler.node_masking import cached_masked_value other = node.args[0] assert isinstance(other, torch.fx.Node) return cached_masked_value(other)
[docs] @_decorators.api(is_device_only=True) def split(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Split the last dimension of a tensor with size two into two separate tensors. Args: tensor: The input tensor whose last dimension has length two. Returns: A tuple ``(lo, hi)`` where each tensor has the same shape as ``tensor`` without its last dimension. See Also: - :func:`~helion.language.join` """ raise NotInsideKernel
@_decorators.register_fake(split) def _(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: out_shape = tensor.shape[:-1] return ( tensor.new_empty(out_shape), tensor.new_empty(out_shape), ) @_decorators.codegen(split, "triton") def _(state: CodegenState) -> list[ast.AST]: split_call = expr_from_string("tl.split({tensor})", tensor=state.ast_arg(0)) return [ expr_from_string("{value}[0]", value=split_call), expr_from_string("{value}[1]", value=split_call), ] @_decorators.codegen(split, "cute") def _(state: CodegenState) -> list[ast.AST]: from .._compiler.ast_extension import statement_from_string from .._compiler.cute.cute_reshape import _flat_index_from_coords from .._compiler.cute.cute_reshape import _get_node_dim_local_coord from .._compiler.cute.cute_reshape import _get_tile_shape from .._compiler.generate_ast import GenerateAST fx_node = state.fx_node assert fx_node is not None input_node = fx_node.args[0] assert isinstance(input_node, torch.fx.Node) input_val = input_node.meta["val"] assert isinstance(input_val, torch.Tensor) output_val = input_val.new_empty(input_val.shape[:-1]) cg = state.codegen assert isinstance(cg, GenerateAST) df = cg.device_function env = CompileEnvironment.current() config = df.config input_shape = _get_tile_shape(input_val, env, config) output_shape = _get_tile_shape(output_val, env, config) input_numel = 1 for s in input_shape: input_numel *= s dtype_str = env.backend.dtype_str(input_val.dtype) smem_ptr = df.new_var("split_smem_ptr") smem = df.new_var("split_smem") src_coords = [ _get_node_dim_local_coord(cg, input_node, input_val, i) for i in range(len(input_shape)) ] src_flat = _flat_index_from_coords(src_coords, input_shape) if output_shape: output_coords = [ _get_node_dim_local_coord(cg, input_node, output_val, i) for i in range(len(output_shape)) ] out_flat_base = _flat_index_from_coords(output_coords, output_shape) else: out_flat_base = "cutlass.Int32(0)" lo_flat = f"({out_flat_base}) * cutlass.Int32(2)" hi_flat = f"({out_flat_base}) * cutlass.Int32(2) + cutlass.Int32(1)" cg.add_statement( statement_from_string( f"{smem_ptr} = cute.arch.alloc_smem({dtype_str}, {input_numel})" ) ) cg.add_statement( statement_from_string( f"{smem} = cute.make_tensor({smem_ptr}, ({input_numel},))" ) ) cg.add_statement( statement_from_string(f"{smem}[{src_flat}] = {{_inp}}", _inp=state.ast_arg(0)) ) cg.add_statement(statement_from_string("cute.arch.sync_threads()")) lo_var = df.new_var("split_lo") hi_var = df.new_var("split_hi") cg.add_statement(statement_from_string(f"{lo_var} = {smem}[{lo_flat}]")) cg.add_statement(statement_from_string(f"{hi_var} = {smem}[{hi_flat}]")) return [ expr_from_string(lo_var), expr_from_string(hi_var), ] @_decorators.ref(split) def _(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return cast("tuple[torch.Tensor, torch.Tensor]", torch.unbind(tensor, dim=-1))
[docs] @_decorators.api(is_device_only=True) def join( tensor0: torch.Tensor, tensor1: torch.Tensor, ) -> torch.Tensor: """ Join two tensors along a new minor dimension. Args: tensor0: First tensor to join. tensor1: Second tensor to join. Must be broadcast-compatible with ``tensor0``. Returns: torch.Tensor: A tensor with shape ``broadcast_shape + (2,)`` where ``broadcast_shape`` is the broadcast of the input shapes. See Also: - :func:`~helion.language.split` """ raise NotInsideKernel
@_decorators.register_fake(join) def _(tensor0: torch.Tensor, tensor1: torch.Tensor) -> torch.Tensor: if tensor0.dtype != tensor1.dtype: raise TypeError("join() requires both tensors to have the same dtype") if tensor0.device != tensor1.device: raise ValueError("join() requires both tensors to be on the same device") broadcast_shape = torch.broadcast_shapes(tensor0.shape, tensor1.shape) return tensor0.new_empty([*broadcast_shape, 2]) @_decorators.codegen(join, "triton") def _(state: CodegenState) -> ast.AST: return expr_from_string( "tl.join({tensor0}, {tensor1})", tensor0=state.ast_arg(0), tensor1=state.ast_arg(1), ) @_decorators.codegen(join, "cute") def _(state: CodegenState) -> ast.AST: from .._compiler.cute.cute_reshape import _get_dim_local_coord from .._compiler.generate_ast import GenerateAST fx_node = state.fx_node assert fx_node is not None output_val = fx_node.meta["val"] assert isinstance(output_val, torch.Tensor) assert isinstance(state.codegen, GenerateAST) new_dim = output_val.ndim - 1 selector = _get_dim_local_coord(state.codegen, output_val, new_dim) return expr_from_string( f"(({{a}}) if ({selector}) == cutlass.Int32(0) else ({{b}}))", a=state.ast_arg(0), b=state.ast_arg(1), ) @_decorators.ref(join) def _(tensor0: torch.Tensor, tensor1: torch.Tensor) -> torch.Tensor: left, right = torch.broadcast_tensors(tensor0, tensor1) return torch.stack((left, right), dim=-1)