# Source code for helion.autotuner.local_cache

from __future__ import annotations

import dataclasses
import hashlib
import inspect
import itertools
import json
import logging
import os
from pathlib import Path
import platform
import textwrap
from typing import TYPE_CHECKING
import uuid

import torch
from torch._inductor.runtime.cache_dir_utils import cache_dir

from .._compat import extract_device
from .._compat import get_device_name
from ..runtime.config import Config
from .base_cache import AutotuneCacheBase
from .base_cache import LooseAutotuneCacheKey
from .base_cache import StrictAutotuneCacheKey

if TYPE_CHECKING:
    from collections.abc import Iterator
    from collections.abc import Sequence

    from .base_search import BaseSearch

log: logging.Logger = logging.getLogger(__name__)


def get_helion_cache_dir() -> Path:
    """Return the root directory under which every Helion cache lives.

    The ``HELION_CACHE_DIR`` environment variable takes precedence whenever it
    is set (even to an empty string); otherwise the root is the ``helion``
    directory inside the path returned by ``cache_dir()``.
    """
    override = os.environ.get("HELION_CACHE_DIR")
    if override is None:
        return Path(cache_dir()) / "helion"
    return Path(override)
def helion_triton_cache_dir(device_index: int) -> str:
    """Return the Triton cache directory for one device, as a string.

    The directory is ``triton/<device_index>`` beneath the Helion cache root.
    """
    triton_root = get_helion_cache_dir() / "triton"
    return str(triton_root / str(device_index))
@dataclasses.dataclass(frozen=True)
class SavedBestConfig:
    """One cache entry as parsed from a ``.best_config`` file."""

    hardware: str
    specialization_key: str
    config: Config
    config_spec_hash: str
    # May be None, i.e. no flattened configuration is available for the entry.
    flat_config: tuple[object, ...] | None

    def to_mutable_flat_config(self) -> list[object]:
        """Return a fresh list copy of the stored ``flat_config``.

        The entry is asserted to carry a ``flat_config``.
        """
        stored = self.flat_config
        assert stored is not None
        return [*stored]
def iter_cache_entries(
    cache_path: Path, *, max_scan: int | None = None
) -> Iterator[SavedBestConfig]:
    """Yield parsed cache entries from *cache_path*, newest first.

    Only files matching ``*.best_config`` are considered, ordered by
    modification time (most recent first).  At most *max_scan* of them are
    read; ``None`` means no limit.  A ``flat_config`` stored as a JSON string
    is decoded, one stored as a JSON array is used directly, and a missing one
    becomes ``None``.

    Corrupt or unparsable files are skipped with a warning.  Files that can
    no longer be stat'ed (for example because they were removed while the
    directory was being scanned) are skipped as well.
    """
    if not cache_path.exists():
        return

    # Collect modification times up front and tolerate files that vanish
    # between the glob and the stat, instead of aborting the whole scan.
    dated: list[tuple[float, Path]] = []
    for candidate in cache_path.glob("*.best_config"):
        try:
            dated.append((candidate.stat().st_mtime, candidate))
        except OSError as e:
            log.debug("Skipping unreadable cache file %s: %s", candidate.name, e)
    dated.sort(key=lambda item: item[0], reverse=True)

    for _, p in itertools.islice(dated, max_scan):
        try:
            data = json.loads(p.read_text())
            fields = data["key"]["fields"]
            raw_flat = data.get("flat_config")
            if isinstance(raw_flat, str):
                flat_config: tuple[object, ...] | None = tuple(json.loads(raw_flat))
            elif raw_flat is not None:
                flat_config = tuple(raw_flat)
            else:
                flat_config = None
            entry = SavedBestConfig(
                hardware=fields.get("hardware", ""),
                specialization_key=fields.get("specialization_key", ""),
                config=Config.from_json(data["config"]),
                config_spec_hash=fields.get("config_spec_hash", ""),
                flat_config=flat_config,
            )
        except (OSError, KeyError, ValueError, TypeError, AttributeError) as e:
            # AttributeError covers well-formed JSON of the wrong shape, e.g.
            # a "fields" value that is not a mapping and so has no ``.get``.
            log.warning("Skipping corrupt cache file %s: %s", p.name, e)
            continue
        # Yield outside the ``try`` so only parsing errors are handled above.
        yield entry
class LocalAutotuneCache(AutotuneCacheBase):
    """
    This class implements the local autotune cache, storing the best config
    artifact on the local file system either by default on torch's cache
    directory, or at a user specified HELION_CACHE_DIR directory.

    It uses the LooseAutotuneCacheKey implementation for the cache key which
    takes into account device and source code properties, but does not account
    for library level code changes such as Triton, Helion or PyTorch. Use
    StrictLocalAutotuneCache to consider these properties.
    """

    def __init__(self, autotuner: BaseSearch) -> None:
        """Initialise the base cache and compute this cache's key once."""
        super().__init__(autotuner)
        self.key = self._generate_key()

    def _generate_key(self) -> LooseAutotuneCacheKey:
        """Build the cache key from the kernel, its arguments and the device.

        The key combines the bound-kernel cache key, a SHA-256 hash of the
        dedented kernel source, the device name, a runtime version string, the
        backend name and the config-spec fingerprint.

        The runtime string is the XPU driver version, the CUDA or HIP version
        reported by ``torch.version``, the macOS version for MPS, or the
        ``torch_tpu`` version for TPU.  Any other situation leaves it unset
        and trips the assertion below.
        """
        in_memory_cache_key = self.kernel.kernel._create_bound_kernel_cache_key(
            self.kernel,
            tuple(self.args),
            self.kernel.kernel.specialization_key(self.args),
        )
        kernel_source = textwrap.dedent(inspect.getsource(self.kernel.kernel.fn))
        kernel_source_hash = hashlib.sha256(kernel_source.encode("utf-8")).hexdigest()

        dev = extract_device(self.args)
        assert dev is not None
        hardware = get_device_name(dev)
        runtime_name = None
        if (
            dev.type == "xpu"
            and getattr(torch, "xpu", None) is not None
            and torch.xpu.is_available()
        ):
            runtime_name = torch.xpu.get_device_properties(dev).driver_version
        elif dev.type == "cuda" and torch.cuda.is_available():
            if torch.version.cuda is not None:
                runtime_name = str(torch.version.cuda)
            elif torch.version.hip is not None:
                runtime_name = torch.version.hip
        elif dev.type == "mps":
            # Include OS version as Metal runtime is part of OS
            runtime_name = platform.mac_ver()[0] or "mps"
        elif dev.type == "tpu":
            hardware = "tpu"
            try:
                import torch_tpu  # type: ignore[import-not-found]

                runtime_name = getattr(torch_tpu, "__version__", "unknown")
            except ImportError:
                runtime_name = "unknown"

        assert hardware is not None and runtime_name is not None
        config_spec_hash = self.kernel.config_spec.structural_fingerprint_hash()
        return LooseAutotuneCacheKey(
            specialization_key=in_memory_cache_key.specialization_key,
            extra_results=in_memory_cache_key.extra_results,
            kernel_source_hash=kernel_source_hash,
            hardware=hardware,
            runtime_name=runtime_name,
            backend=self.kernel.env.backend.name,
            config_spec_hash=config_spec_hash,
        )

    def _get_local_cache_path(self) -> Path:
        """Return the ``<stable hash>.best_config`` path for this cache's key."""
        return get_helion_cache_dir() / f"{self.key.stable_hash()}.best_config"

    def get(self) -> Config | None:
        """Return the cached best config for this key, or ``None`` on a miss.

        An entry that exists but cannot be read or parsed is also treated as a
        miss; that case is logged so a corrupt cache file can be diagnosed.
        """
        path = self._get_local_cache_path()
        try:
            data = json.loads(path.read_text())
            return Config.from_json(data["config"])
        except FileNotFoundError:
            # The ordinary cache miss: nothing has been stored for this key.
            return None
        except Exception as e:
            # Deliberately best-effort: an unusable entry must never break
            # autotuning, but it should not vanish without a trace either.
            log.warning("Ignoring unreadable cache file %s: %s", path, e)
            return None

    def put(self, config: Config) -> None:
        """Store *config* as the best config for this cache's key.

        The file holds the config JSON, the key (type name plus stringified
        fields), the flattened config as a JSON string and, when one can be
        computed, the backend cache key.  The file is written to a temporary
        sibling first and then moved into place, so readers never observe a
        partially written entry.
        """
        path = self._get_local_cache_path()
        path.parent.mkdir(parents=True, exist_ok=True)
        # Save both config and key for better debugging
        # Store key as dict for safer reconstruction (avoids eval)
        key_dict = {
            "type": type(self.key).__name__,
            "fields": {k: str(v) for k, v in vars(self.key).items()},
        }
        data: dict[str, object] = {
            "config": config.to_json(),
            "key": key_dict,
        }
        config_gen = self.kernel.config_spec.create_config_generation()
        data["flat_config"] = json.dumps(config_gen.flatten(config))
        backend_cache_key = self.kernel.backend_cache_key(config)
        if backend_cache_key is None:
            # Config may have been minimized (default values stripped),
            # so it won't match the full config in _compile_cache.
            # Expand it back by merging with defaults.
            default = self.kernel.config_spec.default_config()
            # pyrefly: ignore [bad-argument-type]
            full_config = Config(**(default.config | config.config))
            backend_cache_key = self.kernel.backend_cache_key(full_config)
        if backend_cache_key is not None:
            data["backend_cache_key"] = backend_cache_key

        # Serialise before touching the file system so that a serialisation
        # error cannot leave a stray temporary file behind.
        payload = json.dumps(data, indent=2)

        # Atomic write
        tmp = path.parent / f"tmp.{uuid.uuid4()!s}"
        try:
            tmp.write_text(payload)
            # os.replace overwrites an existing destination on every
            # platform; os.rename fails on Windows when the entry exists.
            os.replace(str(tmp), str(path))
        except BaseException:
            # Do not leak the temporary file when the write or move fails.
            try:
                tmp.unlink(missing_ok=True)
            except OSError:
                # Cleanup is best-effort; the original error is what matters.
                pass
            raise

    def _get_cache_info_message(self) -> str:
        """Return a human-readable hint naming the cache directory."""
        directory = self._get_local_cache_path().parent
        return f"Cache directory: {directory}. To run autotuning again, delete the cache directory or set HELION_SKIP_CACHE=1."

    def _get_cache_key(self) -> LooseAutotuneCacheKey:
        """Return the key computed when this cache was constructed."""
        return self.key

    def _list_cache_entries(self) -> Sequence[tuple[str, LooseAutotuneCacheKey]]:
        """List all cache entries in the cache directory.

        Returns ``(file name, key)`` pairs for every ``*.best_config`` file
        except the one belonging to the current key.  The key objects are
        plain attribute bags built from the stored fields, not real key
        instances.  Files that cannot be read or parsed are skipped.
        """
        directory = self._get_local_cache_path().parent
        if not directory.exists():
            return []

        # Create a simple namespace object that has the same attributes
        # for comparison purposes (we don't need the full key object)
        class CachedKey:
            def __init__(self, fields: dict[str, str]) -> None:
                for name, value in fields.items():
                    setattr(self, name, value)

        current_key_hash = self.key.stable_hash()
        entries: list[tuple[str, LooseAutotuneCacheKey]] = []
        for cache_file in directory.glob("*.best_config"):
            # The current key's own file is never listed, so skip it before
            # paying for the read and the JSON parse.
            if cache_file.stem == current_key_hash:
                continue
            try:
                data = json.loads(cache_file.read_text())
                cached_key = CachedKey(data["key"]["fields"])
            except Exception as e:
                # Best-effort listing: one bad file must not hide the rest.
                log.debug(
                    "Skipping unreadable cache file %s: %s", cache_file.name, e
                )
                continue
            entries.append((cache_file.name, cached_key))  # type: ignore[arg-type]
        return entries
class StrictLocalAutotuneCache(LocalAutotuneCache):
    """
    Stricter variant of the local autotune cache whose key also takes into
    account library level code changes such as Triton, Helion or PyTorch.
    """

    def _generate_key(self) -> StrictAutotuneCacheKey:
        """Build the loose key, then re-wrap its fields as a strict key."""
        base_fields = vars(super()._generate_key())
        return StrictAutotuneCacheKey(**base_fields)