Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_lift_topology(self):
# (or, equivalently, the SC has a simplex in its facets set if complex_dim = |maximal_clique|-1)

# Convert adjacency matrix to NetworkX graph
G_from_latent_complex = nx.from_numpy_matrix(
G_from_latent_complex = nx.from_numpy_array(
edge_prob_one_adj.to_dense().numpy()
)
G_input = nx.Graph()
Expand All @@ -95,7 +95,7 @@ def test_lift_topology(self):
# (or, equivalently, there is no subset of the 1-skeleton of the SC isomorphic to the input graph)

# Convert adjacency matrix to NetworkX graph
G_from_latent_complex = nx.from_numpy_matrix(
G_from_latent_complex = nx.from_numpy_array(
edge_prob_any_adj.to_dense().numpy()
)
G_input = nx.Graph()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def setup_method(self):
[0.16, 0.45],
]
)
self.data = Data(x=pos, y=torch.tensor(y))
self.data = Data(x=pos, y=y)

# Initialise the HypergraphKHopLifting class
self.lifting = MoGMSTLifting(min_components=3, random_state=0)
Expand Down
33 changes: 33 additions & 0 deletions topobench/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,38 @@
"""TopoBench: A library for benchmarking of topological models."""

# torch >= 2.6 defaults to weights_only=True in torch.load, but OGB and
# older PyG code serialize these classes. Register them as safe so that
# torch.load works without weights_only=False everywhere.
import numpy as np
import torch

if hasattr(torch.serialization, "add_safe_globals"):
from torch_geometric.data.data import DataEdgeAttr, DataTensorAttr
from torch_geometric.data.storage import (
EdgeStorage,
GlobalStorage,
NodeStorage,
)

safe_globals = [
DataEdgeAttr,
DataTensorAttr,
GlobalStorage,
NodeStorage,
EdgeStorage,
np.core.multiarray.scalar,
np.dtype,
]
# numpy >= 1.25 uses typed DType subclasses (e.g. Int64DType) in pickle
# streams; register all of them so weights_only=True succeeds.
import numpy.dtypes

for name in dir(numpy.dtypes):
obj = getattr(numpy.dtypes, name)
if isinstance(obj, type) and name.endswith("DType"):
safe_globals.append(obj)
torch.serialization.add_safe_globals(safe_globals)

# Import submodules
from . import (
data,
Expand Down
4 changes: 3 additions & 1 deletion topobench/transforms/data_manipulations/group_homophily.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def forward(self, data: torch_geometric.data.Data):
if max_k != 1:
H_k = H[:, torch.where(he_cardinalities == max_k)[0]].clone()

he_cardinalities_k = torch.tensor(H_k.sum(0), dtype=torch.long)
he_cardinalities_k = (
H_k.sum(0).detach().clone().to(dtype=torch.long)
)
Dt, D = self.calculate_D_matrix(
H_k,
labels,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def lift_topology(self, data: Data) -> dict:
for i in range(n):
st.insert([i])

graph: nx.Graph = nx.from_numpy_matrix(adj_mat).to_undirected()
graph: nx.Graph = nx.from_numpy_array(adj_mat).to_undirected()

# Insert all edges
for v, u in graph.edges:
Expand Down
23 changes: 10 additions & 13 deletions uv_env_setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,16 @@ echo "======================================================="
# ------------------------------------------------------------------------------
TORCH_VER="2.3.0"

if [ "$PLATFORM" == "cpu" ]; then
TARGET_INDEX="pytorch-cpu"
PYG_URL="https://data.pyg.org/whl/torch-${TORCH_VER}+cpu.html"
elif [ "$PLATFORM" == "cu118" ]; then
TARGET_INDEX="pytorch-cu118"
PYG_URL="https://data.pyg.org/whl/torch-${TORCH_VER}+cu118.html"
elif [ "$PLATFORM" == "cu121" ]; then
TARGET_INDEX="pytorch-cu121"
PYG_URL="https://data.pyg.org/whl/torch-${TORCH_VER}+cu121.html"
else
echo "❌ Error: Invalid platform '$PLATFORM'. Use: cpu, cu118, or cu121."
return 1 2>/dev/null || exit 1
fi
case "$PLATFORM" in
cpu|cu118|cu121)
TARGET_INDEX="pytorch-${PLATFORM}"
PYG_URL="https://data.pyg.org/whl/torch-${TORCH_VER}+${PLATFORM}.html"
;;
*)
echo "❌ Error: Invalid platform '$PLATFORM'. Use: cpu, cu118, or cu121."
return 1 2>/dev/null || exit 1
;;
esac

echo "⚙️ Updating pyproject.toml..."

Expand Down
Loading