1+ #include < pybind11/stl.h> // Automatic conversion between C++ containers and Python types
12#include < torch/extension.h> // PyTorch C++ API (tensors, TORCH_CHECK)
2- #include < pybind11/stl.h> // Automatic conversion between C++ containers and Python types
33
4- #include < algorithm> // std::partial_sort, std::min
5- #include < cstdint> // Fixed-width integer types: int32_t, int64_t, uint32_t, uint64_t
4+ #include < algorithm> // std::partial_sort, std::min
5+ #include < cstdint> // Fixed-width integer types: int32_t, int64_t, uint32_t, uint64_t
66#include < unordered_map> // std::unordered_map — like Python dict, O(1) average lookup
77#include < unordered_set> // std::unordered_set — like Python set, O(1) average lookup
88#include < vector> // std::vector — like Python list, contiguous in memory
@@ -47,16 +47,12 @@ static inline uint64_t pack_key(int32_t node_id, int32_t etype_id) {
4747// 4. push_residuals(fetched_by_etype_id) — push residuals, update queue
4848// 5. extract_top_k(max_ppr_nodes) — top-k selection per seed per node type
4949class PPRForwardPushState {
50- public:
51- PPRForwardPushState (
52- torch::Tensor seed_nodes,
53- int32_t seed_node_type_id,
54- double alpha,
55- double requeue_threshold_factor,
56- std::vector<std::vector<int32_t >> node_type_to_edge_type_ids,
57- std::vector<int32_t > edge_type_to_dst_ntype_id,
58- std::vector<torch::Tensor> degree_tensors
59- )
50+ public:
51+ PPRForwardPushState (torch::Tensor seed_nodes, int32_t seed_node_type_id, double alpha,
52+ double requeue_threshold_factor,
53+ std::vector<std::vector<int32_t >> node_type_to_edge_type_ids,
54+ std::vector<int32_t > edge_type_to_dst_ntype_id,
55+ std::vector<torch::Tensor> degree_tensors)
6056 : alpha_(alpha),
6157 one_minus_alpha_ (1.0 - alpha),
6258 requeue_threshold_factor_(requeue_threshold_factor),
@@ -66,23 +62,25 @@ class PPRForwardPushState {
6662 node_type_to_edge_type_ids_(std::move(node_type_to_edge_type_ids)),
6763 edge_type_to_dst_ntype_id_(std::move(edge_type_to_dst_ntype_id)),
6864 degree_tensors_(std::move(degree_tensors)) {
69-
7065 TORCH_CHECK (seed_nodes.dim () == 1 , " seed_nodes must be 1D" );
7166 batch_size_ = static_cast <int32_t >(seed_nodes.size (0 ));
7267 num_node_types_ = static_cast <int32_t >(node_type_to_edge_type_ids_.size ());
7368
7469 // Allocate per-seed, per-node-type tables.
7570 // .assign(n, val) fills a vector with n copies of val — like [val] * n in Python.
7671 // Each inner element is an empty hash map / hash set for that (seed, ntype) pair.
77- ppr_scores_.assign (batch_size_, std::vector<std::unordered_map<int32_t , double >>(num_node_types_));
78- residuals_.assign (batch_size_, std::vector<std::unordered_map<int32_t , double >>(num_node_types_));
79- queue_.assign (batch_size_, std::vector<std::unordered_set<int32_t >>(num_node_types_));
80- queued_nodes_.assign (batch_size_, std::vector<std::unordered_set<int32_t >>(num_node_types_));
72+ ppr_scores_.assign (batch_size_,
73+ std::vector<std::unordered_map<int32_t , double >>(num_node_types_));
74+ residuals_.assign (batch_size_,
75+ std::vector<std::unordered_map<int32_t , double >>(num_node_types_));
76+ queue_.assign (batch_size_, std::vector<std::unordered_set<int32_t >>(num_node_types_));
77+ queued_nodes_.assign (batch_size_,
78+ std::vector<std::unordered_set<int32_t >>(num_node_types_));
8179
8280 // accessor<dtype, ndim>() returns a typed view into the tensor's data that
8381 // supports [i] indexing with bounds checking in debug builds. Here we read
8482 // each seed node ID from the 1-D int64 tensor.
85- auto acc = seed_nodes.accessor <int64_t , 1 >();
83+ auto acc = seed_nodes.accessor <int64_t , 1 >();
8684 num_nodes_in_queue_ = batch_size_;
8785 for (int32_t i = 0 ; i < batch_size_; ++i) {
8886 // static_cast<int32_t>: explicit narrowing from int64 to int32.
@@ -116,7 +114,8 @@ class PPRForwardPushState {
116114 // (alias) to the existing set — clearing it modifies the original in-place
117115 // rather than operating on a copy.
118116 for (int32_t s = 0 ; s < batch_size_; ++s)
119- for (auto & qs : queued_nodes_[s]) qs.clear ();
117+ for (auto & qs : queued_nodes_[s])
118+ qs.clear ();
120119
121120 // nodes_to_lookup[eid] = set of node IDs that need a neighbor fetch for
122121 // edge type eid this round. Using a set deduplicates nodes that appear
@@ -126,7 +125,8 @@ class PPRForwardPushState {
126125
127126 for (int32_t s = 0 ; s < batch_size_; ++s) {
128127 for (int32_t nt = 0 ; nt < num_node_types_; ++nt) {
129- if (queue_[s][nt].empty ()) continue ;
128+ if (queue_[s][nt].empty ())
129+ continue ;
130130
131131 // Move the live queue into the snapshot (no data copy — O(1)).
132132 // queue_ is then reset to an empty set so new entries added by
@@ -213,7 +213,8 @@ class PPRForwardPushState {
213213 // c. Enqueue any neighbor whose residual now exceeds the requeue threshold.
214214 for (int32_t s = 0 ; s < batch_size_; ++s) {
215215 for (int32_t nt = 0 ; nt < num_node_types_; ++nt) {
216- if (queued_nodes_[s][nt].empty ()) continue ;
216+ if (queued_nodes_[s][nt].empty ())
217+ continue ;
217218
218219 for (int32_t src : queued_nodes_[s][nt]) {
219220 // `auto&` gives a reference to the residual map for this
@@ -222,7 +223,7 @@ class PPRForwardPushState {
222223 auto & src_res = residuals_[s][nt];
223224 // .find() returns an iterator; .end() means "not found".
224225 // We treat a missing entry as residual = 0.
225- auto it = src_res.find (src);
226+ auto it = src_res.find (src);
226227 double res = (it != src_res.end ()) ? it->second : 0.0 ;
227228
228229 // a. Absorb: move residual into the PPR score.
@@ -232,7 +233,8 @@ class PPRForwardPushState {
232233 int32_t total_deg = get_total_degree (src, nt);
233234 // Destination-only nodes (no outgoing edges) absorb residual
234235 // into their PPR score but do not push further.
235- if (total_deg == 0 ) continue ;
236+ if (total_deg == 0 )
237+ continue ;
236238
237239 // b. Distribute: each neighbor of src (across all edge types
238240 // from nt) receives an equal share of the pushed residual.
@@ -249,18 +251,20 @@ class PPRForwardPushState {
249251 // We use a pointer (rather than copying the list) so we can check
250252 // for absence with nullptr without allocating anything.
251253 const std::vector<int32_t >* nbr_list = nullptr ;
252- auto fi = fetched.find (pack_key (src, eid));
254+ auto fi = fetched.find (pack_key (src, eid));
253255 if (fi != fetched.end ()) {
254256 // `&fi->second` takes the address of the vector stored in
255257 // the map — nbr_list now points to it without copying.
256258 nbr_list = &fi->second ;
257259 } else {
258260 auto ci = neighbor_cache_.find (pack_key (src, eid));
259- if (ci != neighbor_cache_.end ()) nbr_list = &ci->second ;
261+ if (ci != neighbor_cache_.end ())
262+ nbr_list = &ci->second ;
260263 }
261264 // Skip if no neighbor list is available (node has no edges of
262265 // this type, or the fetch returned an empty list).
263- if (!nbr_list || nbr_list->empty ()) continue ;
266+ if (!nbr_list || nbr_list->empty ())
267+ continue ;
264268
265269 int32_t dst_nt = edge_type_to_dst_ntype_id_[eid];
266270
@@ -270,7 +274,7 @@ class PPRForwardPushState {
270274 residuals_[s][dst_nt][nbr] += res_per_nbr;
271275
272276 double threshold = requeue_threshold_factor_ *
273- static_cast <double >(get_total_degree (nbr, dst_nt));
277+ static_cast <double >(get_total_degree (nbr, dst_nt));
274278
275279 // Only enqueue if: (1) not already in queue for this
276280 // iteration, and (2) residual exceeds the push threshold
@@ -315,13 +319,14 @@ class PPRForwardPushState {
315319 std::unordered_set<int32_t > active;
316320 for (int32_t s = 0 ; s < batch_size_; ++s)
317321 for (int32_t nt = 0 ; nt < num_node_types_; ++nt)
318- if (!ppr_scores_[s][nt].empty ()) active.insert (nt);
322+ if (!ppr_scores_[s][nt].empty ())
323+ active.insert (nt);
319324
320325 py::dict result;
321326 for (int32_t nt : active) {
322327 // Flat output vectors — entries for all seeds are concatenated.
323328 std::vector<int64_t > flat_ids;
324- std::vector<float > flat_weights;
329+ std::vector<float > flat_weights;
325330 std::vector<int64_t > valid_counts;
326331
327332 for (int32_t s = 0 ; s < batch_size_; ++s) {
@@ -341,7 +346,8 @@ class PPRForwardPushState {
341346 // is an anonymous comparator (like Python's `key=` argument).
342347 // `.second` accesses the score (second element of the pair);
343348 // `>` makes it descending (highest score first).
344- std::partial_sort (items.begin (), items.begin () + k, items.end (),
349+ std::partial_sort (
350+ items.begin (), items.begin () + k, items.end (),
345351 [](const auto & a, const auto & b) { return a.second > b.second ; });
346352
347353 for (int32_t i = 0 ; i < k; ++i) {
@@ -355,27 +361,25 @@ class PPRForwardPushState {
355361 }
356362
357363 // py::make_tuple wraps C++ values into a Python tuple.
358- result[py::int_ (nt)] = py::make_tuple (
359- torch::tensor (flat_ids, torch::kLong ),
360- torch::tensor (flat_weights, torch::kFloat ),
361- torch::tensor (valid_counts, torch::kLong )
362- );
364+ result[py::int_ (nt)] = py::make_tuple (torch::tensor (flat_ids, torch::kLong ),
365+ torch::tensor (flat_weights, torch::kFloat ),
366+ torch::tensor (valid_counts, torch::kLong ));
363367 }
364368 return result;
365369 }
366370
367- private:
371+ private:
368372 // Look up the total (across all edge types) out-degree of a node.
369373 // Returns 0 for destination-only node types (no outgoing edges).
370374 int32_t get_total_degree (int32_t node_id, int32_t ntype_id) const {
371- if (ntype_id >= static_cast <int32_t >(degree_tensors_.size ())) return 0 ;
375+ if (ntype_id >= static_cast <int32_t >(degree_tensors_.size ()))
376+ return 0 ;
372377 const auto & t = degree_tensors_[ntype_id];
373- if (t.numel () == 0 ) return 0 ; // destination-only type: no outgoing edges
374- TORCH_CHECK (
375- node_id < static_cast <int32_t >(t.size (0 )),
376- " Node ID " , node_id, " out of range for degree tensor of ntype_id " , ntype_id,
377- " (size=" , t.size (0 ), " ). This indicates corrupted graph data or a sampler bug."
378- );
378+ if (t.numel () == 0 )
379+ return 0 ; // destination-only type: no outgoing edges
380+ TORCH_CHECK (node_id < static_cast <int32_t >(t.size (0 )), " Node ID " , node_id,
381+ " out of range for degree tensor of ntype_id " , ntype_id, " (size=" , t.size (0 ),
382+ " ). This indicates corrupted graph data or a sampler bug." );
379383 // data_ptr<int32_t>() returns a raw C pointer to the tensor's int32 data
380384 // buffer. Direct pointer indexing ([node_id]) is safe here because we
381385 // validated the bounds with TORCH_CHECK above.
@@ -385,13 +389,14 @@ class PPRForwardPushState {
385389 // -------------------------------------------------------------------------
386390 // Scalar algorithm parameters
387391 // -------------------------------------------------------------------------
388- double alpha_; // Restart probability
389- double one_minus_alpha_; // 1 - alpha, precomputed to avoid repeated subtraction
390- double requeue_threshold_factor_; // alpha * eps; multiplied by degree to get per-node threshold
392+ double alpha_; // Restart probability
393+ double one_minus_alpha_; // 1 - alpha, precomputed to avoid repeated subtraction
394+ double
395+ requeue_threshold_factor_; // alpha * eps; multiplied by degree to get per-node threshold
391396
392- int32_t batch_size_; // Number of seeds in the current batch
393- int32_t num_node_types_; // Total number of node types (homo + hetero)
394- int32_t num_nodes_in_queue_{0 }; // Running count of nodes across all seeds / types
397+ int32_t batch_size_; // Number of seeds in the current batch
398+ int32_t num_node_types_; // Total number of node types (homo + hetero)
399+ int32_t num_nodes_in_queue_{0 }; // Running count of nodes across all seeds / types
395400
396401 // -------------------------------------------------------------------------
397402 // Graph structure (read-only after construction)
@@ -400,10 +405,10 @@ class PPRForwardPushState {
400405 // traversed from that node type (outgoing or incoming, depending on edge_dir).
401406 std::vector<std::vector<int32_t >> node_type_to_edge_type_ids_;
402407 // edge_type_to_dst_ntype_id_[eid] → node type ID at the destination end.
403- std::vector<int32_t > edge_type_to_dst_ntype_id_;
408+ std::vector<int32_t > edge_type_to_dst_ntype_id_;
404409 // degree_tensors_[ntype_id][node_id] → total degree of that node across all
405410 // edge types traversable from its type. Empty tensor means no outgoing edges.
406- std::vector<torch::Tensor> degree_tensors_;
411+ std::vector<torch::Tensor> degree_tensors_;
407412
408413 // -------------------------------------------------------------------------
409414 // Per-seed, per-node-type PPR state (indexed [seed_idx][ntype_id])
@@ -417,11 +422,11 @@ class PPRForwardPushState {
417422 // residuals_[s][nt]: node_id → unabsorbed probability mass waiting to be pushed
418423 std::vector<std::vector<std::unordered_map<int32_t , double >>> residuals_;
419424 // queue_[s][nt]: nodes whose residual exceeds the threshold and need a push next round
420- std::vector<std::vector<std::unordered_set<int32_t >>> queue_;
425+ std::vector<std::vector<std::unordered_set<int32_t >>> queue_;
421426 // queued_nodes_[s][nt]: snapshot of queue_ taken by drain_queue() for the current round.
422427 // Separating it from queue_ lets push_residuals() enqueue new nodes into queue_ without
423428 // modifying the set currently being iterated.
424- std::vector<std::vector<std::unordered_set<int32_t >>> queued_nodes_;
429+ std::vector<std::vector<std::unordered_set<int32_t >>> queued_nodes_;
425430
426431 // -------------------------------------------------------------------------
427432 // Neighbor cache
@@ -445,15 +450,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
445450 // .def(py::init<...>()) exposes the constructor. The template arguments
446451 // list the exact C++ parameter types so pybind11 can convert Python
447452 // arguments to the correct C++ types automatically.
448- .def (py::init<
449- torch::Tensor,
450- int32_t ,
451- double , double ,
452- std::vector<std::vector<int32_t >>,
453- std::vector<int32_t >,
454- std::vector<torch::Tensor>
455- >())
456- .def (" drain_queue" , &PPRForwardPushState::drain_queue)
453+ .def (py::init<torch::Tensor, int32_t , double , double , std::vector<std::vector<int32_t >>,
454+ std::vector<int32_t >, std::vector<torch::Tensor>>())
455+ .def (" drain_queue" , &PPRForwardPushState::drain_queue)
457456 .def (" push_residuals" , &PPRForwardPushState::push_residuals)
458- .def (" extract_top_k" , &PPRForwardPushState::extract_top_k);
457+ .def (" extract_top_k" , &PPRForwardPushState::extract_top_k);
459458}
0 commit comments