Skip to content

Commit cfcb8cb

Browse files
committed
Apply clang-format to ppr_forward_push.cpp
1 parent 0fe733c commit cfcb8cb

1 file changed

Lines changed: 63 additions & 64 deletions

File tree

gigl/distributed/cpp_extensions/ppr_forward_push.cpp

Lines changed: 63 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1+
#include <pybind11/stl.h> // Automatic conversion between C++ containers and Python types
12
#include <torch/extension.h> // PyTorch C++ API (tensors, TORCH_CHECK)
2-
#include <pybind11/stl.h> // Automatic conversion between C++ containers and Python types
33

4-
#include <algorithm> // std::partial_sort, std::min
5-
#include <cstdint> // Fixed-width integer types: int32_t, int64_t, uint32_t, uint64_t
4+
#include <algorithm> // std::partial_sort, std::min
5+
#include <cstdint> // Fixed-width integer types: int32_t, int64_t, uint32_t, uint64_t
66
#include <unordered_map> // std::unordered_map — like Python dict, O(1) average lookup
77
#include <unordered_set> // std::unordered_set — like Python set, O(1) average lookup
88
#include <vector> // std::vector — like Python list, contiguous in memory
@@ -47,16 +47,12 @@ static inline uint64_t pack_key(int32_t node_id, int32_t etype_id) {
4747
// 4. push_residuals(fetched_by_etype_id) — push residuals, update queue
4848
// 5. extract_top_k(max_ppr_nodes) — top-k selection per seed per node type
4949
class PPRForwardPushState {
50-
public:
51-
PPRForwardPushState(
52-
torch::Tensor seed_nodes,
53-
int32_t seed_node_type_id,
54-
double alpha,
55-
double requeue_threshold_factor,
56-
std::vector<std::vector<int32_t>> node_type_to_edge_type_ids,
57-
std::vector<int32_t> edge_type_to_dst_ntype_id,
58-
std::vector<torch::Tensor> degree_tensors
59-
)
50+
public:
51+
PPRForwardPushState(torch::Tensor seed_nodes, int32_t seed_node_type_id, double alpha,
52+
double requeue_threshold_factor,
53+
std::vector<std::vector<int32_t>> node_type_to_edge_type_ids,
54+
std::vector<int32_t> edge_type_to_dst_ntype_id,
55+
std::vector<torch::Tensor> degree_tensors)
6056
: alpha_(alpha),
6157
one_minus_alpha_(1.0 - alpha),
6258
requeue_threshold_factor_(requeue_threshold_factor),
@@ -66,23 +62,25 @@ class PPRForwardPushState {
6662
node_type_to_edge_type_ids_(std::move(node_type_to_edge_type_ids)),
6763
edge_type_to_dst_ntype_id_(std::move(edge_type_to_dst_ntype_id)),
6864
degree_tensors_(std::move(degree_tensors)) {
69-
7065
TORCH_CHECK(seed_nodes.dim() == 1, "seed_nodes must be 1D");
7166
batch_size_ = static_cast<int32_t>(seed_nodes.size(0));
7267
num_node_types_ = static_cast<int32_t>(node_type_to_edge_type_ids_.size());
7368

7469
// Allocate per-seed, per-node-type tables.
7570
// .assign(n, val) fills a vector with n copies of val — like [val] * n in Python.
7671
// Each inner element is an empty hash map / hash set for that (seed, ntype) pair.
77-
ppr_scores_.assign(batch_size_, std::vector<std::unordered_map<int32_t, double>>(num_node_types_));
78-
residuals_.assign(batch_size_, std::vector<std::unordered_map<int32_t, double>>(num_node_types_));
79-
queue_.assign(batch_size_, std::vector<std::unordered_set<int32_t>>(num_node_types_));
80-
queued_nodes_.assign(batch_size_, std::vector<std::unordered_set<int32_t>>(num_node_types_));
72+
ppr_scores_.assign(batch_size_,
73+
std::vector<std::unordered_map<int32_t, double>>(num_node_types_));
74+
residuals_.assign(batch_size_,
75+
std::vector<std::unordered_map<int32_t, double>>(num_node_types_));
76+
queue_.assign(batch_size_, std::vector<std::unordered_set<int32_t>>(num_node_types_));
77+
queued_nodes_.assign(batch_size_,
78+
std::vector<std::unordered_set<int32_t>>(num_node_types_));
8179

8280
// accessor<dtype, ndim>() returns a typed view into the tensor's data that
8381
// supports [i] indexing with bounds checking in debug builds. Here we read
8482
// each seed node ID from the 1-D int64 tensor.
85-
auto acc = seed_nodes.accessor<int64_t, 1>();
83+
auto acc = seed_nodes.accessor<int64_t, 1>();
8684
num_nodes_in_queue_ = batch_size_;
8785
for (int32_t i = 0; i < batch_size_; ++i) {
8886
// static_cast<int32_t>: explicit narrowing from int64 to int32.
@@ -116,7 +114,8 @@ class PPRForwardPushState {
116114
// (alias) to the existing set — clearing it modifies the original in-place
117115
// rather than operating on a copy.
118116
for (int32_t s = 0; s < batch_size_; ++s)
119-
for (auto& qs : queued_nodes_[s]) qs.clear();
117+
for (auto& qs : queued_nodes_[s])
118+
qs.clear();
120119

121120
// nodes_to_lookup[eid] = set of node IDs that need a neighbor fetch for
122121
// edge type eid this round. Using a set deduplicates nodes that appear
@@ -126,7 +125,8 @@ class PPRForwardPushState {
126125

127126
for (int32_t s = 0; s < batch_size_; ++s) {
128127
for (int32_t nt = 0; nt < num_node_types_; ++nt) {
129-
if (queue_[s][nt].empty()) continue;
128+
if (queue_[s][nt].empty())
129+
continue;
130130

131131
// Move the live queue into the snapshot (no data copy — O(1)).
132132
// queue_ is then reset to an empty set so new entries added by
@@ -213,7 +213,8 @@ class PPRForwardPushState {
213213
// c. Enqueue any neighbor whose residual now exceeds the requeue threshold.
214214
for (int32_t s = 0; s < batch_size_; ++s) {
215215
for (int32_t nt = 0; nt < num_node_types_; ++nt) {
216-
if (queued_nodes_[s][nt].empty()) continue;
216+
if (queued_nodes_[s][nt].empty())
217+
continue;
217218

218219
for (int32_t src : queued_nodes_[s][nt]) {
219220
// `auto&` gives a reference to the residual map for this
@@ -222,7 +223,7 @@ class PPRForwardPushState {
222223
auto& src_res = residuals_[s][nt];
223224
// .find() returns an iterator; .end() means "not found".
224225
// We treat a missing entry as residual = 0.
225-
auto it = src_res.find(src);
226+
auto it = src_res.find(src);
226227
double res = (it != src_res.end()) ? it->second : 0.0;
227228

228229
// a. Absorb: move residual into the PPR score.
@@ -232,7 +233,8 @@ class PPRForwardPushState {
232233
int32_t total_deg = get_total_degree(src, nt);
233234
// Destination-only nodes (no outgoing edges) absorb residual
234235
// into their PPR score but do not push further.
235-
if (total_deg == 0) continue;
236+
if (total_deg == 0)
237+
continue;
236238

237239
// b. Distribute: each neighbor of src (across all edge types
238240
// from nt) receives an equal share of the pushed residual.
@@ -249,18 +251,20 @@ class PPRForwardPushState {
249251
// We use a pointer (rather than copying the list) so we can check
250252
// for absence with nullptr without allocating anything.
251253
const std::vector<int32_t>* nbr_list = nullptr;
252-
auto fi = fetched.find(pack_key(src, eid));
254+
auto fi = fetched.find(pack_key(src, eid));
253255
if (fi != fetched.end()) {
254256
// `&fi->second` takes the address of the vector stored in
255257
// the map — nbr_list now points to it without copying.
256258
nbr_list = &fi->second;
257259
} else {
258260
auto ci = neighbor_cache_.find(pack_key(src, eid));
259-
if (ci != neighbor_cache_.end()) nbr_list = &ci->second;
261+
if (ci != neighbor_cache_.end())
262+
nbr_list = &ci->second;
260263
}
261264
// Skip if no neighbor list is available (node has no edges of
262265
// this type, or the fetch returned an empty list).
263-
if (!nbr_list || nbr_list->empty()) continue;
266+
if (!nbr_list || nbr_list->empty())
267+
continue;
264268

265269
int32_t dst_nt = edge_type_to_dst_ntype_id_[eid];
266270

@@ -270,7 +274,7 @@ class PPRForwardPushState {
270274
residuals_[s][dst_nt][nbr] += res_per_nbr;
271275

272276
double threshold = requeue_threshold_factor_ *
273-
static_cast<double>(get_total_degree(nbr, dst_nt));
277+
static_cast<double>(get_total_degree(nbr, dst_nt));
274278

275279
// Only enqueue if: (1) not already in queue for this
276280
// iteration, and (2) residual exceeds the push threshold
@@ -315,13 +319,14 @@ class PPRForwardPushState {
315319
std::unordered_set<int32_t> active;
316320
for (int32_t s = 0; s < batch_size_; ++s)
317321
for (int32_t nt = 0; nt < num_node_types_; ++nt)
318-
if (!ppr_scores_[s][nt].empty()) active.insert(nt);
322+
if (!ppr_scores_[s][nt].empty())
323+
active.insert(nt);
319324

320325
py::dict result;
321326
for (int32_t nt : active) {
322327
// Flat output vectors — entries for all seeds are concatenated.
323328
std::vector<int64_t> flat_ids;
324-
std::vector<float> flat_weights;
329+
std::vector<float> flat_weights;
325330
std::vector<int64_t> valid_counts;
326331

327332
for (int32_t s = 0; s < batch_size_; ++s) {
@@ -341,7 +346,8 @@ class PPRForwardPushState {
341346
// is an anonymous comparator (like Python's `key=` argument).
342347
// `.second` accesses the score (second element of the pair);
343348
// `>` makes it descending (highest score first).
344-
std::partial_sort(items.begin(), items.begin() + k, items.end(),
349+
std::partial_sort(
350+
items.begin(), items.begin() + k, items.end(),
345351
[](const auto& a, const auto& b) { return a.second > b.second; });
346352

347353
for (int32_t i = 0; i < k; ++i) {
@@ -355,27 +361,25 @@ class PPRForwardPushState {
355361
}
356362

357363
// py::make_tuple wraps C++ values into a Python tuple.
358-
result[py::int_(nt)] = py::make_tuple(
359-
torch::tensor(flat_ids, torch::kLong),
360-
torch::tensor(flat_weights, torch::kFloat),
361-
torch::tensor(valid_counts, torch::kLong)
362-
);
364+
result[py::int_(nt)] = py::make_tuple(torch::tensor(flat_ids, torch::kLong),
365+
torch::tensor(flat_weights, torch::kFloat),
366+
torch::tensor(valid_counts, torch::kLong));
363367
}
364368
return result;
365369
}
366370

367-
private:
371+
private:
368372
// Look up the total (across all edge types) out-degree of a node.
369373
// Returns 0 for destination-only node types (no outgoing edges).
370374
int32_t get_total_degree(int32_t node_id, int32_t ntype_id) const {
371-
if (ntype_id >= static_cast<int32_t>(degree_tensors_.size())) return 0;
375+
if (ntype_id >= static_cast<int32_t>(degree_tensors_.size()))
376+
return 0;
372377
const auto& t = degree_tensors_[ntype_id];
373-
if (t.numel() == 0) return 0; // destination-only type: no outgoing edges
374-
TORCH_CHECK(
375-
node_id < static_cast<int32_t>(t.size(0)),
376-
"Node ID ", node_id, " out of range for degree tensor of ntype_id ", ntype_id,
377-
" (size=", t.size(0), "). This indicates corrupted graph data or a sampler bug."
378-
);
378+
if (t.numel() == 0)
379+
return 0; // destination-only type: no outgoing edges
380+
TORCH_CHECK(node_id < static_cast<int32_t>(t.size(0)), "Node ID ", node_id,
381+
" out of range for degree tensor of ntype_id ", ntype_id, " (size=", t.size(0),
382+
"). This indicates corrupted graph data or a sampler bug.");
379383
// data_ptr<int32_t>() returns a raw C pointer to the tensor's int32 data
380384
// buffer. Direct pointer indexing ([node_id]) is safe here because we
381385
// validated the bounds with TORCH_CHECK above.
@@ -385,13 +389,14 @@ class PPRForwardPushState {
385389
// -------------------------------------------------------------------------
386390
// Scalar algorithm parameters
387391
// -------------------------------------------------------------------------
388-
double alpha_; // Restart probability
389-
double one_minus_alpha_; // 1 - alpha, precomputed to avoid repeated subtraction
390-
double requeue_threshold_factor_; // alpha * eps; multiplied by degree to get per-node threshold
392+
double alpha_; // Restart probability
393+
double one_minus_alpha_; // 1 - alpha, precomputed to avoid repeated subtraction
394+
double
395+
requeue_threshold_factor_; // alpha * eps; multiplied by degree to get per-node threshold
391396

392-
int32_t batch_size_; // Number of seeds in the current batch
393-
int32_t num_node_types_; // Total number of node types (homo + hetero)
394-
int32_t num_nodes_in_queue_{0}; // Running count of nodes across all seeds / types
397+
int32_t batch_size_; // Number of seeds in the current batch
398+
int32_t num_node_types_; // Total number of node types (homo + hetero)
399+
int32_t num_nodes_in_queue_{0}; // Running count of nodes across all seeds / types
395400

396401
// -------------------------------------------------------------------------
397402
// Graph structure (read-only after construction)
@@ -400,10 +405,10 @@ class PPRForwardPushState {
400405
// traversed from that node type (outgoing or incoming, depending on edge_dir).
401406
std::vector<std::vector<int32_t>> node_type_to_edge_type_ids_;
402407
// edge_type_to_dst_ntype_id_[eid] → node type ID at the destination end.
403-
std::vector<int32_t> edge_type_to_dst_ntype_id_;
408+
std::vector<int32_t> edge_type_to_dst_ntype_id_;
404409
// degree_tensors_[ntype_id][node_id] → total degree of that node across all
405410
// edge types traversable from its type. Empty tensor means no outgoing edges.
406-
std::vector<torch::Tensor> degree_tensors_;
411+
std::vector<torch::Tensor> degree_tensors_;
407412

408413
// -------------------------------------------------------------------------
409414
// Per-seed, per-node-type PPR state (indexed [seed_idx][ntype_id])
@@ -417,11 +422,11 @@ class PPRForwardPushState {
417422
// residuals_[s][nt]: node_id → unabsorbed probability mass waiting to be pushed
418423
std::vector<std::vector<std::unordered_map<int32_t, double>>> residuals_;
419424
// queue_[s][nt]: nodes whose residual exceeds the threshold and need a push next round
420-
std::vector<std::vector<std::unordered_set<int32_t>>> queue_;
425+
std::vector<std::vector<std::unordered_set<int32_t>>> queue_;
421426
// queued_nodes_[s][nt]: snapshot of queue_ taken by drain_queue() for the current round.
422427
// Separating it from queue_ lets push_residuals() enqueue new nodes into queue_ without
423428
// modifying the set currently being iterated.
424-
std::vector<std::vector<std::unordered_set<int32_t>>> queued_nodes_;
429+
std::vector<std::vector<std::unordered_set<int32_t>>> queued_nodes_;
425430

426431
// -------------------------------------------------------------------------
427432
// Neighbor cache
@@ -445,15 +450,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
445450
// .def(py::init<...>()) exposes the constructor. The template arguments
446451
// list the exact C++ parameter types so pybind11 can convert Python
447452
// arguments to the correct C++ types automatically.
448-
.def(py::init<
449-
torch::Tensor,
450-
int32_t,
451-
double, double,
452-
std::vector<std::vector<int32_t>>,
453-
std::vector<int32_t>,
454-
std::vector<torch::Tensor>
455-
>())
456-
.def("drain_queue", &PPRForwardPushState::drain_queue)
453+
.def(py::init<torch::Tensor, int32_t, double, double, std::vector<std::vector<int32_t>>,
454+
std::vector<int32_t>, std::vector<torch::Tensor>>())
455+
.def("drain_queue", &PPRForwardPushState::drain_queue)
457456
.def("push_residuals", &PPRForwardPushState::push_residuals)
458-
.def("extract_top_k", &PPRForwardPushState::extract_top_k);
457+
.def("extract_top_k", &PPRForwardPushState::extract_top_k);
459458
}

0 commit comments

Comments
 (0)