helion.language.jagged_tile

helion.language.jagged_tile(parent)

Iterate over a jagged inner dimension using an N-D parent tensor of per-lane ends.

jagged_tile is the jagged counterpart to tile(). Instead of taking a scalar upper bound, it takes a tensor each of whose axes comes from an enclosing tile context. Each element of parent gives the true end of the jagged child loop for the corresponding parent lane.

Conceptually, Helion lowers:

for tile_k in hl.jagged_tile(parent):
    ...

to:

end = parent.amax()
for tile_k in hl.tile(end):
    mask = tile_k.index[None, :] < parent[:, None]
    ...

while automatically masking out indices where tile_k.index >= parent for each parent lane. This lets you write ragged loops directly instead of writing a dense loop and manually constructing masks.
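For intuition, the 1-D case can be reproduced in eager PyTorch. The helper below is a hypothetical reference sketch, not part of Helion; it only spells out what the lowering computes for a vector parent:

import torch

def masked_row_sum_reference(x: torch.Tensor, parent: torch.Tensor) -> torch.Tensor:
    # Lane i contributes x[i, j] only while j < parent[i]; positions at or
    # past the per-lane end are masked to zero, matching the lowering above.
    k = torch.arange(x.size(1), device=x.device)
    mask = k[None, :] < parent[:, None]
    return torch.where(mask, x, torch.zeros_like(x)).sum(dim=1)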

Parameters:

parent (object) – N-D tensor each of whose axes is an enclosing tile axis. parent[i, ...] is the true end of the jagged child loop for that combination of parent lanes. The 1-D case, a vector of per-row lengths, is the most common pattern.

Returns:

Iterator over tile objects for the jagged child dimension

Return type:

Iterator[Tile]

Examples

Before jagged_tile: dense loop plus manual mask:

import torch

import helion
import helion.language as hl


@helion.kernel
def jagged_row_sum_masked(
    x: torch.Tensor, row_lengths: torch.Tensor
) -> torch.Tensor:
    b = row_lengths.size(0)
    out = torch.zeros([b], dtype=x.dtype, device=x.device)

    for tile_b in hl.tile(b):
        lengths = row_lengths[tile_b]
        max_len = lengths.amax()
        acc = hl.zeros([tile_b], dtype=x.dtype)

        for tile_k in hl.tile(max_len):
            mask = tile_k.index[None, :] < lengths[:, None]
            vals = hl.load(x, [tile_b, tile_k], extra_mask=mask)
            acc = acc + vals.sum(dim=1)

        out[tile_b] = acc
    return out

With jagged_tile: the mask becomes implicit:

@helion.kernel
def jagged_row_sum(
    x: torch.Tensor, row_lengths: torch.Tensor
) -> torch.Tensor:
    b = row_lengths.size(0)
    out = torch.zeros([b], dtype=x.dtype, device=x.device)

    for tile_b in hl.tile(b):
        lengths = row_lengths[tile_b]
        acc = hl.zeros([tile_b], dtype=x.dtype)

        for tile_k in hl.jagged_tile(lengths):
            acc = acc + x[tile_b, tile_k].sum(dim=1)

        out[tile_b] = acc
    return out
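A kernel decorated with @helion.kernel is invoked like an ordinary function from the host. The snippet below is a minimal usage sketch, assuming a CUDA device and an integer row_lengths tensor; the eager check mirrors the masking semantics described above:

x = torch.randn(8, 32, device="cuda")
row_lengths = torch.randint(0, 33, (8,), device="cuda", dtype=torch.int32)
result = jagged_row_sum(x, row_lengths)

# Eager check: zero out each row's padded tail before summing.
k = torch.arange(x.size(1), device=x.device)
expected = (x * (k[None, :] < row_lengths[:, None])).sum(dim=1)
torch.testing.assert_close(result, expected)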

Packed jagged data with offsets:

@helion.kernel
def jagged_sum(
    x_data: torch.Tensor,
    x_offsets: torch.Tensor,
) -> torch.Tensor:
    b = x_offsets.size(0) - 1
    out = torch.zeros([b], dtype=x_data.dtype, device=x_data.device)

    for tile_b in hl.tile(b):
        starts = x_offsets[tile_b]
        ends = x_offsets[tile_b.index + 1]
        lengths = ends - starts

        acc = hl.zeros([tile_b], dtype=x_data.dtype)
        for tile_k in hl.jagged_tile(lengths):
            idx = starts[:, None] + tile_k.index[None, :]
            acc = acc + x_data[idx].sum(dim=1)

        out[tile_b] = acc

    return out
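Here x_offsets uses the usual packed (CSR-style) layout: it has b + 1 entries and row i occupies x_data[x_offsets[i]:x_offsets[i + 1]]. A hypothetical host-side setup and eager check, assuming a CUDA device:

lengths = torch.tensor([3, 0, 5], device="cuda")
x_offsets = torch.cat(
    [torch.zeros(1, dtype=lengths.dtype, device="cuda"), lengths.cumsum(0)]
)
x_data = torch.randn(int(x_offsets[-1].item()), device="cuda")

result = jagged_sum(x_data, x_offsets)

# Eager check: sum each row's slice of the packed buffer.
expected = torch.stack(
    [x_data[s:e].sum() for s, e in zip(x_offsets[:-1], x_offsets[1:])]
)
torch.testing.assert_close(result, expected)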

See also

  • tile(): For dense or uniform iteration spaces

Note

jagged_tile currently has a few important restrictions and usage notes:

  • The input must be a tensor of rank >= 1. Scalars are not allowed, and every axis of the parent tensor must come from an enclosing tile context.

  • jagged_tile cannot be used as the outermost loop of a kernel.

  • A jagged child tile must be indexed together with its parent axes. For example, x[tile_k] is invalid if tile_k comes from hl.jagged_tile(lengths) under tile_b. Use x[tile_b, tile_k] or another indexing expression that preserves the parent context.

  • Use tile() when the loop bound is uniform across lanes; see the sketch after this list.

  • See the examples/ directory for more jagged kernels that use hl.jagged_tile.
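
For the uniform-bound case noted above, a plain tile() loop is the natural form. A minimal sketch mirroring jagged_row_sum, with a hypothetical kernel name:

@helion.kernel
def dense_row_sum(x: torch.Tensor) -> torch.Tensor:
    b, k = x.size(0), x.size(1)
    out = torch.zeros([b], dtype=x.dtype, device=x.device)

    for tile_b in hl.tile(b):
        acc = hl.zeros([tile_b], dtype=x.dtype)

        # Every row has the same length k, so no per-lane masking is needed.
        for tile_k in hl.tile(k):
            acc = acc + x[tile_b, tile_k].sum(dim=1)

        out[tile_b] = acc
    return out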