
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/concatenate.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_concatenate.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_concatenate.py:


Tensor Concatenation Example
============================

This example demonstrates how to implement a tensor concatenation operation using Helion.

.. GENERATED FROM PYTHON SOURCE LINES 9-11

Imports
-------

.. GENERATED FROM PYTHON SOURCE LINES 11-20

.. code-block:: Python

    from __future__ import annotations

    import torch

    import helion
    from helion._testing import run_example
    import helion.language as hl


.. GENERATED FROM PYTHON SOURCE LINES 21-23

Concatenation Kernel
--------------------

.. GENERATED FROM PYTHON SOURCE LINES 23-55

.. code-block:: Python

    @helion.kernel()
    def concat2d_dim1(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Concatenates two 2D tensors along dimension 1 (columns).

        Args:
            x: First input tensor of shape [M, N1]
            y: Second input tensor of shape [M, N2] with same first dimension as x

        Returns:
            Output tensor of shape [M, N1+N2] containing the concatenation of x and y along dimension 1
        """
        assert x.size(0) == y.size(0)
        out = torch.empty(
            [x.size(0), x.size(1) + y.size(1)], dtype=x.dtype, device=x.device
        )
        for tile0, tile1 in hl.tile(out.size()):
            # Most masking is automatic in Helion, but because tile1 spans both
            # x and y, we need to do some manual masking.
            x_part = hl.load(
                x, [tile0, tile1], extra_mask=(tile1.index < x.size(1))[None, :]
            )
            y_part = hl.load(
                y,
                [tile0, tile1.index - x.size(1)],
                extra_mask=(tile1.index >= x.size(1))[None, :],
            )
            out[tile0, tile1] = torch.where(
                (tile1.index < x.size(1))[None, :], x_part, y_part
            )
        return out


.. GENERATED FROM PYTHON SOURCE LINES 56-58

Main Function
-------------

.. GENERATED FROM PYTHON SOURCE LINES 58-70

.. code-block:: Python

    def main() -> None:
        """
        Main entry point that runs the concatenation kernel verification.
        Tests with two tensors of shapes [1500, 400] and [1500, 600].
        """
        x = torch.randn([1500, 400], device="cuda")
        y = torch.randn([1500, 600], device="cuda")
        run_example(concat2d_dim1, lambda x, y: torch.cat([x, y], dim=1), (x, y))


    if __name__ == "__main__":
        main()


.. _sphx_glr_download_examples_concatenate.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: concatenate.ipynb <concatenate.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: concatenate.py <concatenate.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: concatenate.zip <concatenate.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_

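
The mask-and-select idea used by ``concat2d_dim1`` can also be followed without Helion or a GPU. The block below is a minimal sketch in plain PyTorch, not part of the Helion example: the function name ``concat2d_dim1_reference`` is hypothetical, it treats the whole output as a single tile, and it assumes both inputs have at least one column. Where the kernel passes ``extra_mask`` to ``hl.load``, this sketch clamps the column indices so that ordinary indexing stays in bounds.

.. code-block:: Python

    import torch


    def concat2d_dim1_reference(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Hypothetical CPU reference that mirrors the kernel's select logic."""
        assert x.size(0) == y.size(0)
        n1, n2 = x.size(1), y.size(1)
        # Assumption of this sketch: neither input is empty along dimension 1.
        assert n1 > 0 and n2 > 0

        # Output column indices; these play the role of tile1.index.
        cols = torch.arange(n1 + n2, device=x.device)
        # Same predicate the kernel uses for its first extra_mask.
        from_x = cols < n1

        # Clamp so every index is valid for its source tensor. The values read
        # at clamped positions are discarded by torch.where below.
        x_part = x[:, cols.clamp(max=n1 - 1)]
        y_part = y[:, (cols - n1).clamp(min=0)]

        # Take x where the output column belongs to x, otherwise take y.
        return torch.where(from_x[None, :], x_part, y_part)


    x = torch.randn(5, 3)
    y = torch.randn(5, 4)
    result = concat2d_dim1_reference(x, y)

Comparing ``result`` against ``torch.cat([x, y], dim=1)`` is a quick way to check the selection logic on small inputs before running the kernel itself.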