from __future__ import annotations
from typing import TYPE_CHECKING
from typing import cast
import torch
from torch.fx import Node
from .._compiler.compile_environment import CompileEnvironment
from .._compiler.host_function import HostFunction
from .._compiler.rng_utils import BOX_MULLER_MIN
from .._compiler.rng_utils import HALF_MASK16
from .._compiler.rng_utils import PHILOX_KEY_A
from .._compiler.rng_utils import PHILOX_KEY_B
from .._compiler.rng_utils import PHILOX_ROUND_A
from .._compiler.rng_utils import PHILOX_ROUND_B
from .._compiler.rng_utils import PHILOX_ROUNDS
from .._compiler.rng_utils import TWO_PI
from .._compiler.rng_utils import UINT32_TO_UNIFORM_SCALE
from .._compiler.rng_utils import codegen_rng_seed_expr
from .._compiler.rng_utils import philox_rand_ref
from .._compiler.rng_utils import philox_randint_ref
from ..exc import NotInsideKernel
from . import _decorators
from .ref_tile import RefTile
if TYPE_CHECKING:
import ast
from collections.abc import Callable
from torch._prims_common import DeviceLikeType
from .._compiler.inductor_lowering import CodegenState
from .tile_interface import TileInterface
__all__ = ["rand", "randint"]
# A tensor dimension inside a kernel: a concrete int or a symbolic (tile) size.
_ShapeDim = int | torch.SymInt
# (is_dynamic, payload): payload is an index into the runtime args when
# is_dynamic is True, otherwise the literal value captured at rewrite time.
_ArgDesc = tuple[bool, object]
# Constants for emulating unsigned 32-bit arithmetic with int64 math.
MASK32 = (1 << 32) - 1
SIGN_BIT32 = 1 << 31
INT32_MAX = (1 << 31) - 1
def _pallas_safe_i32_scalar_like(
    ref: torch.Tensor,
    value: int,
) -> int | torch.Tensor:
    """Return *value* in a form the active backend can consume.

    Non-pallas backends get the plain Python int back unchanged.  For the
    pallas backend, values above ``INT32_MAX`` are wrapped into the signed
    32-bit range and packaged as a scalar tensor matching ``ref``'s dtype
    and device.
    """
    backend_name = CompileEnvironment.current().backend.name
    if backend_name != "pallas":
        return value
    wrapped = value - (1 << 32) if value > INT32_MAX else value
    return torch.scalar_tensor(wrapped, dtype=ref.dtype, device=ref.device)
def _mask_u32(x: torch.Tensor) -> torch.Tensor:
    """Keep only the low 32 bits of ``x`` (emulated uint32 arithmetic)."""
    mask = _pallas_safe_i32_scalar_like(x, MASK32)
    return x & mask
def _shape_dim_extent(dim: _ShapeDim) -> _ShapeDim:
    """Map a tile-derived dimension to its full logical extent.

    Dimensions that correspond to a registered block size are replaced by
    the block's full size (so RNG offsets are stable across tiles); plain
    sizes are returned unchanged.
    """
    env = CompileEnvironment.current()
    block_id = env.get_block_id(dim)
    if block_id is None:
        return dim
    extent = env.block_sizes[env.canonical_block_id(block_id)].size
    assert isinstance(extent, (int, torch.SymInt))
    return extent
def _shape_dim_index(
    dim: _ShapeDim,
    *,
    device: torch.device,
) -> torch.Tensor:
    """Return int64 indices covering *dim* for RNG offset computation.

    Tile-backed dims index via ``tile_index`` so each tile sees its own
    slice of the logical iteration space; plain sizes use ``arange``.
    """
    from .tile_ops import tile_index
    env = CompileEnvironment.current()
    if env.get_block_id(dim) is not None:
        # A dim with a registered block id is always a tile-backed SymInt.
        assert isinstance(dim, torch.SymInt)
        return _convert_element_type(
            tile_index(cast("TileInterface", dim)), torch.int64
        ) # pyrefly: ignore[bad-argument-type]
    if isinstance(dim, int):
        return torch.arange(dim, device=device, dtype=torch.int64)
    # Same call as above; the isinstance split only narrows types for checkers.
    # pyrefly: ignore[no-matching-overload]
    return torch.arange(dim, device=device, dtype=torch.int64)
def _explicit_offset_from_shape(
    shape: list[_ShapeDim],
    *,
    device: torch.device,
) -> torch.Tensor:
    """Compute per-element linear Philox counter offsets for *shape*.

    Strides are derived from the *full* logical extents (not tile sizes), so
    a given logical element always maps to the same counter no matter how
    the iteration space is tiled.
    """
    if not shape:
        # 0-d case: single element with offset 0, built via arange so the
        # result stays traceable with the right device/dtype.
        return torch.arange(1, device=device, dtype=torch.int64).reshape([]) * 0
    extents: list[_ShapeDim] = [_shape_dim_extent(dim) for dim in shape]
    indices = [_shape_dim_index(dim, device=device) for dim in shape]
    # Row-major strides over the full extents.
    strides: list[_ShapeDim] = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * extents[i + 1]
    ndim = len(shape)
    init_shape: list[_ShapeDim] = [1] * ndim
    init_shape[0] = shape[0]
    # Zero accumulator derived from the first index tensor so it carries
    # device/dtype through tracing; broadcast builds up the full offset grid.
    offset = indices[0].reshape(init_shape) * 0
    for dim, (index, stride) in enumerate(zip(indices, strides, strict=True)):
        if ndim > 1:
            view_shape: list[_ShapeDim] = [1] * ndim
            view_shape[dim] = shape[dim]
            index = index.reshape(view_shape)
        offset = cast("torch.Tensor", offset + index * stride)
    return offset
def _ref_rng_shape_and_offset(
    shape: list[int | RefTile],
    *,
    device: torch.device,
) -> tuple[list[int], torch.Tensor]:
    """Resolve a ref-mode shape into concrete sizes and Philox offsets.

    ``RefTile`` dims contribute their tile's slice of indices but stride by
    the tile's *full* extent, so tiled and untiled execution sample the same
    values for the same logical element.

    Returns:
        ``(processed_shape, offset)`` where ``offset`` is an int64 tensor of
        linear counters with shape ``processed_shape`` (0-d for empty shape).
    """
    processed_shape: list[int] = []
    full_extents: list[int] = []
    indices: list[torch.Tensor] = []
    for dim in shape:
        if isinstance(dim, RefTile):
            # Tile: local size for the output, full extent for the strides.
            processed_shape.append(dim.end - dim.begin)
            full_extents.append(dim._extent_end - dim._extent_begin)
            indices.append(
                torch.arange(dim.begin, dim.end, dtype=torch.int64, device=device)
            )
        else:
            size = int(dim)
            processed_shape.append(size)
            full_extents.append(size)
            indices.append(torch.arange(size, dtype=torch.int64, device=device))
    if not processed_shape:
        return [], torch.zeros([], dtype=torch.int64, device=device)
    # Row-major strides over the full extents.
    strides = [1] * len(processed_shape)
    for i in range(len(processed_shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * full_extents[i + 1]
    offset = torch.zeros(processed_shape, dtype=torch.int64, device=device)
    ndim = len(processed_shape)
    for dim, (index, stride) in enumerate(zip(indices, strides, strict=True)):
        if ndim > 1:
            view_shape = [1] * ndim
            view_shape[dim] = processed_shape[dim]
            index = index.reshape(view_shape)
        offset = offset + index * stride
    return processed_shape, offset
def _as_int64_scalar(
value: int | torch.SymInt | torch.Tensor, *, device: torch.device
) -> torch.Tensor:
if isinstance(value, torch.Tensor):
return _convert_element_type(value, torch.int64)
if isinstance(value, torch.SymInt):
return torch.scalar_tensor(cast("int", value), dtype=torch.int64, device=device)
return torch.scalar_tensor(value, dtype=torch.int64, device=device)
def _uint32_to_signed_int64(x: torch.Tensor) -> torch.Tensor:
    """Reinterpret emulated-uint32 bits as a signed int32 value (in int64)."""
    as64 = _convert_element_type(x, torch.int64)
    bias = _pallas_safe_i32_scalar_like(as64, SIGN_BIT32)
    # Add the sign bit, wrap mod 2**32, then subtract it back: values with
    # the high bit set come out negative.
    return _mask_u32(as64 + bias) - bias
def _uint32_to_uniform_float(x: torch.Tensor) -> torch.Tensor:
    """Map emulated-uint32 random bits to uniform float32 values."""
    signed = _uint32_to_signed_int64(x)
    # Fold negatives into [0, 2**31 - 1] before scaling.
    folded = torch.where(signed < 0, -signed - 1, signed)
    return _convert_element_type(folded, torch.float32) * UINT32_TO_UNIFORM_SCALE
def _convert_element_type(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
return torch.ops.prims.convert_element_type.default(x, dtype)
def _mulhi_lo_u32(
    a: int | torch.Tensor,
    b: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Emulated 32x32 -> 64-bit multiply via 16-bit limbs.

    Returns ``(hi, lo)``: the high and low 32 bits of ``a * b``.  The
    schoolbook 16-bit decomposition keeps every partial product small enough
    to be exact in the int64 math backends actually execute.
    """
    a64 = _convert_element_type(a, torch.int64) if isinstance(a, torch.Tensor) else a
    b64 = _convert_element_type(b, torch.int64)
    # Split each operand into low (a0/b0) and high (a1/b1) 16-bit halves.
    a0 = a64 & HALF_MASK16
    a1 = (a64 >> 16) & HALF_MASK16
    b0 = b64 & HALF_MASK16
    b1 = (b64 >> 16) & HALF_MASK16
    t = a0 * b0
    w0 = t & HALF_MASK16  # product bits 0..15
    k = t >> 16  # carry into bits 16..31
    t = a1 * b0 + k
    w1 = t & HALF_MASK16
    w2 = t >> 16
    t = a0 * b1 + w1
    lo = _mask_u32(((t & HALF_MASK16) << 16) | w0)
    hi = _mask_u32(a1 * b1 + w2 + (t >> 16))
    return hi, lo
def _philox_uint32x4(
    seed: int | torch.SymInt | torch.Tensor,
    offset: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the Philox counter-based RNG, returning four uint32 output words.

    The 64-bit per-element ``offset`` supplies the low two counter words and
    the 64-bit ``seed`` supplies the key.  All "uint32" values are carried
    in int64 tensors masked to 32 bits.
    """
    device = offset.device
    offset64 = _convert_element_type(offset, torch.int64)
    seed64 = _as_int64_scalar(seed, device=device)
    # Counter: low/high words of the offset; the upper two words start at 0.
    c0 = _mask_u32(offset64)
    c1 = _mask_u32(offset64 >> 32)
    c2 = c0 * 0
    c3 = c0 * 0
    # Key: low/high words of the seed.
    k0 = _mask_u32(seed64)
    k1 = _mask_u32(seed64 >> 32)
    for _ in range(PHILOX_ROUNDS):
        hi0, lo0 = _mulhi_lo_u32(PHILOX_ROUND_B, c2)
        hi1, lo1 = _mulhi_lo_u32(PHILOX_ROUND_A, c0)
        c0 = _mask_u32(hi0 ^ c1 ^ k0)
        c1 = lo0
        c2 = _mask_u32(hi1 ^ c3 ^ k1)
        c3 = lo1
        # Bump the key by the Weyl constants each round.
        k0 = _mask_u32(k0 + _pallas_safe_i32_scalar_like(k0, PHILOX_KEY_A))
        k1 = _mask_u32(k1 + _pallas_safe_i32_scalar_like(k1, PHILOX_KEY_B))
    return c0, c1, c2, c3
def _philox_rand_from_seed_and_offset(
    seed: int | torch.SymInt | torch.Tensor,
    offset: torch.Tensor,
) -> torch.Tensor:
    """Uniform [0, 1) float32 samples from the first Philox output word."""
    counter0 = _philox_uint32x4(seed, offset)[0]
    return _uint32_to_uniform_float(counter0)
def _philox_randn_from_seed_and_offset(
    seed: int | torch.SymInt | torch.Tensor,
    offset: torch.Tensor,
) -> torch.Tensor:
    """Standard-normal samples via Box-Muller on two Philox uniform words."""
    c0, c1, _, _ = _philox_uint32x4(seed, offset)
    # Clamp away zero so log(u1) stays finite.
    u1 = torch.clamp_min(_uint32_to_uniform_float(c0), BOX_MULLER_MIN)
    u2 = _uint32_to_uniform_float(c1)
    neg_two = torch.full([], -2.0, dtype=torch.float32, device=u1.device)
    two_pi = torch.full([], TWO_PI, dtype=torch.float32, device=u1.device)
    magnitude = torch.sqrt(torch.log(u1) * neg_two)
    angle = u2 * two_pi
    return magnitude * torch.cos(angle)
def _philox_randint_from_seed_and_offset(
    seed: int | torch.SymInt | torch.Tensor,
    offset: torch.Tensor,
    *,
    low: int,
    high: int,
) -> torch.Tensor:
    """int32 samples in [low, high) from the first Philox output word.

    Raises:
        ValueError: if ``low >= high``.
    """
    if low >= high:
        raise ValueError(f"low ({low}) must be less than high ({high})")
    c0, _, _, _ = _philox_uint32x4(seed, offset)
    signed = _uint32_to_signed_int64(c0)
    magnitude = torch.where(signed < 0, -signed, signed)
    span = high - low
    return _convert_element_type(low + (magnitude % span), torch.int32)
@_decorators.api()
def _rng_seed(index: int) -> torch.Tensor:
    """Device-side handle for the per-kernel RNG seed stored in slot *index*."""
    raise AssertionError("this should never be called directly")
@_decorators.register_fake(_rng_seed)
def _(index: int) -> torch.Tensor:
    # Fake value: a 0-d int64 tensor on the environment's device.
    env = CompileEnvironment.current()
    return torch.empty([], dtype=torch.int64, device=env.device)
@_decorators.codegen(_rng_seed, "common")
def _(state: CodegenState) -> ast.AST:
    # Emit the backend expression that loads the seed for this slot.
    seed_index = state.proxy_arg(0)
    assert isinstance(seed_index, int)
    return codegen_rng_seed_expr(state.codegen, seed_index)
def _next_rng_seed_slot() -> int:
    """Allocate a fresh RNG seed slot on the current host function."""
    return HostFunction.current().allocate_rng_seed_slot()
def _next_ref_rng_seed_slot() -> int:
    """Allocate a fresh RNG seed slot in ref (eager) mode."""
    from ..runtime.ref_mode import RefModeContext
    return RefModeContext.current().allocate_rng_seed_slot()
def _ref_rng_seed(index: int) -> torch.Tensor:
    """Look up the seed value stored in slot *index* in ref (eager) mode."""
    from ..runtime.ref_mode import RefModeContext
    return RefModeContext.current().lookup_rng_seed(index)
def _add_rewrite_desc(
descriptors: list[_ArgDesc],
value: object,
) -> _ArgDesc:
if isinstance(value, Node):
desc: _ArgDesc = (True, len(descriptors))
descriptors.append((True, value))
return desc
return (False, value)
def _shape_rewrite_desc(
    shape_arg: object,
) -> tuple[list[_ArgDesc], list[_ArgDesc]]:
    """Build descriptors for every dimension of a shape argument."""
    assert isinstance(shape_arg, (list, tuple, torch.Size))
    descriptors: list[_ArgDesc] = []
    shape_desc: list[_ArgDesc] = []
    for dim in shape_arg:
        shape_desc.append(_add_rewrite_desc(descriptors, dim))
    return descriptors, shape_desc
def decompose_rand(
    shape: list[int | torch.SymInt],
    *,
    seed: int | torch.SymInt | torch.Tensor,
) -> torch.Tensor:
    """Lower ``hl.rand`` into an explicit Philox offset computation."""
    device = CompileEnvironment.current().device
    offsets = _explicit_offset_from_shape(shape, device=device)
    return _philox_rand_from_seed_and_offset(seed, offsets)
def decompose_randint(
    shape: list[int | torch.SymInt],
    *,
    low: int,
    high: int,
    seed: int | torch.SymInt | torch.Tensor,
) -> torch.Tensor:
    """Lower ``hl.randint`` into an explicit Philox offset computation."""
    device = CompileEnvironment.current().device
    offsets = _explicit_offset_from_shape(shape, device=device)
    return _philox_randint_from_seed_and_offset(seed, offsets, low=low, high=high)
def _canonicalize_rng_device(device: DeviceLikeType) -> torch.device:
    """Normalize *device*, filling in the environment's device index.

    A device of the same type as the environment's but without an explicit
    index inherits the environment's index.
    """
    requested = torch.device(device)
    env_device = torch.device(CompileEnvironment.current().device)
    same_type_no_index = requested.index is None and requested.type == env_device.type
    if same_type_no_index:
        return torch.device(requested.type, env_device.index)
    return requested
def _assert_rng_device_matches_env(device: DeviceLikeType | None) -> None:
    """Assert that a user-supplied RNG device matches the environment's."""
    if device is None:
        return
    requested = _canonicalize_rng_device(device)
    env_device = torch.device(CompileEnvironment.current().device)
    assert requested == env_device, f"expected {env_device}, got {requested}"
def _normalize_implicit_rng_request(
    shape: list[_ShapeDim],
    *,
    dtype: torch.dtype | None,
    default_dtype: torch.dtype,
    device: DeviceLikeType | None,
    requires_grad: object = False,
) -> tuple[list[_ShapeDim], torch.dtype]:
    """Validate an implicit torch.rand/randn request and resolve its dtype.

    Raises:
        NotImplementedError: if the resolved dtype is not floating-point.
    """
    _assert_rng_device_matches_env(device)
    assert not requires_grad
    resolved_dtype = dtype if dtype is not None else default_dtype
    if resolved_dtype.is_floating_point:
        return shape, resolved_dtype
    raise NotImplementedError(
        f"implicit RNG only supports floating-point dtypes, got {resolved_dtype}"
    )
def _runtime_seeded_random(
    shape: list[_ShapeDim],
    *,
    dtype: torch.dtype,
    sampler: Callable[
        [int | torch.SymInt | torch.Tensor, torch.Tensor],
        torch.Tensor,
    ],
) -> torch.Tensor:
    """Sample via *sampler* using a freshly allocated per-kernel seed slot."""
    seed = _rng_seed(_next_rng_seed_slot())
    device = CompileEnvironment.current().device
    offset = _explicit_offset_from_shape(shape, device=device)
    values = sampler(seed, offset)
    if dtype == torch.float32:
        return values
    return _convert_element_type(values, dtype)
def _implicit_random(
    shape: list[_ShapeDim],
    *,
    dtype: torch.dtype | None,
    default_dtype: torch.dtype,
    device: DeviceLikeType | None,
    requires_grad: object = False,
    sampler: Callable[
        [int | torch.SymInt | torch.Tensor, torch.Tensor],
        torch.Tensor,
    ],
) -> torch.Tensor:
    """Handle an implicit torch.rand/randn call inside a kernel body."""
    normalized_shape, resolved_dtype = _normalize_implicit_rng_request(
        shape,
        dtype=dtype,
        default_dtype=default_dtype,
        device=device,
        requires_grad=requires_grad,
    )
    return _runtime_seeded_random(
        normalized_shape, dtype=resolved_dtype, sampler=sampler
    )
def _ref_runtime_seeded_random(
    shape: list[int | RefTile],
    *,
    dtype: torch.dtype,
    rng_device: torch.device,
    sampler: Callable[
        [int | torch.SymInt | torch.Tensor, torch.Tensor],
        torch.Tensor,
    ],
) -> torch.Tensor:
    """Ref-mode analogue of ``_runtime_seeded_random``."""
    seed = _ref_rng_seed(_next_ref_rng_seed_slot())
    processed_shape, offset = _ref_rng_shape_and_offset(shape, device=rng_device)
    values = sampler(seed, offset).reshape(processed_shape)
    if dtype != torch.float32:
        values = values.to(dtype)
    return values.to(device=rng_device)
def ref_implicit_random(
    shape: list[int | RefTile],
    *,
    dtype: torch.dtype | None,
    default_dtype: torch.dtype,
    device: DeviceLikeType | None,
    requires_grad: object = False,
    normal: bool,
) -> torch.Tensor:
    """Ref-mode implementation of implicit torch.rand/randn calls."""
    rng_shape, resolved_dtype = _normalize_implicit_rng_request(
        cast("list[_ShapeDim]", shape),
        dtype=dtype,
        default_dtype=default_dtype,
        device=device,
        requires_grad=requires_grad,
    )
    if device is None:
        rng_device = torch.device(CompileEnvironment.current().device)
    else:
        rng_device = _canonicalize_rng_device(device)
    if normal:
        sampler = _philox_randn_from_seed_and_offset
    else:
        sampler = _philox_rand_from_seed_and_offset
    return _ref_runtime_seeded_random(
        cast("list[int | RefTile]", rng_shape),
        dtype=resolved_dtype,
        rng_device=rng_device,
        sampler=sampler,
    )
def _rewrite_runtime_args(
descriptors: list[_ArgDesc],
) -> tuple[list[Node], list[object]]:
runtime_args: list[Node] = []
example_args: list[object] = []
for is_dynamic, value in descriptors:
if is_dynamic:
assert isinstance(value, Node)
runtime_args.append(value)
example_args.append(value.meta["val"])
return runtime_args, example_args
def _copy_rewrite_subgraph(
    graph: torch.fx.Graph,
    helper_graph: torch.fx.Graph,
    *,
    before: Node,
    runtime_args: list[Node],
) -> Node:
    """Splice *helper_graph* into *graph* just before *before*.

    Helper placeholders are bound positionally to ``runtime_args``.  Helper
    graphs that hold constants (``get_attr`` nodes) are rejected, since those
    attributes would have no owning module in the target graph.

    Returns:
        The node corresponding to the helper graph's output value.
    """
    helper_placeholders = list(helper_graph.find_nodes(op="placeholder"))
    helper_getattrs = list(helper_graph.find_nodes(op="get_attr"))
    if helper_getattrs:
        raise NotImplementedError(
            f"unexpected helper constants: {[node.target for node in helper_getattrs]!r}"
        )
    with graph.inserting_before(before):
        copied = graph.graph_copy(
            helper_graph,
            dict(zip(helper_placeholders, runtime_args, strict=True)),
        )
    assert isinstance(copied, Node)
    return copied
def _trace_rewrite_subgraph(
    graph: torch.fx.Graph,
    node: Node,
    helper: Callable[..., torch.Tensor],
    descriptors: list[_ArgDesc],
) -> Node:
    """Trace *helper* and splice the traced graph in front of *node*.

    The helper is traced against the dynamic args' example values; traced
    nodes inherit *node*'s source-location metadata so diagnostics keep
    pointing at the user's original code.
    """
    from .._compiler.device_ir import _make_fx
    runtime_args, example_args = _rewrite_runtime_args(descriptors)
    helper_graph = _make_fx(helper, *example_args)
    # Propagate the replaced node's location onto traced nodes lacking one.
    location = node.meta.get("location")
    if location is not None:
        for helper_node in helper_graph.nodes:
            if helper_node.op == "call_function" and "location" not in helper_node.meta:
                helper_node.meta["location"] = location
    return _copy_rewrite_subgraph(
        graph,
        helper_graph,
        before=node,
        runtime_args=runtime_args,
    )
def _resolve_rewrite_arg(
flat_args: tuple[object, ...],
desc: _ArgDesc,
) -> object:
return flat_args[cast("int", desc[1])] if desc[0] else desc[1]
def _resolve_shape_desc(
    flat_args: tuple[object, ...],
    desc: _ArgDesc,
) -> int | torch.SymInt:
    """Resolve a shape-dimension descriptor to an int or SymInt."""
    resolved = _resolve_rewrite_arg(flat_args, desc)
    assert isinstance(resolved, (int, torch.SymInt))
    return resolved
def _resolve_int_desc(
    flat_args: tuple[object, ...],
    desc: _ArgDesc,
) -> int:
    """Resolve a descriptor that must be a plain int (e.g. low/high bounds)."""
    resolved = _resolve_rewrite_arg(flat_args, desc)
    assert isinstance(resolved, int)
    return resolved
def _resolve_seed_desc(
    flat_args: tuple[object, ...],
    desc: _ArgDesc,
) -> int | torch.SymInt | torch.Tensor:
    """Resolve a seed descriptor; seeds may be ints, SymInts, or tensors."""
    value = _resolve_rewrite_arg(flat_args, desc)
    return cast("int | torch.SymInt | torch.Tensor", value)
def _random_rewrite_nodes(graph: torch.fx.Graph) -> list[Node]:
    """Collect every RNG call node in *graph* that needs rewriting.

    Covers both the explicit helion ops (``rand``/``randint``) and implicit
    aten RNG ops.  The nodes are sorted into a deterministic order so the
    rewrite produces stable graphs across runs.
    """
    targets = (
        rand,
        randint,
        torch.ops.aten.rand.default,
        torch.ops.aten.randn.default,
        torch.ops.aten.rand_like.default,
        torch.ops.aten.randn_like.default,
    )
    return sorted(
        node
        for target in targets
        for node in graph.find_nodes(
            op="call_function",
            target=target,
            sort=False,
        )
    )
def rewrite_implicit_random_ops(graph: torch.fx.Graph) -> None:
    """Replace RNG call nodes in *graph* with traced Philox decompositions.

    Each supported RNG node is rewritten in place: its arguments are split
    into dynamic (fx ``Node``) and static descriptors, a helper closure
    resolving those descriptors is traced, and the traced subgraph replaces
    the original node.  Helper keyword defaults bind the loop-local
    descriptor values, avoiding the late-binding-closure pitfall.
    """
    for node in _random_rewrite_nodes(graph):
        if node.target is rand:
            # Explicit hl.rand(shape, seed).
            shape_arg = node.args[0]
            descriptors, shape_desc = _shape_rewrite_desc(shape_arg)
            seed_desc = _add_rewrite_desc(descriptors, node.args[1])
            def helper(
                *flat_args: object,
                shape_desc: tuple[_ArgDesc, ...] = tuple(shape_desc),
                seed_desc: _ArgDesc = seed_desc,
            ) -> torch.Tensor:
                shape = [_resolve_shape_desc(flat_args, desc) for desc in shape_desc]
                seed = _resolve_seed_desc(flat_args, seed_desc)
                return decompose_rand(shape, seed=seed)
            replacement = _trace_rewrite_subgraph(graph, node, helper, descriptors)
        elif node.target is randint:
            # Explicit hl.randint(shape, low, high, seed).
            shape_arg = node.args[0]
            descriptors, shape_desc = _shape_rewrite_desc(shape_arg)
            low_desc = _add_rewrite_desc(descriptors, node.args[1])
            high_desc = _add_rewrite_desc(descriptors, node.args[2])
            seed_desc = _add_rewrite_desc(descriptors, node.args[3])
            def helper(
                *flat_args: object,
                shape_desc: tuple[_ArgDesc, ...] = tuple(shape_desc),
                low_desc: _ArgDesc = low_desc,
                high_desc: _ArgDesc = high_desc,
                seed_desc: _ArgDesc = seed_desc,
            ) -> torch.Tensor:
                shape = [_resolve_shape_desc(flat_args, desc) for desc in shape_desc]
                low_val = _resolve_int_desc(flat_args, low_desc)
                high_val = _resolve_int_desc(flat_args, high_desc)
                seed = _resolve_seed_desc(flat_args, seed_desc)
                return decompose_randint(
                    shape,
                    low=low_val,
                    high=high_val,
                    seed=seed,
                )
            replacement = _trace_rewrite_subgraph(graph, node, helper, descriptors)
        elif node.target in {torch.ops.aten.rand.default, torch.ops.aten.randn.default}:
            # Implicit torch.rand/torch.randn with an explicit shape.
            descriptors, shape_desc = _shape_rewrite_desc(node.args[0])
            dtype = node.kwargs.get("dtype", torch.float32)
            assert dtype is None or isinstance(dtype, torch.dtype)
            device = node.kwargs.get("device")
            requires_grad = node.kwargs.get("requires_grad", False)
            if node.target is torch.ops.aten.rand.default:
                sampler = _philox_rand_from_seed_and_offset
            else:
                sampler = _philox_randn_from_seed_and_offset
            def helper(
                *flat_args: object,
                shape_desc: tuple[_ArgDesc, ...] = tuple(shape_desc),
                dtype: torch.dtype | None = dtype,
                device_arg: object | None = device,
                requires_grad: object = requires_grad,
                sampler: Callable[
                    [int | torch.SymInt | torch.Tensor, torch.Tensor],
                    torch.Tensor,
                ] = sampler,
            ) -> torch.Tensor:
                shape = [_resolve_shape_desc(flat_args, desc) for desc in shape_desc]
                return _implicit_random(
                    shape,
                    dtype=dtype,
                    default_dtype=torch.float32,
                    device=cast("DeviceLikeType | None", device_arg),
                    requires_grad=requires_grad,
                    sampler=sampler,
                )
            replacement = _trace_rewrite_subgraph(graph, node, helper, descriptors)
        elif node.target in {
            torch.ops.aten.rand_like.default,
            torch.ops.aten.randn_like.default,
        }:
            # Implicit torch.rand_like/randn_like: shape comes from the input
            # tensor at trace time, so the tensor itself is the dynamic arg.
            tensor = node.args[0]
            assert isinstance(tensor, Node)
            descriptors: list[_ArgDesc] = [(True, tensor)]
            dtype = node.kwargs.get("dtype")
            assert dtype is None or isinstance(dtype, torch.dtype)
            device = node.kwargs.get("device")
            requires_grad = node.kwargs.get("requires_grad", False)
            if node.target is torch.ops.aten.rand_like.default:
                sampler = _philox_rand_from_seed_and_offset
            else:
                sampler = _philox_randn_from_seed_and_offset
            def helper(
                *flat_args: object,
                dtype: torch.dtype | None = dtype,
                device_arg: object | None = device,
                requires_grad: object = requires_grad,
                sampler: Callable[
                    [int | torch.SymInt | torch.Tensor, torch.Tensor],
                    torch.Tensor,
                ] = sampler,
            ) -> torch.Tensor:
                (input_tensor,) = flat_args
                assert isinstance(input_tensor, torch.Tensor)
                return _implicit_random(
                    [*input_tensor.shape],
                    dtype=dtype,
                    default_dtype=input_tensor.dtype,
                    device=cast("DeviceLikeType | None", device_arg),
                    requires_grad=requires_grad,
                    sampler=sampler,
                )
            replacement = _trace_rewrite_subgraph(graph, node, helper, descriptors)
        else:
            continue
        node.replace_all_uses_with(replacement)
        graph.erase_node(node)
# NOTE(review): the stray ``[docs]`` line that preceded this decorator was a
# Sphinx HTML-extraction artifact; it evaluated the undefined name ``docs``
# and raised NameError at import time, so it has been removed.
@_decorators.api(tiles_as_sizes=True)
def rand(
    shape: list[object],
    seed: int | torch.Tensor,
    device: torch.device | None = None,
) -> torch.Tensor:
    """
    hl.rand provides a Philox-based pseudorandom number generator (PRNG) that
    operates independently of PyTorch's global random seed. Instead, it
    requires an explicit seed argument. Offsets are derived from the full
    logical sizes of the tiles specified in the shape argument.
    Args:
        shape: A list of sizes for the output tensor
        seed: A single element int64 tensor or int literal
        device: Device must match the current compile environment device
    Returns:
        torch.Tensor: A device tensor of float32 dtype filled with uniform
        random values in [0, 1)
    Examples:
        .. code-block:: python
            @helion.kernel
            def process_kernel(x: torch.Tensor) -> torch.Tensor:
                output = torch.zeros_like(x)
                (m,) = x.shape
                for tile_m in hl.tile(m):
                    output[tile_m] = hl.rand([tile_m], seed=42)
                return output
    """
    # Placeholder: the real implementation is injected by the compiler; calling
    # this outside a kernel is an error.
    raise NotInsideKernel
@_decorators.register_fake(rand)
def _rand_fake(
    shape: list[int | torch.SymInt],
    seed: int | torch.Tensor,
    device: torch.device | None = None,
) -> torch.Tensor:
    """Fake-tensor implementation of :func:`rand` used during tracing."""
    if not isinstance(shape, (list, tuple)):
        raise TypeError(f"Expected list[SymInt], got {type(shape).__name__}")
    env = CompileEnvironment.current()
    env.add_kernel_tensor_size(shape)
    out_device = env.device if device is None else device
    return torch.empty([*shape], dtype=torch.float32, device=out_device)
@_decorators.get_masked_value(rand)
def _(node: torch.fx.Node) -> float:
    # Masked-out lanes of hl.rand read as 0.
    return 0
@_decorators.ref(rand)
def _(
    shape: list[int | RefTile],
    seed: int | torch.Tensor,
    device: torch.device | None = None,
) -> torch.Tensor:
    """Eager reference implementation of :func:`rand` for ref mode."""
    env = CompileEnvironment.current()
    rng_device = env.device if device is None else device
    processed_shape, offset = _ref_rng_shape_and_offset(shape, device=rng_device)
    values = philox_rand_ref(seed, offset)
    return values.reshape(processed_shape).to(device=rng_device)
# NOTE(review): the stray ``[docs]`` line that preceded this decorator was a
# Sphinx HTML-extraction artifact; it evaluated the undefined name ``docs``
# and raised NameError at import time, so it has been removed.
@_decorators.api(tiles_as_sizes=True)
def randint(
    shape: list[object],
    low: int,
    high: int,
    seed: int | torch.Tensor,
    device: torch.device | None = None,
) -> torch.Tensor:
    """
    hl.randint provides a Philox-based pseudorandom integer generator (PRNG)
    that operates independently of PyTorch's global random seed. Instead, it
    requires an explicit seed argument. Offsets are derived from the full
    logical sizes of the tiles specified in the shape argument.
    Args:
        shape: A list of sizes for the output tensor
        low: Lowest integer to be drawn from the distribution (inclusive)
        high: One above the highest integer to be drawn from the distribution
            (exclusive)
        seed: A single element int64 tensor or int literal
        device: Device must match the current compile environment device
    Returns:
        torch.Tensor: A device tensor of int32 dtype filled with random
        integers in [low, high)
    Examples:
        .. code-block:: python
            @helion.kernel
            def process_kernel(x: torch.Tensor) -> torch.Tensor:
                output = torch.zeros(x.shape, dtype=torch.int32, device=x.device)
                (m,) = x.shape
                for tile_m in hl.tile(m):
                    output[tile_m] = hl.randint([tile_m], low=0, high=10, seed=42)
                return output
    """
    # Placeholder: the real implementation is injected by the compiler; calling
    # this outside a kernel is an error.
    raise NotInsideKernel
@_decorators.register_fake(randint)
def _randint_fake(
    shape: list[int | torch.SymInt],
    low: int,
    high: int,
    seed: int | torch.Tensor,
    device: torch.device | None = None,
) -> torch.Tensor:
    """Fake-tensor implementation of :func:`randint` used during tracing.

    Raises:
        TypeError: if ``shape`` is not a list/tuple.
        ValueError: if ``low >= high``.
    """
    if not isinstance(shape, (list, tuple)):
        raise TypeError(f"Expected list[SymInt], got {type(shape).__name__}")
    if low >= high:
        raise ValueError(f"low ({low}) must be less than high ({high})")
    env = CompileEnvironment.current()
    env.add_kernel_tensor_size(shape)
    out_device = env.device if device is None else device
    return torch.empty([*shape], dtype=torch.int32, device=out_device)
@_decorators.get_masked_value(randint)
def _(node: torch.fx.Node) -> int:
    # Masked-out lanes of hl.randint read as 0.
    return 0
@_decorators.ref(randint)
def _(
    shape: list[int | RefTile],
    low: int,
    high: int,
    seed: int | torch.Tensor,
    device: torch.device | None = None,
) -> torch.Tensor:
    """Eager reference implementation of :func:`randint` for ref mode."""
    env = CompileEnvironment.current()
    rng_device = env.device if device is None else device
    processed_shape, offset = _ref_rng_shape_and_offset(shape, device=rng_device)
    values = philox_randint_ref(seed, offset, low, high)
    return values.reshape(processed_shape).to(device=rng_device)