From 00ced0f89f2635c222eef1d118697bfd801c9d2e Mon Sep 17 00:00:00 2001 From: Kyle Montemayor Date: Mon, 6 Apr 2026 20:59:32 +0000 Subject: [PATCH 1/8] Extract sampler factory helpers to gigl/distributed/utils/dist_sampler.py Move create_dist_sampler(), SamplerInput, and SamplerRuntime out of dist_sampling_producer.py into a shared utils module so they can be reused by the upcoming SharedDistSamplingBackend. Also rename `w` -> `worker` in DistSamplingProducer.init() for clarity. Co-Authored-By: Claude Opus 4.6 --- gigl/distributed/dist_sampling_producer.py | 60 ++++---------- gigl/distributed/utils/dist_sampler.py | 94 ++++++++++++++++++++++ 2 files changed, 108 insertions(+), 46 deletions(-) create mode 100644 gigl/distributed/utils/dist_sampler.py diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index f155bd929..3a51715e2 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -37,13 +37,8 @@ from torch.utils.data.dataset import Dataset from gigl.common.logger import Logger -from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler -from gigl.distributed.dist_ppr_sampler import DistPPRNeighborSampler -from gigl.distributed.sampler_options import ( - KHopNeighborSamplerOptions, - PPRSamplerOptions, - SamplerOptions, -) +from gigl.distributed.sampler_options import SamplerOptions +from gigl.distributed.utils.dist_sampler import create_dist_sampler logger = Logger() @@ -100,42 +95,15 @@ def _sampling_worker_loop( if sampling_config.seed is not None: seed_everything(sampling_config.seed) - # Shared args for all sampler types (positional args to DistNeighborSampler.__init__) - shared_sampler_args = ( - data, - sampling_config.num_neighbors, - sampling_config.with_edge, - sampling_config.with_neg, - sampling_config.with_weight, - sampling_config.edge_dir, - sampling_config.collect_features, - channel, - worker_options.use_all2all, - worker_options.worker_concurrency, - current_device, + dist_sampler = create_dist_sampler( + data=data, + sampling_config=sampling_config, + worker_options=worker_options, + channel=channel, + sampler_options=sampler_options, + degree_tensors=degree_tensors, + current_device=current_device, ) - - if isinstance(sampler_options, KHopNeighborSamplerOptions): - dist_sampler = DistNeighborSampler( - *shared_sampler_args, - seed=sampling_config.seed, - ) - elif isinstance(sampler_options, PPRSamplerOptions): - assert degree_tensors is not None - dist_sampler = DistPPRNeighborSampler( - *shared_sampler_args, - seed=sampling_config.seed, - alpha=sampler_options.alpha, - eps=sampler_options.eps, - max_ppr_nodes=sampler_options.max_ppr_nodes, - num_neighbors_per_hop=sampler_options.num_neighbors_per_hop, - total_degree_dtype=sampler_options.total_degree_dtype, - degree_tensors=degree_tensors, - ) - else: - raise NotImplementedError( - f"Unsupported sampler options type: {type(sampler_options)}" - ) dist_sampler.start_loop() unshuffled_index_loader: Optional[DataLoader] @@ -236,7 +204,7 @@ def init(self): self.num_workers * self.worker_options.worker_concurrency ) self._task_queues.append(task_queue) - w = mp_context.Process( + worker = mp_context.Process( target=_sampling_worker_loop, args=( rank, @@ -253,7 +221,7 @@ def init(self): self._degree_tensors, ), ) - w.daemon = True - w.start() - self._workers.append(w) + worker.daemon = True + worker.start() + self._workers.append(worker) barrier.wait() diff --git a/gigl/distributed/utils/dist_sampler.py b/gigl/distributed/utils/dist_sampler.py new file mode 100644 index 000000000..6ed26f613 --- /dev/null +++ b/gigl/distributed/utils/dist_sampler.py @@ -0,0 +1,94 @@ +"""Sampler factory helpers shared across sampling producers.""" + +from typing import Optional, Union + +import torch +from graphlearn_torch.channel import ChannelBase +from graphlearn_torch.distributed import ( + DistDataset, + MpDistSamplingWorkerOptions, + RemoteDistSamplingWorkerOptions, +) +from graphlearn_torch.sampler import EdgeSamplerInput, NodeSamplerInput, SamplingConfig +from graphlearn_torch.typing import EdgeType + +from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler +from gigl.distributed.dist_ppr_sampler import DistPPRNeighborSampler +from gigl.distributed.sampler import ABLPNodeSamplerInput +from gigl.distributed.sampler_options import ( + KHopNeighborSamplerOptions, + PPRSamplerOptions, + SamplerOptions, +) + +SamplerInput = Union[NodeSamplerInput, EdgeSamplerInput, ABLPNodeSamplerInput] +"""Union of all supported sampler input types.""" + +SamplerRuntime = Union[DistNeighborSampler, DistPPRNeighborSampler] +"""Union of all supported GiGL sampler runtime types.""" + + +def create_dist_sampler( + *, + data: DistDataset, + sampling_config: SamplingConfig, + worker_options: Union[MpDistSamplingWorkerOptions, RemoteDistSamplingWorkerOptions], + channel: ChannelBase, + sampler_options: SamplerOptions, + degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], + current_device: torch.device, +) -> SamplerRuntime: + """Create a GiGL sampler runtime for one channel on one worker. + + Args: + data: The distributed dataset containing graph topology and features. + sampling_config: Configuration for sampling behavior (neighbors, edges, etc.). + worker_options: Worker-level options (RPC settings, device placement, concurrency). + channel: The communication channel for passing sampled messages. + sampler_options: Algorithm-specific options (k-hop or PPR). + degree_tensors: Pre-computed degree tensors required by PPR sampling. + Must not be ``None`` when ``sampler_options`` is :class:`PPRSamplerOptions`. + current_device: The device on which sampling will run. + + Returns: + A configured sampler runtime, either :class:`DistNeighborSampler` or + :class:`DistPPRNeighborSampler`. + + Raises: + NotImplementedError: If ``sampler_options`` is an unsupported type. + """ + shared_sampler_args = ( + data, + sampling_config.num_neighbors, + sampling_config.with_edge, + sampling_config.with_neg, + sampling_config.with_weight, + sampling_config.edge_dir, + sampling_config.collect_features, + channel, + worker_options.use_all2all, + worker_options.worker_concurrency, + current_device, + ) + if isinstance(sampler_options, KHopNeighborSamplerOptions): + sampler: SamplerRuntime = DistNeighborSampler( + *shared_sampler_args, + seed=sampling_config.seed, + ) + elif isinstance(sampler_options, PPRSamplerOptions): + assert degree_tensors is not None + sampler = DistPPRNeighborSampler( + *shared_sampler_args, + seed=sampling_config.seed, + alpha=sampler_options.alpha, + eps=sampler_options.eps, + max_ppr_nodes=sampler_options.max_ppr_nodes, + num_neighbors_per_hop=sampler_options.num_neighbors_per_hop, + total_degree_dtype=sampler_options.total_degree_dtype, + degree_tensors=degree_tensors, + ) + else: + raise NotImplementedError( + f"Unsupported sampler options type: {type(sampler_options)}" + ) + return sampler From f53e60fa631d121c166d0cceb6848eb80939ffda Mon Sep 17 00:00:00 2001 From: Kyle Montemayor Date: Mon, 6 Apr 2026 21:03:16 +0000 Subject: [PATCH 2/8] =?UTF-8?q?Add=20SharedDistSamplingBackend=20=E2=80=94?= =?UTF-8?q?=20multi-channel=20sampling=20backend?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce SharedDistSamplingBackend which manages a pool of worker processes servicing multiple compute-rank channels through a fair-queued round-robin scheduler. This replaces the per-channel producer model in graph-store mode with a shared backend + lightweight per-channel state. Includes tests for pure business logic helpers (_compute_num_batches, _epoch_batch_indices, _compute_worker_seeds_ranges), shuffle behavior, and completion reporting. Co-Authored-By: Claude Opus 4.6 --- .../shared_dist_sampling_producer.py | 814 ++++++++++++++++++ .../dist_sampling_producer_test.py | 214 +++++ 2 files changed, 1028 insertions(+) create mode 100644 gigl/distributed/graph_store/shared_dist_sampling_producer.py create mode 100644 tests/unit/distributed/dist_sampling_producer_test.py diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py new file mode 100644 index 000000000..b22ad9403 --- /dev/null +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -0,0 +1,814 @@ +"""Shared graph-store sampling backend and fair-queued worker loop. + +This module implements the multi-channel sampling backend used in graph-store +mode. A single ``SharedDistSamplingBackend`` per loader instance manages a +pool of worker processes that service many compute-rank channels through a +fair-queued scheduler (``_shared_sampling_worker_loop``). +""" + +import datetime +import queue +import threading +import time +from collections import defaultdict, deque +from dataclasses import dataclass +from enum import Enum, auto +from multiprocessing.process import BaseProcess +from threading import Barrier +from typing import Optional, Union, cast + +import torch +import torch.multiprocessing as mp +from graphlearn_torch.channel import ChannelBase +from graphlearn_torch.distributed import ( + DistDataset, + RemoteDistSamplingWorkerOptions, + get_context, + init_rpc, + init_worker_group, + shutdown_rpc, +) +from graphlearn_torch.distributed.dist_sampling_producer import MP_STATUS_CHECK_INTERVAL +from graphlearn_torch.sampler import ( + EdgeSamplerInput, + NodeSamplerInput, + SamplingConfig, + SamplingType, +) +from graphlearn_torch.typing import EdgeType +from torch._C import _set_worker_signal_handlers + +from gigl.common.logger import Logger +from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions +from gigl.distributed.utils.dist_sampler import ( + SamplerInput, + SamplerRuntime, + create_dist_sampler, +) + +logger = Logger() + + +def _prepare_degree_tensors( + data: DistDataset, + sampler_options: SamplerOptions, +) -> Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: + """Materialize PPR degree tensors before worker spawn when required.""" + degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]] = None + if isinstance(sampler_options, PPRSamplerOptions): + degree_tensors = data.degree_tensor + if isinstance(degree_tensors, dict): + logger.info( + "Pre-computed degree tensors for PPR sampling across " + f"{len(degree_tensors)} edge types." + ) + elif degree_tensors is not None: + logger.info( + "Pre-computed degree tensor for PPR sampling with " + f"{degree_tensors.size(0)} nodes." + ) + return degree_tensors + + +EPOCH_DONE_EVENT = "EPOCH_DONE" +SCHEDULER_TICK_SECS = 0.05 +SCHEDULER_STATE_LOG_INTERVAL_SECS = 10.0 +SCHEDULER_STATE_MAX_CHANNELS = 6 +SCHEDULER_SLOW_SUBMIT_SECS = 1.0 + + +class SharedMpCommand(Enum): + REGISTER_INPUT = auto() + UNREGISTER_INPUT = auto() + START_EPOCH = auto() + STOP = auto() + + +@dataclass(frozen=True) +class RegisterInputCmd: + channel_id: int + worker_key: str + sampler_input: SamplerInput + sampling_config: SamplingConfig + channel: ChannelBase + + +@dataclass(frozen=True) +class StartEpochCmd: + channel_id: int + epoch: int + seeds_index: Optional[torch.Tensor] + + +@dataclass +class ActiveEpochState: + channel_id: int + epoch: int + input_len: int + batch_size: int + drop_last: bool + seeds_index: Optional[torch.Tensor] + total_batches: int + submitted_batches: int = 0 + completed_batches: int = 0 + cancelled: bool = False + + +def _command_channel_id(command: SharedMpCommand, payload: object) -> Optional[int]: + """Extract the channel id from a worker command payload.""" + if command == SharedMpCommand.STOP: + return None + if isinstance(payload, RegisterInputCmd): + return payload.channel_id + if isinstance(payload, StartEpochCmd): + return payload.channel_id + if isinstance(payload, int): + return payload + return None + + +def _compute_num_batches(input_len: int, batch_size: int, drop_last: bool) -> int: + """Compute the number of batches emitted for an input length.""" + if input_len <= 0: + return 0 + if drop_last: + return input_len // batch_size + return (input_len + batch_size - 1) // batch_size + + +def _epoch_batch_indices(state: ActiveEpochState) -> Optional[torch.Tensor]: + """Return the next batch of indices for an active epoch. + + Returns the index tensor for the next batch, or None if no more batches + should be submitted (epoch cancelled, all batches already submitted, or + incomplete final batch with drop_last=True). + """ + if state.cancelled or state.submitted_batches >= state.total_batches: + return None + + batch_start = state.submitted_batches * state.batch_size + batch_end = min(batch_start + state.batch_size, state.input_len) + if state.drop_last and batch_end - batch_start < state.batch_size: + return None + + if state.seeds_index is None: + return torch.arange(batch_start, batch_end, dtype=torch.long) + return state.seeds_index[batch_start:batch_end] + + +def _compute_worker_seeds_ranges( + input_len: int, batch_size: int, num_workers: int +) -> list[tuple[int, int]]: + """Distribute complete batches across workers like GLT's producer does.""" + num_worker_batches = [0] * num_workers + num_total_complete_batches = input_len // batch_size + for rank in range(num_workers): + num_worker_batches[rank] += num_total_complete_batches // num_workers + for rank in range(num_total_complete_batches % num_workers): + num_worker_batches[rank] += 1 + + index_ranges: list[tuple[int, int]] = [] + start = 0 + for rank in range(num_workers): + end = start + num_worker_batches[rank] * batch_size + if rank == num_workers - 1: + end = input_len + index_ranges.append((start, end)) + start = end + return index_ranges + + +def _shared_sampling_worker_loop( + rank: int, + data: DistDataset, + worker_options: RemoteDistSamplingWorkerOptions, + task_queue: mp.Queue, + event_queue: mp.Queue, + mp_barrier: Barrier, + sampler_options: SamplerOptions, + degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], +) -> None: + """Run one shared graph-store worker that schedules many input channels. + + Each worker subprocess runs this function as a fair-queued batch scheduler. + Multiple input channels (each representing one compute rank's data stream) + share the same sampling worker processes and graph data. + + Algorithm: + 1. Initialize RPC, sampler infrastructure, and signal the parent via barrier. + 2. Enter the main event loop which alternates between: + a. Draining all pending commands from ``task_queue`` (register/unregister + channels, start epochs, stop). + b. Submitting batches round-robin from ``runnable_channels`` — a FIFO + queue of channels that have pending work. Each channel gets one batch + submitted per round to prevent starvation. + c. If no commands were processed and no batches submitted, blocking on + ``task_queue`` with a short timeout to avoid busy-waiting. + 3. Completion callbacks from the sampler update per-channel state and emit + ``EPOCH_DONE_EVENT`` to ``event_queue`` when all batches for an epoch + are finished. + """ + samplers: dict[int, SamplerRuntime] = {} + channels: dict[int, ChannelBase] = {} + inputs: dict[int, SamplerInput] = {} + cfgs: dict[int, SamplingConfig] = {} + route_key_by_channel: dict[int, str] = {} + started_epoch: dict[int, int] = {} + active_epochs_by_channel: dict[int, ActiveEpochState] = {} + runnable_channels: deque[int] = deque() + runnable_set: set[int] = set() + removing: set[int] = set() + state_lock = threading.RLock() + last_state_log_time = 0.0 + current_device: Optional[torch.device] = None + + # --- Scheduler helper functions --- + + def _enqueue_channel_if_runnable_locked(channel_id: int) -> None: + """Add channel to the fair-queue if it has pending batches.""" + state = active_epochs_by_channel.get(channel_id) + if state is None: + return + if state.cancelled or state.submitted_batches >= state.total_batches: + return + if channel_id in runnable_set: + return + runnable_channels.append(channel_id) + runnable_set.add(channel_id) + + def _clear_registered_input_locked(channel_id: int) -> None: + """Remove a channel's registration and clean up all associated state. + + If the channel still has in-flight batches (submitted but not yet + completed), marks it for deferred removal instead of cleaning up + immediately. + ``_on_batch_done`` will finish the cleanup once the last in-flight + batch completes. + """ + state = active_epochs_by_channel.get(channel_id) + if state is not None and state.completed_batches < state.submitted_batches: + removing.add(channel_id) + state.cancelled = True + return + sampler = samplers.pop(channel_id, None) + if sampler is not None: + sampler.wait_all() + sampler.shutdown_loop() + channels.pop(channel_id, None) + inputs.pop(channel_id, None) + cfgs.pop(channel_id, None) + route_key_by_channel.pop(channel_id, None) + started_epoch.pop(channel_id, None) + active_epochs_by_channel.pop(channel_id, None) + runnable_set.discard(channel_id) + removing.discard(channel_id) + + def _format_scheduler_state_locked() -> str: + """Format a human-readable snapshot of the scheduler for logging. + + Must be called while holding ``state_lock``. + """ + channel_ids = sorted(channels.keys()) + preview = channel_ids[:SCHEDULER_STATE_MAX_CHANNELS] + previews: list[str] = [] + for channel_id in preview: + active_epoch = active_epochs_by_channel.get(channel_id) + if active_epoch is None: + previews.append(f"{channel_id}:idle") + else: + previews.append( + f"{channel_id}:e{active_epoch.epoch}" + f"/{active_epoch.submitted_batches}" + f"/{active_epoch.completed_batches}" + f"/{active_epoch.total_batches}" + ) + extra = "" + if len(channel_ids) > len(preview): + extra = f" +{len(channel_ids) - len(preview)}" + return ( + f"registered={len(channels)} active={len(active_epochs_by_channel)} " + f"runnable={len(runnable_set)} removing={len(removing)} " + f"channels=[{', '.join(previews)}]{extra}" + ) + + def _maybe_log_scheduler_state(reason: str, force: bool = False) -> None: + """Log scheduler state at most once per ``SCHEDULER_STATE_LOG_INTERVAL_SECS``. + + Args: + reason: Short tag included in the log line (e.g. "start_epoch"). + force: If True, log regardless of the time-based throttle. + """ + nonlocal last_state_log_time + now = time.monotonic() + if not force and now - last_state_log_time < SCHEDULER_STATE_LOG_INTERVAL_SECS: + return + with state_lock: + scheduler_state = _format_scheduler_state_locked() + logger.info( + f"shared_sampling_scheduler worker_rank={rank} reason={reason} " + f"{scheduler_state}" + ) + last_state_log_time = now + + def _on_batch_done(channel_id: int, epoch: int) -> None: + """Sampler completion callback — invoked from sampler worker threads. + + Updates the channel's completed-batch counter. + When all batches for the epoch are done, emits ``EPOCH_DONE_EVENT`` + to ``event_queue``. + If the channel is pending removal, finishes cleanup via + ``_clear_registered_input_locked``. + """ + with state_lock: + state = active_epochs_by_channel.get(channel_id) + if state is None or state.epoch != epoch: + return + state.completed_batches += 1 + if state.completed_batches == state.total_batches: + active_epochs_by_channel.pop(channel_id, None) + event_queue.put((EPOCH_DONE_EVENT, channel_id, epoch, rank)) + if ( + channel_id in removing + and state.completed_batches == state.submitted_batches + ): + _clear_registered_input_locked(channel_id) + + def _submit_one_batch(channel_id: int) -> bool: + """Submit the next batch for a channel to its sampler. + + Re-enqueues the channel into ``runnable_channels`` if more batches + remain. + Returns True if a batch was submitted, False if the channel had no + pending work. + """ + with state_lock: + state = active_epochs_by_channel.get(channel_id) + if state is None: + return False + batch_indices = _epoch_batch_indices(state) + if batch_indices is None: + return False + state.submitted_batches += 1 + cfg = cfgs[channel_id] + sampler = samplers[channel_id] + channel_input = inputs[channel_id] + current_epoch = state.epoch + if state.submitted_batches < state.total_batches and not state.cancelled: + runnable_channels.append(channel_id) + runnable_set.add(channel_id) + + sampler_input = channel_input[batch_indices] + + callback = lambda _: _on_batch_done(channel_id, current_epoch) + if cfg.sampling_type == SamplingType.NODE: + sampler.sample_from_nodes( + cast(NodeSamplerInput, sampler_input), callback=callback + ) + elif cfg.sampling_type == SamplingType.LINK: + sampler.sample_from_edges( + cast(EdgeSamplerInput, sampler_input), callback=callback + ) + elif cfg.sampling_type == SamplingType.SUBGRAPH: + sampler.subgraph(cast(NodeSamplerInput, sampler_input), callback=callback) + else: + raise RuntimeError(f"Unsupported sampling type: {cfg.sampling_type}") + return True + + def _pump_runnable_channels() -> bool: + """Submit one batch per runnable channel in round-robin order. + + Returns True if at least one batch was submitted. + """ + made_progress = False + with state_lock: + num_candidates = len(runnable_channels) + for _ in range(num_candidates): + with state_lock: + if not runnable_channels: + break + channel_id = runnable_channels.popleft() + runnable_set.discard(channel_id) + made_progress = _submit_one_batch(channel_id) or made_progress + return made_progress + + def _handle_command(command: SharedMpCommand, payload: object) -> bool: + """Dispatch one command from the task queue. + + Returns True to keep running, False on ``STOP``. + """ + channel_id = _command_channel_id(command, payload) + if command == SharedMpCommand.REGISTER_INPUT: + register = cast(RegisterInputCmd, payload) + assert current_device is not None + sampler = create_dist_sampler( + data=data, + sampling_config=register.sampling_config, + worker_options=worker_options, + channel=register.channel, + sampler_options=sampler_options, + degree_tensors=degree_tensors, + current_device=current_device, + ) + sampler.start_loop() + with state_lock: + samplers[register.channel_id] = sampler + channels[register.channel_id] = register.channel + inputs[register.channel_id] = register.sampler_input + cfgs[register.channel_id] = register.sampling_config + route_key_by_channel[register.channel_id] = register.worker_key + started_epoch[register.channel_id] = -1 + _maybe_log_scheduler_state("register_input", force=True) + return True + + if command == SharedMpCommand.START_EPOCH: + start_epoch = cast(StartEpochCmd, payload) + with state_lock: + if channel_id not in channels: + return True + if started_epoch.get(channel_id, -1) >= start_epoch.epoch: + return True + started_epoch[channel_id] = start_epoch.epoch + sampling_config = cfgs[channel_id] + local_input_len = ( + len(start_epoch.seeds_index) + if start_epoch.seeds_index is not None + else len(inputs[channel_id]) + ) + state = ActiveEpochState( + channel_id=channel_id, + epoch=start_epoch.epoch, + input_len=local_input_len, + batch_size=sampling_config.batch_size, + drop_last=sampling_config.drop_last, + seeds_index=start_epoch.seeds_index, + total_batches=_compute_num_batches( + local_input_len, + sampling_config.batch_size, + sampling_config.drop_last, + ), + ) + active_epochs_by_channel[channel_id] = state + if state.total_batches == 0: + active_epochs_by_channel.pop(channel_id, None) + event_queue.put( + (EPOCH_DONE_EVENT, channel_id, start_epoch.epoch, rank) + ) + return True + _enqueue_channel_if_runnable_locked(channel_id) + _maybe_log_scheduler_state("start_epoch", force=True) + return True + + if command == SharedMpCommand.UNREGISTER_INPUT: + assert channel_id is not None + with state_lock: + _clear_registered_input_locked(channel_id) + _maybe_log_scheduler_state("unregister_input", force=True) + return True + + if command == SharedMpCommand.STOP: + return False + + raise RuntimeError(f"Unknown command type: {command}") + + try: + init_worker_group( + world_size=worker_options.worker_world_size, + rank=worker_options.worker_ranks[rank], + group_name="_sampling_worker_subprocess", + ) + if worker_options.use_all2all: + torch.distributed.init_process_group( + backend="gloo", + timeout=datetime.timedelta(seconds=worker_options.rpc_timeout), + rank=worker_options.worker_ranks[rank], + world_size=worker_options.worker_world_size, + init_method="tcp://{}:{}".format( + worker_options.master_addr, worker_options.master_port + ), + ) + + if worker_options.num_rpc_threads is None: + num_rpc_threads = min(data.num_partitions, 16) + else: + num_rpc_threads = worker_options.num_rpc_threads + current_device = worker_options.worker_devices[rank] + + _set_worker_signal_handlers() + torch.set_num_threads(num_rpc_threads + 1) + + init_rpc( + master_addr=worker_options.master_addr, + master_port=worker_options.master_port, + num_rpc_threads=num_rpc_threads, + rpc_timeout=worker_options.rpc_timeout, + ) + mp_barrier.wait() + + # --- Main event loop --- + keep_running = True + while keep_running: + # Phase 1: Drain all pending commands without blocking. + processed_command = False + while keep_running: + try: + command, payload = task_queue.get_nowait() + except queue.Empty: + break + processed_command = True + keep_running = _handle_command(command, payload) + + # Phase 2: Submit batches round-robin from runnable channels. + made_progress = _pump_runnable_channels() + _maybe_log_scheduler_state("steady_state") + if not keep_running: + break + + # Phase 3: If idle (no commands, no batches), block until next command. + if not (processed_command or made_progress): + try: + command, payload = task_queue.get(timeout=SCHEDULER_TICK_SECS) + except queue.Empty: + continue + keep_running = _handle_command(command, payload) + except KeyboardInterrupt: + pass + finally: + for sampler in list(samplers.values()): + sampler.wait_all() + sampler.shutdown_loop() + shutdown_rpc(graceful=False) + + +class SharedDistSamplingBackend: + """Shared graph-store sampling backend reused across many remote channels.""" + + def __init__( + self, + *, + data: DistDataset, + worker_options: RemoteDistSamplingWorkerOptions, + sampling_config: SamplingConfig, + sampler_options: SamplerOptions, + ) -> None: + self.data = data + self.worker_options = worker_options + self.num_workers = worker_options.num_workers + self._backend_sampling_config = sampling_config + self._sampler_options = sampler_options + self._task_queues: list[mp.Queue] = [] + self._workers: list[BaseProcess] = [] + self._event_queue: Optional[mp.Queue] = None + self._shutdown = False + self._initialized = False + self._lock = threading.RLock() + self._channel_sampling_config: dict[int, SamplingConfig] = {} + self._channel_input_sizes: dict[int, list[int]] = {} + self._channel_worker_seeds_ranges: dict[int, list[tuple[int, int]]] = {} + self._channel_shuffle_generators: dict[int, Optional[torch.Generator]] = {} + self._channel_epoch: dict[int, int] = {} + self._completed_workers: defaultdict[tuple[int, int], set[int]] = defaultdict( + set + ) + + def init_backend(self) -> None: + """Initialize worker processes once for this backend.""" + with self._lock: + if self._initialized: + return + self.worker_options._assign_worker_devices() + current_ctx = get_context() + if current_ctx is None or not current_ctx.is_server(): + raise RuntimeError( + "SharedDistSamplingBackend.init_backend() requires a GLT server context." + ) + self.worker_options._set_worker_ranks(current_ctx) + degree_tensors = _prepare_degree_tensors( + self.data, + self._sampler_options, + ) + mp_context = mp.get_context("spawn") + barrier = mp_context.Barrier(self.num_workers + 1) + self._event_queue = mp_context.Queue() + for rank in range(self.num_workers): + task_queue = mp_context.Queue( + self.num_workers * self.worker_options.worker_concurrency + ) + self._task_queues.append(task_queue) + worker = mp_context.Process( + target=_shared_sampling_worker_loop, + args=( + rank, + self.data, + self.worker_options, + task_queue, + self._event_queue, + barrier, + self._sampler_options, + degree_tensors, + ), + ) + worker.daemon = True + worker.start() + self._workers.append(worker) + barrier.wait() + self._initialized = True + + def _enqueue_worker_command( + self, + worker_rank: int, + command: SharedMpCommand, + payload: object, + ) -> None: + queue_ = self._task_queues[worker_rank] + enqueue_start = time.monotonic() + queue_.put((command, payload)) + elapsed = time.monotonic() - enqueue_start + if elapsed >= SCHEDULER_SLOW_SUBMIT_SECS: + logger.warning( + f"task_queue enqueue_slow worker_rank={worker_rank} " + f"command={command.name} elapsed_secs={elapsed:.2f}" + ) + + def register_input( + self, + channel_id: int, + worker_key: str, + sampler_input: SamplerInput, + sampling_config: SamplingConfig, + channel: ChannelBase, + ) -> None: + """Register a channel-specific input on all backend workers.""" + with self._lock: + if not self._initialized: + raise RuntimeError("SharedDistSamplingBackend is not initialized.") + if channel_id in self._channel_sampling_config: + raise ValueError(f"channel_id {channel_id} is already registered.") + if sampling_config != self._backend_sampling_config: + raise ValueError( + "Sampling config must match the backend sampling config for shared backends." + ) + + shared_sampler_input = sampler_input.share_memory() + worker_ranges = _compute_worker_seeds_ranges( + len(shared_sampler_input), + sampling_config.batch_size, + self.num_workers, + ) + self._channel_sampling_config[channel_id] = sampling_config + self._channel_input_sizes[channel_id] = [ + end - start for start, end in worker_ranges + ] + self._channel_worker_seeds_ranges[channel_id] = worker_ranges + if sampling_config.shuffle: + generator = torch.Generator() + if sampling_config.seed is None: + generator.manual_seed(torch.seed()) + else: + generator.manual_seed(sampling_config.seed) + self._channel_shuffle_generators[channel_id] = generator + else: + self._channel_shuffle_generators[channel_id] = None + self._channel_epoch[channel_id] = -1 + for worker_rank in range(self.num_workers): + self._enqueue_worker_command( + worker_rank, + SharedMpCommand.REGISTER_INPUT, + RegisterInputCmd( + channel_id=channel_id, + worker_key=worker_key, + sampler_input=shared_sampler_input, + sampling_config=sampling_config, + channel=channel, + ), + ) + + def _drain_events(self) -> None: + """Drain worker completion events into the backend-local state.""" + if self._event_queue is None: + return + while True: + try: + event = self._event_queue.get_nowait() + except queue.Empty: + return + if event[0] == EPOCH_DONE_EVENT: + _, channel_id, epoch, worker_rank = event + self._completed_workers[(channel_id, epoch)].add(worker_rank) + + def start_new_epoch_sampling(self, channel_id: int, epoch: int) -> None: + """Start one new epoch for one registered channel.""" + with self._lock: + self._drain_events() + sampling_config = self._channel_sampling_config[channel_id] + if self._channel_epoch[channel_id] >= epoch: + return + previous_epoch = self._channel_epoch[channel_id] + self._channel_epoch[channel_id] = epoch + stale_keys = [ + k + for k in self._completed_workers + if k[0] == channel_id and k[1] <= epoch + ] + for k in stale_keys: + del self._completed_workers[k] + input_len = sum(self._channel_input_sizes[channel_id]) + worker_ranges = self._channel_worker_seeds_ranges[channel_id] + if sampling_config.shuffle: + generator = self._channel_shuffle_generators[channel_id] + assert generator is not None + full_index = torch.randperm(input_len, generator=generator) + for worker_rank, (start, end) in enumerate(worker_ranges): + worker_index = full_index[start:end] + worker_index.share_memory_() + self._enqueue_worker_command( + worker_rank, + SharedMpCommand.START_EPOCH, + StartEpochCmd( + channel_id=channel_id, + epoch=epoch, + seeds_index=worker_index, + ), + ) + else: + for worker_rank, (start, end) in enumerate(worker_ranges): + worker_index = torch.arange(start, end, dtype=torch.long) + worker_index.share_memory_() + self._enqueue_worker_command( + worker_rank, + SharedMpCommand.START_EPOCH, + StartEpochCmd( + channel_id=channel_id, + epoch=epoch, + seeds_index=worker_index, + ), + ) + + def unregister_input(self, channel_id: int) -> None: + """Unregister a channel from the backend workers.""" + with self._lock: + if channel_id not in self._channel_sampling_config: + return + self._drain_events() + self._channel_sampling_config.pop(channel_id, None) + self._channel_input_sizes.pop(channel_id, None) + self._channel_worker_seeds_ranges.pop(channel_id, None) + self._channel_shuffle_generators.pop(channel_id, None) + self._channel_epoch.pop(channel_id, None) + stale_keys = [k for k in self._completed_workers if k[0] == channel_id] + for k in stale_keys: + del self._completed_workers[k] + for worker_rank in range(self.num_workers): + self._enqueue_worker_command( + worker_rank, + SharedMpCommand.UNREGISTER_INPUT, + channel_id, + ) + + def is_channel_epoch_done(self, channel_id: int, epoch: int) -> bool: + """Return whether every worker finished the epoch for one channel.""" + with self._lock: + self._drain_events() + return ( + len(self._completed_workers.get((channel_id, epoch), set())) + == self.num_workers + ) + + def describe_channel(self, channel_id: int) -> dict[str, object]: + """Return lightweight diagnostics for one registered channel.""" + with self._lock: + self._drain_events() + epoch = self._channel_epoch.get(channel_id, -1) + completed_workers = len( + self._completed_workers.get((channel_id, epoch), set()) + ) + return { + "epoch": epoch, + "input_sizes": self._channel_input_sizes.get(channel_id, []), + "completed_workers": completed_workers, + } + + def shutdown(self) -> None: + """Stop all worker processes and release backend resources.""" + with self._lock: + if self._shutdown: + return + self._shutdown = True + try: + for worker_rank in range(len(self._task_queues)): + self._enqueue_worker_command( + worker_rank, + SharedMpCommand.STOP, + None, + ) + for worker in self._workers: + worker.join(timeout=MP_STATUS_CHECK_INTERVAL) + for queue_ in self._task_queues: + queue_.cancel_join_thread() + queue_.close() + if self._event_queue is not None: + self._event_queue.cancel_join_thread() + self._event_queue.close() + finally: + for worker in self._workers: + if worker.is_alive(): + worker.terminate() diff --git a/tests/unit/distributed/dist_sampling_producer_test.py b/tests/unit/distributed/dist_sampling_producer_test.py new file mode 100644 index 000000000..9d19768da --- /dev/null +++ b/tests/unit/distributed/dist_sampling_producer_test.py @@ -0,0 +1,214 @@ +import queue +from typing import cast +from unittest.mock import MagicMock, patch + +import torch +import torch.multiprocessing as mp +from graphlearn_torch.sampler import NodeSamplerInput, SamplingConfig, SamplingType + +from gigl.distributed.graph_store.shared_dist_sampling_producer import ( + EPOCH_DONE_EVENT, + ActiveEpochState, + SharedDistSamplingBackend, + SharedMpCommand, + StartEpochCmd, + _compute_num_batches, + _compute_worker_seeds_ranges, + _epoch_batch_indices, +) +from gigl.distributed.sampler_options import KHopNeighborSamplerOptions +from tests.test_assets.test_case import TestCase + + +def _make_sampling_config(*, shuffle: bool = False) -> SamplingConfig: + return SamplingConfig( + sampling_type=SamplingType.NODE, + num_neighbors=[2], + batch_size=2, + shuffle=shuffle, + drop_last=False, + with_edge=True, + collect_features=True, + with_neg=False, + with_weight=False, + edge_dir="out", + seed=1234, + ) + + +class _FakeProcess: + def __init__(self, *args, **kwargs) -> None: + self.daemon = False + + def start(self) -> None: + return None + + def join(self, timeout: float | None = None) -> None: + return None + + def is_alive(self) -> bool: + return False + + def terminate(self) -> None: + return None + + +class _FakeMpContext: + def Barrier(self, parties: int): + return MagicMock(wait=MagicMock()) + + def Queue(self, maxsize: int = 0): + return MagicMock() + + def Process(self, *args, **kwargs): + return _FakeProcess(*args, **kwargs) + + +class DistSamplingProducerTest(TestCase): + def test_compute_num_batches(self) -> None: + self.assertEqual(_compute_num_batches(0, 2, False), 0) + self.assertEqual(_compute_num_batches(1, 2, True), 0) + self.assertEqual(_compute_num_batches(1, 2, False), 1) + self.assertEqual(_compute_num_batches(5, 2, False), 3) + self.assertEqual(_compute_num_batches(5, 2, True), 2) + + def test_epoch_batch_indices(self) -> None: + active_state = ActiveEpochState( + channel_id=0, + epoch=0, + input_len=6, + batch_size=2, + drop_last=False, + seeds_index=torch.arange(6), + total_batches=3, + submitted_batches=1, + cancelled=False, + ) + result = _epoch_batch_indices(active_state) + assert result is not None + self.assert_tensor_equality(result, torch.tensor([2, 3])) + + def test_compute_worker_seeds_ranges(self) -> None: + self.assertEqual( + _compute_worker_seeds_ranges(input_len=7, batch_size=2, num_workers=3), + [(0, 2), (2, 4), (4, 7)], + ) + + @patch("gigl.distributed.graph_store.shared_dist_sampling_producer.get_context") + @patch("gigl.distributed.graph_store.shared_dist_sampling_producer.mp.get_context") + @patch( + "gigl.distributed.graph_store.shared_dist_sampling_producer._prepare_degree_tensors" + ) + def test_init_backend_prepares_worker_options( + self, + mock_prepare_degree_tensors: MagicMock, + mock_get_mp_context: MagicMock, + mock_get_context: MagicMock, + ) -> None: + worker_options = MagicMock() + worker_options.num_workers = 2 + worker_options.worker_concurrency = 1 + mock_get_context.return_value = MagicMock( + is_server=MagicMock(return_value=True) + ) + mock_get_mp_context.return_value = _FakeMpContext() + backend = SharedDistSamplingBackend( + data=MagicMock(), + worker_options=worker_options, + sampling_config=_make_sampling_config(), + sampler_options=KHopNeighborSamplerOptions(num_neighbors=[2]), + ) + + backend.init_backend() + + worker_options._assign_worker_devices.assert_called_once() + worker_options._set_worker_ranks.assert_called_once_with( + mock_get_context.return_value + ) + self.assertEqual(len(backend._task_queues), 2) + self.assertEqual(len(backend._workers), 2) + self.assertTrue(backend._initialized) + mock_prepare_degree_tensors.assert_called_once() + + def test_start_new_epoch_sampling_shuffle_refreshes_per_epoch(self) -> None: + worker_options = MagicMock() + worker_options.num_workers = 2 + worker_options.worker_concurrency = 1 + backend = SharedDistSamplingBackend( + data=MagicMock(), + worker_options=worker_options, + sampling_config=_make_sampling_config(shuffle=True), + sampler_options=KHopNeighborSamplerOptions(num_neighbors=[2]), + ) + backend._initialized = True + recorded: list[tuple[int, SharedMpCommand, object]] = [] + backend._enqueue_worker_command = lambda worker_rank, command, payload: recorded.append( # type: ignore[method-assign] + (worker_rank, command, payload) + ) + + channel = MagicMock() + input_tensor = torch.arange(6, dtype=torch.long) + backend.register_input( + channel_id=1, + worker_key="loader_a_compute_rank_0", + sampler_input=NodeSamplerInput(node=input_tensor.clone()), + sampling_config=_make_sampling_config(shuffle=True), + channel=channel, + ) + backend.register_input( + channel_id=2, + worker_key="loader_b_compute_rank_0", + sampler_input=NodeSamplerInput(node=input_tensor.clone()), + sampling_config=_make_sampling_config(shuffle=True), + channel=channel, + ) + + def _collect_epoch_indices(channel_id: int, epoch: int) -> torch.Tensor: + recorded.clear() + backend.start_new_epoch_sampling(channel_id, epoch) + worker_payloads = { + worker_rank: cast(StartEpochCmd, payload).seeds_index + for worker_rank, command, payload in recorded + if command == SharedMpCommand.START_EPOCH + } + assert all( + seed_index is not None for seed_index in worker_payloads.values() + ) + return torch.cat( + [ + cast(torch.Tensor, worker_payloads[worker_rank]) + for worker_rank in sorted(worker_payloads) + ] + ) + + channel_1_epoch_0 = _collect_epoch_indices(1, 0) + channel_2_epoch_0 = _collect_epoch_indices(2, 0) + channel_1_epoch_1 = _collect_epoch_indices(1, 1) + + self.assert_tensor_equality(channel_1_epoch_0, channel_2_epoch_0) + self.assertNotEqual( + channel_1_epoch_0.tolist(), + channel_1_epoch_1.tolist(), + ) + + def test_describe_channel_reports_completed_workers(self) -> None: + worker_options = MagicMock() + worker_options.num_workers = 2 + worker_options.worker_concurrency = 1 + backend = SharedDistSamplingBackend( + data=MagicMock(), + worker_options=worker_options, + sampling_config=_make_sampling_config(), + sampler_options=KHopNeighborSamplerOptions(num_neighbors=[2]), + ) + backend._initialized = True + backend._event_queue = cast(mp.Queue, queue.Queue()) + backend._channel_input_sizes[1] = [4, 2] + backend._channel_epoch[1] = 3 + cast(queue.Queue, backend._event_queue).put((EPOCH_DONE_EVENT, 1, 3, 0)) + + description = backend.describe_channel(1) + + self.assertEqual(description["epoch"], 3) + self.assertEqual(description["input_sizes"], [4, 2]) + self.assertEqual(description["completed_workers"], 1) From bc400818fffe9eb2ba650dbad1942171ad73b7ed Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 8 Apr 2026 16:57:43 +0000 Subject: [PATCH 3/8] docs: add ASCII architecture diagrams to shared_dist_sampling_producer module docstring Co-Authored-By: Claude Sonnet 4.6 --- .../shared_dist_sampling_producer.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index b22ad9403..24188fb10 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -4,6 +4,58 @@ mode. A single ``SharedDistSamplingBackend`` per loader instance manages a pool of worker processes that service many compute-rank channels through a fair-queued scheduler (``_shared_sampling_worker_loop``). + +High-level architecture:: + + ┌──────────────────────────────────────────────┐ + │ SharedDistSamplingBackend │ + │ (main process) │ + ├──────────────────────────────────────────────┤ + │ register_input() │ + │ start_new_epoch_sampling() │ + │ is_channel_epoch_done() │ + │ unregister_input() │ + │ shutdown() │ + └──────┬──────────────────────────────▲────────┘ + │ task_queues │ event_queue + │ (SharedMpCommand, payload) │ (EPOCH_DONE_EVENT, + │ │ channel_id, epoch, + ▼ │ worker_rank) + ┌──────────────────────────────────────────────┐ + │ Worker 0 .. N-1 │ + │ _shared_sampling_worker_loop() │ + │ │ + │ ┌─────────────┐ sample_from_* ┌─────────┐│ + │ │ Sampler │───────────────▶│ Channel ││ + │ │ (per channel)│ (results) │ (output) ││ + │ └─────────────┘ └─────────┘│ + └──────────────────────────────────────────────┘ + +Worker event-loop internals:: + + ┌─────────────────────────────────────────────────┐ + │ Phase 1: Drain commands (non-blocking) │ + │ task_queue.get_nowait() ──▶ _handle_command() │ + │ REGISTER_INPUT ──▶ create sampler + state │ + │ START_EPOCH ──▶ ActiveEpochState │ + │ + enqueue to runnable │ + │ UNREGISTER_INPUT ──▶ cleanup / defer │ + │ STOP ──▶ exit loop │ + ├─────────────────────────────────────────────────┤ + │ Phase 2: Round-robin batch submission │ + │ for each channel in runnable_channels: │ + │ pop ──▶ _submit_one_batch() │ + │ ──▶ sampler.sample_from_*() │ + │ if more batches: re-enqueue channel │ + │ │ + │ completion callback (_on_batch_done): │ + │ completed_batches += 1 │ + │ if all done ──▶ EPOCH_DONE to event_queue │ + ├─────────────────────────────────────────────────┤ + │ Phase 3: Idle wait │ + │ if no commands and no batches submitted: │ + │ task_queue.get(timeout=SCHEDULER_TICK_SECS) │ + └─────────────────────────────────────────────────┘ """ import datetime From 546a5135d2d89099688ab1980056cdec4aab39ba Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 8 Apr 2026 16:59:19 +0000 Subject: [PATCH 4/8] docs: add Google-style docstrings to SharedMpCommand enum and dataclasses Co-Authored-By: Claude Sonnet 4.6 --- .../shared_dist_sampling_producer.py | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index 24188fb10..1158c51d6 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -130,6 +130,23 @@ def _prepare_degree_tensors( class SharedMpCommand(Enum): + """Commands sent from the backend to worker subprocesses via task queues. + + Each command is paired with a payload in a ``(command, payload)`` tuple + placed on the per-worker ``task_queue``. + + Attributes: + REGISTER_INPUT: Register a new channel with its sampler input, + sampling config, and output channel. + Payload: ``RegisterInputCmd``. + UNREGISTER_INPUT: Remove a channel and clean up its state. + Payload: ``int`` (the channel_id). + START_EPOCH: Begin sampling a new epoch for one channel. + Payload: ``StartEpochCmd``. + STOP: Shut down the worker process. + Payload: ``None``. + """ + REGISTER_INPUT = auto() UNREGISTER_INPUT = auto() START_EPOCH = auto() @@ -138,6 +155,20 @@ class SharedMpCommand(Enum): @dataclass(frozen=True) class RegisterInputCmd: + """Payload for ``SharedMpCommand.REGISTER_INPUT``. + + Carries everything a worker needs to set up sampling for one channel. + + Attributes: + channel_id: Unique identifier for this channel across the backend. + worker_key: Routing key used to identify this channel in the worker + group (passed through to ``create_dist_sampler``). + sampler_input: The full set of seed node/edge inputs for this channel, + already in shared memory. + sampling_config: Sampling parameters (batch size, num neighbors, etc.). + channel: The output channel where sampled subgraphs are written. + """ + channel_id: int worker_key: str sampler_input: SamplerInput @@ -147,6 +178,17 @@ class RegisterInputCmd: @dataclass(frozen=True) class StartEpochCmd: + """Payload for ``SharedMpCommand.START_EPOCH``. + + Attributes: + channel_id: The channel whose epoch is starting. + epoch: Monotonically increasing epoch number. + Duplicate or stale epochs are silently ignored by the worker. + seeds_index: Index tensor selecting which seeds from the channel's + ``sampler_input`` to sample this epoch. + ``None`` means use the full input range. + """ + channel_id: int epoch: int seeds_index: Optional[torch.Tensor] @@ -154,6 +196,29 @@ class StartEpochCmd: @dataclass class ActiveEpochState: + """Mutable per-channel state for an in-progress epoch inside a worker. + + Created by ``_handle_command`` on ``START_EPOCH`` and removed when all + batches complete. + + Attributes: + channel_id: The channel this epoch belongs to. + epoch: The epoch number. + input_len: Total number of seed indices assigned to this worker for + this epoch. + batch_size: Number of seeds per batch. + drop_last: If True, the final incomplete batch is skipped. + seeds_index: Index tensor into the channel's ``sampler_input``. + ``None`` means sequential indices ``[0, input_len)``. + total_batches: Pre-computed number of batches for this epoch. + submitted_batches: Number of batches submitted to the sampler so far. + Mutated by ``_submit_one_batch``. + completed_batches: Number of batches whose sampler callbacks have + fired. Mutated by ``_on_batch_done``. + cancelled: Set to True when the channel is unregistered while batches + are still in flight. Mutated by ``_clear_registered_input_locked``. + """ + channel_id: int epoch: int input_len: int From 2a2ee93b1bfe1f5256323a05654fb96fc1bfd6bd Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 8 Apr 2026 17:01:35 +0000 Subject: [PATCH 5/8] docs: add Google-style docstrings to module-level helper functions Co-Authored-By: Claude Sonnet 4.6 --- .../shared_dist_sampling_producer.py | 76 +++++++++++++++++-- 1 file changed, 68 insertions(+), 8 deletions(-) diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index 1158c51d6..d68945d52 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -105,7 +105,22 @@ def _prepare_degree_tensors( data: DistDataset, sampler_options: SamplerOptions, ) -> Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: - """Materialize PPR degree tensors before worker spawn when required.""" + """Materialize PPR degree tensors before worker spawn when required. + + Called once in the main process so that degree data is available in shared + memory before workers fork. Returns ``None`` for non-PPR sampler options. + + Args: + data: The distributed dataset whose ``degree_tensor`` property is + read. + sampler_options: Sampler configuration. Degree tensors are only + materialized when this is a ``PPRSamplerOptions`` instance. + + Returns: + A single degree tensor (homogeneous graph), a dict mapping edge types + to degree tensors (heterogeneous graph), or ``None`` if PPR sampling + is not configured. + """ degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]] = None if isinstance(sampler_options, PPRSamplerOptions): degree_tensors = data.degree_tensor @@ -232,7 +247,17 @@ class ActiveEpochState: def _command_channel_id(command: SharedMpCommand, payload: object) -> Optional[int]: - """Extract the channel id from a worker command payload.""" + """Extract the channel id from a worker command payload. + + Args: + command: The command type. + payload: The associated payload — one of ``RegisterInputCmd``, + ``StartEpochCmd``, ``int`` (channel_id), or ``None``. + + Returns: + The channel id if the command targets a specific channel, + or ``None`` for ``STOP``. + """ if command == SharedMpCommand.STOP: return None if isinstance(payload, RegisterInputCmd): @@ -245,7 +270,17 @@ def _command_channel_id(command: SharedMpCommand, payload: object) -> Optional[i def _compute_num_batches(input_len: int, batch_size: int, drop_last: bool) -> int: - """Compute the number of batches emitted for an input length.""" + """Compute the number of batches emitted for an input length. + + Args: + input_len: Total number of seed indices. + batch_size: Number of seeds per batch. + drop_last: If True, drops the final batch when it is smaller than + ``batch_size``. + + Returns: + The number of batches. Returns 0 when ``input_len <= 0``. + """ if input_len <= 0: return 0 if drop_last: @@ -254,11 +289,21 @@ def _compute_num_batches(input_len: int, batch_size: int, drop_last: bool) -> in def _epoch_batch_indices(state: ActiveEpochState) -> Optional[torch.Tensor]: - """Return the next batch of indices for an active epoch. + """Return the next batch of seed indices for an active epoch. + + Advances the logical cursor by one batch based on + ``state.submitted_batches``. - Returns the index tensor for the next batch, or None if no more batches - should be submitted (epoch cancelled, all batches already submitted, or - incomplete final batch with drop_last=True). + Args: + state: The mutable epoch state for the channel. + ``submitted_batches`` is read but **not** mutated here — the + caller (``_submit_one_batch``) increments it after calling. + + Returns: + A 1-D ``torch.long`` tensor of seed indices for the next batch, + or ``None`` if no more batches should be submitted (epoch cancelled, + all batches already submitted, or incomplete final batch with + ``drop_last=True``). """ if state.cancelled or state.submitted_batches >= state.total_batches: return None @@ -276,7 +321,22 @@ def _epoch_batch_indices(state: ActiveEpochState) -> Optional[torch.Tensor]: def _compute_worker_seeds_ranges( input_len: int, batch_size: int, num_workers: int ) -> list[tuple[int, int]]: - """Distribute complete batches across workers like GLT's producer does.""" + """Distribute seed indices across workers using GLT-compatible logic. + + Divides complete batches as evenly as possible across workers + (lower-ranked workers get one extra batch when the division is uneven). + The last worker's range extends to ``input_len`` so that the remainder + (incomplete final batch) is included. + + Args: + input_len: Total number of seed indices. + batch_size: Number of seeds per batch. + num_workers: Number of worker processes. + + Returns: + A list of ``(start, end)`` index ranges, one per worker. The ranges + are contiguous and non-overlapping, covering ``[0, input_len)``. + """ num_worker_batches = [0] * num_workers num_total_complete_batches = input_len // batch_size for rank in range(num_workers): From 4dcb3e7c2dc33debd85923d2e7d01d1a1bba195d Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 8 Apr 2026 17:03:21 +0000 Subject: [PATCH 6/8] docs: add Args section to _shared_sampling_worker_loop docstring --- .../shared_dist_sampling_producer.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index d68945d52..debc97f20 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -371,6 +371,26 @@ def _shared_sampling_worker_loop( Multiple input channels (each representing one compute rank's data stream) share the same sampling worker processes and graph data. + Args: + rank: This worker's index within the pool (``0 .. num_workers-1``). + data: The distributed dataset, shared across all workers via the + spawn context. + worker_options: GLT remote sampling worker configuration (RPC + addresses, devices, concurrency settings). + task_queue: Per-worker command queue. The backend enqueues + ``(SharedMpCommand, payload)`` tuples; the worker drains them + in Phase 1 of the event loop. + event_queue: Shared completion queue. Workers emit + ``(EPOCH_DONE_EVENT, channel_id, epoch, worker_rank)`` tuples + when all batches for an epoch have completed. + mp_barrier: Synchronization barrier. The worker signals it after + RPC initialization is complete so the parent can proceed. + sampler_options: GiGL sampler configuration (e.g. ``PPRSamplerOptions`` + for PPR-based sampling). + degree_tensors: Pre-computed degree tensors for PPR sampling, or + ``None`` for non-PPR samplers. Materialized once in the parent + process by ``_prepare_degree_tensors`` and shared across workers. + Algorithm: 1. Initialize RPC, sampler infrastructure, and signal the parent via barrier. 2. Enter the main event loop which alternates between: From 270d2c63a628ae86d9b70898df229e29b4eff196 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 8 Apr 2026 17:05:35 +0000 Subject: [PATCH 7/8] docs: add Google-style docstrings to SharedDistSamplingBackend methods Co-Authored-By: Claude Opus 4.6 --- .../shared_dist_sampling_producer.py | 142 +++++++++++++++++- 1 file changed, 134 insertions(+), 8 deletions(-) diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index debc97f20..ecea07a86 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -747,6 +747,20 @@ def __init__( sampling_config: SamplingConfig, sampler_options: SamplerOptions, ) -> None: + """Initialize the shared sampling backend. + + Does not start worker processes — call ``init_backend`` to spawn them. + + Args: + data: The distributed dataset to sample from. + worker_options: GLT remote sampling worker configuration (RPC + addresses, devices, concurrency). + sampling_config: Sampling parameters (batch size, neighbor counts, + shuffle, etc.). All channels registered on this backend must + use the same config. + sampler_options: GiGL sampler variant configuration (e.g. + ``PPRSamplerOptions`` for PPR-based sampling). + """ self.data = data self.worker_options = worker_options self.num_workers = worker_options.num_workers @@ -768,7 +782,26 @@ def __init__( ) def init_backend(self) -> None: - """Initialize worker processes once for this backend.""" + """Initialize worker processes once for this backend. + + Spawns ``num_workers`` subprocesses running + ``_shared_sampling_worker_loop``. Each worker initializes RPC and + signals readiness via a shared barrier. This method blocks until all + workers are ready. + + The initialization sequence is: + + 1. Assign devices and worker ranks from the GLT server context. + 2. Pre-compute degree tensors for PPR sampling (if applicable). + 3. Spawn worker processes with per-worker task queues and a shared + event queue. + 4. Wait on the barrier for all workers to finish RPC init. + + No-op if already initialized. + + Raises: + RuntimeError: If no GLT server context is active. + """ with self._lock: if self._initialized: return @@ -816,6 +849,17 @@ def _enqueue_worker_command( command: SharedMpCommand, payload: object, ) -> None: + """Enqueue a command on one worker's task queue. + + Logs a warning if the enqueue blocks for longer than + ``SCHEDULER_SLOW_SUBMIT_SECS``. + + Args: + worker_rank: Index of the target worker (``0 .. num_workers-1``). + command: The command type to send. + payload: The command payload (``RegisterInputCmd``, + ``StartEpochCmd``, ``int``, or ``None``). + """ queue_ = self._task_queues[worker_rank] enqueue_start = time.monotonic() queue_.put((command, payload)) @@ -834,7 +878,25 @@ def register_input( sampling_config: SamplingConfig, channel: ChannelBase, ) -> None: - """Register a channel-specific input on all backend workers.""" + """Register a new channel on all backend workers. + + Moves ``sampler_input`` into shared memory, computes per-worker seed + ranges, initializes shuffle state (if configured), and broadcasts a + ``REGISTER_INPUT`` command to every worker. + + Args: + channel_id: Unique identifier for this channel. + worker_key: Routing key for the channel in the worker group. + sampler_input: Seed node/edge inputs for this channel. + sampling_config: Must match the backend's ``sampling_config``. + channel: Output channel where sampled subgraphs are written. + + Raises: + RuntimeError: If the backend has not been initialized via + ``init_backend``. + ValueError: If ``channel_id`` is already registered, or if + ``sampling_config`` does not match the backend config. + """ with self._lock: if not self._initialized: raise RuntimeError("SharedDistSamplingBackend is not initialized.") @@ -880,7 +942,12 @@ def register_input( ) def _drain_events(self) -> None: - """Drain worker completion events into the backend-local state.""" + """Drain worker completion events into the backend-local state. + + Reads all pending ``EPOCH_DONE_EVENT`` tuples from the shared + ``event_queue`` and records which workers have finished each + ``(channel_id, epoch)`` in ``_completed_workers``. + """ if self._event_queue is None: return while True: @@ -893,7 +960,21 @@ def _drain_events(self) -> None: self._completed_workers[(channel_id, epoch)].add(worker_rank) def start_new_epoch_sampling(self, channel_id: int, epoch: int) -> None: - """Start one new epoch for one registered channel.""" + """Start a new sampling epoch for one registered channel. + + Cleans up stale completion records, generates a shuffled or sequential + seed permutation, slices it into per-worker ranges, and dispatches + ``START_EPOCH`` commands to all workers. + + No-op if the channel has already started an epoch >= ``epoch``. + + Args: + channel_id: The registered channel to start. + epoch: Monotonically increasing epoch number. + + Raises: + KeyError: If ``channel_id`` is not registered. + """ with self._lock: self._drain_events() sampling_config = self._channel_sampling_config[channel_id] @@ -941,7 +1022,16 @@ def start_new_epoch_sampling(self, channel_id: int, epoch: int) -> None: ) def unregister_input(self, channel_id: int) -> None: - """Unregister a channel from the backend workers.""" + """Unregister a channel from all backend workers. + + Removes backend-side bookkeeping and broadcasts + ``UNREGISTER_INPUT`` to every worker. + + No-op if ``channel_id`` is not currently registered. + + Args: + channel_id: The channel to remove. + """ with self._lock: if channel_id not in self._channel_sampling_config: return @@ -962,7 +1052,18 @@ def unregister_input(self, channel_id: int) -> None: ) def is_channel_epoch_done(self, channel_id: int, epoch: int) -> bool: - """Return whether every worker finished the epoch for one channel.""" + """Return whether every worker finished the epoch for one channel. + + Drains pending completion events before checking. + + Args: + channel_id: The channel to query. + epoch: The epoch number to check. + + Returns: + ``True`` if all ``num_workers`` workers have reported + ``EPOCH_DONE`` for this ``(channel_id, epoch)`` pair. + """ with self._lock: self._drain_events() return ( @@ -971,7 +1072,21 @@ def is_channel_epoch_done(self, channel_id: int, epoch: int) -> bool: ) def describe_channel(self, channel_id: int) -> dict[str, object]: - """Return lightweight diagnostics for one registered channel.""" + """Return lightweight diagnostics for one registered channel. + + Drains pending completion events before building the snapshot. + + Args: + channel_id: The channel to describe. + + Returns: + A dict with keys: + + - ``"epoch"``: Current epoch number (``-1`` if never started). + - ``"input_sizes"``: Per-worker seed counts. + - ``"completed_workers"``: Number of workers that finished the + current epoch. + """ with self._lock: self._drain_events() epoch = self._channel_epoch.get(channel_id, -1) @@ -985,7 +1100,18 @@ def describe_channel(self, channel_id: int) -> dict[str, object]: } def shutdown(self) -> None: - """Stop all worker processes and release backend resources.""" + """Stop all worker processes and release backend resources. + + Cleanup sequence: + + 1. Send ``STOP`` to every worker's task queue. + 2. Join each worker with a timeout of + ``MP_STATUS_CHECK_INTERVAL`` seconds. + 3. Close all task queues and the event queue. + 4. Terminate any workers still alive after the join timeout. + + No-op if already shut down. + """ with self._lock: if self._shutdown: return From 51a45b95f592011a4ea025a00787e20a289c54d3 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 8 Apr 2026 18:00:00 +0000 Subject: [PATCH 8/8] update --- .../shared_dist_sampling_producer.py | 188 ++++++++---------- .../shared_dist_sampling_producer_test.py} | 8 +- 2 files changed, 81 insertions(+), 115 deletions(-) rename tests/unit/distributed/{dist_sampling_producer_test.py => graph_store/shared_dist_sampling_producer_test.py} (97%) diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index ecea07a86..c69a51f52 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -5,6 +5,12 @@ pool of worker processes that service many compute-rank channels through a fair-queued scheduler (``_shared_sampling_worker_loop``). +We need this "fair-queued" scheduler to ensure that each compute rank gets a fair share of the work. +If we didn't have this, then compute ranks with more data would starve the compute ranks with less data +as `sample_from_*` calls would be blocked by the compute ranks with more data. +Surprisingly, upping `worker_concurrency` does not fix this problem. +TODO(kmonte): Look into why worker_concurrency does not fix this problem. + High-level architecture:: ┌──────────────────────────────────────────────┐ @@ -25,10 +31,10 @@ │ Worker 0 .. N-1 │ │ _shared_sampling_worker_loop() │ │ │ - │ ┌─────────────┐ sample_from_* ┌─────────┐│ - │ │ Sampler │───────────────▶│ Channel ││ - │ │ (per channel)│ (results) │ (output) ││ - │ └─────────────┘ └─────────┘│ + │ ┌─────────────┐ sample_from_* ┌─────────┐ │ + │ │ Sampler │───────────────▶│ Channel │ │ + │ │ (per channel)│ (results) │ (output)│ │ + │ └─────────────┘ └─────────┘ │ └──────────────────────────────────────────────┘ Worker event-loop internals:: @@ -42,10 +48,10 @@ │ UNREGISTER_INPUT ──▶ cleanup / defer │ │ STOP ──▶ exit loop │ ├─────────────────────────────────────────────────┤ - │ Phase 2: Round-robin batch submission │ - │ for each channel in runnable_channels: │ + │ Phase 2: Round-robin batch submission │ + │ for each channel in runnable_channel_ids: │ │ pop ──▶ _submit_one_batch() │ - │ ──▶ sampler.sample_from_*() │ + │ ──▶ sampler.sample_from_*() │ │ if more batches: re-enqueue channel │ │ │ │ completion callback (_on_batch_done): │ @@ -91,7 +97,7 @@ from torch._C import _set_worker_signal_handlers from gigl.common.logger import Logger -from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions +from gigl.distributed.sampler_options import SamplerOptions from gigl.distributed.utils.dist_sampler import ( SamplerInput, SamplerRuntime, @@ -101,42 +107,6 @@ logger = Logger() -def _prepare_degree_tensors( - data: DistDataset, - sampler_options: SamplerOptions, -) -> Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: - """Materialize PPR degree tensors before worker spawn when required. - - Called once in the main process so that degree data is available in shared - memory before workers fork. Returns ``None`` for non-PPR sampler options. - - Args: - data: The distributed dataset whose ``degree_tensor`` property is - read. - sampler_options: Sampler configuration. Degree tensors are only - materialized when this is a ``PPRSamplerOptions`` instance. - - Returns: - A single degree tensor (homogeneous graph), a dict mapping edge types - to degree tensors (heterogeneous graph), or ``None`` if PPR sampling - is not configured. - """ - degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]] = None - if isinstance(sampler_options, PPRSamplerOptions): - degree_tensors = data.degree_tensor - if isinstance(degree_tensors, dict): - logger.info( - "Pre-computed degree tensors for PPR sampling across " - f"{len(degree_tensors)} edge types." - ) - elif degree_tensors is not None: - logger.info( - "Pre-computed degree tensor for PPR sampling with " - f"{degree_tensors.size(0)} nodes." - ) - return degree_tensors - - EPOCH_DONE_EVENT = "EPOCH_DONE" SCHEDULER_TICK_SECS = 0.05 SCHEDULER_STATE_LOG_INTERVAL_SECS = 10.0 @@ -396,7 +366,7 @@ def _shared_sampling_worker_loop( 2. Enter the main event loop which alternates between: a. Draining all pending commands from ``task_queue`` (register/unregister channels, start epochs, stop). - b. Submitting batches round-robin from ``runnable_channels`` — a FIFO + b. Submitting batches round-robin from ``runnable_channel_ids`` — a FIFO queue of channels that have pending work. Each channel gets one batch submitted per round to prevent starvation. c. If no commands were processed and no batches submitted, blocking on @@ -405,16 +375,16 @@ def _shared_sampling_worker_loop( ``EPOCH_DONE_EVENT`` to ``event_queue`` when all batches for an epoch are finished. """ - samplers: dict[int, SamplerRuntime] = {} - channels: dict[int, ChannelBase] = {} - inputs: dict[int, SamplerInput] = {} - cfgs: dict[int, SamplingConfig] = {} - route_key_by_channel: dict[int, str] = {} - started_epoch: dict[int, int] = {} - active_epochs_by_channel: dict[int, ActiveEpochState] = {} - runnable_channels: deque[int] = deque() - runnable_set: set[int] = set() - removing: set[int] = set() + sampler_by_channel_id: dict[int, SamplerRuntime] = {} + output_channel_by_channel_id: dict[int, ChannelBase] = {} + input_by_channel_id: dict[int, SamplerInput] = {} + config_by_channel_id: dict[int, SamplingConfig] = {} + route_key_by_channel_id: dict[int, str] = {} + started_epoch_by_channel_id: dict[int, int] = {} + active_epoch_by_channel_id: dict[int, ActiveEpochState] = {} + runnable_channel_ids: deque[int] = deque() + runnable_channel_id_set: set[int] = set() + removing_channel_ids: set[int] = set() state_lock = threading.RLock() last_state_log_time = 0.0 current_device: Optional[torch.device] = None @@ -423,15 +393,15 @@ def _shared_sampling_worker_loop( def _enqueue_channel_if_runnable_locked(channel_id: int) -> None: """Add channel to the fair-queue if it has pending batches.""" - state = active_epochs_by_channel.get(channel_id) + state = active_epoch_by_channel_id.get(channel_id) if state is None: return if state.cancelled or state.submitted_batches >= state.total_batches: return - if channel_id in runnable_set: + if channel_id in runnable_channel_id_set: return - runnable_channels.append(channel_id) - runnable_set.add(channel_id) + runnable_channel_ids.append(channel_id) + runnable_channel_id_set.add(channel_id) def _clear_registered_input_locked(channel_id: int) -> None: """Remove a channel's registration and clean up all associated state. @@ -442,34 +412,34 @@ def _clear_registered_input_locked(channel_id: int) -> None: ``_on_batch_done`` will finish the cleanup once the last in-flight batch completes. """ - state = active_epochs_by_channel.get(channel_id) + state = active_epoch_by_channel_id.get(channel_id) if state is not None and state.completed_batches < state.submitted_batches: - removing.add(channel_id) + removing_channel_ids.add(channel_id) state.cancelled = True return - sampler = samplers.pop(channel_id, None) + sampler = sampler_by_channel_id.pop(channel_id, None) if sampler is not None: sampler.wait_all() sampler.shutdown_loop() - channels.pop(channel_id, None) - inputs.pop(channel_id, None) - cfgs.pop(channel_id, None) - route_key_by_channel.pop(channel_id, None) - started_epoch.pop(channel_id, None) - active_epochs_by_channel.pop(channel_id, None) - runnable_set.discard(channel_id) - removing.discard(channel_id) + output_channel_by_channel_id.pop(channel_id, None) + input_by_channel_id.pop(channel_id, None) + config_by_channel_id.pop(channel_id, None) + route_key_by_channel_id.pop(channel_id, None) + started_epoch_by_channel_id.pop(channel_id, None) + active_epoch_by_channel_id.pop(channel_id, None) + runnable_channel_id_set.discard(channel_id) + removing_channel_ids.discard(channel_id) def _format_scheduler_state_locked() -> str: """Format a human-readable snapshot of the scheduler for logging. Must be called while holding ``state_lock``. """ - channel_ids = sorted(channels.keys()) + channel_ids = sorted(output_channel_by_channel_id.keys()) preview = channel_ids[:SCHEDULER_STATE_MAX_CHANNELS] previews: list[str] = [] for channel_id in preview: - active_epoch = active_epochs_by_channel.get(channel_id) + active_epoch = active_epoch_by_channel_id.get(channel_id) if active_epoch is None: previews.append(f"{channel_id}:idle") else: @@ -483,8 +453,8 @@ def _format_scheduler_state_locked() -> str: if len(channel_ids) > len(preview): extra = f" +{len(channel_ids) - len(preview)}" return ( - f"registered={len(channels)} active={len(active_epochs_by_channel)} " - f"runnable={len(runnable_set)} removing={len(removing)} " + f"registered={len(output_channel_by_channel_id)} active={len(active_epoch_by_channel_id)} " + f"runnable={len(runnable_channel_id_set)} removing={len(removing_channel_ids)} " f"channels=[{', '.join(previews)}]{extra}" ) @@ -517,15 +487,15 @@ def _on_batch_done(channel_id: int, epoch: int) -> None: ``_clear_registered_input_locked``. """ with state_lock: - state = active_epochs_by_channel.get(channel_id) + state = active_epoch_by_channel_id.get(channel_id) if state is None or state.epoch != epoch: return state.completed_batches += 1 if state.completed_batches == state.total_batches: - active_epochs_by_channel.pop(channel_id, None) + active_epoch_by_channel_id.pop(channel_id, None) event_queue.put((EPOCH_DONE_EVENT, channel_id, epoch, rank)) if ( - channel_id in removing + channel_id in removing_channel_ids and state.completed_batches == state.submitted_batches ): _clear_registered_input_locked(channel_id) @@ -533,26 +503,26 @@ def _on_batch_done(channel_id: int, epoch: int) -> None: def _submit_one_batch(channel_id: int) -> bool: """Submit the next batch for a channel to its sampler. - Re-enqueues the channel into ``runnable_channels`` if more batches + Re-enqueues the channel into ``runnable_channel_ids`` if more batches remain. Returns True if a batch was submitted, False if the channel had no pending work. """ with state_lock: - state = active_epochs_by_channel.get(channel_id) + state = active_epoch_by_channel_id.get(channel_id) if state is None: return False batch_indices = _epoch_batch_indices(state) if batch_indices is None: return False state.submitted_batches += 1 - cfg = cfgs[channel_id] - sampler = samplers[channel_id] - channel_input = inputs[channel_id] + cfg = config_by_channel_id[channel_id] + sampler = sampler_by_channel_id[channel_id] + channel_input = input_by_channel_id[channel_id] current_epoch = state.epoch if state.submitted_batches < state.total_batches and not state.cancelled: - runnable_channels.append(channel_id) - runnable_set.add(channel_id) + runnable_channel_ids.append(channel_id) + runnable_channel_id_set.add(channel_id) sampler_input = channel_input[batch_indices] @@ -571,20 +541,20 @@ def _submit_one_batch(channel_id: int) -> bool: raise RuntimeError(f"Unsupported sampling type: {cfg.sampling_type}") return True - def _pump_runnable_channels() -> bool: + def _pump_runnable_channel_ids() -> bool: """Submit one batch per runnable channel in round-robin order. Returns True if at least one batch was submitted. """ made_progress = False with state_lock: - num_candidates = len(runnable_channels) + num_candidates = len(runnable_channel_ids) for _ in range(num_candidates): with state_lock: - if not runnable_channels: + if not runnable_channel_ids: break - channel_id = runnable_channels.popleft() - runnable_set.discard(channel_id) + channel_id = runnable_channel_ids.popleft() + runnable_channel_id_set.discard(channel_id) made_progress = _submit_one_batch(channel_id) or made_progress return made_progress @@ -608,28 +578,28 @@ def _handle_command(command: SharedMpCommand, payload: object) -> bool: ) sampler.start_loop() with state_lock: - samplers[register.channel_id] = sampler - channels[register.channel_id] = register.channel - inputs[register.channel_id] = register.sampler_input - cfgs[register.channel_id] = register.sampling_config - route_key_by_channel[register.channel_id] = register.worker_key - started_epoch[register.channel_id] = -1 + sampler_by_channel_id[register.channel_id] = sampler + output_channel_by_channel_id[register.channel_id] = register.channel + input_by_channel_id[register.channel_id] = register.sampler_input + config_by_channel_id[register.channel_id] = register.sampling_config + route_key_by_channel_id[register.channel_id] = register.worker_key + started_epoch_by_channel_id[register.channel_id] = -1 _maybe_log_scheduler_state("register_input", force=True) return True if command == SharedMpCommand.START_EPOCH: start_epoch = cast(StartEpochCmd, payload) with state_lock: - if channel_id not in channels: + if channel_id not in output_channel_by_channel_id: return True - if started_epoch.get(channel_id, -1) >= start_epoch.epoch: + if started_epoch_by_channel_id.get(channel_id, -1) >= start_epoch.epoch: return True - started_epoch[channel_id] = start_epoch.epoch - sampling_config = cfgs[channel_id] + started_epoch_by_channel_id[channel_id] = start_epoch.epoch + sampling_config = config_by_channel_id[channel_id] local_input_len = ( len(start_epoch.seeds_index) if start_epoch.seeds_index is not None - else len(inputs[channel_id]) + else len(input_by_channel_id[channel_id]) ) state = ActiveEpochState( channel_id=channel_id, @@ -644,9 +614,9 @@ def _handle_command(command: SharedMpCommand, payload: object) -> bool: sampling_config.drop_last, ), ) - active_epochs_by_channel[channel_id] = state + active_epoch_by_channel_id[channel_id] = state if state.total_batches == 0: - active_epochs_by_channel.pop(channel_id, None) + active_epoch_by_channel_id.pop(channel_id, None) event_queue.put( (EPOCH_DONE_EVENT, channel_id, start_epoch.epoch, rank) ) @@ -715,7 +685,7 @@ def _handle_command(command: SharedMpCommand, payload: object) -> bool: keep_running = _handle_command(command, payload) # Phase 2: Submit batches round-robin from runnable channels. - made_progress = _pump_runnable_channels() + made_progress = _pump_runnable_channel_ids() _maybe_log_scheduler_state("steady_state") if not keep_running: break @@ -730,7 +700,7 @@ def _handle_command(command: SharedMpCommand, payload: object) -> bool: except KeyboardInterrupt: pass finally: - for sampler in list(samplers.values()): + for sampler in list(sampler_by_channel_id.values()): sampler.wait_all() sampler.shutdown_loop() shutdown_rpc(graceful=False) @@ -746,6 +716,7 @@ def __init__( worker_options: RemoteDistSamplingWorkerOptions, sampling_config: SamplingConfig, sampler_options: SamplerOptions, + degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], ) -> None: """Initialize the shared sampling backend. @@ -760,6 +731,7 @@ def __init__( use the same config. sampler_options: GiGL sampler variant configuration (e.g. ``PPRSamplerOptions`` for PPR-based sampling). + degree_tensors: Pre-computed degree tensors for PPR sampling (if applicable). """ self.data = data self.worker_options = worker_options @@ -780,6 +752,7 @@ def __init__( self._completed_workers: defaultdict[tuple[int, int], set[int]] = defaultdict( set ) + self._degree_tensors = degree_tensors def init_backend(self) -> None: """Initialize worker processes once for this backend. @@ -792,7 +765,6 @@ def init_backend(self) -> None: The initialization sequence is: 1. Assign devices and worker ranks from the GLT server context. - 2. Pre-compute degree tensors for PPR sampling (if applicable). 3. Spawn worker processes with per-worker task queues and a shared event queue. 4. Wait on the barrier for all workers to finish RPC init. @@ -812,10 +784,6 @@ def init_backend(self) -> None: "SharedDistSamplingBackend.init_backend() requires a GLT server context." ) self.worker_options._set_worker_ranks(current_ctx) - degree_tensors = _prepare_degree_tensors( - self.data, - self._sampler_options, - ) mp_context = mp.get_context("spawn") barrier = mp_context.Barrier(self.num_workers + 1) self._event_queue = mp_context.Queue() @@ -834,7 +802,7 @@ def init_backend(self) -> None: self._event_queue, barrier, self._sampler_options, - degree_tensors, + self._degree_tensors, ), ) worker.daemon = True diff --git a/tests/unit/distributed/dist_sampling_producer_test.py b/tests/unit/distributed/graph_store/shared_dist_sampling_producer_test.py similarity index 97% rename from tests/unit/distributed/dist_sampling_producer_test.py rename to tests/unit/distributed/graph_store/shared_dist_sampling_producer_test.py index 9d19768da..01598cfbd 100644 --- a/tests/unit/distributed/dist_sampling_producer_test.py +++ b/tests/unit/distributed/graph_store/shared_dist_sampling_producer_test.py @@ -96,12 +96,8 @@ def test_compute_worker_seeds_ranges(self) -> None: @patch("gigl.distributed.graph_store.shared_dist_sampling_producer.get_context") @patch("gigl.distributed.graph_store.shared_dist_sampling_producer.mp.get_context") - @patch( - "gigl.distributed.graph_store.shared_dist_sampling_producer._prepare_degree_tensors" - ) def test_init_backend_prepares_worker_options( self, - mock_prepare_degree_tensors: MagicMock, mock_get_mp_context: MagicMock, mock_get_context: MagicMock, ) -> None: @@ -117,6 +113,7 @@ def test_init_backend_prepares_worker_options( worker_options=worker_options, sampling_config=_make_sampling_config(), sampler_options=KHopNeighborSamplerOptions(num_neighbors=[2]), + degree_tensors=None, ) backend.init_backend() @@ -128,7 +125,6 @@ def test_init_backend_prepares_worker_options( self.assertEqual(len(backend._task_queues), 2) self.assertEqual(len(backend._workers), 2) self.assertTrue(backend._initialized) - mock_prepare_degree_tensors.assert_called_once() def test_start_new_epoch_sampling_shuffle_refreshes_per_epoch(self) -> None: worker_options = MagicMock() @@ -139,6 +135,7 @@ def test_start_new_epoch_sampling_shuffle_refreshes_per_epoch(self) -> None: worker_options=worker_options, sampling_config=_make_sampling_config(shuffle=True), sampler_options=KHopNeighborSamplerOptions(num_neighbors=[2]), + degree_tensors=None, ) backend._initialized = True recorded: list[tuple[int, SharedMpCommand, object]] = [] @@ -200,6 +197,7 @@ def test_describe_channel_reports_completed_workers(self) -> None: worker_options=worker_options, sampling_config=_make_sampling_config(), sampler_options=KHopNeighborSamplerOptions(num_neighbors=[2]), + degree_tensors=None, ) backend._initialized = True backend._event_queue = cast(mp.Queue, queue.Queue())