helion.language.jagged_tile

helion.language.jagged_tile(parent)[source]

Iterate over a jagged inner dimension using a 1D parent tensor of per-lane ends.

jagged_tile is the jagged counterpart to tile(). Instead of taking a scalar upper bound, it takes a 1D tensor from the enclosing parent tile context. Each element of parent gives the true end of the jagged child loop for the corresponding parent lane.

Conceptually, Helion lowers:

for tile_k in hl.jagged_tile(parent):
    ...

to:

end = parent.amax()
for tile_k in hl.tile(end):
    mask = tile_k.index[None, :] < parent[:, None]
    ...

while automatically masking out indices where tile_k.index >= parent for each parent lane. This lets you write ragged loops directly instead of writing a dense loop and manually constructing masks.

Parameters:

parent (object) – 1D tensor in the parent tile context. parent[i] is the true end of the jagged child loop for parent lane i.

Returns:

Iterator over Tile objects for the jagged child dimension

Return type:

Iterator[Tile]

Examples

Before jagged_tile: a dense loop plus a manual mask:

import torch

import helion
import helion.language as hl

@helion.kernel
def jagged_row_sum_masked(
    x: torch.Tensor, row_lengths: torch.Tensor
) -> torch.Tensor:
    b = row_lengths.size(0)
    max_len = row_lengths.amax()
    out = torch.zeros([b], dtype=x.dtype, device=x.device)

    for tile_b in hl.tile(b):
        lengths = row_lengths[tile_b]
        acc = hl.zeros([tile_b], dtype=x.dtype)

        for tile_k in hl.tile(max_len):
            mask = tile_k.index[None, :] < lengths[:, None]
            vals = hl.load(x, [tile_b, tile_k], extra_mask=mask)
            acc = acc + vals.sum(dim=1)

        out[tile_b] = acc
    return out

With jagged_tile: the mask becomes implicit:

@helion.kernel
def jagged_row_sum(
    x: torch.Tensor, row_lengths: torch.Tensor
) -> torch.Tensor:
    b = row_lengths.size(0)
    out = torch.zeros([b], dtype=x.dtype, device=x.device)

    for tile_b in hl.tile(b):
        lengths = row_lengths[tile_b]
        acc = hl.zeros([tile_b], dtype=x.dtype)

        for tile_k in hl.jagged_tile(lengths):
            acc = acc + x[tile_b, tile_k].sum(dim=1)

        out[tile_b] = acc
    return out
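
Either kernel can be called like an ordinary function on device tensors. A minimal usage sketch (the shapes, the random data, and the CUDA device below are illustrative assumptions, not requirements of the API):

# Illustrative inputs: 8 rows padded to length 32, with per-row lengths in [1, 32].
b, max_len = 8, 32
x = torch.randn(b, max_len, device="cuda")
row_lengths = torch.randint(1, max_len + 1, (b,), device="cuda")

result = jagged_row_sum(x, row_lengths)

# Reference: zero out the padded tail of each row, then sum on the host.
valid = torch.arange(max_len, device="cuda")[None, :] < row_lengths[:, None]
torch.testing.assert_close(result, (x * valid).sum(dim=1))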

Packed jagged data with offsets:

@helion.kernel
def jagged_sum(
    x_data: torch.Tensor,
    x_offsets: torch.Tensor,
) -> torch.Tensor:
    b = x_offsets.size(0) - 1
    out = torch.zeros([b], dtype=x_data.dtype, device=x_data.device)

    for tile_b in hl.tile(b):
        starts = x_offsets[tile_b]
        ends = x_offsets[tile_b.index + 1]
        lengths = ends - starts

        acc = hl.zeros([tile_b], dtype=x_data.dtype)
        for tile_k in hl.jagged_tile(lengths):
            idx = starts[:, None] + tile_k.index[None, :]
            acc = acc + x_data[idx].sum(dim=1)

        out[tile_b] = acc

    return out
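
The offsets layout pairs naturally with a cumulative sum over per-row lengths. A minimal sketch of driving this kernel (the concrete lengths and the CUDA device are illustrative assumptions):

# Pack three rows of lengths 3, 1, and 5 into one flat buffer.
lengths = torch.tensor([3, 1, 5], device="cuda")
x_offsets = torch.cat(
    [torch.zeros(1, dtype=torch.int64, device="cuda"), lengths.cumsum(0)]
)
x_data = torch.randn(int(x_offsets[-1].item()), device="cuda")

result = jagged_sum(x_data, x_offsets)

# Reference: sum each variable-length segment on the host.
expected = torch.stack([seg.sum() for seg in x_data.split(lengths.tolist())])
torch.testing.assert_close(result, expected)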

See also

  • tile(): For dense or uniform iteration spaces

Note

jagged_tile currently has a few important restrictions:

  • The input must be a 1D tensor. Scalars and higher-rank tensors are not allowed.

  • jagged_tile cannot be used as the outermost loop of a kernel.

  • A jagged child tile must be indexed together with its parent axes. For example, x[tile_k] is invalid if tile_k comes from hl.jagged_tile(lengths) under tile_b; use x[tile_b, tile_k] or another indexing expression that preserves the parent context (see the sketch after this list).

  • Use tile() when the loop bound is uniform across lanes.

  • See the examples/ directory for more jagged kernels that use hl.jagged_tile.
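
To make the parent-axis indexing restriction concrete, here is a minimal sketch (x, row_lengths, and b are placeholder names for tensors set up as in the examples above):

for tile_b in hl.tile(b):
    lengths = row_lengths[tile_b]
    for tile_k in hl.jagged_tile(lengths):
        vals = x[tile_b, tile_k]  # OK: tile_k indexed together with its parent axis tile_b
        # x[tile_k]               # invalid: tile_k used without its parent context
        ...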