diff --git a/.gitmodules b/.gitmodules index 41cf4672e..e08122c8b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ -[submodule "ext/shader_sdma"] - path = ext/shader_sdma - url = https://github.com/AARInternal/shader_sdma.git +[submodule "ext/rocm-xio"] + path = ext/rocm-xio + url = https://github.com/ROCm/rocm-xio.git diff --git a/examples/06_message_passing/message_passing_host_initiated.py b/examples/06_message_passing/message_passing_host_initiated.py index c6dd2238e..d2c1b88c0 100644 --- a/examples/06_message_passing/message_passing_host_initiated.py +++ b/examples/06_message_passing/message_passing_host_initiated.py @@ -8,7 +8,7 @@ the consumer (GPU 1) remains a device kernel. Key difference from message_passing_put.py: -- Producer: Host uses anvil to initiate SDMA transfers from Python +- Producer: Host uses sdma_ep (rocm-xio) to initiate SDMA transfers from Python - Consumer: Same device kernel waiting for data This shows how to orchestrate GPU-to-GPU transfers from Python without diff --git a/iris/device/__init__.py b/iris/device/__init__.py new file mode 100644 index 000000000..5d28bee59 --- /dev/null +++ b/iris/device/__init__.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Device-side utilities for Iris. + +This module provides low-level device-side functions for use in Triton kernels, +including SDMA queue management and packet construction utilities. +""" + +from . import sdma_utils + +__all__ = ["sdma_utils"] diff --git a/iris/device/sdma_utils.py b/iris/device/sdma_utils.py new file mode 100644 index 000000000..20fed0588 --- /dev/null +++ b/iris/device/sdma_utils.py @@ -0,0 +1,396 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +SDMA (System DMA) device-side utilities for Triton kernels. 
+ +This module provides low-level Triton device functions for directly managing +SDMA queues from GPU kernels, including packet construction, queue reservation, +and submission operations. +""" + +import triton +import triton.language as tl +from xio import sdma_ep + + +@triton.jit +def wait_cnt(): + tl.inline_asm_elementwise("s_waitcnt vmcnt(0)", "=r", [], dtype=tl.int32, is_pure=False, pack=1) + + +@triton.jit +def wrap_into_ring(index: tl.uint64): + queue_size_u32 = sdma_ep.SDMA_QUEUE_SIZE + queue_size = queue_size_u32.to(tl.uint64) + return index.to(tl.uint64) % queue_size + + +@triton.jit +def can_write_up_to(rptr, up_to_index: tl.uint64): + """Check if there's space to write up to the given index in the ring buffer.""" + hw_read_ptr = tl.load(rptr, cache_modifier=".cv", volatile=True) + return (up_to_index - hw_read_ptr) < sdma_ep.SDMA_QUEUE_SIZE + + +@triton.jit +def acquire( + queue_ptr_u32, + read_ptr, + write_ptr, + doorbell_ptr, + cached_write_ptr: tl.pointer_type(tl.uint64), + committed_write_ptr, + command_in_bytes: tl.uint64, +): + """ + Reserve space in the SDMA queue. + Returns (base_index, offset) where: + - base_index: the index where the packet should be written (cur_index initially) + - offset: padding bytes added for wraparound (0 if no wraparound) + + Based on ReserveQueueSpace from anvil_device.hpp. 
+ """ + queue_size_u32 = sdma_ep.SDMA_QUEUE_SIZE + queue_size_in_bytes = queue_size_u32.to(tl.uint64) + + base_u32 = 0 + base = (base_u32).to(tl.uint64) + offset_u32 = 0 + offset = (offset_u32).to(tl.uint64) + + stop_loop = False + while not stop_loop: + cur_index = tl.load(cached_write_ptr, volatile=True) + offset = (offset_u32).to(tl.uint64) + + # Calculate current position in ring buffer + cur_ring_pos = wrap_into_ring(cur_index) + + # Check if we need to wrap around + if (cur_ring_pos + command_in_bytes) > queue_size_in_bytes: + # Need to pad to end of ring before wrap around + offset = queue_size_in_bytes - cur_ring_pos + + # Calculate new index including any wraparound padding + new_index = cur_index + command_in_bytes + offset + base = cur_index + + # Check if queue has space + if can_write_up_to(read_ptr, new_index): + # Try to atomically claim this space + if tl.atomic_cas(cached_write_ptr, cur_index, new_index, sem="relaxed", scope="gpu") == cur_index: + stop_loop = True + return base, offset + + +# acquire function using atomic_add instead of atomic_cas +@triton.jit +def acquire_fadd( + queue_ptr_u32, + read_ptr, + write_ptr, + doorbell_ptr, + cached_write_ptr: tl.pointer_type(tl.uint64), + committed_write_ptr, + command_in_bytes: tl.uint64, +): + """ + Reserve space in the SDMA queue using atomic_add. + Returns (base_index, 0) where base_index is where the packet should be written. + + Uses atomic_add instead of CAS. Immediately acquires space, and if wraparound + is detected, places padding NOP packet, submits it, and tries again. + Always returns a non-wrapping allocation. 
+ """ + queue_size_u32 = sdma_ep.SDMA_QUEUE_SIZE + queue_size_in_bytes = queue_size_u32.to(tl.uint64) + + stop_loop = False + base_u32 = 0 + base = (base_u32).to(tl.uint64) + offset_u32 = 0 + offset = (offset_u32).to(tl.uint64) + + while not stop_loop: + # Atomically acquire space for the command + base = tl.atomic_add(cached_write_ptr, command_in_bytes, sem="relaxed", scope="gpu") + end_index = base + command_in_bytes + # Calculate current position in ring buffer + cur_ring_pos = wrap_into_ring(base) + + # Block until there is space in the queue to write the command + while not can_write_up_to(read_ptr, end_index): + pass + + # Check if we need to wrap around + if (cur_ring_pos + command_in_bytes) > queue_size_in_bytes: + # Wrap detected - need to pad to end of ring + padding_bytes = queue_size_in_bytes - cur_ring_pos + + # Place NOP packet at end of ring + place_nop_packet(queue_ptr_u32, base, padding_bytes) + + # Place remaining NOP padding at the beginning of the ring. + remaining_padding_bytes = command_in_bytes - padding_bytes + place_nop_packet(queue_ptr_u32, base + padding_bytes, remaining_padding_bytes) + + # Submit the padding - update committed write pointer + # This allows other threads to proceed past this padding + submit(write_ptr, doorbell_ptr, committed_write_ptr, base, end_index) + + # Continue loop to acquire space for the actual command (will be at ring start) + else: + # No wrap - this allocation is good + stop_loop = True + + return base, offset + + +@triton.jit +def submit(write_ptr, doorbell_ptr, committed_write_ptr, base, pending_wptr): + """ + Submit SDMA commands to the hardware by updating write pointer and ringing the doorbell. + + Waits for previous threads to commit, then updates write pointer, rings doorbell, + and updates committed pointer to allow subsequent threads to proceed. 
+ """ + while tl.load(committed_write_ptr, cache_modifier=".cv", volatile=True) != base: + pass + + wait_cnt() + tl.debug_barrier() + + tl.store(write_ptr, pending_wptr, cache_modifier=".wt") + wait_cnt() + tl.debug_barrier() + + # Ring doorbell + tl.store(doorbell_ptr, pending_wptr, cache_modifier=".wt") + wait_cnt() + tl.debug_barrier() + + tl.store(committed_write_ptr, pending_wptr, cache_modifier=".wt") + + +@triton.jit +def place_nop_packet(queue_ptr_u32, offset_bytes: tl.uint64, padding_bytes): + """Place a NOP (no operation) packet for ring buffer padding.""" + num_padding_dwords = (padding_bytes // 4).to(tl.int32) + offset_ring_pos = wrap_into_ring(offset_bytes) + offset_in_dwords = (offset_ring_pos // 4).to(tl.int32) + for i in range(num_padding_dwords): + if i == 0: + tl.store(queue_ptr_u32 + offset_in_dwords, ((num_padding_dwords - 1) & 0xFFFF) << 16, cache_modifier=".wt") + else: + tl.store(queue_ptr_u32 + offset_in_dwords + i, 0, cache_modifier=".wt") + + +@triton.jit +def place_copy_packet(queue_ptr_u32, offset_bytes: tl.uint64, size_bytes: tl.uint32, src_ptr_val, dst_ptr_val): + """Place a SDMA_PKT_COPY_LINEAR packet for 1D linear memory copy.""" + slot_ptr_u32 = queue_ptr_u32 + (wrap_into_ring(offset_bytes) // 4) + # offset 0: op + sub_op + tl.store(slot_ptr_u32 + 0, 1, cache_modifier=".wt") + # offset 1: count + tl.store(slot_ptr_u32 + 1, size_bytes - 1, cache_modifier=".wt") + # offset 2: parameters + tl.store(slot_ptr_u32 + 2, 0, cache_modifier=".wt") + # offset 3: src address 31:0 + tl.store(slot_ptr_u32 + 3, src_ptr_val.to(tl.uint32), cache_modifier=".wt") + # offset 4: src address 63:32 + tl.store(slot_ptr_u32 + 4, (src_ptr_val >> 32).to(tl.uint32), cache_modifier=".wt") + # offset 5: dst address 31:0 + tl.store(slot_ptr_u32 + 5, dst_ptr_val.to(tl.uint32), cache_modifier=".wt") + # offset 6: dst address 63:32 + tl.store(slot_ptr_u32 + 6, (dst_ptr_val >> 32).to(tl.uint32), cache_modifier=".wt") + + +# atomic op codes and operation +# atomic add 
32bit w/rtn: op 10, operation 15 +# atomic add 64bit w/rtn: op 10, operation: 47 -> 32 + 15 +# atomic add 32bit w/o rtn: op 10, operation: 31 -> 64 + 15 +# atomic add 64bit w/o rtn: op 10, operation: 63 -> 96 + 15 +# atomic cmp&swap 32bit w/rtn: op 10, operation: 8 +# atomic cmp&swap 64bit w/rtn: op 10, operation: -> 32 + 8 +# atomic cmp&swap 32bit w/o rtn: op 10, operation -> 64 + 8 +# atomic cmp&swap 64bit w/o rtn: op 10, operation 56 -> 06 + 8 +@triton.jit +def place_atomic_packet( + queue_ptr_u32, + offset_bytes: tl.uint64, + dst_ptr_val, + src_data, + comp_data, + OP: tl.constexpr, + RETURN: tl.constexpr = False, + IS_64_BIT: tl.constexpr = False, +): + """ + Place a SDMA_PKT_ATOMIC packet for atomic memory operations. + + OP codes: + 15: atomic add (32/64-bit with/without return) + 8: atomic compare-and-swap (32/64-bit with/without return) + Flags are encoded via IS_64_BIT (bit 4) and RETURN (bit 5). + """ + slot_ptr_u32 = queue_ptr_u32 + (wrap_into_ring(offset_bytes) // 4) + if IS_64_BIT: + OP = OP | (0x1 << 4) + if not RETURN: + OP = OP | (0x1 << 5) + tl.store(slot_ptr_u32 + 0, ((OP & 0x7F) << 25) | (0xA & 0xFF), cache_modifier=".wt") + # offset 1: dst address 31:0 + tl.store(slot_ptr_u32 + 1, dst_ptr_val.to(tl.uint32), cache_modifier=".wt") + # offset 2: dst address 63:32 + tl.store(slot_ptr_u32 + 2, (dst_ptr_val >> 32).to(tl.uint32), cache_modifier=".wt") + # offset 3: src data 31:0 + tl.store(slot_ptr_u32 + 3, src_data, cache_modifier=".wt") + # offset 4: src data 63:32 + if IS_64_BIT: + tl.store(slot_ptr_u32 + 4, (src_data >> 32).to(tl.uint32), cache_modifier=".wt") + else: + tl.store(slot_ptr_u32 + 4, 0, cache_modifier=".wt") + # offset 5: compare data 31:0 + tl.store(slot_ptr_u32 + 5, comp_data, cache_modifier=".wt") + # offset 6: compare data 63:32 + if IS_64_BIT: + tl.store(slot_ptr_u32 + 6, comp_data >> 32, cache_modifier=".wt") + else: + tl.store(slot_ptr_u32 + 6, 0, cache_modifier=".wt") + # offset 7: loop timer + loop interval + tl.store(slot_ptr_u32 
+ 7, 0, cache_modifier=".wt") + + +@triton.jit +def place_atomic_add_packet(queue_ptr_u32, offset_bytes: tl.uint64, dst_ptr_val, val): + """Place an atomic add packet (OP=15, with return).""" + place_atomic_packet(queue_ptr_u32, offset_bytes, dst_ptr_val, val, 0, 15, True) + + +@triton.jit +def place_atomic_cas_packet( + queue_ptr_u32, + offset_bytes: tl.uint64, + dst_ptr_val, + compare_val, + swap_val, +): + """Place an atomic compare-and-swap packet (OP=8, with return).""" + place_atomic_packet(queue_ptr_u32, offset_bytes, dst_ptr_val, swap_val, compare_val, 8, True) + + +@triton.jit +def place_poll_regmem_packet( + queue_ptr_u32, + offset_bytes: tl.uint64, + flag_ptr_val, + expected_value, + interval: tl.constexpr = 10, + retry_count: tl.constexpr = 0xFFF, +): + """ + Place a SDMA_PKT_POLL_REGMEM packet for memory polling. + + Polls memory location until (value >= expected_value). + """ + slot_ptr_u32 = queue_ptr_u32 + (wrap_into_ring(offset_bytes) // 4) + header = ((1 & 0x1) << 31) | ((5 & 0x7) << 28) | (8 & 0xFF) + dw5 = ((retry_count & 0xFFF) << 16) | (interval & 0xFFFF) + + tl.store(slot_ptr_u32 + 0, header, cache_modifier=".wt") + tl.store(slot_ptr_u32 + 1, flag_ptr_val.to(tl.uint32), cache_modifier=".wt") + tl.store(slot_ptr_u32 + 2, (flag_ptr_val.to(tl.uint64) >> 32).to(tl.uint32), cache_modifier=".wt") + tl.store(slot_ptr_u32 + 3, expected_value.to(tl.uint32), cache_modifier=".wt") + tl.store(slot_ptr_u32 + 4, 0xFFFFFFFF, cache_modifier=".wt") + tl.store(slot_ptr_u32 + 5, dw5, cache_modifier=".wt") + + +@triton.jit +def place_sub_window_copy_packet( + queue_ptr_u32, + offset_bytes: tl.uint64, + src_ptr_val, + dst_ptr_val, + tile_width: tl.uint32, + tile_height: tl.uint32, + src_buffer_pitch: tl.uint32, + dst_buffer_pitch: tl.uint32, + src_x: tl.uint32, + src_y: tl.uint32, + dst_x: tl.uint32, + dst_y: tl.uint32, +): + """ + Place a SDMA_PKT_LINEAR_LARGE_SUB_WINDOW_COPY packet for 2D tile transfer. 
+ + Copies a rectangular tile with arbitrary source/destination offsets. + Note: pitch, slice_pitch and rect fields are 1-based (subtract 1 before writing). + Args: + queue_ptr_u32: Pointer to the SDMA queue buffer (as uint32 array) + offset_bytes: Byte offset in the queue where to place the packet + src_ptr_val: Source buffer base address + dst_ptr_val: Destination buffer base address + tile_width: Width of the tile to copy in bytes + tile_height: Height of the tile to copy in rows + src_buffer_pitch: Row stride of the source buffer in bytes + dst_buffer_pitch: Row stride of the destination buffer in bytes + src_x: Source X offset in bytes + src_y: Source Y offset in rows + dst_x: Destination X offset in bytes + dst_y: Destination Y offset in rows + """ + slot_ptr_u32 = queue_ptr_u32 + (wrap_into_ring(offset_bytes) // 4) + + # DW 0: Header (op=1, sub_op=0x24) + # op[7:0] = 1 (SDMA_OP_COPY), sub_op[15:8] = 0x24 (SDMA_SUBOP_COPY_LINEAR_SUB_WINDOW) + tl.store(slot_ptr_u32 + 0, ((0x24 & 0xFF) << 8) | (0x1 & 0xFF), cache_modifier=".wt") + + # DW 1-2: Source base address + tl.store(slot_ptr_u32 + 1, src_ptr_val.to(tl.uint32), cache_modifier=".wt") + tl.store(slot_ptr_u32 + 2, (src_ptr_val >> 32).to(tl.uint32), cache_modifier=".wt") + + # DW 3: Source X offset (bytes) + tl.store(slot_ptr_u32 + 3, src_x, cache_modifier=".wt") + + # DW 4: Source Y offset (rows) + tl.store(slot_ptr_u32 + 4, src_y, cache_modifier=".wt") + + # DW 5: Source Z offset (0 for 2D) + tl.store(slot_ptr_u32 + 5, 0, cache_modifier=".wt") + + # DW 6: Source pitch (1-based, so subtract 1) + tl.store(slot_ptr_u32 + 6, src_buffer_pitch - 1, cache_modifier=".wt") + + # DW 7-8: Source slice pitch (1-based, 0 means slice_pitch of 1, for 2D) + tl.store(slot_ptr_u32 + 7, 0, cache_modifier=".wt") + tl.store(slot_ptr_u32 + 8, 0, cache_modifier=".wt") + + # DW 9-10: Destination base address + tl.store(slot_ptr_u32 + 9, dst_ptr_val.to(tl.uint32), cache_modifier=".wt") + tl.store(slot_ptr_u32 + 10, (dst_ptr_val >> 
32).to(tl.uint32), cache_modifier=".wt") + + # DW 11: Destination X offset (bytes) + tl.store(slot_ptr_u32 + 11, dst_x, cache_modifier=".wt") + + # DW 12: Destination Y offset (rows) + tl.store(slot_ptr_u32 + 12, dst_y, cache_modifier=".wt") + + # DW 13: Destination Z offset (0 for 2D) + tl.store(slot_ptr_u32 + 13, 0, cache_modifier=".wt") + + # DW 14: Destination pitch (1-based, so subtract 1) + tl.store(slot_ptr_u32 + 14, dst_buffer_pitch - 1, cache_modifier=".wt") + + # DW 15-16: Destination slice pitch (1-based, 0 means slice_pitch of 1, for 2D) + tl.store(slot_ptr_u32 + 15, 0, cache_modifier=".wt") + tl.store(slot_ptr_u32 + 16, 0, cache_modifier=".wt") + + # DW 17: Rectangle X (width in bytes, 1-based) + tl.store(slot_ptr_u32 + 17, tile_width - 1, cache_modifier=".wt") + + # DW 18: Rectangle Y (height in rows, 1-based) + tl.store(slot_ptr_u32 + 18, tile_height - 1, cache_modifier=".wt") + + # DW 19: Rectangle Z (depth, 1-based, 0 for 2D means depth of 1) + tl.store(slot_ptr_u32 + 19, 0, cache_modifier=".wt") diff --git a/iris/iris.py b/iris/iris.py index 608dbb432..d471bf011 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -56,7 +56,8 @@ count_devices, ) -import anvil +from xio import sdma_ep +from iris.device import sdma_utils from iris.symmetric_heap import SymmetricHeap import numpy as np from typing import Any @@ -70,7 +71,7 @@ from .tracing import ( Tracing, DeviceTracing, -) # noqa: F401 re-export for iris.TraceEvent +) # noqa: F401 # Import shared tensor-creation helpers from . 
import tensor_creation @@ -140,10 +141,9 @@ def __init__(self, heap_size=1 << 30, allocator_type="torch"): distributed_barrier() # initialize copy engines - self.copy_engines = anvil.AnvilLib.get_instance() - self.copy_engines.init() + sdma_ep.init() - context_size = anvil.QUEUE_DEVICE_CTX_SIZE + context_size = sdma_ep.QUEUE_DEVICE_CTX_SIZE self.copy_engines_device_ctx = torch.zeros((num_ranks, context_size), dtype=torch.uint64, device=self.device) num_local_ranks = min(num_gpus, num_ranks) @@ -151,11 +151,11 @@ def __init__(self, heap_size=1 << 30, allocator_type="torch"): for local_rank in range(num_local_ranks): # Device-initiated queues - self.copy_engines.connect(cur_local_rank, local_rank, allocate_on_host=False) + sdma_ep.create_queue(cur_local_rank, local_rank) # Host-initiated queues - self.copy_engines.connect(cur_local_rank, local_rank, allocate_on_host=True) + sdma_ep.create_host_queue(cur_local_rank, local_rank) - handle = self.copy_engines.get_queue_device_ctx(cur_local_rank, local_rank) + handle = sdma_ep.get_queue_device_ctx(cur_local_rank, local_rank) self.debug(f"---- Queue {local_rank} ------------") self.debug(f"queue_buf {handle.queue_buf:#x} at {id(handle.queue_buf):#x}") self.debug(f"rptr {handle.rptr:#x} at {id(handle.rptr):#x}") @@ -1034,7 +1034,7 @@ def translate(self, ptr: int, from_rank: int, to_rank: int) -> int: >>> buffer = ctx.zeros(1024, dtype=torch.float32) >>> # Translate buffer address from rank 0 to rank 1's address space >>> remote_addr = ctx.translate(buffer.data_ptr(), 0, 1) - >>> ctx.copy_engines.host_put(0, 1, 0, src_ptr, remote_addr, size) + >>> ctx.copy_engines.put(0, 1, 0, src_ptr, remote_addr, size) """ # Use pre-cached CPU copy to avoid GPU->CPU transfer on every call from_base = int(self.heap_bases_cpu[from_rank]) @@ -1042,6 +1042,41 @@ def translate(self, ptr: int, from_rank: int, to_rank: int) -> int: offset = ptr - from_base return to_base + offset + @staticmethod + def _dtype_to_flag_bits(dtype: torch.dtype) -> 
int: + if dtype in (torch.int32, torch.int): + return 32 + if dtype in (torch.int64, torch.long): + return 64 + raise ValueError(f"Unsupported flag tensor dtype: {dtype}") + + def _flag_pointer_and_bits( + self, + flag, + *, + translate: bool = False, + dst_rank: int | None = None, + default_bits: int = 32, + ) -> tuple[int, int]: + if flag is None: + return 0, 0 + + if isinstance(flag, torch.Tensor): + bits = self._dtype_to_flag_bits(flag.dtype) + ptr = flag.data_ptr() + else: + ptr = int(flag) + bits = default_bits + + if translate: + if dst_rank is None: + raise ValueError("dst_rank must be provided when translate=True") + ptr = self.translate(ptr, self.get_rank(), dst_rank) + + if ptr == 0: + return 0, 0 + return ptr, bits + def put( self, src_tensor: torch.Tensor, @@ -1097,63 +1132,36 @@ def put( dst_ptr = self.translate(dst_tensor.data_ptr(), src_rank, dst_rank) size = src_tensor.numel() * src_tensor.element_size() - # Determine which SDMA packet combination to use - has_wait = wait_flag is not None - has_signal = signal_flag is not None + wait_ptr, wait_bits = self._flag_pointer_and_bits(wait_flag) + signal_ptr, signal_bits = self._flag_pointer_and_bits(signal_flag, translate=True, dst_rank=dst_rank) + + has_wait = wait_ptr != 0 + has_signal = signal_ptr != 0 if has_wait and has_signal: - # POLL + COPY + ATOMIC (two submissions) - wait_ptr = wait_flag.data_ptr() - signal_ptr = self.translate(signal_flag.data_ptr(), src_rank, dst_rank) - - # First: POLL + COPY - self.copy_engines.host_wait_flag_then_put( - src_rank, - dst_rank, - channel, - wait_ptr, - wait_value, - src_ptr, - dst_ptr, - size, + # Wait + copy + signal (two calls) + wait_val = int(wait_value if wait_value is not None else 0) + signal_val = int(signal_value) + sdma_ep.wait_flag_then_put( + src_rank, dst_rank, channel, wait_ptr, wait_val, src_ptr, dst_ptr, size, wait_bits ) - # Then: ATOMIC - self.copy_engines.host_atomic_add(src_rank, dst_rank, channel, signal_ptr, signal_value) - + 
sdma_ep.signal(src_rank, dst_rank, channel, signal_ptr, signal_val, signal_bits) elif has_wait: - # POLL + COPY - wait_ptr = wait_flag.data_ptr() - self.copy_engines.host_wait_flag_then_put( - src_rank, - dst_rank, - channel, - wait_ptr, - wait_value, - src_ptr, - dst_ptr, - size, + # Wait + copy + wait_val = int(wait_value if wait_value is not None else 0) + sdma_ep.wait_flag_then_put( + src_rank, dst_rank, channel, wait_ptr, wait_val, src_ptr, dst_ptr, size, wait_bits ) - elif has_signal: - # COPY + ATOMIC (combined in one submission) - signal_ptr = self.translate(signal_flag.data_ptr(), src_rank, dst_rank) - self.copy_engines.host_put_signal( - src_rank, - dst_rank, - channel, - src_ptr, - dst_ptr, - size, - signal_ptr, - signal_value, - ) - + # Copy + signal + signal_val = int(signal_value) + sdma_ep.put_signal(src_rank, dst_rank, channel, src_ptr, dst_ptr, size, signal_ptr, signal_val, signal_bits) else: - # Simple COPY - self.copy_engines.host_put(src_rank, dst_rank, channel, src_ptr, dst_ptr, size) + # Simple copy + sdma_ep.put(src_rank, dst_rank, channel, src_ptr, dst_ptr, size) if not async_op: - self.copy_engines.host_quiet(src_rank, dst_rank, channel) + sdma_ep.quiet(src_rank, dst_rank, channel) def put_tile( self, @@ -1174,7 +1183,7 @@ def put_tile( Low-level API - caller provides pre-translated pointers for performance. 
Args: - tile: Pre-configured anvil.Tile object with data pointer and dimensions set + tile: Pre-configured sdma_ep.Tile object with data pointer and dimensions set dst_rank: Destination rank dst_ptr: Destination pointer (already translated to remote address space) dst_stride: Destination row stride in bytes @@ -1187,7 +1196,7 @@ def put_tile( Examples: - >>> import anvil - >>> tile = anvil.Tile() + >>> from xio import sdma_ep + >>> tile = sdma_ep.Tile() >>> tile.pid_m = 0 >>> tile.pid_n = 0 >>> tile.block_m = 256 @@ -1204,55 +1213,38 @@ def put_tile( """ src_rank = self.get_rank() - has_wait = wait_flag is not None - has_signal = signal_flag is not None + wait_ptr, wait_bits = self._flag_pointer_and_bits(wait_flag, default_bits=32) + signal_ptr, signal_bits = self._flag_pointer_and_bits(signal_flag, default_bits=32) + + has_wait = wait_ptr != 0 + has_signal = signal_ptr != 0 if has_wait and has_signal: - # POLL + SUB_WINDOW_COPY + ATOMIC (two submissions) - self.copy_engines.host_wait_flag_then_put_tile( - src_rank, - dst_rank, - channel, - wait_flag, - wait_value, - tile, - dst_ptr, - dst_stride, + # Wait + tile copy + signal (two calls) + wait_val = int(wait_value if wait_value is not None else 0) + signal_val = int(signal_value) + sdma_ep.wait_flag_then_put_tile( + src_rank, dst_rank, channel, wait_ptr, wait_val, tile, int(dst_ptr), int(dst_stride), wait_bits ) - self.copy_engines.host_atomic_add_32(src_rank, dst_rank, channel, signal_flag, signal_value) - + sdma_ep.signal(src_rank, dst_rank, channel, signal_ptr, signal_val, signal_bits) elif has_wait: - # POLL + SUB_WINDOW_COPY - self.copy_engines.host_wait_flag_then_put_tile( - src_rank, - dst_rank, - channel, - wait_flag, - wait_value, - tile, - dst_ptr, - dst_stride, + # Wait + tile copy + wait_val = int(wait_value if wait_value is not None else 0) + sdma_ep.wait_flag_then_put_tile( + src_rank, dst_rank, channel, wait_ptr, wait_val, tile, 
int(dst_ptr), int(dst_stride), wait_bits ) - elif has_signal: - # SUB_WINDOW_COPY + ATOMIC - self.copy_engines.host_put_tile_signal( - src_rank, - dst_rank, - channel, - tile, - dst_ptr, - dst_stride, - signal_flag, - signal_value, + # Tile copy + signal + signal_val = int(signal_value) + sdma_ep.put_tile_signal( + src_rank, dst_rank, channel, tile, int(dst_ptr), int(dst_stride), signal_ptr, signal_val, signal_bits ) - else: - # Simple SUB_WINDOW_COPY - self.copy_engines.host_put_tile(src_rank, dst_rank, channel, tile, dst_ptr, dst_stride) + # Simple tile copy + sdma_ep.put_tile(src_rank, dst_rank, channel, tile, int(dst_ptr), int(dst_stride)) if not async_op: - self.copy_engines.host_quiet(src_rank, dst_rank, channel) + sdma_ep.quiet(src_rank, dst_rank, channel) def put_tiles( self, @@ -1271,7 +1264,7 @@ def put_tiles( Batched 2D tile transfer with optional shared wait/signal. Args: - tiles: Sequence of pre-configured anvil.Tile objects + tiles: Sequence of pre-configured sdma_ep.Tile objects dst_rank: Destination rank dst_ptrs: Sequence of translated destination pointers dst_strides: Sequence of destination row strides in bytes @@ -1287,38 +1280,40 @@ def put_tiles( if len(tiles) != len(dst_ptrs) or len(tiles) != len(dst_strides): raise ValueError("tiles, dst_ptrs, and dst_strides must have the same length") - has_wait = wait_flag is not None - has_signal = signal_flag is not None - - if has_wait: - self.copy_engines.host_wait_flag_then_put_tiles( - src_rank, - dst_rank, - channel, - wait_flag, - wait_value, - tiles, - dst_ptrs, - dst_strides, + wait_ptr, wait_bits = self._flag_pointer_and_bits(wait_flag, default_bits=32) + signal_ptr, signal_bits = self._flag_pointer_and_bits(signal_flag, default_bits=32) + + has_wait = wait_ptr != 0 + has_signal = signal_ptr != 0 + + dst_ptr_list = [int(p) for p in dst_ptrs] + dst_stride_list = [int(s) for s in dst_strides] + + if has_wait and has_signal: + # Wait + tiles copy + signal (two calls) + wait_val = int(wait_value if 
wait_value is not None else 0) + signal_val = int(signal_value) + sdma_ep.wait_flag_then_put_tiles( + src_rank, dst_rank, channel, wait_ptr, wait_val, list(tiles), dst_ptr_list, dst_stride_list, wait_bits + ) + sdma_ep.signal(src_rank, dst_rank, channel, signal_ptr, signal_val, signal_bits) + elif has_wait: + # Wait + tiles copy + wait_val = int(wait_value if wait_value is not None else 0) + sdma_ep.wait_flag_then_put_tiles( + src_rank, dst_rank, channel, wait_ptr, wait_val, list(tiles), dst_ptr_list, dst_stride_list, wait_bits ) - if has_signal: - self.copy_engines.host_atomic_add_32(src_rank, dst_rank, channel, signal_flag, signal_value) + elif has_signal: + # Tiles copy + signal (loop + signal) + signal_val = int(signal_value) + sdma_ep.put_tiles(src_rank, dst_rank, channel, list(tiles), dst_ptr_list, dst_stride_list) + sdma_ep.signal(src_rank, dst_rank, channel, signal_ptr, signal_val, signal_bits) else: - for tile, dst_ptr, dst_stride in zip(tiles, dst_ptrs, dst_strides): - self.put_tile( - tile, - dst_rank=dst_rank, - dst_ptr=dst_ptr, - dst_stride=dst_stride, - signal_flag=None, - async_op=True, - channel=channel, - ) - if has_signal: - self.copy_engines.host_atomic_add_32(src_rank, dst_rank, channel, signal_flag, signal_value) + # Simple tiles copy + sdma_ep.put_tiles(src_rank, dst_rank, channel, list(tiles), dst_ptr_list, dst_stride_list) if not async_op: - self.copy_engines.host_quiet(src_rank, dst_rank, channel) + sdma_ep.quiet(src_rank, dst_rank, channel) def quiet(self, dst_rank: int = None, channel: int = 0): """ @@ -1336,11 +1332,11 @@ def quiet(self, dst_rank: int = None, channel: int = 0): """ src_rank = self.get_rank() if dst_rank is not None: - self.copy_engines.host_quiet(src_rank, dst_rank, channel) + sdma_ep.quiet(src_rank, dst_rank, channel) else: # Quiet to all ranks for rank in range(self.get_num_ranks()): - self.copy_engines.host_quiet(src_rank, 
rank, channel) + sdma_ep.quiet(src_rank, rank, channel) def _build_device_context(self): """ @@ -2925,7 +2921,7 @@ def put( tl.store(translated_to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) else: - ctx = copy_engine_ctx + (anvil.QUEUE_DEVICE_CTX_SIZE * to_rank) + ctx = copy_engine_ctx + (sdma_ep.QUEUE_DEVICE_CTX_SIZE * to_rank) queue_ptr_u32 = tl.load(ctx + 0).to(tl.pointer_type(tl.uint32)) read_ptr = tl.load(ctx + 1).to(tl.pointer_type(tl.uint64)) write_ptr = tl.load(ctx + 2).to(tl.pointer_type(tl.uint64)) @@ -2964,10 +2960,12 @@ def put( # Linear copy packet: 32 bytes for 1D, Sub-window copy packet: 80 bytes for 2D # IS_2D_COPY is a compile-time constant for proper branch elimination mask_int = mask.to(tl.int32) - command_in_bytes = anvil.SDMA_PKT_LINEAR_SUB_WINDOW_BYTES if IS_2D_COPY else anvil.SDMA_PKT_COPY_LINEAR_BYTES + command_in_bytes = ( + sdma_ep.COPY_LINEAR_SUB_WINDOW_COMMAND_BYTES if IS_2D_COPY else sdma_ep.COPY_LINEAR_COMMAND_BYTES + ) # Acquire space in the queue - base, offset = anvil.acquire_fadd( + base, offset = sdma_utils.acquire_fadd( queue_ptr_u32, read_ptr, write_ptr, @@ -2978,7 +2976,7 @@ def put( ) # Write padding NOPs if we wrapped around - anvil.place_nop_packet(queue_ptr_u32, base, offset) + sdma_utils.place_nop_packet(queue_ptr_u32, base, offset) # Place the appropriate packet type packet_offset_bytes = base + offset @@ -2989,7 +2987,7 @@ def put( size_bytes = (num_elements * element_size_bytes).to(tl.uint32) # Place linear copy packet for 1D/flat copies - anvil.place_copy_packet( + sdma_utils.place_copy_packet( queue_ptr_u32, packet_offset_bytes, size_bytes, @@ -3018,7 +3016,7 @@ def put( dst_y_val = (tile_offset_bytes_dst // dst_stride).to(tl.uint32) dst_x_val = (tile_offset_bytes_dst % dst_stride).to(tl.uint32) - anvil.place_sub_window_copy_packet( + sdma_utils.place_sub_window_copy_packet( queue_ptr_u32, packet_offset_bytes, src_base, @@ -3035,7 +3033,7 @@ def put( # Submit the command to the queue pending_wptr = base 
+ offset + command_in_bytes - anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, pending_wptr) + sdma_utils.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, pending_wptr) @triton.jit @@ -3087,7 +3085,7 @@ def atomic_add( if not USE_COPY_ENGINE: return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) else: - handle = copy_engine_ctx + (anvil.QUEUE_DEVICE_CTX_SIZE * to_rank) + handle = copy_engine_ctx + (sdma_ep.QUEUE_DEVICE_CTX_SIZE * to_rank) queue_ptr_u32 = tl.load(handle + 0).to(tl.pointer_type(tl.uint32)) read_ptr = tl.load(handle + 1).to(tl.pointer_type(tl.uint64)) write_ptr = tl.load(handle + 2).to(tl.pointer_type(tl.uint64)) @@ -3097,10 +3095,10 @@ def atomic_add( dst_ptr_val = translated_ptr.to(tl.uint64) - command_in_bytes = anvil.SDMA_PKT_ATOMIC_BYTES + command_in_bytes = sdma_ep.ATOMIC_COMMAND_BYTES # Acquire space (returns base index and wraparound offset) - base, offset = anvil.acquire_fadd( - # base = anvil.acquire( + base, offset = sdma_utils.acquire_fadd( + # base = sdma_utils.acquire( queue_ptr_u32, read_ptr, write_ptr, @@ -3110,17 +3108,17 @@ def atomic_add( command_in_bytes, ) # Write padding NOPs if we wrapped around - anvil.place_nop_packet(queue_ptr_u32, base, offset) + sdma_utils.place_nop_packet(queue_ptr_u32, base, offset) # Calculate packet position (base + offset for wraparound) packet_offset_bytes = base + offset # Place command packet - anvil.place_atomic_add_packet(queue_ptr_u32, packet_offset_bytes, dst_ptr_val, val) + sdma_utils.place_atomic_add_packet(queue_ptr_u32, packet_offset_bytes, dst_ptr_val, val) # Submit command pending_wptr = base + offset + command_in_bytes - anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, pending_wptr) + sdma_utils.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, pending_wptr) @triton.jit @@ -3227,7 +3225,7 @@ def atomic_cas( if not USE_COPY_ENGINE: return tl.atomic_cas(translated_ptr, cmp, val, sem=sem, scope=scope) else: - handle = 
copy_engine_ctx + (anvil.QUEUE_DEVICE_CTX_SIZE * to_rank) + handle = copy_engine_ctx + (sdma_ep.QUEUE_DEVICE_CTX_SIZE * to_rank) queue_ptr_u32 = tl.load(handle + 0).to(tl.pointer_type(tl.uint32)) read_ptr = tl.load(handle + 1).to(tl.pointer_type(tl.uint64)) write_ptr = tl.load(handle + 2).to(tl.pointer_type(tl.uint64)) @@ -3237,9 +3235,9 @@ def atomic_cas( dst_ptr_val = translated_ptr.to(tl.uint64) - command_in_bytes = anvil.SDMA_PKT_ATOMIC_BYTES + command_in_bytes = sdma_ep.ATOMIC_COMMAND_BYTES # Acquire space (returns base index and wraparound offset) - base, offset = anvil.acquire_fadd( + base, offset = sdma_utils.acquire_fadd( queue_ptr_u32, read_ptr, write_ptr, @@ -3249,11 +3247,11 @@ def atomic_cas( command_in_bytes, ) # Write padding NOPs if we wrapped around - anvil.place_nop_packet(queue_ptr_u32, base, offset) + sdma_utils.place_nop_packet(queue_ptr_u32, base, offset) # Calculate packet position (base + offset for wraparound) packet_offset_bytes = base + offset # Place command packet - anvil.place_atomic_cas_packet( + sdma_utils.place_atomic_cas_packet( queue_ptr_u32, packet_offset_bytes, dst_ptr_val, @@ -3262,7 +3260,7 @@ def atomic_cas( ) # Submit command pending_wptr = base + offset + command_in_bytes - anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, pending_wptr) + sdma_utils.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, pending_wptr) @triton.jit diff --git a/tests/examples/test_message_passing.py b/tests/examples/test_message_passing.py index d4dace005..cc8ce9392 100644 --- a/tests/examples/test_message_passing.py +++ b/tests/examples/test_message_passing.py @@ -125,6 +125,8 @@ def run_message_passing_kernels(module, args, *, use_copy_engine: bool = False): import gc gc.collect() + # Clear CUDA cache to free GPU memory between tests + torch.cuda.empty_cache() @pytest.mark.parametrize(