from __future__ import annotations
import dataclasses
import hashlib
import inspect
import itertools
import json
import logging
import os
from pathlib import Path
import platform
import textwrap
from typing import TYPE_CHECKING
import uuid
import torch
from torch._inductor.runtime.cache_dir_utils import cache_dir
from .._compat import extract_device
from .._compat import get_device_name
from ..runtime.config import Config
from .base_cache import AutotuneCacheBase
from .base_cache import LooseAutotuneCacheKey
from .base_cache import StrictAutotuneCacheKey
if TYPE_CHECKING:
from collections.abc import Iterator
from collections.abc import Sequence
from .base_search import BaseSearch
log: logging.Logger = logging.getLogger(__name__)
[docs]
def get_helion_cache_dir() -> Path:
"""Return the root directory for all Helion caches."""
if (user_path := os.environ.get("HELION_CACHE_DIR")) is not None:
return Path(user_path)
return Path(cache_dir()) / "helion"
[docs]
def helion_triton_cache_dir(device_index: int) -> str:
"""Return per-device Triton cache directory under Helion's cache root."""
return str(get_helion_cache_dir() / "triton" / str(device_index))
[docs]
@dataclasses.dataclass(frozen=True)
class SavedBestConfig:
"""A parsed cache entry from a .best_config file."""
hardware: str
specialization_key: str
config: Config
config_spec_hash: str
flat_config: tuple[object, ...] | None
[docs]
def to_mutable_flat_config(self) -> list[object]:
"""Return the stored flat_config as a mutable list."""
assert self.flat_config is not None
return list(self.flat_config)
[docs]
def iter_cache_entries(
cache_path: Path, *, max_scan: int | None = None
) -> Iterator[SavedBestConfig]:
"""Yield parsed cache entries from *cache_path*, newest first.
Corrupt or unparsable files are skipped with a warning.
"""
if not cache_path.exists():
return
files = list(cache_path.glob("*.best_config"))
files.sort(key=lambda p: p.stat().st_mtime, reverse=True)
for p in itertools.islice(files, max_scan):
try:
data = json.loads(p.read_text())
fields = data["key"]["fields"]
raw_flat = data.get("flat_config")
if isinstance(raw_flat, str):
flat_config: tuple[object, ...] | None = tuple(json.loads(raw_flat))
elif raw_flat is not None:
flat_config = tuple(raw_flat)
else:
flat_config = None
yield SavedBestConfig(
hardware=fields.get("hardware", ""),
specialization_key=fields.get("specialization_key", ""),
config=Config.from_json(data["config"]),
config_spec_hash=fields.get("config_spec_hash", ""),
flat_config=flat_config,
)
except (OSError, KeyError, ValueError, TypeError) as e:
log.warning("Skipping corrupt cache file %s: %s", p.name, e)
continue
[docs]
class LocalAutotuneCache(AutotuneCacheBase):
"""
This class implements the local autotune cache, storing the
best config artifact on the local file system either by default
on torch's cache directory, or at a user specified HELION_CACHE_DIR
directory.
It uses the LooseAutotuneCacheKey implementation for the cache key
which takes into account device and source code properties, but does
not account for library level code changes such as Triton, Helion or
PyTorch. Use StrictLocalAutotuneCache to consider these properties.
"""
[docs]
def __init__(self, autotuner: BaseSearch) -> None:
super().__init__(autotuner)
self.key = self._generate_key()
def _generate_key(self) -> LooseAutotuneCacheKey:
in_memory_cache_key = self.kernel.kernel._create_bound_kernel_cache_key(
self.kernel,
tuple(self.args),
self.kernel.kernel.specialization_key(self.args),
)
kernel_source = textwrap.dedent(inspect.getsource(self.kernel.kernel.fn))
kernel_source_hash = hashlib.sha256(kernel_source.encode("utf-8")).hexdigest()
dev = extract_device(self.args)
assert dev is not None
hardware = get_device_name(dev)
runtime_name = None
if (
dev.type == "xpu"
and getattr(torch, "xpu", None) is not None
and torch.xpu.is_available()
):
runtime_name = torch.xpu.get_device_properties(dev).driver_version
elif dev.type == "cuda" and torch.cuda.is_available():
if torch.version.cuda is not None:
runtime_name = str(torch.version.cuda)
elif torch.version.hip is not None:
runtime_name = torch.version.hip
elif dev.type == "mps":
# Include OS version as Metal runtime is part of OS
runtime_name = platform.mac_ver()[0] or "mps"
elif dev.type == "tpu":
hardware = "tpu"
try:
import torch_tpu # type: ignore[import-not-found]
runtime_name = getattr(torch_tpu, "__version__", "unknown")
except ImportError:
runtime_name = "unknown"
assert hardware is not None and runtime_name is not None
config_spec_hash = self.kernel.config_spec.structural_fingerprint_hash()
return LooseAutotuneCacheKey(
specialization_key=in_memory_cache_key.specialization_key,
extra_results=in_memory_cache_key.extra_results,
kernel_source_hash=kernel_source_hash,
hardware=hardware,
runtime_name=runtime_name,
backend=self.kernel.env.backend.name,
config_spec_hash=config_spec_hash,
)
def _get_local_cache_path(self) -> Path:
return get_helion_cache_dir() / f"{self.key.stable_hash()}.best_config"
[docs]
def get(self) -> Config | None:
path = self._get_local_cache_path()
try:
data = json.loads(path.read_text())
return Config.from_json(data["config"])
except Exception:
return None
[docs]
def put(self, config: Config) -> None:
path = self._get_local_cache_path()
path.parent.mkdir(parents=True, exist_ok=True)
# Save both config and key for better debugging
# Store key as dict for safer reconstruction (avoids eval)
key_dict = {
"type": type(self.key).__name__,
"fields": {k: str(v) for k, v in vars(self.key).items()},
}
data: dict[str, object] = {
"config": config.to_json(),
"key": key_dict,
}
config_gen = self.kernel.config_spec.create_config_generation()
data["flat_config"] = json.dumps(config_gen.flatten(config))
backend_cache_key = self.kernel.backend_cache_key(config)
if backend_cache_key is None:
# Config may have been minimized (default values stripped),
# so it won't match the full config in _compile_cache.
# Expand it back by merging with defaults.
default = self.kernel.config_spec.default_config()
# pyrefly: ignore [bad-argument-type]
full_config = Config(**(default.config | config.config))
backend_cache_key = self.kernel.backend_cache_key(full_config)
if backend_cache_key is not None:
data["backend_cache_key"] = backend_cache_key
# Atomic write
tmp = path.parent / f"tmp.{uuid.uuid4()!s}"
tmp.write_text(json.dumps(data, indent=2))
os.rename(str(tmp), str(path))
def _get_cache_info_message(self) -> str:
cache_dir = self._get_local_cache_path().parent
return f"Cache directory: {cache_dir}. To run autotuning again, delete the cache directory or set HELION_SKIP_CACHE=1."
def _get_cache_key(self) -> LooseAutotuneCacheKey:
return self.key
def _list_cache_entries(self) -> Sequence[tuple[str, LooseAutotuneCacheKey]]:
"""List all cache entries in the cache directory."""
cache_dir = self._get_local_cache_path().parent
if not cache_dir.exists():
return []
current_key_hash = self.key.stable_hash()
entries: list[tuple[str, LooseAutotuneCacheKey]] = []
for cache_file in cache_dir.glob("*.best_config"):
try:
data = json.loads(cache_file.read_text())
file_hash = cache_file.stem
if file_hash == current_key_hash:
continue
key_data = data["key"]
# Create a simple namespace object that has the same attributes
# for comparison purposes (we don't need the full key object)
class CachedKey:
def __init__(self, fields: dict[str, str]) -> None:
for name, value in fields.items():
setattr(self, name, value)
cached_key = CachedKey(key_data["fields"])
entries.append((cache_file.name, cached_key)) # type: ignore[arg-type]
except Exception:
pass
return entries
[docs]
class StrictLocalAutotuneCache(LocalAutotuneCache):
"""
Stricter implementation of the local autotune cache, which takes into
account library level code changes such as Triton, Helion or PyTorch.
"""
def _generate_key(self) -> StrictAutotuneCacheKey:
loose_key = super()._generate_key()
return StrictAutotuneCacheKey(**vars(loose_key))