
Commit 067bc74

committed
stitch redesign tmp
1 parent b65e754 commit 067bc74

11 files changed

Lines changed: 2355 additions & 0 deletions


CLAUDE.md

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
# Before Every Response
- NEVER describe how code works without reading it first (Read/Grep). If you didn't use a tool to check, say "I haven't verified this" or check first.
- NEVER use nested/inline imports. ALL imports go at the top of the file. Check your edits for this before submitting.
- Never remove breakpoints or uncomment code that was left commented out.

# Rules
- Do what's asked, nothing more/less. NEVER create files unless absolutely necessary.
- NEVER add comments about what code used to be or what was moved/removed.
- Follow instructions precisely. If asked to implement but not integrate, don't integrate.
- NEVER use unittest mocks — only mocker fixture.
- Always write vectorized numpy — no Python loops over arrays.
- Keep notebooks simple — short function calls only, all logic in modules.
- No patchwork — design complete algorithms from first principles.
- No fat VM solutions — hard constraint.
- Never create git commits — user commits themselves.
- Never modify user's code without asking first.
- Test code end-to-end before presenting.
- Terse responses — no trailing summaries.

# Project Context
Read `pychunkedgraph/debug/stitch_test/SESSION.md` for full stitch redesign context.
pychunkedgraph/debug/stitch_test/SESSION.md

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
# Stitch Redesign — Session Context

For a new Claude session to pick up this work, tell it:

> Read pychunkedgraph/debug/stitch_test/SESSION.md and design.md to understand the stitch redesign state.

## Key files

- `.env/stitching/design.md` — high-level algorithm design doc
- `pychunkedgraph/debug/stitch_test/design.md` — detailed algorithm description with phase breakdowns
- `pychunkedgraph/debug/stitch_test/proposed.py` — the proposed stitch implementation
- `pychunkedgraph/debug/stitch_test/wave.py` — unified test runner (single/wave/multiwave experiments)
- `pychunkedgraph/debug/stitch_test/utils.py` — structure extraction, batched parallel extraction, comparison functions
- `pychunkedgraph/debug/stitch_test/compare.py` — orchestration, persistence helpers
- `pychunkedgraph/debug/stitch_test/current.py` — wrapper for current `add_edges` baseline
- `pychunkedgraph/debug/stitch_test/tables.py` — BigTable backup/restore, env setup, autoscaling
- `.env/stitching/hsmith_mec.ipynb` — test notebook

## Module dependency order (no cycles)

tables → utils → {current, proposed} → compare → wave

- `utils.py` has pure functions: extract_structure, _compare_*, _convert_for_json, batched extraction, SV-based comparison
- `compare.py` has orchestration + persistence: imports from current, proposed, utils
- Never import from compare into utils
## Current status (2026-03-23)

### What works
- Proposed algorithm implemented and structurally correct (single-file match verified)
- Single file test: proposed ~151s vs current ~205s (1.35x speedup on this VM)
- Wave 0 current baseline: 606 files, 311K roots, ~1050s wall with 512 workers
- Wave 0 proposed: completed in 638s wall (1.64x speedup); structural comparison pending (comparison bug fixed, needs re-run)

### Extraction and comparison design
- **SV-based components**: `extract_structure` resolves L2 → SVs so components are frozensets of SV IDs (stable across tables, order-independent)
- **Compressed storage**: `np.savez_compressed` with flat arrays + offsets for variable-length SV sets (see the sketch after this list)
- **Independent extraction**: each side extracted into its own subdirectory (`current/`, `proposed/`)
- **Order-independent comparison**: uses sets of frozensets per layer, not sorted lists. No shard-to-shard matching needed.
- **No table deletion**: user manages table cleanup via prefix
### Retry safety
- **`_get_all_parents_filtered`**: replaces `get_all_parents_dict_multiple` for stitching. Applies `filter_failed_node_ids` at every layer during parent chain traversal to detect and remap orphaned nodes from prior failed stitch attempts.
- **`filter_failed_node_ids`**: applied to both `l2ids` and `l2_siblings` after reading their children.
- **Two-phase writes**: `_build_entries` returns `(node_entries, parent_entries)`. Node rows are written first, then Parent pointers, so Parent pointers only ever reference rows that exist (see the sketch after this list).
- **No FormerParent**: proposed path does not write FormerParent/deprecation entries.
- **Crash recovery**: `stitch_results.json` saved immediately after stitch completes. Pass `run_id` to resume.
- **Fresh runs**: `_clear_log_dir` deletes old results before restoring table.
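A minimal sketch of the two-phase write order; `_build_entries` and its two return halves come from the notes above, while `client.write` and the function name are assumptions, not the verified pychunkedgraph API:

```python
def write_new_hierarchy(client, node_entries, parent_entries):
    # Sketch only: `node_entries` / `parent_entries` are the two halves
    # returned by _build_entries in the proposed path.

    # Phase 1: create every new node row first.
    client.write(node_entries)

    # Phase 2: write Parent pointers. Every row they reference now exists,
    # so a crash between phases leaves only orphaned rows (remapped later
    # by filter_failed_node_ids on retry), never dangling Parent pointers.
    client.write(parent_entries)
```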
### Architecture decisions
- **No neighbor CrossChunkEdge updates**: stale is OK — future proposed stitches read AtomicCrossChunkEdge (immutable) + Parent + Child.
- **No locks**: lock-free, enables true parallelism within waves.
- **No table deletion**: user manages cleanup.
### Test infrastructure
- **Entry points**: `run_current(experiment)` and `run_proposed_and_compare(experiment, run_id=None)`
- **Experiment types**: "single" (one file), "wave" (wave 0), "multiwave" (all waves)
- **Extraction**: 500K-root batches, sharded across cpu_count workers, each saving its own .npz
- **Retries**: tenacity on extraction reads (3 attempts, exponential backoff; see the sketch after this list)
- **Workers**: `min(n_files, 4 * cpu_count)` for wave processing
- **Progress**: tqdm for wave file processing
- **Autoscaling**: for wave/multiwave, sets BigTable CPU target to 25% before, reverts to 60% after (in `finally`).
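The retry policy above plausibly maps onto tenacity's standard decorator; in this sketch the function name, the read call, and the wait bounds are illustrative, not values taken from utils.py:

```python
from tenacity import retry, stop_after_attempt, wait_exponential

@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=30))
def _read_batch(cg, node_ids):
    # Any transient BigTable read error raises out of here, and tenacity
    # re-invokes the call with exponentially growing waits, up to 3 attempts.
    return cg.get_children(node_ids)
```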
### Performance

**Single file (task_0_0.edges, 1024 edges)**:
- Proposed ~151s vs current ~205s (1.35x)

**Wave 0 (606 files, 311K roots)**:
- Current: ~1050s wall
- Proposed: ~638s wall (1.64x)
- Proposed per-file: mean=245s, median=272s, p95=295s, max=399s (task_0_591.edges)
### Remaining work
- Re-run wave 0 comparison with fixed SV-based extraction
- Add incremental file result saving during wave runs
- Optimize proposed further (straggler task_0_591 took 399s)
- Run multiwave test once wave 0 validates

## User preferences (critical)
- **Never describe how code works without reading it first** — use Read/Grep, or say "I haven't verified this"
- **Never use nested/inline imports** — all imports at module top level, design modules to avoid circular deps
- **Never create commits** — user does them
- **Vectorized numpy** — no Python loops where numpy works
- **Keep notebooks simple** — short function calls only, all logic in modules
- **No patchwork** — design complete algorithms from first principles
- **No fat VMs** — hard constraint
- **Max effort always**
- **No mocks** — only mocker fixture
- **Test end-to-end before presenting**
- **Never modify user's code without asking**
- **Terse responses** — no trailing summaries
- **Never delete tables** — user manages cleanup via prefix
## Dataset
- **hsmith_mec**: 7 layers, ~600k edges, 1095 total files
- Wave 0: 606 files
- Edge source: `gs://dodam_exp/hammerschmith_mec/100GVx_cutout/proofreadable_exp16_0.26/agg_chunk_ext_edges`
- Backup table: `hsmith-mec-100gvx-exp16-0.26-backup`
- BigTable project: `zetta-proofreading`, instance: `pychunkedgraph`
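Loading one wave-0 edge file from this bucket follows the same pattern compare.py below uses (the pickle load and dtype cast are taken directly from that file; the specific task file is just an example):

```python
import pickle

import numpy as np
from cloudfiles import CloudFile

from pychunkedgraph.graph import basetypes

EDGES_SRC = "gs://dodam_exp/hammerschmith_mec/100GVx_cutout/proofreadable_exp16_0.26/agg_chunk_ext_edges"

# Each .edges file is a pickled array of atomic edge pairs.
edges = pickle.loads(CloudFile(f"{EDGES_SRC}/task_0_0.edges").get())
edges = np.asarray(edges, dtype=basetypes.NODE_ID)
```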
pychunkedgraph/debug/stitch_test/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
from .tables import restore_test_table
from .inspect import inspect_stitch_edges, inspect_l2_cross_edges, inspect_hierarchy
from .current import run_current_stitch
from .proposed import run_proposed_stitch
from .wave import run_current, run_proposed_and_compare, list_wave_files, list_all_waves
pychunkedgraph/debug/stitch_test/compare.py

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
import json
import pickle
import secrets
from datetime import datetime, timezone
from pathlib import Path

import numpy as np
from cloudfiles import CloudFile

from pychunkedgraph.graph import basetypes

from .current import run_current_stitch
from .proposed import run_proposed_stitch
from .tables import restore_test_table, setup_env, PREFIX, EDGES_SRC, _get_instance
from .utils import _compare_components, _compare_cross_edges, _convert_for_json

LOGS_ROOT = Path("/home/akhilesh/opt/zetta_utils/.env/pcg/.env/stitching/runs")


def generate_run_id() -> str:
    return secrets.token_hex(4)


# ─────────────────────────────────────────────────────────────────────
# Top-level API
# ─────────────────────────────────────────────────────────────────────


def run_current_baseline(experiment: str = "single", edge_file: str | None = None):
    """
    Run the current stitch path once for an experiment type.
    If the table + saved results already exist, skips and prints "reusing".
    """
    setup_env()
    if edge_file is None:
        edge_file = f"{EDGES_SRC}/task_0_0.edges"

    table_name = f"{PREFIX}hsmith_mec_current_{experiment}"
    log_dir = LOGS_ROOT / experiment / "current"
    log_dir.mkdir(parents=True, exist_ok=True)
    structure_path = log_dir / "current_structure.json"

    instance = _get_instance()
    if instance.table(table_name).exists() and structure_path.exists():
        print(f"reusing {table_name}")
        return

    print(f"restoring and running current path for '{experiment}'")
    restore_test_table(table_name)
    edges = pickle.loads(CloudFile(edge_file).get())
    edges = np.asarray(edges, dtype=basetypes.NODE_ID)
    result = run_current_stitch(table_name, edges, do_sanity_check=False)
    _save_run_result(log_dir, "current", result)
    print(f"current {experiment} done: {result['elapsed']:.1f}s")


def run_proposed_and_compare(experiment: str = "single", edge_file: str | None = None):
    """
    Run the proposed stitch path and compare against the current baseline.
    Returns (match, result_current, result_proposed).
    """
    setup_env()
    if edge_file is None:
        edge_file = f"{EDGES_SRC}/task_0_0.edges"

    run_id = generate_run_id()
    log_dir = LOGS_ROOT / experiment / run_id
    log_dir.mkdir(parents=True, exist_ok=True)
    table_proposed = f"{PREFIX}hsmith_mec_{run_id}_proposed"

    print(f"run_id: {run_id}")
    print(f"logs: {log_dir}")

    current_log_dir = LOGS_ROOT / experiment / "current"
    result_current = _load_result(current_log_dir, "current")

    restore_test_table(table_proposed)
    edges = pickle.loads(CloudFile(edge_file).get())
    edges = np.asarray(edges, dtype=basetypes.NODE_ID)
    result_proposed = run_proposed_stitch(table_proposed, edges)
    _save_run_result(log_dir, "proposed", result_proposed)

    print(f"\ncurrent: {result_current['elapsed']:.1f}s, proposed: {result_proposed['elapsed']:.1f}s")
    match = compare_stitch_results(result_current, result_proposed)

    summary = {
        "run_id": run_id,
        "experiment": experiment,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "edge_file": edge_file,
        "match": match,
        "time_current": result_current["elapsed"],
        "time_proposed": result_proposed["elapsed"],
        "proposed_perf": result_proposed.get("perf", {}),
    }
    with open(log_dir / "summary.json", "w") as f:
        json.dump(_convert_for_json(summary), f, indent=2)

    print(f"\n{'MATCH' if match else 'MISMATCH'}")
    return match, result_current, result_proposed


# ─────────────────────────────────────────────────────────────────────
# Comparison
# ─────────────────────────────────────────────────────────────────────


def compare_stitch_results(result_a: dict, result_b: dict) -> bool:
    ids_match = _compare_new_ids_per_layer(result_a, result_b)
    comp_match = _compare_components(result_a["structure"], result_b["structure"])
    cx_match = _compare_cross_edges(result_a["structure"], result_b["structure"])
    return ids_match and comp_match and cx_match


def _compare_new_ids_per_layer(result_a, result_b):
    lc_a = {int(k): v for k, v in result_a.get("layer_counts", {}).items()}
    lc_b = {int(k): v for k, v in result_b.get("layer_counts", {}).items()}
    all_layers = sorted(set(lc_a.keys()) | set(lc_b.keys()))
    match = True
    for layer in all_layers:
        if lc_a.get(layer, 0) != lc_b.get(layer, 0):
            print(f" NEW IDS MISMATCH layer {layer}: {lc_a.get(layer, 0)} vs {lc_b.get(layer, 0)}")
            match = False
    if match:
        print(f" NEW IDS MATCH: {sum(lc_a.values())} across {len(all_layers)} layers")
    return match


# ─────────────────────────────────────────────────────────────────────
# Persistence helpers
# ─────────────────────────────────────────────────────────────────────


def _save_structure(log_dir, name, structure):
    serializable = {}
    comps = structure.get("components", {})
    serializable["components"] = {
        str(layer): [sorted(c) for c in ccs] for layer, ccs in comps.items()
    }
    cx = structure.get("cross_edges", {})
    serializable["cross_edges"] = {
        str(layer): [[sorted(src), sorted(dst)] for src, dst in pairs]
        for layer, pairs in cx.items()
    }
    with open(log_dir / f"{name}_structure.json", "w") as f:
        json.dump(_convert_for_json(serializable), f, indent=2)


def _save_run_result(log_dir, name, result):
    _save_structure(log_dir, name, result["structure"])
    meta = {k: v for k, v in result.items() if k != "structure"}
    with open(log_dir / f"{name}_meta.json", "w") as f:
        json.dump(_convert_for_json(meta), f, indent=2)


def _load_structure(path):
    with open(path) as f:
        data = json.load(f)
    return {
        "components": {
            int(layer): [frozenset(c) for c in ccs]
            for layer, ccs in data.get("components", {}).items()
        },
        "cross_edges": {
            int(layer): [(frozenset(src), frozenset(dst)) for src, dst in pairs]
            for layer, pairs in data.get("cross_edges", {}).items()
        },
    }


def _load_result(log_dir, name):
    with open(log_dir / f"{name}_meta.json") as f:
        result = json.load(f)
    result["structure"] = _load_structure(log_dir / f"{name}_structure.json")
    return result
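A typical notebook driver for the two entry points in this file might look like the following; the import path assumes the file is compare.py in the stitch_test package, which is an inference from the session notes above:

```python
from pychunkedgraph.debug.stitch_test.compare import (
    run_current_baseline,
    run_proposed_and_compare,
)

# Build (or reuse) the current-path baseline, then run the proposed path
# against a fresh run_id table and compare the resulting structures.
run_current_baseline("single")
match, result_current, result_proposed = run_proposed_and_compare("single")
print(match, result_current["elapsed"], result_proposed["elapsed"])
```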
pychunkedgraph/debug/stitch_test/current.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
1+
import os
2+
import time
3+
4+
os.environ["PCG_PROFILER_ENABLED"] = "1"
5+
6+
import numpy as np
7+
8+
import pychunkedgraph.debug.profiler as profiler_mod
9+
from pychunkedgraph.debug.profiler import HierarchicalProfiler
10+
from pychunkedgraph.graph import ChunkedGraph, basetypes
11+
from .utils import extract_structure
12+
13+
14+
def run_current_stitch(graph_id: str, atomic_edges: np.ndarray, do_sanity_check: bool = True) -> dict:
15+
"""
16+
Run the existing add_edges stitch path on a graph copy.
17+
Same calling convention as dist/internal/chunkedgraph/operations.py.
18+
Returns dict with structural result and metadata.
19+
"""
20+
21+
class SilentProfiler(HierarchicalProfiler):
22+
def print_report(self, *a, **kw):
23+
pass
24+
25+
profiler_mod._profiler = SilentProfiler(enabled=True)
26+
27+
atomic_edges = np.asarray(atomic_edges, dtype=basetypes.NODE_ID)
28+
cg = ChunkedGraph(graph_id=graph_id)
29+
30+
print(f" [current] stitch ({len(atomic_edges)} edges)...")
31+
t0 = time.time()
32+
result = cg.add_edges(
33+
user_id="test",
34+
atomic_edges=atomic_edges,
35+
stitch_mode=True,
36+
allow_same_segment_merge=True,
37+
do_sanity_check=do_sanity_check,
38+
)
39+
elapsed = time.time() - t0
40+
new_roots = result.new_root_ids
41+
new_l2_ids = result.new_lvl2_ids
42+
print(f" [current] stitch: {elapsed:.1f}s, {len(new_roots)} roots")
43+
44+
profiler = profiler_mod._profiler
45+
perf = {}
46+
for path, times in profiler.timings.items():
47+
perf[path] = {
48+
"total_ms": sum(times) * 1000,
49+
"calls": profiler.call_counts[path],
50+
"avg_ms": (sum(times) / profiler.call_counts[path]) * 1000,
51+
}
52+
profiler_mod._profiler = HierarchicalProfiler(enabled=False)
53+
54+
t0 = time.time()
55+
structure = extract_structure(cg, new_roots)
56+
print(f" [current] structure: {time.time() - t0:.1f}s")
57+
58+
return {
59+
"structure": structure,
60+
"new_roots": new_roots.tolist(),
61+
"new_l2_ids": [int(x) for x in new_l2_ids],
62+
"operation_id": int(result.operation_id) if result.operation_id else None,
63+
"elapsed": elapsed,
64+
"graph_id": graph_id,
65+
"n_edges": len(atomic_edges),
66+
"layer_counts": {layer: len(ccs) for layer, ccs in structure["components"].items()},
67+
"perf": perf,
68+
}
