Skip to content
Draft
35 changes: 16 additions & 19 deletions gigl/distributed/graph_store/dist_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
)
from gigl.distributed.sampler import ABLPNodeSamplerInput
from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions
from gigl.distributed.utils.neighborloader import shard_nodes_by_process
from gigl.src.common.types.graph_data import EdgeType, NodeType
from gigl.types.graph import FeatureInfo, select_label_edge_types
from gigl.utils.data_splitters import get_labels_for_anchor_nodes
Expand Down Expand Up @@ -283,14 +282,14 @@ def get_node_ids(

Args:
request: The node-fetch request, including split, node type,
and round-robin rank/world_size.
and optional rank/world_size for partitioning.

Returns:
The node ids.

Raises:
ValueError:
* If the rank and world_size are not provided together
* If rank and world_size are not provided together
* If the split is invalid
* If the node ids are not a torch.Tensor or a dict[NodeType, torch.Tensor]
* If the node type is provided for a homogeneous dataset
Expand All @@ -315,25 +314,25 @@ def _get_node_ids(
rank: Optional[int] = None,
world_size: Optional[int] = None,
) -> torch.Tensor:
"""Core implementation for fetching node IDs by split, type, and sharding.
"""Core implementation for fetching node IDs by split, type, and partitioning.

Args:
split: The dataset split to fetch from (``"train"``, ``"val"``,
``"test"``, or ``None`` for all nodes).
node_type: The node type to select. Must be ``None`` for
homogeneous datasets.
rank: Round-robin rank for sharding. Must be provided together
with ``world_size``.
world_size: Total number of processes for sharding. Must be
provided together with ``rank``.
rank: Which partition to return (0-indexed). Must be
provided together with ``world_size``.
world_size: Total number of partitions. Must be provided
together with ``rank``.

Returns:
The node IDs tensor, optionally sharded by rank.
The node IDs tensor, optionally partitioned.

Raises:
ValueError: If rank/world_size are not provided together, the
split is invalid, or the node type is inconsistent with
the dataset type (homogeneous vs. heterogeneous).
ValueError: If the sharding parameters (rank/world_size) are
invalid, the split is invalid, or the node type is inconsistent
with the dataset type (homogeneous vs. heterogeneous).
"""
if (rank is None) ^ (world_size is None):
raise ValueError(
Expand Down Expand Up @@ -367,7 +366,7 @@ def _get_node_ids(
)

if rank is not None and world_size is not None:
return shard_nodes_by_process(nodes, rank, world_size)
return torch.tensor_split(nodes, world_size)[rank]
return nodes

def get_edge_types(self) -> Optional[list[EdgeType]]:
Expand Down Expand Up @@ -396,17 +395,15 @@ def get_ablp_input(
self,
request: FetchABLPInputRequest,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Get the ABLP (Anchor Based Link Prediction) input for a specific rank in distributed processing.

Note: rank and world_size here are for the process group we're *fetching for*, not the process group we're *fetching from*.
e.g. if our compute cluster is of world size 4, and we have 2 storage nodes, then the world size this gets called with is 4, not 2.
"""Get the ABLP (Anchor Based Link Prediction) input for distributed processing.

Args:
request: The ABLP fetch request, including split, node type,
supervision edge type, and round-robin rank/world_size.
supervision edge type, and optional rank/world_size for
partitioning.

Returns:
A tuple containing the anchor nodes for the rank, the positive labels, and the negative labels.
A tuple containing the anchor nodes, the positive labels, and the negative labels.
The positive labels are of shape [N, M], where N is the number of anchor nodes and M is the number of positive labels.
The negative labels are of shape [N, M], where N is the number of anchor nodes and M is the number of negative labels.
The negative labels may be None if no negative labels are available.
Expand Down
75 changes: 51 additions & 24 deletions gigl/distributed/graph_store/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,53 @@
from gigl.src.common.types.graph_data import EdgeType, NodeType


def _validate_sharding_params(
rank: Optional[int],
world_size: Optional[int],
) -> None:
"""Validate that sharding parameters are consistent.

Args:
rank: Which partition to select (0-indexed).
world_size: Total number of partitions.

Raises:
ValueError: If only one of ``rank``/``world_size`` is provided,
or if the values are out of range.
"""
if (rank is None) ^ (world_size is None):
raise ValueError(
"rank and world_size must be provided together. "
f"Received rank={rank}, world_size={world_size}"
)
if rank is not None and world_size is not None:
if world_size <= 0:
raise ValueError(f"world_size must be > 0, received {world_size}")
if rank < 0 or rank >= world_size:
raise ValueError(
"rank must be in [0, world_size). "
f"Received rank={rank}, world_size={world_size}"
)


@dataclass(frozen=True)
class FetchNodesRequest:
"""Request for fetching node IDs from a storage server.

Args:
rank: The rank of the process requesting node ids.
rank: Which partition of the node IDs to return (0-indexed).
Must be provided together with ``world_size``.
world_size: The total number of processes in the distributed setup.
world_size: Total number of partitions.
Must be provided together with ``rank``.
split: The split of the dataset to get node ids from.
node_type: The type of nodes to get node ids for.

Examples:
Fetch all nodes without sharding:
Fetch all nodes without splitting:

>>> FetchNodesRequest()

Fetch training nodes for rank 0 of 4:
Fetch partition 0 of 4 from training nodes:

>>> FetchNodesRequest(rank=0, world_size=4, split="train")

Expand All @@ -38,16 +67,15 @@ class FetchNodesRequest:
node_type: Optional[NodeType] = None

def validate(self) -> None:
"""Validate that the request has consistent rank/world_size.
"""Validate that the request has consistent sharding parameters.

Raises:
ValueError: If only one of ``rank`` or ``world_size`` is provided.
ValueError: If sharding parameters are partially specified or out of range.
"""
if (self.rank is None) ^ (self.world_size is None):
raise ValueError(
"rank and world_size must be provided together. "
f"Received rank={self.rank}, world_size={self.world_size}"
)
_validate_sharding_params(
rank=self.rank,
world_size=self.world_size,
)


@dataclass(frozen=True)
Expand All @@ -58,19 +86,19 @@ class FetchABLPInputRequest:
split: The split of the dataset to get ABLP input from.
node_type: The type of anchor nodes to retrieve.
supervision_edge_type: The edge type used for supervision.
rank: The rank of the process requesting ABLP input.
rank: Which partition of the anchor nodes to return (0-indexed).
Must be provided together with ``world_size``.
world_size: The total number of processes in the distributed setup.
world_size: Total number of partitions.
Must be provided together with ``rank``.

Examples:
Fetch training ABLP input without sharding:
Fetch training ABLP input without splitting:

>>> FetchABLPRequest(split="train", node_type="user", supervision_edge_type=("user", "to", "item"))
>>> FetchABLPInputRequest(split="train", node_type="user", supervision_edge_type=("user", "to", "item"))

Fetch training ABLP input for rank 0 of 4:
Fetch partition 0 of 4 from training ABLP input:

>>> FetchABLPRequest(split="train", node_type="user", supervision_edge_type=("user", "to", "item"), rank=0, world_size=4)
>>> FetchABLPInputRequest(split="train", node_type="user", supervision_edge_type=("user", "to", "item"), rank=0, world_size=4)
"""

split: Union[Literal["train", "val", "test"], str]
Expand All @@ -80,13 +108,12 @@ class FetchABLPInputRequest:
world_size: Optional[int] = None

def validate(self) -> None:
"""Validate that the request has consistent rank/world_size.
"""Validate that the request has consistent sharding parameters.

Raises:
ValueError: If only one of ``rank`` or ``world_size`` is provided.
ValueError: If sharding parameters are partially specified or out of range.
"""
if (self.rank is None) ^ (self.world_size is None):
raise ValueError(
"rank and world_size must be provided together. "
f"Received rank={self.rank}, world_size={self.world_size}"
)
_validate_sharding_params(
rank=self.rank,
world_size=self.world_size,
)
Loading