helion.language.dot

helion.language.dot(mat1, mat2, acc=None)

Performs a matrix multiplication of tensors with support for multiple dtypes.

This operation accepts inputs of various dtypes, including float16, bfloat16, float32, int8, and the FP8 formats (e4m3fn and e5m2). The computation is carried out at a precision appropriate to the input dtypes.

Parameters:
  • mat1 (Tensor) – First matrix (2D or 3D tensor of torch.float16, torch.bfloat16, torch.float32, torch.int8, torch.float8_e4m3fn, or torch.float8_e5m2)

  • mat2 (Tensor) – Second matrix (2D or 3D tensor of torch.float16, torch.bfloat16, torch.float32, torch.int8, torch.float8_e4m3fn, or torch.float8_e5m2)

  • acc (Tensor | None) – The accumulator tensor (2D or 3D tensor of torch.float16, torch.float32, or torch.int32). If not None, the returned value is acc + (mat1 @ mat2). If None, a new tensor is created with an appropriate dtype based on the inputs.

Return type:

Tensor

Returns:

Result of matrix multiplication. If acc is provided, returns acc + (mat1 @ mat2). Otherwise returns (mat1 @ mat2) with promoted dtype.

Example

>>> import torch
>>> import helion.language as hl
>>> # FP8 example
>>> a = torch.randn(32, 64, device="cuda").to(torch.float8_e4m3fn)
>>> b = torch.randn(64, 128, device="cuda").to(torch.float8_e4m3fn)
>>> c = torch.zeros(32, 128, device="cuda", dtype=torch.float32)
>>> result = hl.dot(a, b, acc=c)  # result is c + (a @ b)
>>> # Float16 example
>>> a = torch.randn(32, 64, device="cuda", dtype=torch.float16)
>>> b = torch.randn(64, 128, device="cuda", dtype=torch.float16)
>>> result = hl.dot(a, b)  # result dtype will be torch.float16
>>> # Int8 example
>>> a = torch.randint(-128, 127, (32, 64), device="cuda", dtype=torch.int8)
>>> b = torch.randint(-128, 127, (64, 128), device="cuda", dtype=torch.int8)
>>> acc = torch.zeros(32, 128, device="cuda", dtype=torch.int32)
>>> result = hl.dot(a, b, acc=acc)  # int8 x int8 -> int32
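
In practice, hl.dot is called from device code inside a Helion kernel's tile loops rather than at the top level as above. The sketch below shows one way that might look for a float16 matmul with a float32 accumulator; it assumes the @helion.kernel decorator and the hl.tile / hl.zeros helpers behave as in Helion's standard matmul examples, and the kernel name and shapes are purely illustrative.

>>> import torch
>>> import helion
>>> import helion.language as hl
>>> @helion.kernel
... def matmul_fp16(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
...     # Illustrative sketch: tiled matmul accumulating in float32 via hl.dot.
...     m, k = x.size()
...     k2, n = y.size()
...     assert k == k2, f"size mismatch {k} != {k2}"
...     out = torch.empty([m, n], dtype=torch.float32, device=x.device)
...     for tile_m, tile_n in hl.tile([m, n]):
...         acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
...         for tile_k in hl.tile(k):
...             # acc <- acc + x_tile @ y_tile
...             acc = hl.dot(x[tile_m, tile_k], y[tile_k, tile_n], acc=acc)
...         out[tile_m, tile_n] = acc
...     return out
...
>>> a = torch.randn(256, 512, device="cuda", dtype=torch.float16)
>>> b = torch.randn(512, 128, device="cuda", dtype=torch.float16)
>>> result = matmul_fp16(a, b)  # float32 output accumulated by hl.dot

Passing acc=acc through the inner loop is what makes hl.dot accumulate across the k tiles; without it, each call would return only the product of a single tile pair.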