Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
474 changes: 254 additions & 220 deletions gigl/distributed/base_dist_loader.py

Large diffs are not rendered by default.

26 changes: 6 additions & 20 deletions gigl/distributed/dist_ablp_neighborloader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections import abc, defaultdict
from itertools import count
from typing import Callable, Optional, Union
from typing import Optional, Union

import torch
from graphlearn_torch.channel import SampleMessage
Expand All @@ -21,7 +20,6 @@
PPR_WEIGHT_METADATA_KEY,
)
from gigl.distributed.dist_sampling_producer import DistSamplingProducer
from gigl.distributed.graph_store.dist_server import DistServer
from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset
from gigl.distributed.sampler import (
NEGATIVE_LABEL_METADATA_KEY,
Expand Down Expand Up @@ -64,11 +62,6 @@


class DistABLPLoader(BaseDistLoader):
# Counts instantiations of this class, per process.
# This is needed so we can generate unique worker key for each instance, for graph store mode.
# NOTE: This is per-class, not per-instance.
_counter = count(0)

def __init__(
self,
dataset: Union[DistDataset, RemoteDistDataset],
Expand Down Expand Up @@ -266,7 +259,7 @@ def __init__(
logger.info(f"Sampling cluster setup: {self._sampling_cluster_setup.value}")

del supervision_edge_type
self._instance_count = next(self._counter)
self._instance_count = next(BaseDistLoader._global_loader_counter)

# Resolve distributed context
runtime = BaseDistLoader.resolve_runtime(
Expand Down Expand Up @@ -362,22 +355,17 @@ def __init__(
drop_last=drop_last,
)

# Build the producer: a pre-constructed producer for colocated mode,
# or an RPC callable for graph store mode.
producer: Optional[DistSamplingProducer] = None
if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED:
assert isinstance(dataset, DistDataset)
assert isinstance(worker_options, MpDistSamplingWorkerOptions)
producer: Union[
DistSamplingProducer, Callable[..., int]
] = BaseDistLoader.create_mp_producer(
producer = BaseDistLoader.create_mp_producer(
dataset=dataset,
sampler_input=sampler_input,
sampling_config=sampling_config,
worker_options=worker_options,
sampler_options=sampler_options,
)
else:
producer = DistServer.create_sampling_producer

# Call base class — handles metadata storage and connection initialization
# (including staggered init for colocated mode).
Expand Down Expand Up @@ -624,13 +612,11 @@ def _setup_for_graph_store(
edge_feature_info = dataset.fetch_edge_feature_info()
edge_types = dataset.fetch_edge_types()
compute_rank = torch.distributed.get_rank()
worker_key = (
f"compute_ablp_loader_rank_{compute_rank}_worker_{self._instance_count}"
)
self._backend_key = f"dist_ablp_loader_{self._instance_count}"
worker_key = f"{self._backend_key}_compute_rank_{compute_rank}"
logger.info(f"rank: {compute_rank}, worker_key: {worker_key}")
worker_options = BaseDistLoader.create_graph_store_worker_options(
dataset=dataset,
compute_rank=compute_rank,
worker_key=worker_key,
num_workers=num_workers,
worker_concurrency=worker_concurrency,
Expand Down
130 changes: 87 additions & 43 deletions gigl/distributed/dist_sampling_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
DistDataset,
DistMpSamplingProducer,
MpDistSamplingWorkerOptions,
RemoteDistSamplingWorkerOptions,
init_rpc,
init_worker_group,
shutdown_rpc,
Expand All @@ -39,6 +40,7 @@
from gigl.common.logger import Logger
from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler
from gigl.distributed.dist_ppr_sampler import DistPPRNeighborSampler
from gigl.distributed.sampler import ABLPNodeSamplerInput
from gigl.distributed.sampler_options import (
KHopNeighborSamplerOptions,
PPRSamplerOptions,
Expand All @@ -47,6 +49,78 @@

logger = Logger()

SamplerInput = Union[NodeSamplerInput, EdgeSamplerInput, ABLPNodeSamplerInput]
SamplerRuntime = Union[DistNeighborSampler, DistPPRNeighborSampler]


def _create_dist_sampler(
    *,
    data: DistDataset,
    sampling_config: SamplingConfig,
    worker_options: Union[MpDistSamplingWorkerOptions, RemoteDistSamplingWorkerOptions],
    channel: ChannelBase,
    sampler_options: SamplerOptions,
    degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]],
    current_device: torch.device,
) -> SamplerRuntime:
    """Build the sampler runtime selected by ``sampler_options``.

    All parameters are keyword-only.

    Args:
        data: Dataset handed to the sampler as its first positional argument.
        sampling_config: Supplies ``num_neighbors``, ``with_edge``,
            ``with_neg``, ``with_weight``, ``edge_dir``, ``collect_features``
            and ``seed``.
        worker_options: Supplies ``use_all2all`` and ``worker_concurrency``.
        channel: Channel forwarded to the sampler.
        sampler_options: Chooses the sampler class. ``KHopNeighborSamplerOptions``
            yields a ``DistNeighborSampler``; ``PPRSamplerOptions`` yields a
            ``DistPPRNeighborSampler``.
        degree_tensors: Degree tensor(s) forwarded to the PPR sampler. Must
            not be ``None`` when ``sampler_options`` is ``PPRSamplerOptions``;
            unused for k-hop sampling.
        current_device: Device forwarded to the sampler.

    Returns:
        The constructed sampler. This function does not start it.

    Raises:
        AssertionError: If PPR sampling is requested and ``degree_tensors``
            is ``None``.
        NotImplementedError: If ``sampler_options`` is of an unsupported type.
    """
    # Positional arguments common to every sampler type. They are unpacked
    # positionally below, so their order must not change.
    common_positional_args = (
        data,
        sampling_config.num_neighbors,
        sampling_config.with_edge,
        sampling_config.with_neg,
        sampling_config.with_weight,
        sampling_config.edge_dir,
        sampling_config.collect_features,
        channel,
        worker_options.use_all2all,
        worker_options.worker_concurrency,
        current_device,
    )

    if isinstance(sampler_options, KHopNeighborSamplerOptions):
        return DistNeighborSampler(
            *common_positional_args,
            seed=sampling_config.seed,
        )

    if isinstance(sampler_options, PPRSamplerOptions):
        # The PPR sampler is always given degree tensors; fail before
        # constructing it if the caller did not provide them.
        assert degree_tensors is not None
        return DistPPRNeighborSampler(
            *common_positional_args,
            seed=sampling_config.seed,
            alpha=sampler_options.alpha,
            eps=sampler_options.eps,
            max_ppr_nodes=sampler_options.max_ppr_nodes,
            num_neighbors_per_hop=sampler_options.num_neighbors_per_hop,
            total_degree_dtype=sampler_options.total_degree_dtype,
            degree_tensors=degree_tensors,
        )

    raise NotImplementedError(
        f"Unsupported sampler options type: {type(sampler_options)}"
    )


def _prepare_degree_tensors(
    data: DistDataset,
    sampler_options: SamplerOptions,
) -> Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]:
    """Fetch the dataset's degree tensor(s) when PPR sampling is configured.

    Intended to run before sampling workers are spawned, so that whatever
    ``data.degree_tensor`` returns is obtained once up front.
    NOTE(review): presumably accessing ``data.degree_tensor`` computes it
    lazily; confirm against ``DistDataset``.

    Args:
        data: Dataset whose ``degree_tensor`` attribute is read.
        sampler_options: Sampler configuration. Only ``PPRSamplerOptions``
            causes the degree tensor(s) to be read.

    Returns:
        ``data.degree_tensor`` when ``sampler_options`` is a
        ``PPRSamplerOptions`` (a tensor, a dict of tensors keyed by edge
        type, or possibly ``None``); otherwise ``None``.
    """
    # Non-PPR samplers receive no degree information.
    if not isinstance(sampler_options, PPRSamplerOptions):
        return None

    degree_tensors: Optional[
        Union[torch.Tensor, dict[EdgeType, torch.Tensor]]
    ] = data.degree_tensor

    # Log a short summary whose wording depends on the container shape.
    if isinstance(degree_tensors, dict):
        logger.info(
            "Pre-computed degree tensors for PPR sampling across "
            f"{len(degree_tensors)} edge types."
        )
    elif degree_tensors is not None:
        logger.info(
            "Pre-computed degree tensor for PPR sampling with "
            f"{degree_tensors.size(0)} nodes."
        )
    return degree_tensors


def _sampling_worker_loop(
rank: int,
Expand Down Expand Up @@ -100,42 +174,15 @@ def _sampling_worker_loop(
if sampling_config.seed is not None:
seed_everything(sampling_config.seed)

# Shared args for all sampler types (positional args to DistNeighborSampler.__init__)
shared_sampler_args = (
data,
sampling_config.num_neighbors,
sampling_config.with_edge,
sampling_config.with_neg,
sampling_config.with_weight,
sampling_config.edge_dir,
sampling_config.collect_features,
channel,
worker_options.use_all2all,
worker_options.worker_concurrency,
current_device,
dist_sampler = _create_dist_sampler(
data=data,
sampling_config=sampling_config,
worker_options=worker_options,
channel=channel,
sampler_options=sampler_options,
degree_tensors=degree_tensors,
current_device=current_device,
)

if isinstance(sampler_options, KHopNeighborSamplerOptions):
dist_sampler = DistNeighborSampler(
*shared_sampler_args,
seed=sampling_config.seed,
)
elif isinstance(sampler_options, PPRSamplerOptions):
assert degree_tensors is not None
dist_sampler = DistPPRNeighborSampler(
*shared_sampler_args,
seed=sampling_config.seed,
alpha=sampler_options.alpha,
eps=sampler_options.eps,
max_ppr_nodes=sampler_options.max_ppr_nodes,
num_neighbors_per_hop=sampler_options.num_neighbors_per_hop,
total_degree_dtype=sampler_options.total_degree_dtype,
degree_tensors=degree_tensors,
)
else:
raise NotImplementedError(
f"Unsupported sampler options type: {type(sampler_options)}"
)
dist_sampler.start_loop()

unshuffled_index_loader: Optional[DataLoader]
Expand Down Expand Up @@ -186,16 +233,13 @@ def _sampling_worker_loop(
dist_sampler.wait_all()

with sampling_completed_worker_count.get_lock():
sampling_completed_worker_count.value += (
1 # non-atomic, lock is necessary
)
sampling_completed_worker_count.value += 1

elif command == MpCommand.STOP:
keep_running = False
else:
raise RuntimeError("Unknown command type")
except KeyboardInterrupt:
# Main process will raise KeyboardInterrupt anyways.
pass

if dist_sampler is not None:
Expand Down Expand Up @@ -236,7 +280,7 @@ def init(self):
self.num_workers * self.worker_options.worker_concurrency
)
self._task_queues.append(task_queue)
w = mp_context.Process(
worker = mp_context.Process(
target=_sampling_worker_loop,
args=(
rank,
Expand All @@ -253,7 +297,7 @@ def init(self):
self._degree_tensors,
),
)
w.daemon = True
w.start()
self._workers.append(w)
worker.daemon = True
worker.start()
self._workers.append(worker)
barrier.wait()
24 changes: 6 additions & 18 deletions gigl/distributed/distributed_neighborloader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import sys
from collections import abc
from itertools import count
from typing import Callable, Optional, Tuple, Union
from typing import Optional, Tuple, Union

import torch
from graphlearn_torch.channel import SampleMessage
Expand All @@ -23,7 +22,6 @@
PPR_WEIGHT_METADATA_KEY,
)
from gigl.distributed.dist_sampling_producer import DistSamplingProducer
from gigl.distributed.graph_store.dist_server import DistServer as GiglDistServer
from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset
from gigl.distributed.sampler_options import (
PPRSamplerOptions,
Expand Down Expand Up @@ -61,11 +59,6 @@ def flush():


class DistNeighborLoader(BaseDistLoader):
# Counts instantiations of this class, per process.
# This is needed so we can generate unique worker key for each instance, for graph store mode.
# NOTE: This is per-class, not per-instance.
_counter = count(0)

def __init__(
self,
dataset: Union[DistDataset, RemoteDistDataset],
Expand Down Expand Up @@ -208,7 +201,7 @@ def __init__(
)
logger.info(f"Sampling cluster setup: {self._sampling_cluster_setup.value}")

self._instance_count = next(self._counter)
self._instance_count = next(BaseDistLoader._global_loader_counter)
device = (
pin_memory_device
if pin_memory_device
Expand Down Expand Up @@ -271,22 +264,17 @@ def __init__(
drop_last=drop_last,
)

# Build the producer: a pre-constructed producer for colocated mode,
# or an RPC callable for graph store mode.
producer: Optional[DistSamplingProducer] = None
if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED:
assert isinstance(dataset, DistDataset)
assert isinstance(worker_options, MpDistSamplingWorkerOptions)
producer: Union[
DistSamplingProducer, Callable[..., int]
] = BaseDistLoader.create_mp_producer(
producer = BaseDistLoader.create_mp_producer(
dataset=dataset,
sampler_input=input_data,
sampling_config=sampling_config,
worker_options=worker_options,
sampler_options=sampler_options,
)
else:
producer = GiglDistServer.create_sampling_producer

# Call base class — handles metadata storage and connection initialization
# (including staggered init for colocated mode).
Expand Down Expand Up @@ -341,11 +329,11 @@ def _setup_for_graph_store(
edge_types = dataset.fetch_edge_types()
compute_rank = torch.distributed.get_rank()

worker_key = f"compute_rank_{compute_rank}_worker_{self._instance_count}"
self._backend_key = f"dist_neighbor_loader_{self._instance_count}"
worker_key = f"{self._backend_key}_compute_rank_{compute_rank}"
logger.info(f"Rank {compute_rank} worker key: {worker_key}")
worker_options = BaseDistLoader.create_graph_store_worker_options(
dataset=dataset,
compute_rank=compute_rank,
worker_key=worker_key,
num_workers=num_workers,
worker_concurrency=worker_concurrency,
Expand Down
Loading