# Source code for helion.language.tile_proxy

from __future__ import annotations

import threading
from typing import TYPE_CHECKING
from typing import Protocol
from typing import Sequence
from typing import TypeVar
from typing import cast
from typing_extensions import Self

import torch
from torch.utils._pytree import tree_map_only

from .. import exc
from .._compiler.compile_environment import CompileEnvironment
from .tile_interface import TileInterface

# The names below are only referenced from annotations (kept unevaluated by
# ``from __future__ import annotations``) or from quoted strings, so they are
# defined for the static type checker only.
if TYPE_CHECKING:
    from collections.abc import Callable

    # Type variable letting Tile._tiles_to_sizes return the same type it receives.
    _T = TypeVar("_T")

    class _TLS(Protocol):
        """Static description of the attributes this module stores on ``_tls``."""

        # The _CheckForIndexCalls currently active on this thread; None (or
        # unset) when no checker is active.
        index_calls: _CheckForIndexCalls | None


# Per-thread state shared by Tile.__torch_function__ and _CheckForIndexCalls.
# ``cast`` is a runtime no-op that only gives the checker the ``_TLS`` type; the
# target is quoted because ``_TLS`` does not exist at runtime.
_tls: _TLS = cast("_TLS", threading.local())


class Tile(TileInterface, torch.Tensor):
    """
    A single tile of the iteration space.

    This class should not be instantiated directly; it is the result of
    ``hl.tile(...)``.

    Tiles can be used as indices to tensors, e.g. ``tensor[tile]``.  Tiles
    can also be used as sizes for allocations, e.g. ``torch.empty([tile])``.
    There are also properties such as :meth:`tile.index <index>`,
    :meth:`tile.begin <begin>`, :meth:`tile.end <end>`, :meth:`tile.id <id>`
    and :meth:`tile.block_size <block_size>` that can be used to retrieve
    various information about the tile.

    Masking is implicit for tiles, so if the final tile is smaller than the
    block size, loading that tile will only load the valid elements, and
    reduction operations know to ignore the invalid elements.
    """

    def __init__(self, block_id: int) -> None:
        """Create a tile bound to ``block_id``.

        Args:
            block_id: Key into ``CompileEnvironment.current().block_sizes``
                that identifies this tile's block size (see ``_tile_to_size``).
        """
        super().__init__()
        self.block_id = block_id

    @classmethod
    def __torch_function__(
        cls,
        func: Callable[..., object],
        types: object,
        args: tuple[object, ...] = (),
        kwargs: dict[str, object] | None = None,
    ) -> object:
        """Intercept torch operations that involve a ``Tile``.

        Supported operations:

        * ``tensor[index]`` (``Tensor.__getitem__``) is lowered to
          ``memory_ops.load``.
        * ``tensor[index] = value`` (``Tensor.__setitem__``) is lowered to
          ``memory_ops.store``.
        * ``Tensor.__format__`` returns ``repr`` of the first argument; any
          format spec is ignored.

        Every other operation raises ``exc.IncorrectTileUsage``.  A
        ``Tensor.__index__`` call also raises, but it is first counted on the
        thread's active ``_CheckForIndexCalls`` (if any).  This lets
        ``_CheckForIndexCalls.retry_call`` detect that a Tile was passed where
        an integer was expected.

        Raises:
            exc.IncorrectTileUsage: for unsupported operations, or for
                ``__getitem__``/``__setitem__`` called with unexpected
                arguments or keyword arguments.
        """
        # Imported lazily, presumably to avoid an import cycle with memory_ops.
        from ..language.memory_ops import load
        from ..language.memory_ops import store

        if func is torch.Tensor.__getitem__:
            if len(args) != 2 or kwargs:
                raise exc.IncorrectTileUsage(func)
            tensor, index = args
            assert isinstance(tensor, torch.Tensor)
            return load(tensor, cls._prepare_index(index))

        if func is torch.Tensor.__setitem__:
            if len(args) != 3 or kwargs:
                raise exc.IncorrectTileUsage(func)
            tensor, index, value = args
            assert isinstance(tensor, torch.Tensor)
            assert isinstance(value, torch.Tensor)
            return store(tensor, cls._prepare_index(index), value)

        if (
            func is torch.Tensor.__index__
            and (index_calls := getattr(_tls, "index_calls", None)) is not None
        ):
            # Record the probe, then deliberately fall through to the
            # IncorrectTileUsage below: the failing __index__ call is what
            # _CheckForIndexCalls.retry_call relies on to trigger its retry.
            index_calls.count += 1

        if func is torch.Tensor.__format__:
            return repr(args[0])

        raise exc.IncorrectTileUsage(func)

    @staticmethod
    def _prepare_index(index: object) -> list[object]:
        """Normalize a subscript into the list form passed to load/store.

        A list or tuple subscript (``tensor[a, b]``) is copied into a new
        list.  Any other subscript must be a single ``Tile``
        (``tensor[tile]``) and is wrapped in a one-element list.
        """
        if isinstance(index, (list, tuple)):
            return [*index]
        assert isinstance(index, Tile)
        return [index]

    def __repr__(self, tensor_contents: None = None) -> str:  # pyright: ignore[reportIncompatibleMethodOverride]
        # ``tensor_contents`` only mirrors torch.Tensor.__repr__'s signature; it is unused.
        return f"Tile({self.block_id!r})"

    @classmethod
    def _tiles_to_sizes(cls, it: _T) -> _T:
        """Return ``it`` with every ``Tile`` leaf replaced by its block-size SymInt.

        ``it`` may be any pytree (e.g. nested tuples/lists/dicts); non-Tile
        leaves are left untouched.
        """
        return tree_map_only(Tile, cls._tile_to_size, it)

    @staticmethod
    def _tile_to_size(x: Tile) -> torch.SymInt:
        """Look up the symbolic block size of ``x`` in the current compile environment."""
        return CompileEnvironment.current().block_sizes[x.block_id].var
class _CheckForIndexCalls:
    """
    Workaround for passing a ``Tile`` where PyTorch expects an integer size.

    When an operation such as ``view()`` receives a ``Tile`` as a size
    argument, PyTorch probes it with ``Tensor.__index__``.  That probe is
    routed to ``Tile.__torch_function__``, which cannot produce an integer.
    Instead, it records the probe on the active ``_CheckForIndexCalls``
    (incrementing :attr:`count`) and then raises, so the operation fails.

    ``retry_call`` expects that failure to surface as a ``TypeError``.  It uses
    the recorded count to tell this case apart from unrelated ``TypeError``
    exceptions, then retries with every ``Tile`` replaced by its block size.

    Only one checker may be active per thread at a time; ``__enter__``
    asserts this.
    """

    @classmethod
    def retry_call(
        cls,
        fn: Callable[..., object],
        proxy_args: Sequence[object],
        proxy_kwargs: dict[str, object],
    ) -> object:
        """Call ``fn(*proxy_args, **proxy_kwargs)``, retrying once with Tiles turned into sizes.

        Returns:
            The result of the first call if it succeeds.  Otherwise, if that
            call raised ``TypeError`` after at least one ``__index__`` probe
            of a Tile, the result of the retried call.

        Raises:
            TypeError: re-raised from the first call when no ``__index__``
                probe was recorded during it.
            Any exception raised by the retried call propagates as-is.
        """
        index_calls = cls()
        try:
            with index_calls:
                return fn(*proxy_args, **proxy_kwargs)
        except TypeError:
            if index_calls.count == 0:
                raise
        # A Tile was probed via __index__ before the TypeError, so this is
        # likely a view-like op wanting integer sizes.  Retry with each Tile
        # replaced by its block size.  The retry runs outside the ``except``
        # clause so that its own failure is not implicitly chained onto the
        # expected probe TypeError.
        size_args = Tile._tiles_to_sizes(proxy_args)
        size_kwargs = Tile._tiles_to_sizes(proxy_kwargs)
        return fn(*size_args, **size_kwargs)

    def __init__(self) -> None:
        # Number of Tensor.__index__ probes on a Tile seen while this checker is active.
        self.count = 0

    def __enter__(self) -> Self:
        # Nesting is not supported: at most one active checker per thread.
        assert getattr(_tls, "index_calls", None) is None
        _tls.index_calls = self
        return self

    def __exit__(self, *args: object) -> None:
        # Always deactivate, even when the body raised; returning None means
        # any in-flight exception is not suppressed.
        _tls.index_calls = None