Note
Go to the end of this page to download the full example code.
Embedding Lookup Example
This example demonstrates how to implement an embedding lookup operation using Helion.
Imports
from __future__ import annotations
import torch
import helion
from helion._testing import run_example
import helion.language as hl
Embedding Kernel
@helion.kernel()
def embedding(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """
    Performs embedding lookup for input indices.

    Maps every index in the input tensor to the corresponding row of the
    embedding weight matrix (same result as torch.nn.functional.embedding
    with default options).

    Args:
        x: Input tensor of integer indices, of any shape.
        weight: Embedding weight matrix of shape [num_embeddings, embedding_dim].
            Must be 2-D; the shape unpacking below fails otherwise.

    Returns:
        Output tensor of shape [*x.shape, embedding_dim] with the dtype and
        device of ``weight``, containing the embedding vectors.

    Note:
        The indices are not bounds-checked here. Out-of-range values
        presumably produce undefined results; confirm against Helion's
        indirect-load semantics before relying on it.
    """
    x_flat = x.reshape(-1)  # collapse x into a single dimension so the lookup is a 2-D gather
    _, embedding_dim = weight.size()
    # One output row per flattened index; dtype/device follow the weight table.
    out = torch.empty(
        [x_flat.size(0), embedding_dim], dtype=weight.dtype, device=weight.device
    )
    # Tile over both the flattened-index axis and the embedding axis. No block
    # sizes are given here; they come from the kernel config (presumably
    # autotuned by Helion - confirm).
    for tile_b, tile_e in hl.tile([x_flat.size(0), embedding_dim]):
        # Load this tile's indices, then gather those rows of `weight`,
        # restricted to the columns covered by tile_e.
        out[tile_b, tile_e] = weight[x_flat[tile_b], tile_e]
    # restore the original shape; `out` is freshly allocated (contiguous), so
    # view() is valid here
    return out.view(*x.size(), embedding_dim)
Benchmark Wrapper
def embedding_tritonbench(
    V: int, D: int, inp: torch.Tensor, shared_weight: torch.Tensor
) -> torch.Tensor:
    """
    Adapter that exposes the Helion ``embedding`` kernel through tritonbench's
    call signature.

    Args:
        V: Vocabulary size. Accepted only so the signature matches
            tritonbench; not used.
        D: Embedding dimension. Accepted only so the signature matches
            tritonbench; not used.
        inp: Tensor of indices to look up.
        shared_weight: Embedding weight matrix the rows are read from.

    Returns:
        The tensor produced by ``embedding(inp, shared_weight)``.
    """
    lookup_result = embedding(inp, shared_weight)
    return lookup_result
Main Function
def main() -> None:
    """
    Script entry point: check the Helion embedding kernel against PyTorch.

    Builds a 256x32 grid of int32 indices into a 16x64 random embedding table
    on the CUDA device. It then asks ``run_example`` to compare ``embedding``
    with ``torch.nn.functional.embedding`` using zero tolerance, since a
    lookup only copies values.
    """
    table_rows = 16
    table_cols = 64
    # Random indices first, then the table, so the RNG draw order is unchanged.
    indices = torch.randint(
        0, table_rows, [256, 32], device="cuda", dtype=torch.int32
    )
    table = torch.randn([table_rows, table_cols], device="cuda")
    kernel_args = (indices, table)
    run_example(
        embedding,
        torch.nn.functional.embedding,
        kernel_args,
        atol=0.0,
        rtol=0.0,
    )
if __name__ == "__main__":
    # Run the verification only when this file is executed directly, not on import.
    main()