
KrishnaswamyLab/GraphWaveletEncoder


Graph Wavelet Encoder


Graph wavelet scattering transform encoder for PyTorch Geometric graphs.

Computes multi-scale wavelet features from graph structure and node signals using lazy random-walk diffusion.
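For orientation, the lazy random-walk diffusion used in geometric scattering is usually written as below. Here A is the adjacency matrix, D = diag(d_1, ..., d_n) the degree matrix, I the identity, and x a node signal. This is the standard formulation and is assumed, not confirmed, to be the package's exact normalization. The exponents mirror the default scales=(1, 2, 4, 8, 16) from the usage example.

```latex
P = \tfrac{1}{2}\left(I + A D^{-1}\right), \qquad
(P x)_i = \tfrac{1}{2}\, x_i + \tfrac{1}{2} \sum_{j} \frac{A_{ij}}{d_j}\, x_j, \qquad
x_t = P^{t} x, \quad t \in \{1, 2, 4, 8, 16\}
```

Each step keeps half of every node's value in place and spreads the other half evenly over its neighbors. Larger t therefore averages the signal over wider neighborhoods.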

Installation

pip install graph-wavelet-encoder

For development (editable install):

pip install -e .

With dev dependencies (pytest, etc.):

pip install -e ".[dev]"

Usage

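import torch
from torch_geometric.data import Data, Batch
from graph_wavelet_encoder import GraphWaveletEncoder

# Toy graph: a 3-node path 0-1-2, with each undirected edge listed in both directions
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.randn(3, 4)  # 3 nodes, 4 features per node

# Single graph or batched PyG Data
# Expects: .x (node features), .edge_index, .batch (for batched graphs)
graph = Batch.from_data_list([Data(x=x, edge_index=edge_index)])  # also sets .batch

encoder = GraphWaveletEncoder(
    scales=(1, 2, 4, 8, 16),
    sigma=2.0,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
features = encoder.encode(graph)  # [batch_size, num_nodes, num_features_per_node]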

The encoder uses a lazy random-walk matrix and produces zeroth-, first-, and second-order scattering coefficients. See the docstrings in graph_wavelet_encoder.encoder for details.
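The three coefficient orders can be illustrated with a small dense NumPy sketch. It assumes the standard geometric-scattering filter bank built from the scales t_1, ..., t_5 = 1, 2, 4, 8, 16, namely Ψ_0 = I − P^{t_1} and Ψ_k = P^{t_k} − P^{t_{k+1}}. It keeps features at node level with no low-pass or aggregation step, and it ignores sigma, whose role the README does not describe. The helper names (lazy_random_walk, diffusion_wavelets, scattering_features) are hypothetical and are not part of the graph_wavelet_encoder API.

```python
import numpy as np

# Hypothetical helpers, NOT the graph_wavelet_encoder API.
# Assumes the standard geometric-scattering filter bank; dense matrices for clarity only.


def lazy_random_walk(edge_index: np.ndarray, num_nodes: int) -> np.ndarray:
    """Dense P = 1/2 (I + A D^{-1}) from a [2, E] edge list (undirected, both directions listed)."""
    A = np.zeros((num_nodes, num_nodes))
    A[edge_index[0], edge_index[1]] = 1.0
    deg = A.sum(axis=0)
    deg[deg == 0] = 1.0  # avoid division by zero; an isolated node's column becomes 1/2 e_j
    return 0.5 * (np.eye(num_nodes) + A / deg)  # A / deg divides column j by d_j, i.e. A D^{-1}


def diffusion_wavelets(P: np.ndarray, scales=(1, 2, 4, 8, 16)) -> list:
    """Psi_0 = I - P^{t_1}, Psi_k = P^{t_k} - P^{t_{k+1}} (assumed filter bank)."""
    n = P.shape[0]
    powers = [np.eye(n)] + [np.linalg.matrix_power(P, t) for t in scales]
    return [powers[k] - powers[k + 1] for k in range(len(scales))]


def scattering_features(x: np.ndarray, wavelets: list) -> np.ndarray:
    """Node-level zeroth-, first-, and second-order coefficients, concatenated per node."""
    feats = [x]  # zeroth order: the raw node signal
    first = [np.abs(W @ x) for W in wavelets]  # first order: |Psi_j x|
    feats.extend(first)
    for j, u in enumerate(first):
        for k in range(j + 1, len(wavelets)):  # second order: |Psi_k |Psi_j x||, only k > j
            feats.append(np.abs(wavelets[k] @ u))
    return np.concatenate(feats, axis=1)  # [num_nodes, F * (1 + J + J(J-1)/2)]


# Usage: a 4-node path 0-1-2-3 with a Dirac signal at node 0
edge_index = np.array([[0, 1, 1, 2, 2, 3],
                       [1, 0, 2, 1, 3, 2]])
x = np.array([[1.0], [0.0], [0.0], [0.0]])
P = lazy_random_walk(edge_index, num_nodes=4)
W = diffusion_wavelets(P)
feats = scattering_features(x, W)
print(feats.shape)  # (4, 16)
```

With J = 5 wavelets and F input features, each node gets F(1 + J + J(J−1)/2) = 16F coefficients. Restricting the second order to k > j follows the usual convention of pairing each first-order response only with coarser wavelets.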

Testing and benchmarking

pytest tests/test_encoder.py -v -s

License

Yale Licence
