.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/cross_entropy.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_cross_entropy.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_cross_entropy.py:

Cross Entropy Loss Example
==========================

This example demonstrates how to implement a cross entropy loss function using Helion.

.. GENERATED FROM PYTHON SOURCE LINES 9-11

Imports
-------

.. GENERATED FROM PYTHON SOURCE LINES 11-20

.. code-block:: Python

    from __future__ import annotations

    import torch

    import helion
    from helion._testing import run_example
    import helion.language as hl

.. GENERATED FROM PYTHON SOURCE LINES 21-23

Cross Entropy Kernel
--------------------

.. GENERATED FROM PYTHON SOURCE LINES 23-76

.. code-block:: Python

    @helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper])
    def cross_entropy(
        logits: torch.Tensor,  # [N, V] input logits
        labels: torch.Tensor,  # [N] target labels
    ) -> torch.Tensor:
        """
        Computes the cross entropy loss between logits and target labels.

        Implements the cross entropy loss function commonly used in classification tasks.
        The function computes the log softmax of the logits and then calculates the
        negative log likelihood of the true labels.

        Args:
            logits: Input logits tensor of shape [N, V] where N is batch size and V is vocabulary size
            labels: Target labels tensor of shape [N] containing class indices

        Returns:
            A scalar tensor containing the mean cross entropy loss
        """
        n, v = logits.shape
        losses = torch.zeros([n], dtype=logits.dtype, device=logits.device)

        # Flatten logits once at the beginning
        logits_flat = logits.view(-1)

        for tile_n in hl.tile(n):
            # Get data for this tile
            labels_tile = labels[tile_n]  # [tile_size]
            base_indices_tile = tile_n.index * v  # [tile_size]

            # Compute the actual flat indices by adding the label offset
            flat_indices = base_indices_tile + labels_tile

            # Load the logits at the target indices
            logits_at_target = hl.load(logits_flat, [flat_indices])

            # Compute log_softmax for numerical stability
            # Load the full rows for this tile
            logits_rows = logits[tile_n, :]  # [tile_size, V]

            # Compute log-sum-exp
            max_logits = torch.amax(logits_rows, dim=-1, keepdim=True)
            shifted = logits_rows - max_logits
            exp_shifted = torch.exp(shifted)
            sum_exp = torch.sum(exp_shifted, dim=-1, keepdim=True)
            log_sum_exp = max_logits.squeeze(-1) + torch.log(sum_exp.squeeze(-1))

            # Cross entropy loss: log_sum_exp - logit_at_target
            losses[tile_n] = log_sum_exp - logits_at_target

        return losses.mean()
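
The arithmetic inside the tile loop follows the standard log-sum-exp form of
cross entropy: for one row of logits :math:`x \in \mathbb{R}^{V}` with target
class :math:`y`,

.. math::

   \mathrm{CE}(x, y) = -\log \mathrm{softmax}(x)_{y}
                     = \log \sum_{j=1}^{V} e^{x_{j}} - x_{y}.

To keep the exponentials from overflowing, the kernel evaluates the
log-sum-exp term as
:math:`\max_{j} x_{j} + \log \sum_{j} e^{x_{j} - \max_{k} x_{k}}`, which is
exactly the ``torch.amax`` / ``torch.exp`` / ``torch.log`` sequence above.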
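
For comparison, the same per-sample computation can be written in a few lines
of eager PyTorch. This is a minimal reference sketch, not part of the generated
example: the function name ``cross_entropy_reference`` is ours, and the snippet
reuses the ``torch`` import from the top of the page:

.. code-block:: Python

    def cross_entropy_reference(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Numerically stable log-sum-exp over the vocabulary dimension: [N]
        log_sum_exp = torch.logsumexp(logits, dim=-1)
        # Pick out the logit of the target class for each row: [N]
        logits_at_target = logits.gather(-1, labels[:, None]).squeeze(-1)
        # Mean negative log likelihood, matching reduction="mean"
        return (log_sum_exp - logits_at_target).mean()

The Helion kernel computes the same values, but tiles the batch dimension so
that each tile loads only its own rows of ``logits``.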

.. GENERATED FROM PYTHON SOURCE LINES 77-79

Main Function
-------------

.. GENERATED FROM PYTHON SOURCE LINES 79-102

.. code-block:: Python

    def main() -> None:
        """
        Main entry point that runs the cross entropy kernel verification.
        Tests with a batch size of 128 and vocabulary size of 1000.
        """
        # Test with moderate size
        n, v = 128, 1000
        logits = torch.randn(n, v, device="cuda", dtype=torch.float32)
        labels = torch.randint(0, v, (n,), device="cuda", dtype=torch.long)

        run_example(
            cross_entropy,
            torch.nn.functional.cross_entropy,
            (logits, labels),
            kernel_name="helion",
            baseline_name="torch",
            rtol=1e-4,
            atol=1e-4,
        )


    if __name__ == "__main__":
        main()

.. _sphx_glr_download_examples_cross_entropy.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: cross_entropy.ipynb <cross_entropy.ipynb>`

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: cross_entropy.py <cross_entropy.py>`

        .. container:: sphx-glr-download sphx-glr-download-zip

            :download:`Download zipped: cross_entropy.zip <cross_entropy.zip>`

.. only:: html

    .. rst-class:: sphx-glr-signature

        `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_