.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/embedding.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_embedding.py:


Embedding Lookup Example
========================

This example implements an embedding lookup with Helion: each integer index in the input selects one row of an embedding weight matrix.

.. GENERATED FROM PYTHON SOURCE LINES 9-11

Imports
-------

.. GENERATED FROM PYTHON SOURCE LINES 11-20

.. code-block:: Python


    from __future__ import annotations

    import torch

    import helion
    from helion._testing import run_example
    import helion.language as hl

.. GENERATED FROM PYTHON SOURCE LINES 21-23

Embedding Kernel
----------------

.. GENERATED FROM PYTHON SOURCE LINES 23-48

.. code-block:: Python


    @helion.kernel()
    def embedding(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """
        Performs embedding lookup for input indices.

        Maps indices in the input tensor to vectors from the embedding weight matrix.

        Args:
            x: Input tensor of indices of any shape
            weight: Embedding weight matrix of shape [num_embeddings, embedding_dim]

        Returns:
            Output tensor of shape [*x.shape, embedding_dim] containing the embedding vectors
        """
        x_flat = x.reshape(-1)  # collapse x into a single dimension
        _, embedding_dim = weight.size()
        out = torch.empty(
            [x_flat.size(0), embedding_dim], dtype=weight.dtype, device=weight.device
        )
        for tile_b, tile_e in hl.tile([x_flat.size(0), embedding_dim]):
            # gather the rows of weight selected by this tile of indices
            out[tile_b, tile_e] = weight[x_flat[tile_b], tile_e]
        # restore the original shape
        return out.view(*x.size(), embedding_dim)

.. GENERATED FROM PYTHON SOURCE LINES 49-51

Benchmark Wrapper
-----------------

.. GENERATED FROM PYTHON SOURCE LINES 51-69

.. code-block:: Python


    def embedding_tritonbench(
        V: int, D: int, inp: torch.Tensor, shared_weight: torch.Tensor
    ) -> torch.Tensor:
        """
        Wrapper for tritonbench that matches its interface.

        Args:
            V: Vocabulary size (unused, provided for compatibility)
            D: Embedding dimension (unused, provided for compatibility)
            inp: Input tensor of indices
            shared_weight: Embedding weight matrix

        Returns:
            Output tensor containing the embedding vectors
        """
        return embedding(inp, shared_weight)

.. GENERATED FROM PYTHON SOURCE LINES 70-72

Main Function
-------------

.. GENERATED FROM PYTHON SOURCE LINES 72-87

.. code-block:: Python


    def main() -> None:
        """
        Main entry point that runs the embedding kernel verification.

        Runs the kernel on a [256, 32] batch of int32 indices and a 16x64
        embedding table, and compares the result against
        torch.nn.functional.embedding with zero tolerance.
        """
        num_embeddings, embedding_dim = 16, 64
        x = torch.randint(0, num_embeddings, [256, 32], device="cuda", dtype=torch.int32)
        weight = torch.randn([num_embeddings, embedding_dim], device="cuda")
        run_example(
            embedding, torch.nn.functional.embedding, (x, weight), atol=0.0, rtol=0.0
        )


    if __name__ == "__main__":
        main()

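
How the Tiled Lookup Works
--------------------------

To make the ``hl.tile`` loop concrete, here is a rough CPU-only emulation of the kernel's iteration order in plain PyTorch. It is a sketch for illustration, not the code Helion generates: the helper name ``embedding_tiled_reference`` and the fixed tile sizes ``block_b`` and ``block_e`` are hypothetical, whereas Helion takes its tile sizes from the kernel config (typically chosen by autotuning). Each step fills one ``[tile_b, tile_e]`` block of ``out`` with the rows of ``weight`` selected by the current tile of indices.

.. code-block:: Python

    import torch


    def embedding_tiled_reference(
        x: torch.Tensor, weight: torch.Tensor, block_b: int = 64, block_e: int = 16
    ) -> torch.Tensor:
        # Hypothetical helper: emulates the kernel's hl.tile loop with fixed,
        # made-up tile sizes (block_b, block_e) instead of autotuned ones.
        x_flat = x.reshape(-1).long()  # flatten indices; int64 keeps indexing portable
        n, embedding_dim = x_flat.size(0), weight.size(1)
        out = torch.empty([n, embedding_dim], dtype=weight.dtype, device=weight.device)
        for b0 in range(0, n, block_b):  # tiles over the flattened indices
            tile_b = slice(b0, min(b0 + block_b, n))  # last tile may be partial
            for e0 in range(0, embedding_dim, block_e):  # tiles over columns
                tile_e = slice(e0, min(e0 + block_e, embedding_dim))
                # Gather the selected rows, restricted to this tile's columns.
                out[tile_b, tile_e] = weight[x_flat[tile_b], tile_e]
        return out.view(*x.size(), embedding_dim)  # restore [*x.shape, embedding_dim]


    if __name__ == "__main__":
        x = torch.randint(0, 16, [256, 32])
        weight = torch.randn([16, 64])
        expected = torch.nn.functional.embedding(x, weight)
        # Every output element is a plain copy, so exact equality is the right check.
        assert torch.equal(embedding_tiled_reference(x, weight), expected)

Because every output element is copied from ``weight`` without any arithmetic, the result matches ``torch.nn.functional.embedding`` bit for bit, which is why ``main()`` can pass ``atol=0.0`` and ``rtol=0.0`` to ``run_example``. The clamped slice bounds in the sketch cover extents that the tile sizes do not divide evenly; a compiled kernel has to handle those boundary tiles as well.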