One-Shot All-Reduce Example#
This example demonstrates how to implement a one-shot, pull-based all-reduce operation using Helion and PyTorch's distributed capabilities. It includes a Helion kernel that shows how to perform cross-device synchronization using symmetric memory signal pads and how to access symmetric memory tensors resident on peer devices.
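Semantically, the kernel computes the same elementwise sum across ranks that a standard NCCL collective would produce. A minimal baseline sketch (illustrative only, not part of the example, assuming an already-initialized process group):

import torch
import torch.distributed as dist

def nccl_all_reduce_baseline(x: torch.Tensor) -> torch.Tensor:
    # Same result as the one-shot kernel below: elementwise sum across all ranks.
    # The Helion version instead pulls peers' shards directly out of symmetric
    # memory inside a single fused GPU kernel, synchronizing via signal pads.
    out = x.clone()
    dist.all_reduce(out, op=dist.ReduceOp.SUM)
    return out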
Imports#
from __future__ import annotations
import os
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.utils.cpp_extension import load_inline
import helion
from helion._testing import DEVICE
import helion.language as hl
Workaround until symmetric memory natively supports extracting dev_ptrs as tensors: build the tensor from a raw pointer with a small from_blob C++ extension.
from_blob_cpp = """
#include <cuda.h>
#include <cuda_runtime.h>
#include <iostream>

// Wrap a raw device pointer in an ATen tensor without taking ownership
// (the deleter is a no-op, so the caller keeps responsibility for the memory).
at::Tensor from_blob(uint64_t data_ptr, c10::IntArrayRef sizes, py::object dtype) {
    at::Tensor tensor = at::for_blob((void*)data_ptr, sizes)
        .deleter([](void *ptr) {
            ;
        })
        .options(at::device(at::kCUDA).dtype(((THPDtype*)dtype.ptr())->scalar_type))
        .make_tensor();
    return tensor;
}
"""
cpp_mod = load_inline(
    "cpp_mod", cpp_sources=from_blob_cpp, with_cuda=True, functions=["from_blob"]
)
def dev_array_to_tensor_short(
    dev_array_ptr: int, shape: tuple[int], dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    """
    Convert a device array pointer to a PyTorch tensor.

    This is a workaround function that creates a PyTorch tensor from a raw device
    pointer using the C++ extension. It is used to interface with symmetric memory
    device pointers before native support is available.

    Args:
        dev_array_ptr: Raw device pointer as an integer
        shape: Shape of the tensor to create
        dtype: PyTorch data type for the tensor
        device: Target device for the tensor

    Returns:
        PyTorch tensor created from the device pointer
    """
    # pyrefly: ignore [missing-attribute]
    return cpp_mod.from_blob(dev_array_ptr, shape, dtype)
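As a quick sanity check (not part of the original example), one can verify that the helper aliases existing device memory rather than copying it. A minimal sketch, assuming a CUDA device is available and the extension has compiled:

ref = torch.arange(4, dtype=torch.int32, device="cuda")
# Wrapping ref's data pointer should yield a view over the same storage.
alias = dev_array_to_tensor_short(
    ref.data_ptr(), (4,), dtype=torch.int32, device=ref.device
)
assert torch.equal(alias, ref)
ref.fill_(7)
assert int(alias[0]) == 7  # writes through the original are visible in the alias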
One-Shot All-Reduce Kernel Implementation#
@helion.kernel(
    config=helion.Config(
        block_sizes=[8192],
        num_warps=32,
    ),
    static_shapes=True,
)
def one_shot_all_reduce_kernel(
    signal_pad_addrs: torch.Tensor,
    local_signal_pad: torch.Tensor,
    a_shared: torch.Tensor,
    my_rank: hl.constexpr,
    group_name: hl.constexpr,
) -> torch.Tensor:
    """
    Helion JIT-compiled kernel for the one-shot all-reduce operation.

    This kernel implements a distributed all-reduce using symmetric memory and signal
    pads for cross-device synchronization. It performs element-wise summation across
    all devices in the distributed group, using tiled computation for memory efficiency.

    Args:
        signal_pad_addrs: Tensor containing the signal pad addresses of all devices
        local_signal_pad: Local signal pad used for synchronization
        a_shared: Local shard of the symmetric memory tensor to be reduced
        my_rank: Current device's rank in the distributed group
        group_name: Name of the process group the symmetric memory tensor belongs to

    Returns:
        Tensor containing the all-reduced result (sum across all devices)
    """
    _, world_size = local_signal_pad.size()
    world_size = hl.specialize(world_size)
    out = torch.empty_like(a_shared)
    N = out.size(0)

    a_shared_tuple = torch.ops.symm_mem.get_remote_tensors(a_shared, group_name)

    for tile_n in hl.tile(N):
        # Sync all devices through signal_pad to make sure
        # all previous writes to the shared tensor are visible
        ptr_tile = signal_pad_addrs[:]
        stack_signalpad = hl.stacktensor_like(local_signal_pad, ptr_tile)
        hl.signal(
            stack_signalpad,
            [tile_n.id, my_rank],
            signal=1,
            wait_for=0,
            scope="sys",
            hasPreviousMemAccess=False,
        )

        for world in hl.tile(world_size, block_size=world_size):
            hl.wait(
                local_signal_pad,
                [tile_n.id, world],
                signal=1,
                update=0,
                scope="sys",
            )

        acc = hl.zeros([tile_n], dtype=a_shared.dtype, device=local_signal_pad.device)

        for a in a_shared_tuple:
            acc += a[tile_n]

        out[tile_n] = acc

        # Sync all devices through signal_pad to make sure our writes to the shared
        # tensor are visible to subsequent kernels.
        hl.signal(
            stack_signalpad, [tile_n.id, my_rank], signal=1, wait_for=0, scope="sys"
        )

        for world in hl.tile(world_size, block_size=world_size):
            hl.wait(
                local_signal_pad,
                [tile_n.id, world],
                signal=1,
                update=0,
                scope="sys",
                hasSubsequentMemAccess=False,
            )
    return out
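For intuition, the loop above follows a two-barrier, pull-based pattern: each rank signals readiness and waits for all peers, pulls and sums every rank's shard, then synchronizes again before moving on. A rough eager-mode sketch of the same pattern (illustrative only; hdl is assumed to be the symmetric memory handle for a_shared, and the Helion kernel above does this per tile inside one fused GPU kernel):

def eager_pull_all_reduce(a_shared: torch.Tensor, hdl) -> torch.Tensor:
    # Barrier 1: make every rank's writes to its own shard visible before peers read.
    hdl.barrier()
    out = torch.zeros_like(a_shared)
    for peer in range(hdl.world_size):
        # "Pull": read the peer's shard directly from symmetric memory and accumulate.
        out += hdl.get_buffer(peer, a_shared.shape, a_shared.dtype)
    # Barrier 2: ensure all ranks have finished reading before buffers are reused.
    hdl.barrier()
    return out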
Extract tensors from the symmetric memory handle#
def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
    """
    Prepares symmetric memory tensors for the Helion one-shot all-reduce kernel.

    It extracts the local signal pad and a tensor of signal pad device pointers
    (dev_ptrs) from the symmetric memory handle before launching the kernel.

    Args:
        a_shared: Input tensor to be all-reduced across all devices

    Returns:
        Tensor containing the all-reduced result (sum across all devices)
    """
    assert dist.group.WORLD is not None
    group_name = dist.group.WORLD.group_name
    symm_mem_hdl = symm_mem.rendezvous(a_shared, group_name)

    local_signal_pad = symm_mem_hdl.get_signal_pad(
        symm_mem_hdl.rank, dtype=torch.int32
    ).view(-1, symm_mem_hdl.world_size)

    signal_pad_addrs = dev_array_to_tensor_short(
        symm_mem_hdl.signal_pad_ptrs_dev,
        (symm_mem_hdl.world_size,),
        dtype=torch.uint64,
        device=a_shared.device,
    )

    return one_shot_all_reduce_kernel(
        signal_pad_addrs,
        local_signal_pad,
        a_shared,
        my_rank=symm_mem_hdl.rank,
        group_name=group_name,
    )
def reference_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
    """
    Reference implementation using the symmetric memory one-shot primitive.
    """
    dist_group = dist.group.WORLD
    if dist_group is None:
        raise RuntimeError("No distributed group available")

    a_shared_clone = symm_mem.empty(
        a_shared.shape,
        dtype=a_shared.dtype,
        device=a_shared.device,
    )
    symm_mem.rendezvous(a_shared_clone, dist_group.group_name)
    a_shared_clone.copy_(a_shared)

    return torch.ops.symm_mem.one_shot_all_reduce(
        a_shared_clone, "sum", dist_group.group_name
    )
Testing Function#
def test(N: int, device: torch.device, dtype: torch.dtype) -> None:
    """
    Test the Helion all-reduce implementation against PyTorch's reference implementation.

    Args:
        N: Total number of elements to test (will be divided by world_size per device)
        device: CUDA device to run the test on
        dtype: Data type for the test tensors
    """
    dist_group = dist.group.WORLD
    assert dist_group is not None

    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Create symmetric memory tensor for Helion implementation
    symm_mem.enable_symm_mem_for_group(dist_group.group_name)
    # TODO @kwen2501: no need to divide N
    a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()

    # Create symmetric memory tensor for reference implementation
    a_shared_ref = symm_mem.empty(N // world_size, dtype=dtype, device=device)
    a_shared_ref.copy_(a_shared)

    print(f"[Rank {rank}] Running Helion all-reduce...")
    result_helion = helion_one_shot_all_reduce(a_shared)

    print(f"[Rank {rank}] Running reference all-reduce...")
    result_ref = reference_one_shot_all_reduce(a_shared_ref)

    # Compare results
    print(f"[Rank {rank}] Comparing results...")
    torch.testing.assert_close(result_helion, result_ref, rtol=1e-1, atol=1e-1)
    print(f"[Rank {rank}] Results match! ✓")
def main() -> None:
    """
    Main entry point for the all-reduce example.

    Sets up the distributed environment, initializes the CUDA device, runs the
    all-reduce test, and then cleans up.
    """
    # Only the NVSHMEM backend implements `get_remote_tensor` for now.
    symm_mem.set_backend("NVSHMEM")
    rank = int(os.environ["LOCAL_RANK"])
    torch.manual_seed(42 + rank)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")
    test(16384, device, torch.bfloat16)
    dist.destroy_process_group()
if __name__ == "__main__":
    """
    Run with:
    python -m torch.distributed.run --standalone \
        --nproc-per-node 4 \
        --rdzv-backend c10d --rdzv-endpoint localhost:0 \
        --no_python python3 examples/distributed/all_reduce.py
    """
    # TODO(adam-smnk): generalize to XPU
    assert DEVICE.type == "cuda", "Requires CUDA device"
    main()