Source code for helion.autotuner.local_cache

from __future__ import annotations

import hashlib
import inspect
import json
import logging
import os
from pathlib import Path
import platform
import textwrap
from typing import TYPE_CHECKING
import uuid

import torch
from torch._inductor.runtime.cache_dir_utils import (
    cache_dir,  # pyright: ignore[reportPrivateImportUsage]
)

from ..runtime.config import Config
from .base_cache import AutotuneCacheBase
from .base_cache import LooseAutotuneCacheKey
from .base_cache import StrictAutotuneCacheKey

if TYPE_CHECKING:
    from collections.abc import Sequence

    from .base_search import BaseSearch

log: logging.Logger = logging.getLogger(__name__)


class LocalAutotuneCache(AutotuneCacheBase):
    """
    This class implements the local autotune cache, storing the best config
    artifact on the local file system, either in torch's cache directory by
    default or in a user-specified HELION_CACHE_DIR directory.

    It uses the LooseAutotuneCacheKey implementation for the cache key, which
    takes into account device and source code properties but does not account
    for library-level code changes such as Triton, Helion, or PyTorch. Use
    StrictLocalAutotuneCache to consider these properties.
    """
    def __init__(self, autotuner: BaseSearch) -> None:
        super().__init__(autotuner)
        self.key = self._generate_key()

    def _generate_key(self) -> LooseAutotuneCacheKey:
        in_memory_cache_key = self.kernel.kernel._create_bound_kernel_cache_key(
            self.kernel,
            tuple(self.args),
            self.kernel.kernel.specialization_key(self.args),
        )
        kernel_source = textwrap.dedent(inspect.getsource(self.kernel.kernel.fn))
        kernel_source_hash = hashlib.sha256(kernel_source.encode("utf-8")).hexdigest()

        hardware = None
        runtime_name = None

        for arg in self.args:
            if isinstance(arg, torch.Tensor):
                dev = arg.device
                # CPU support
                if dev.type == "cpu":
                    hardware = "cpu"
                    runtime_name = platform.machine().lower()
                    break
                # XPU (Intel) path
                if (
                    dev.type == "xpu"
                    and getattr(torch, "xpu", None) is not None
                    and torch.xpu.is_available()
                ):  # pyright: ignore[reportAttributeAccessIssue]
                    device_properties = torch.xpu.get_device_properties(dev)
                    hardware = device_properties.name
                    runtime_name = device_properties.driver_version  # pyright: ignore[reportAttributeAccessIssue]
                    break
                # CUDA/ROCm path
                if dev.type == "cuda" and torch.cuda.is_available():
                    device_properties = torch.cuda.get_device_properties(dev)
                    if torch.version.cuda is not None:  # pyright: ignore[reportAttributeAccessIssue]
                        hardware = device_properties.name
                        runtime_name = str(torch.version.cuda)
                    elif torch.version.hip is not None:  # pyright: ignore[reportAttributeAccessIssue]
                        hardware = device_properties.gcnArchName
                        runtime_name = torch.version.hip  # pyright: ignore[reportAttributeAccessIssue]
                    break

        assert hardware is not None and runtime_name is not None
        return LooseAutotuneCacheKey(
            specialization_key=in_memory_cache_key.specialization_key,
            extra_results=in_memory_cache_key.extra_results,
            kernel_source_hash=kernel_source_hash,
            hardware=hardware,
            runtime_name=runtime_name,
        )

    def _get_local_cache_path(self) -> Path:
        if (user_path := os.environ.get("HELION_CACHE_DIR", None)) is not None:
            cache_path = Path(user_path)
        else:
            cache_path = Path(cache_dir()) / "helion"
        return cache_path / f"{self.key.stable_hash()}.best_config"
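
    # Illustrative example (hypothetical hash value): with
    # HELION_CACHE_DIR=/tmp/helion-cache, an entry for this kernel would be
    # stored at /tmp/helion-cache/3fc9...a217.best_config, where the stable
    # hash is derived from the cache key built above (specialization, kernel
    # source hash, hardware, and runtime name).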

    def get(self) -> Config | None:
        path = self._get_local_cache_path()
        try:
            data = json.loads(path.read_text())
            return Config.from_json(data["config"])
        except Exception:
            return None

    def put(self, config: Config) -> None:
        path = self._get_local_cache_path()
        path.parent.mkdir(parents=True, exist_ok=True)

        # Save both config and key for better debugging
        # Store key as dict for safer reconstruction (avoids eval)
        key_dict = {
            "type": type(self.key).__name__,
            "fields": {k: str(v) for k, v in vars(self.key).items()},
        }
        data = {
            "config": config.to_json(),
            "key": key_dict,
        }

        # Atomic write
        tmp = path.parent / f"tmp.{uuid.uuid4()!s}"
        tmp.write_text(json.dumps(data, indent=2))
        os.rename(str(tmp), str(path))

    def _get_cache_info_message(self) -> str:
        cache_dir = self._get_local_cache_path().parent
        return (
            f"Cache directory: {cache_dir}. To run autotuning again, "
            "delete the cache directory or set HELION_SKIP_CACHE=1."
        )

    def _get_cache_key(self) -> LooseAutotuneCacheKey:
        return self.key

    def _list_cache_entries(self) -> Sequence[tuple[str, LooseAutotuneCacheKey]]:
        """List all cache entries in the cache directory."""
        cache_dir = self._get_local_cache_path().parent
        if not cache_dir.exists():
            return []

        current_key_hash = self.key.stable_hash()
        entries: list[tuple[str, LooseAutotuneCacheKey]] = []
        for cache_file in cache_dir.glob("*.best_config"):
            try:
                data = json.loads(cache_file.read_text())
                file_hash = cache_file.stem
                if file_hash == current_key_hash:
                    continue
                key_data = data["key"]

                # Create a simple namespace object that has the same attributes
                # for comparison purposes (we don't need the full key object)
                class CachedKey:
                    def __init__(self, fields: dict[str, str]) -> None:
                        for name, value in fields.items():
                            setattr(self, name, value)

                cached_key = CachedKey(key_data["fields"])
                entries.append((cache_file.name, cached_key))  # type: ignore[arg-type]
            except Exception:
                pass
        return entries
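
# Note on the on-disk format: each *.best_config artifact written by put()
# above is JSON of the shape
#   {"config": <Config.to_json()>,
#    "key": {"type": <key class name>, "fields": {<field>: str(<value>)}}}
# which is exactly what _list_cache_entries() parses back into lightweight
# CachedKey namespace objects for comparison.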

class StrictLocalAutotuneCache(LocalAutotuneCache):
    """
    Stricter implementation of the local autotune cache, which takes into
    account library-level code changes such as Triton, Helion, or PyTorch.
    """

    def _generate_key(self) -> StrictAutotuneCacheKey:
        loose_key = super()._generate_key()
        return StrictAutotuneCacheKey(**vars(loose_key))
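
The sketch below is a minimal, hypothetical illustration of inspecting the
resulting cache directory from outside Helion, assuming autotuning has
already populated it. Only the HELION_CACHE_DIR variable and the
*.best_config naming come from the module above; the /tmp/helion-cache
fallback path and the printed fields are illustrative.

    import json
    import os
    from pathlib import Path

    # Hypothetical location; matches what _get_local_cache_path() uses when
    # HELION_CACHE_DIR is set in the environment before autotuning runs.
    cache_root = Path(os.environ.get("HELION_CACHE_DIR", "/tmp/helion-cache"))

    # Each entry is <stable_hash>.best_config: JSON holding the winning
    # config plus the stringified cache key that produced it (see put()).
    for entry in sorted(cache_root.glob("*.best_config")):
        data = json.loads(entry.read_text())
        key = data["key"]
        print(entry.name)
        print("  key type:", key["type"])  # Loose- or StrictAutotuneCacheKey
        print("  hardware:", key["fields"].get("hardware"))
        print("  config:  ", data["config"])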