helion.language.jagged_tile
- helion.language.jagged_tile(parent)
Iterate over a jagged inner dimension using an N-D parent tensor of per-lane ends.
jagged_tile is the jagged counterpart to tile(). Instead of taking a scalar upper bound, it takes a tensor whose every axis comes from an enclosing parent tile context. Each element of parent gives the true end of the jagged child loop for the corresponding parent lane.

Conceptually, Helion lowers:
for tile_k in hl.jagged_tile(parent): ...
to:
end = parent.amax()
for tile_k in hl.tile(end):
    mask = tile_k.index[None, :] < parent[:, None]
    ...
while automatically masking out indices where tile_k.index >= parent for each parent lane. This lets you write ragged loops directly instead of writing a dense loop and manually constructing masks.
- Parameters:
parent (object) – N-D tensor whose every axis is an enclosing tile axis. parent[i, ...] is the true end of the jagged child loop for that combination of parent lanes. The 1-D case is the common scalar-of-rows pattern.
- Returns:
Iterator over tile objects for the jagged child dimension
- Return type:
Iterator[Tile]
Examples
Before jagged_tile: dense loop plus manual mask:

@helion.kernel
def jagged_row_sum_masked(
    x: torch.Tensor, row_lengths: torch.Tensor
) -> torch.Tensor:
    b = row_lengths.size(0)
    out = torch.zeros([b], dtype=x.dtype, device=x.device)
    for tile_b in hl.tile(b):
        lengths = row_lengths[tile_b]
        max_len = lengths.amax()
        acc = hl.zeros([tile_b], dtype=x.dtype)
        for tile_k in hl.tile(max_len):
            mask = tile_k.index[None, :] < lengths[:, None]
            vals = hl.load(x, [tile_b, tile_k], extra_mask=mask)
            acc = acc + vals.sum(dim=1)
        out[tile_b] = acc
    return out
With jagged_tile: the mask becomes implicit:

@helion.kernel
def jagged_row_sum(
    x: torch.Tensor, row_lengths: torch.Tensor
) -> torch.Tensor:
    b = row_lengths.size(0)
    out = torch.zeros([b], dtype=x.dtype, device=x.device)
    for tile_b in hl.tile(b):
        lengths = row_lengths[tile_b]
        acc = hl.zeros([tile_b], dtype=x.dtype)
        for tile_k in hl.jagged_tile(lengths):
            acc = acc + x[tile_b, tile_k].sum(dim=1)
        out[tile_b] = acc
    return out
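As a usage sketch for the kernel above (the device, shapes, values, and the dense reference check are illustrative assumptions, not part of the API):

import torch

# Illustrative data: 4 rows padded to width 8; only the first
# row_lengths[i] elements of row i are meaningful.
x = torch.randn(4, 8, device="cuda")
row_lengths = torch.tensor([8, 3, 5, 1], device="cuda")

out = jagged_row_sum(x, row_lengths)  # shape: [4]

# Dense reference computed with an explicit mask, for comparison.
mask = torch.arange(8, device="cuda")[None, :] < row_lengths[:, None]
expected = (x * mask).sum(dim=1)
torch.testing.assert_close(out, expected)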
Packed jagged data with offsets:
@helion.kernel
def jagged_sum(
    x_data: torch.Tensor,
    x_offsets: torch.Tensor,
) -> torch.Tensor:
    b = x_offsets.size(0) - 1
    out = torch.zeros([b], dtype=x_data.dtype, device=x_data.device)
    for tile_b in hl.tile(b):
        starts = x_offsets[tile_b]
        ends = x_offsets[tile_b.index + 1]
        lengths = ends - starts
        acc = hl.zeros([tile_b], dtype=x_data.dtype)
        for tile_k in hl.jagged_tile(lengths):
            idx = starts[:, None] + tile_k.index[None, :]
            acc = acc + x_data[idx].sum(dim=1)
        out[tile_b] = acc
    return out
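In this packed layout, x_offsets has b + 1 entries and row i occupies x_data[x_offsets[i]:x_offsets[i + 1]]. A hedged usage sketch (device and values are illustrative):

import torch

# Three rows of lengths 3, 2, and 4 packed into one flat buffer.
x_offsets = torch.tensor([0, 3, 5, 9], device="cuda")
x_data = torch.randn(9, device="cuda")

out = jagged_sum(x_data, x_offsets)  # shape: [3]

# Equivalent per-row reduction on the host, for comparison.
expected = torch.stack(
    [x_data[s:e].sum() for s, e in zip(x_offsets[:-1], x_offsets[1:])]
)
torch.testing.assert_close(out, expected)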
See also
tile(): For dense or uniform iteration spaces
Note
jagged_tile currently has a few important restrictions:
- The input must be a tensor of rank >= 1. Scalars are not allowed, and every axis of the parent tensor must come from an enclosing tile context.
- jagged_tile cannot be used as the outermost loop of a kernel.
- A jagged child tile must be indexed together with its parent axes. For example, x[tile_k] is invalid if tile_k comes from hl.jagged_tile(lengths) under tile_b. Use x[tile_b, tile_k] or another indexing expression that preserves the parent context (see the sketch after this list).
- Use tile() when the loop bound is uniform across lanes.
- More jagged kernels using hl.jagged_tile can be found in the examples/ directory.
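To illustrate the indexing restriction, a minimal sketch (identifiers reuse the row-sum example above; the commented-out line shows the invalid pattern):

for tile_b in hl.tile(b):
    lengths = row_lengths[tile_b]
    for tile_k in hl.jagged_tile(lengths):
        # Invalid: indexing with tile_k alone drops the parent
        # (tile_b) context needed for per-lane masking.
        # vals = x[tile_k]

        # Valid: the jagged child tile is indexed together with
        # its parent axis.
        vals = x[tile_b, tile_k]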