Source code for helion.autotuner.finite_search
from __future__ import annotations
from typing import TYPE_CHECKING
from .. import exc
from .base_search import BaseSearch
if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Sequence
from ..runtime.config import Config
from ..runtime.kernel import BoundKernel
from .base_search import _AutotunableKernel
from .config_generation import ConfigGeneration
[docs]
class FiniteSearch(BaseSearch):
    """Search over a given list of configs, returning the best one.

    This strategy is similar to ``triton.autotune``, and is the default if you
    specify ``helion.kernel(configs=[...])``.
    """

    def __init__(
        self,
        kernel: _AutotunableKernel,
        args: Sequence[object],
        configs: Sequence[Config] | None = None,
    ) -> None:
        """Prepare a search over an explicit, finite list of configs.

        Args:
            kernel: The kernel being autotuned.
            args: Kernel arguments, forwarded to ``BaseSearch.__init__``.
            configs: Candidate configs to benchmark. When ``None``, falls back
                to ``kernel.configs``. An explicitly empty sequence is used
                as-is (and is therefore rejected below).

        Raises:
            exc.NotEnoughConfigs: If fewer than two candidate configs are
                available.
        """
        super().__init__(kernel, args)
        self.config_gen: ConfigGeneration = self.config_spec.create_config_generation(
            overrides=self.settings.autotune_config_overrides or None,
            advanced_controls_files=self.settings.autotune_search_acf or None,
            process_group_name=kernel.env.process_group_name,
        )
        self.configs: list[Config] = list(
            configs if configs is not None else kernel.configs
        )
        # With fewer than two candidates there is nothing to choose between.
        if len(self.configs) < 2:
            raise exc.NotEnoughConfigs(len(self.configs))

    def _autotune(self) -> Config:
        """Benchmark every candidate in ``self.configs`` and return the fastest.

        Ties keep the config that ``benchmark_batch`` reported first, because
        only a strictly smaller ``perf`` replaces the current best. A ``perf``
        of ``inf`` (or NaN) never wins.

        Returns:
            The config with the smallest measured ``perf``.

        Raises:
            RuntimeError: If no candidate produced a ``perf`` below ``inf``
                (presumably because every candidate failed to benchmark —
                confirm against ``BaseSearch.benchmark_batch``).
        """
        best_config: Config | None = None
        best_time = float("inf")
        for result in self.benchmark_batch(self.configs, desc="Benchmarking"):
            if result.perf < best_time:
                best_time = result.perf
                best_config = result.config
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let this method return None.
        if best_config is None:
            raise RuntimeError(
                f"{type(self).__name__}: none of the {len(self.configs)} "
                "candidate config(s) produced a benchmark time below inf"
            )
        return best_config
[docs]
class CachedFiniteSearch(FiniteSearch):
    """FiniteSearch seeded with previously-cached best_configs prepended to the explicit config list."""

    def __init__(
        self,
        kernel: _AutotunableKernel,
        args: Sequence[object],
        *,
        configs: Sequence[Config] = (),
        max_configs: int | None = None,
    ) -> None:
        """Build the candidate list from cached configs followed by ``configs``.

        Args:
            kernel: The kernel being autotuned.
            args: Kernel arguments, forwarded to ``BaseSearch.__init__``.
            configs: Explicit configs appended after the cached ones.
            max_configs: Cap passed to ``_find_similar_cached_configs``.
                When ``None``, ``settings.autotune_best_available_max_configs``
                is used.

        Raises:
            exc.NotEnoughConfigs: If fewer than two configs (cached plus
                explicit) are available.
        """
        # Deliberately bypass FiniteSearch.__init__. It validates the explicit
        # list on its own, so with the default configs=() it would raise
        # NotEnoughConfigs before the cached configs could be prepended.
        BaseSearch.__init__(self, kernel, args)
        self.config_gen: ConfigGeneration = self.config_spec.create_config_generation(
            overrides=self.settings.autotune_config_overrides or None,
            advanced_controls_files=self.settings.autotune_search_acf or None,
            process_group_name=kernel.env.process_group_name,
        )
        cap = (
            max_configs
            if max_configs is not None
            else self.settings.autotune_best_available_max_configs
        )
        cached = self._load_cached_configs(cap)
        self.configs: list[Config] = [*cached, *configs]
        if len(self.configs) < 2:
            raise exc.NotEnoughConfigs(len(self.configs))

    def _load_cached_configs(self, cap: int) -> list[Config]:
        """Convert entries from ``_find_similar_cached_configs(cap)`` into Configs.

        Each entry is flattened and re-built via ``self.config_gen.unflatten``.
        Entries that fail to convert are logged and skipped, so one stale cache
        entry does not abort the whole search.
        """
        cached: list[Config] = []
        for i, entry in enumerate(self._find_similar_cached_configs(cap)):
            try:
                cached.append(self.config_gen.unflatten(entry.to_mutable_flat_config()))
            except (
                ValueError,
                TypeError,
                KeyError,
                AssertionError,
                exc.InvalidConfig,
            ) as e:
                self.log(f"from_cache: failed to transfer cached config {i + 1}: {e}")
        self.log(f"from_cache: resolved {len(cached)} cached config(s) (cap={cap})")
        return cached
[docs]
def from_cache(
    *, max_configs: int | None = None, configs: Sequence[Config] = ()
) -> Callable[..., CachedFiniteSearch]:
    """Return an autotuner_fn that seeds FiniteSearch with previously-cached best_configs.

    Args:
        max_configs: Cap on how many cached configs to look up. ``None`` lets
            ``CachedFiniteSearch`` fall back to its settings default.
        configs: Explicit configs benchmarked after the cached ones. They are
            copied when ``from_cache`` is called, so later changes to the
            caller's sequence do not affect the returned autotuner.

    Returns:
        A callable ``(bound_kernel, args, **kwargs) -> CachedFiniteSearch``.
        Extra keyword arguments are accepted and ignored.
    """
    # Snapshot now. Capturing the caller's object directly would let later
    # mutation change every future run, and a one-shot iterator would be
    # exhausted after the first run.
    frozen_configs: tuple[Config, ...] = tuple(configs)

    def _fn(
        bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
    ) -> CachedFiniteSearch:
        return CachedFiniteSearch(
            bound_kernel, args, configs=frozen_configs, max_configs=max_configs
        )

    return _fn