
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/jagged_mean.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_jagged_mean.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_jagged_mean.py:


Jagged Mean Example
===================

This example demonstrates how to compute the mean of each row in a jagged tensor
with variable features per row using Helion.

.. GENERATED FROM PYTHON SOURCE LINES 10-12

Imports
-------

.. GENERATED FROM PYTHON SOURCE LINES 12-21

.. code-block:: Python


    from __future__ import annotations

    import torch

    import helion
    from helion._testing import run_example
    import helion.language as hl


.. GENERATED FROM PYTHON SOURCE LINES 22-24

Jagged Mean Kernel
------------------

.. GENERATED FROM PYTHON SOURCE LINES 24-104

.. code-block:: Python


    @helion.kernel()
    def jagged_mean_kernel(
        x_data: torch.Tensor,
        x_offsets: torch.Tensor,
        x_feature_counts: torch.Tensor,  # [num_rows] - number of features per row
        max_M_tensor: torch.Tensor,  # Dummy tensor whose size indicates max features
    ) -> torch.Tensor:
        """
        Compute the mean of each row in a jagged tensor with variable features per row.

        Args:
            x_data: 2-D tensor of shape (total_elements, max_M) holding all elements
            x_offsets: (num_rows + 1) tensor. Row i is the slice
                       x_data[x_offsets[i] : x_offsets[i+1], :]
            x_feature_counts: (num_rows) tensor. Number of valid features for each row
            max_M_tensor: Dummy tensor whose numel() gives max number of features

        Returns:
            2-D tensor of shape (num_rows, max_M) containing the mean of each row.
            Invalid features (beyond x_feature_counts[i]) are set to 0.
        """
        num_rows = x_offsets.size(0) - 1
        max_M = max_M_tensor.numel()  # Extract max features from dummy tensor

        out = torch.zeros([num_rows, max_M], dtype=x_data.dtype, device=x_data.device)

        # Flatten x_data for easier indexing
        x_flat = x_data.view(-1)

        # Process rows in tiles
        for tile_b in hl.tile(num_rows):
            starts = x_offsets[tile_b]
            ends = x_offsets[tile_b.index + 1]
            nnz = ends - starts
            max_nnz = nnz.amax()

            # Get feature counts for this tile of rows
            feature_counts = x_feature_counts[tile_b]

            # Process features in tiles
            for tile_m in hl.tile(max_M):
                # Create mask for valid features
                feature_valid = tile_m.index < feature_counts[:, None]

                # Initialize accumulator
                row_sums = hl.zeros([tile_b, tile_m], dtype=x_data.dtype)

                # Process elements within each row
                for tile_k in hl.tile(0, max_nnz):
                    # Compute flattened indices.
                    # Offset into x_flat is element_row * max_M + feature, which
                    # relies on x_data having exactly max_M columns (see docstring).
                    base_indices = starts[:, None] + tile_k.index[None, :]
                    flat_indices = (
                        base_indices[:, :, None] * max_M + tile_m.index[None, None, :]
                    )

                    # Combined mask: valid row element AND valid feature
                    row_mask = tile_k.index[None, :] < nnz[:, None]
                    combined_mask = row_mask[:, :, None] & feature_valid[:, None, :]

                    x_slice = hl.load(
                        x_flat,
                        [flat_indices],
                        extra_mask=combined_mask,
                    )
                    # Accumulate - sum across the k dimension (dim=1)
                    row_sums = row_sums + x_slice.sum(dim=1)

                # Compute mean
                nnz_float = nnz.to(x_data.dtype)
                nnz_expanded = nnz_float[:, None]

                # Compute result with feature masking
                result = torch.where(nnz_expanded > 0, row_sums / nnz_expanded, 0.0)

                # Apply feature mask to output
                out[tile_b, tile_m] = torch.where(feature_valid, result, 0.0)

        return out


.. GENERATED FROM PYTHON SOURCE LINES 105-107

Reference Implementation
------------------------

.. GENERATED FROM PYTHON SOURCE LINES 107-136

.. code-block:: Python


    def reference_jagged_mean_kernel_pytorch(
        x_data: torch.Tensor,
        x_offsets: torch.Tensor,
        x_feature_counts: torch.Tensor,
        max_M: int,
    ) -> torch.Tensor:
        """
        PyTorch reference implementation for jagged mean with variable features.

        Args:
            x_data: 2-D tensor holding all elements
            x_offsets: Offsets tensor for row indexing
            x_feature_counts: Number of valid features per row
            max_M: Maximum number of features

        Returns:
            Tensor containing the mean of each row
        """
        num_rows = x_offsets.numel() - 1
        out = torch.zeros((num_rows, max_M), dtype=x_data.dtype, device=x_data.device)
        for i in range(num_rows):
            start = int(x_offsets[i])
            end = int(x_offsets[i + 1])
            num_features = int(x_feature_counts[i])
            if end > start and num_features > 0:
                out[i, :num_features] = x_data[start:end, :num_features].mean(dim=0)
        return out


.. GENERATED FROM PYTHON SOURCE LINES 137-139

Benchmark Wrapper
-----------------

.. GENERATED FROM PYTHON SOURCE LINES 139-169

.. code-block:: Python


    def jagged_mean_tritonbench(
        x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
    ) -> torch.Tensor:
        """
        Wrapper for tritonbench that matches the expected interface.

        Args:
            x: Nested tensor in jagged format with shape (B, *, M)
            B: Batch size
            M: Number of features
            seqlen: Maximum sequence length (not used)
            sparsity: Sparsity factor (not used)

        Returns:
            Tensor of shape (B, M) with mean values per row and feature
        """
        x_values = x._values
        x_offsets = x._offsets  # pyright: ignore[reportAttributeAccessIssue]

        feature_counts = torch.full(
            (B,),
            M,
            dtype=torch.int32,
            device=x_values.device,  # pyright: ignore[reportAttributeAccessIssue]
        )
        max_M_tensor = torch.empty(M, device=x_values.device)  # pyright: ignore[reportAttributeAccessIssue]

        return jagged_mean_kernel(x_values, x_offsets, feature_counts, max_M_tensor)


.. GENERATED FROM PYTHON SOURCE LINES 170-172

Main Function
-------------

.. GENERATED FROM PYTHON SOURCE LINES 172-203

.. code-block:: Python


    def main() -> None:
        """
        Main entry point that runs the jagged mean kernel verification.

        Creates test data with random jagged tensors and feature counts, then compares
        the kernel implementation against the PyTorch reference implementation.
        """
        num_rows, max_cols = 32, 64
        device = "cuda"

        lengths = torch.randint(1, max_cols + 1, (num_rows,), device=device)
        x_offsets = torch.cat(
            [torch.zeros(1, dtype=torch.long, device=device), torch.cumsum(lengths, dim=0)]
        )
        nnz = int(x_offsets[-1])
        M = 8  # number of features
        x_data = torch.randn(nnz, M, dtype=torch.float32, device=device)
        feature_counts = torch.randint(
            1, M + 1, (num_rows,), dtype=torch.int32, device=device
        )
        max_M_tensor = torch.empty(M, device=device)

        run_example(
            lambda x, o, fc, mt: jagged_mean_kernel(x, o, fc, mt),
            lambda x, o, fc, mt: reference_jagged_mean_kernel_pytorch(x, o, fc, mt.numel()),
            (x_data, x_offsets, feature_counts, max_M_tensor),
        )


    if __name__ == "__main__":
        main()


.. _sphx_glr_download_examples_jagged_mean.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: jagged_mean.ipynb <jagged_mean.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: jagged_mean.py <jagged_mean.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: jagged_mean.zip <jagged_mean.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_