From e3e1651dfa36faf1d5969a43c48b579e29698b0c Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Tue, 10 Feb 2026 10:10:30 +0530 Subject: [PATCH 01/50] Sync changes from CDB_DiskANN repo - Refactored recall utilities in diskann-benchmark - Updated tokio utilities - Added attribute and format parser improvements in label-filter - Updated ground_truth utilities in diskann-tools --- diskann-benchmark/src/utils/recall.rs | 703 +----------------- diskann-benchmark/src/utils/tokio.rs | 20 +- diskann-label-filter/src/attribute.rs | 1 + diskann-label-filter/src/parser/format.rs | 2 + .../src/utils/flatten_utils.rs | 2 +- diskann-tools/Cargo.toml | 18 +- diskann-tools/src/utils/ground_truth.rs | 161 +++- 7 files changed, 196 insertions(+), 711 deletions(-) diff --git a/diskann-benchmark/src/utils/recall.rs b/diskann-benchmark/src/utils/recall.rs index 5b7fd1594..bfaf46772 100644 --- a/diskann-benchmark/src/utils/recall.rs +++ b/diskann-benchmark/src/utils/recall.rs @@ -3,15 +3,13 @@ * Licensed under the MIT license. */ -use std::{collections::HashSet, hash::Hash}; - -use diskann_utils::strided::StridedView; -use diskann_utils::views::MatrixView; +use diskann_benchmark_core as benchmark_core; +pub(crate) use benchmark_core::recall::knn; use serde::Serialize; -use thiserror::Error; -#[derive(Debug, Serialize)] +#[derive(Debug, Clone, Serialize)] +#[non_exhaustive] pub(crate) struct RecallMetrics { /// The `k` value for `k-recall-at-n`. pub(crate) recall_k: usize, @@ -25,278 +23,19 @@ pub(crate) struct RecallMetrics { pub(crate) minimum: usize, /// The maximum observed recall (max possible value: `recall_k`). pub(crate) maximum: usize, - /// Recall results by query - pub(crate) by_query: Option>, -} - -// impl RecallMetrics { -// pub(crate) fn num_queries(&self) -> usize { -// self.num_queries -// } - -// pub(crate) fn average(&self) -> f64 { -// self.average -// } -// } - -#[derive(Debug, Error)] -pub(crate) enum ComputeRecallError { - #[error("results matrix has {0} rows but ground truth has {1}")] - RowsMismatch(usize, usize), - #[error("distances matrix has {0} rows but ground truth has {1}")] - DistanceRowsMismatch(usize, usize), - #[error("recall k value {0} must be less than or equal to recall n {1}")] - RecallKAndNError(usize, usize), - #[error("number of results per query {0} must be at least the specified recall k {1}")] - NotEnoughResults(usize, usize), - #[error( - "number of groundtruth values per query {0} must be at least the specified recall n {1}" - )] - NotEnoughGroundTruth(usize, usize), - #[error("number of groundtruth distances {0} does not match groundtruth entries {1}")] - NotEnoughGroundTruthDistances(usize, usize), -} - -pub(crate) trait ComputeKnnRecall { - fn compute_knn_recall( - &self, - groundtruth_distances: Option>, - results: StridedView<'_, T>, - recall_k: usize, - recall_n: usize, - allow_insufficient_results: bool, - enhanced_metrics: bool, - ) -> Result; -} - -impl ComputeKnnRecall for MatrixView<'_, T> -where - T: Eq + Hash + Copy + std::fmt::Debug, -{ - fn compute_knn_recall( - &self, - groundtruth_distances: Option>, - results: StridedView<'_, T>, - recall_k: usize, - recall_n: usize, - allow_insufficient_results: bool, - enhanced_metrics: bool, - ) -> Result { - compute_knn_recall( - self, - groundtruth_distances, - results, - recall_k, - recall_n, - allow_insufficient_results, - enhanced_metrics, - ) - } -} - -impl ComputeKnnRecall for Vec> -where - T: Eq + Hash + Copy + std::fmt::Debug, -{ - fn compute_knn_recall( - &self, - groundtruth_distances: Option>, - results: StridedView<'_, T>, - recall_k: usize, - recall_n: usize, - allow_insufficient_results: bool, - enhanced_metrics: bool, - ) -> Result { - compute_knn_recall( - self, - groundtruth_distances, - results, - recall_k, - recall_n, - allow_insufficient_results, - enhanced_metrics, - ) - } -} - -pub(crate) trait KnnRecall { - type Item; - - fn nrows(&self) -> usize; - fn ncols(&self) -> Option; - fn row(&self, i: usize) -> &[Self::Item]; -} - -impl KnnRecall for MatrixView<'_, T> { - type Item = T; - - fn nrows(&self) -> usize { - MatrixView::<'_, T>::nrows(self) - } - fn ncols(&self) -> Option { - Some(MatrixView::<'_, T>::ncols(self)) - } - fn row(&self, i: usize) -> &[Self::Item] { - MatrixView::<'_, T>::row(self, i) - } -} - -impl KnnRecall for Vec> { - type Item = T; - - fn nrows(&self) -> usize { - self.len() - } - fn ncols(&self) -> Option { - None - } - fn row(&self, i: usize) -> &[Self::Item] { - &self[i] - } } -impl KnnRecall for StridedView<'_, T> { - type Item = T; - - fn nrows(&self) -> usize { - StridedView::<'_, T>::nrows(self) - } - fn ncols(&self) -> Option { - Some(StridedView::<'_, T>::ncols(self)) - } - fn row(&self, i: usize) -> &[Self::Item] { - StridedView::<'_, T>::row(self, i) - } -} - -fn compute_knn_recall( - groundtruth: &K, - groundtruth_distances: Option>, - results: StridedView<'_, T>, - recall_k: usize, - recall_n: usize, - allow_insufficient_results: bool, - enhanced_metrics: bool, -) -> Result -where - T: Eq + Hash + Copy + std::fmt::Debug, - K: KnnRecall, -{ - if recall_k > recall_n { - return Err(ComputeRecallError::RecallKAndNError(recall_k, recall_n)); - } - - let nrows = results.nrows(); - if nrows != groundtruth.nrows() { - return Err(ComputeRecallError::RowsMismatch(nrows, groundtruth.nrows())); - } - - if results.ncols() < recall_n && !allow_insufficient_results { - return Err(ComputeRecallError::NotEnoughResults( - results.ncols(), - recall_n, - )); - } - - // Validate groundtruth size for fixed-size sources - match groundtruth.ncols() { - Some(ncols) if ncols < recall_k => { - return Err(ComputeRecallError::NotEnoughGroundTruth(ncols, recall_k)); - } - _ => {} - } - - if let Some(distances) = groundtruth_distances { - if nrows != distances.nrows() { - return Err(ComputeRecallError::DistanceRowsMismatch( - distances.nrows(), - nrows, - )); - } - - match groundtruth.ncols() { - Some(ncols) if distances.ncols() != ncols => { - return Err(ComputeRecallError::NotEnoughGroundTruthDistances( - distances.ncols(), - ncols, - )); - } - _ => {} +impl From<&benchmark_core::recall::RecallMetrics> for RecallMetrics { + fn from(m: &benchmark_core::recall::RecallMetrics) -> Self { + Self { + recall_k: m.recall_k, + recall_n: m.recall_n, + num_queries: m.num_queries, + average: m.average, + minimum: m.minimum, + maximum: m.maximum, } } - - // The actual recall computation for fixed-size groundtruth - let mut recall_values: Vec = Vec::new(); - let mut this_groundtruth = HashSet::new(); - let mut this_results = HashSet::new(); - - for (i, result) in results.row_iter().enumerate() { - let gt_row = groundtruth.row(i); - - // Populate the groundtruth using the top-k - this_groundtruth.clear(); - this_groundtruth.extend(gt_row.iter().copied().take(recall_k)); - - // If we have distances, then continue to append distances as long as the distance - // value is constant - if let Some(distances) = groundtruth_distances { - if recall_k > 0 { - let distances_row = distances.row(i); - if distances_row.len() > recall_k - 1 && gt_row.len() > recall_k - 1 { - let last_distance = distances_row[recall_k - 1]; - for (d, g) in distances_row.iter().zip(gt_row.iter()).skip(recall_k) { - if *d == last_distance { - this_groundtruth.insert(*g); - } else { - break; - } - } - } - } - } - - this_results.clear(); - this_results.extend(result.iter().copied().take(recall_n)); - - // Count the overlap - let r = this_groundtruth - .iter() - .filter(|i| this_results.contains(i)) - .count() - .min(recall_k); - - recall_values.push(r); - } - - // Perform post-processing - let total: usize = recall_values.iter().sum(); - let minimum = recall_values.iter().min().unwrap_or(&0); - let maximum = recall_values.iter().max().unwrap_or(&0); - - let div = if groundtruth.ncols().is_some() { - recall_k * nrows - } else { - (0..groundtruth.nrows()) - .map(|i| groundtruth.row(i).len()) - .sum::() - .max(1) - }; - - let average = (total as f64) / (div as f64); - - Ok(RecallMetrics { - recall_k, - recall_n, - num_queries: nrows, - average, - minimum: *minimum, - maximum: *maximum, - by_query: if enhanced_metrics { - Some(recall_values) - } else { - None - }, - }) } /// Compute `k-recall-at-n` for all valid combinations of values in `recall_k` and @@ -309,14 +48,13 @@ where feature = "product-quantization" ))] pub(crate) fn compute_multiple_recalls( - results: StridedView<'_, T>, - groundtruth: StridedView<'_, T>, + results: &dyn benchmark_core::recall::Rows, + groundtruth: &dyn benchmark_core::recall::Rows, recall_k: &[usize], recall_n: &[usize], - enhanced_metrics: bool, -) -> Result, ComputeRecallError> +) -> Result, benchmark_core::recall::ComputeRecallError> where - T: Eq + Hash + Copy + std::fmt::Debug, + T: benchmark_core::recall::RecallCompatible, { let mut result = Vec::new(); for k in recall_k { @@ -325,414 +63,27 @@ where continue; } - result.push(compute_knn_recall( - &groundtruth, - None, - results, - *k, - *n, - false, - enhanced_metrics, - )?); + let recall = benchmark_core::recall::knn(groundtruth, None, results, *k, *n, false)?; + result.push((&recall).into()); } } Ok(result) } -#[derive(Debug, Serialize)] -pub(crate) struct APMetrics { +#[derive(Debug, Clone, Serialize)] +#[non_exhaustive] +pub(crate) struct AveragePrecisionMetrics { /// The number of queries. pub(crate) num_queries: usize, /// The average precision pub(crate) average_precision: f64, } -#[derive(Debug, Error)] -pub(crate) enum ComputeAPError { - #[error("results has {0} elements but ground truth has {1}")] - EntriesMismatch(usize, usize), -} - -/// Compute average precision of a range search result -pub(crate) fn compute_average_precision( - results: Vec>, - groundtruth: &[Vec], -) -> Result -where - T: Eq + Hash + Copy + std::fmt::Debug, -{ - if results.len() != groundtruth.len() { - return Err(ComputeAPError::EntriesMismatch( - results.len(), - groundtruth.len(), - )); - } - - // The actual recall computation. - let mut num_gt_results = 0; - let mut num_reported_results = 0; - - let mut scratch = HashSet::new(); - - std::iter::zip(results.iter(), groundtruth.iter()).for_each(|(result, gt)| { - scratch.clear(); - scratch.extend(result.iter().copied()); - num_reported_results += gt.iter().filter(|i| scratch.contains(i)).count(); - num_gt_results += gt.len(); - }); - - // Perform post-processing. - let average_precision = (num_reported_results as f64) / (num_gt_results as f64); - - Ok(APMetrics { - average_precision, - num_queries: results.len(), - }) -} - -/////////// -// Tests // -/////////// - -#[cfg(test)] -mod tests { - use diskann_utils::views::Matrix; - - use super::*; - - pub(crate) fn compute_knn_recall( - results: StridedView<'_, u32>, - groundtruth: G, // StridedView - groundtruth_distances: Option>, - recall_k: usize, - recall_n: usize, - allow_insufficient_results: bool, - enhanced_metrics: bool, - ) -> Result - where - G: ComputeKnnRecall + KnnRecall + Clone, - { - groundtruth.compute_knn_recall( - groundtruth_distances, - results, - recall_k, - recall_n, - allow_insufficient_results, - enhanced_metrics, - ) - } - - struct ExpectedRecall { - recall_k: usize, - recall_n: usize, - // Recall for each component. - components: Vec, - } - - impl ExpectedRecall { - fn new(recall_k: usize, recall_n: usize, components: Vec) -> Self { - assert!(recall_k <= recall_n); - components.iter().for_each(|x| { - assert!(*x <= recall_k); - }); - Self { - recall_k, - recall_n, - components, - } - } - - fn compute_recall(&self) -> f64 { - (self.components.iter().sum::() as f64) - / ((self.components.len() * self.recall_k) as f64) - } - } - - #[test] - fn test_happy_path() { - let groundtruth = Matrix::try_from( - vec![ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 0 - 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, // row 1 - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 2 - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 3 - ] - .into(), - 4, - 10, - ) - .unwrap(); - - let distances = Matrix::try_from( - vec![ - 0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, // row 0 - 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, // row 1 - 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, // row 2 - 0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, // row 3 - ] - .into(), - 4, - 10, - ) - .unwrap(); - - // Shift row 0 by one and row 1 by two. - let our_results = Matrix::try_from( - vec![ - 100, 0, 1, 2, 5, 6, // row 0 - 100, 101, 7, 8, 9, 10, // row 1 - 0, 1, 2, 3, 4, 5, // row 2 - 0, 1, 2, 3, 4, 5, // row 3 - ] - .into(), - 4, - 6, - ) - .unwrap(); - - //---------// - // No Ties // - //---------// - let expected_no_ties = vec![ - // Equal `k` and `n` - ExpectedRecall::new(1, 1, vec![0, 0, 1, 1]), - ExpectedRecall::new(2, 2, vec![1, 0, 2, 2]), - ExpectedRecall::new(3, 3, vec![2, 1, 3, 3]), - ExpectedRecall::new(4, 4, vec![3, 2, 4, 4]), - ExpectedRecall::new(5, 5, vec![3, 3, 5, 5]), - ExpectedRecall::new(6, 6, vec![4, 4, 6, 6]), - // Unequal `k` and `n`. - ExpectedRecall::new(1, 2, vec![1, 0, 1, 1]), - ExpectedRecall::new(1, 3, vec![1, 0, 1, 1]), - ExpectedRecall::new(2, 3, vec![2, 0, 2, 2]), - ExpectedRecall::new(3, 5, vec![3, 1, 3, 3]), - ]; - let epsilon = 1e-6; // Define a small tolerance - - for (i, expected) in expected_no_ties.iter().enumerate() { - assert_eq!(expected.components.len(), our_results.nrows()); - let recall = compute_knn_recall( - our_results.as_view().into(), - groundtruth.as_view(), - None, - expected.recall_k, - expected.recall_n, - false, - true, - ) - .unwrap(); - - let left = recall.average; - let right = expected.compute_recall(); - assert!( - (left - right).abs() < epsilon, - "left = {}, right = {} on input {}", - left, - right, - i - ); - - assert_eq!(recall.num_queries, our_results.nrows()); - assert_eq!(recall.recall_k, expected.recall_k); - assert_eq!(recall.recall_n, expected.recall_n); - assert_eq!(recall.minimum, *expected.components.iter().min().unwrap()); - assert_eq!(recall.maximum, *expected.components.iter().max().unwrap()); - } - - //-----------// - // With Ties // - //-----------// - let expected_with_ties = vec![ - // Equal `k` and `n` - ExpectedRecall::new(1, 1, vec![0, 0, 1, 1]), - ExpectedRecall::new(2, 2, vec![1, 0, 2, 2]), - ExpectedRecall::new(3, 3, vec![2, 1, 3, 3]), - ExpectedRecall::new(4, 4, vec![3, 2, 4, 4]), - ExpectedRecall::new(5, 5, vec![4, 3, 5, 5]), // tie-breaker kicks in - ExpectedRecall::new(6, 6, vec![5, 4, 6, 6]), // tie-breaker kicks in - // Unequal `k` and `n`. - ExpectedRecall::new(1, 2, vec![1, 0, 1, 1]), - ExpectedRecall::new(1, 3, vec![1, 0, 1, 1]), - ExpectedRecall::new(2, 3, vec![2, 1, 2, 2]), - ExpectedRecall::new(4, 5, vec![4, 3, 4, 4]), - ]; - - for (i, expected) in expected_with_ties.iter().enumerate() { - assert_eq!(expected.components.len(), our_results.nrows()); - let recall = compute_knn_recall( - our_results.as_view().into(), - groundtruth.as_view(), - Some(distances.as_view().into()), - expected.recall_k, - expected.recall_n, - false, - true, - ) - .unwrap(); - - let left = recall.average; - let right = expected.compute_recall(); - assert!( - (left - right).abs() < epsilon, - "left = {}, right = {} on input {}", - left, - right, - i - ); - - assert_eq!(recall.num_queries, our_results.nrows()); - assert_eq!(recall.recall_k, expected.recall_k); - assert_eq!(recall.recall_n, expected.recall_n); - assert_eq!(recall.minimum, *expected.components.iter().min().unwrap()); - assert_eq!(recall.maximum, *expected.components.iter().max().unwrap()); - assert_eq!(recall.by_query, Some(expected.components.clone())); - } - } - - #[test] - fn test_errors() { - // k greater than n - { - let groundtruth = Matrix::::new(0, 10, 10); - let results = Matrix::::new(0, 10, 10); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 11, - 10, - false, - true, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::RecallKAndNError(..))); - } - - // Unequal rows - { - let groundtruth = Matrix::::new(0, 11, 10); - let results = Matrix::::new(0, 10, 10); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 10, - 10, - false, - true, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::RowsMismatch(..))); - let err_allow_insufficient_results = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 10, - 10, - true, - false, - ) - .unwrap_err(); - assert!(matches!( - err_allow_insufficient_results, - ComputeRecallError::RowsMismatch(..) - )); - } - - // Not enough results - { - let groundtruth = Matrix::::new(0, 10, 10); - let results = Matrix::::new(0, 10, 5); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 5, - 10, - false, - false, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::NotEnoughResults(..))); - let _ = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 5, - 10, - true, - false, - ); - } - - // Not enough groundtruth - { - let groundtruth = Matrix::::new(0, 10, 5); - let results = Matrix::::new(0, 10, 10); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 10, - 10, - false, - true, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::NotEnoughGroundTruth(..))); - let err_allow_insufficient_results = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 10, - 10, - true, - false, - ) - .unwrap_err(); - assert!(matches!( - err_allow_insufficient_results, - ComputeRecallError::NotEnoughGroundTruth(..) - )); - } - - // Distance Row Mismatch - { - let groundtruth = Matrix::::new(0, 10, 10); - let distances = Matrix::::new(0.0, 9, 10); - let results = Matrix::::new(0, 10, 10); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - Some(distances.as_view().into()), - 10, - 10, - false, - true, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::DistanceRowsMismatch(..))); - } - - // Distance Cols Mismatch - { - let groundtruth = Matrix::::new(0, 10, 10); - let distances = Matrix::::new(0.0, 10, 9); - let results = Matrix::::new(0, 10, 10); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - Some(distances.as_view().into()), - 10, - 10, - false, - true, - ) - .unwrap_err(); - assert!(matches!( - err, - ComputeRecallError::NotEnoughGroundTruthDistances(..) - )); +impl From<&benchmark_core::recall::AveragePrecisionMetrics> for AveragePrecisionMetrics { + fn from(m: &benchmark_core::recall::AveragePrecisionMetrics) -> Self { + Self { + num_queries: m.num_queries, + average_precision: m.average_precision, } } } diff --git a/diskann-benchmark/src/utils/tokio.rs b/diskann-benchmark/src/utils/tokio.rs index a21d3f520..21c78abb2 100644 --- a/diskann-benchmark/src/utils/tokio.rs +++ b/diskann-benchmark/src/utils/tokio.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -/// Create a multi-threaded runtime with `num_threads`. +/// Create a generic multi-threaded runtime with `num_threads`. pub(crate) fn runtime(num_threads: usize) -> anyhow::Result { Ok(tokio::runtime::Builder::new_multi_thread() .worker_threads(num_threads) @@ -18,21 +18,3 @@ pub(crate) fn block_on(future: F) -> F::Output { .expect("current thread runtime initialization failed") .block_on(future) } - -/////////// -// Tests // -/////////// - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_runtimes() { - for num_threads in [1, 2, 4, 8] { - let rt = runtime(num_threads).unwrap(); - let metrics = rt.metrics(); - assert_eq!(metrics.num_workers(), num_threads); - } - } -} diff --git a/diskann-label-filter/src/attribute.rs b/diskann-label-filter/src/attribute.rs index 9eb7ff500..f0d99bfd9 100644 --- a/diskann-label-filter/src/attribute.rs +++ b/diskann-label-filter/src/attribute.rs @@ -5,6 +5,7 @@ use std::fmt::Display; use std::hash::{Hash, Hasher}; +use std::io::Write; use serde_json::Value; use thiserror::Error; diff --git a/diskann-label-filter/src/parser/format.rs b/diskann-label-filter/src/parser/format.rs index c042d8338..5e9e3a9c1 100644 --- a/diskann-label-filter/src/parser/format.rs +++ b/diskann-label-filter/src/parser/format.rs @@ -15,8 +15,10 @@ pub struct Document { /// label in raw json format #[serde(flatten)] pub label: serde_json::Value, + } + /// Represents a query expression as defined in the RFC. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct QueryExpression { diff --git a/diskann-label-filter/src/utils/flatten_utils.rs b/diskann-label-filter/src/utils/flatten_utils.rs index 16404af4b..83c9f80f9 100644 --- a/diskann-label-filter/src/utils/flatten_utils.rs +++ b/diskann-label-filter/src/utils/flatten_utils.rs @@ -154,7 +154,7 @@ fn flatten_json_pointer_inner( } Value::Array(arr) => { for (i, item) in arr.iter().enumerate() { - flatten_recursive(item, stack.push(&i, separator), out, separator); + flatten_recursive(item, stack.push(&String::from(""), separator), out, separator); } } _ => { diff --git a/diskann-tools/Cargo.toml b/diskann-tools/Cargo.toml index 7f0cb203a..1b4b3408e 100644 --- a/diskann-tools/Cargo.toml +++ b/diskann-tools/Cargo.toml @@ -5,14 +5,13 @@ version.workspace = true authors.workspace = true description.workspace = true documentation.workspace = true -license.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] byteorder.workspace = true clap = { workspace = true, features = ["derive"] } -diskann-providers = { workspace = true, default-features = false } # see `linalg/Cargo.toml` +diskann-providers = { workspace = true, default-features = false } # see `linalg/Cargo.toml` diskann-vector = { workspace = true } diskann-disk = { workspace = true } diskann-utils = { workspace = true } @@ -24,31 +23,24 @@ ordered-float = "4.2.0" rand_distr.workspace = true rand.workspace = true serde = { workspace = true, features = ["derive"] } -toml = "0.8.13" +serde_json.workspace = true bincode.workspace = true opentelemetry.workspace = true -opentelemetry_sdk.workspace = true -csv.workspace = true -tokio = { workspace = true, features = ["full"] } -arc-swap.workspace = true diskann-quantization = { workspace = true } diskann = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } tracing.workspace = true bit-set.workspace = true anyhow.workspace = true -serde_json.workspace = true itertools.workspace = true diskann-label-filter.workspace = true [dev-dependencies] rstest.workspace = true -assert_ok = "1.0.2" -# Use virtual-storage for integration tests -diskann-disk = { path = "../diskann-disk", features = ["virtual_storage"] } vfs = { workspace = true } -ureq = { version = "3.0.11", default-features = false, features = ["native-tls", "gzip"] } -diskann-providers = { path = "../diskann-providers", default-features = false, features = ["testing", "virtual_storage"] } +diskann-providers = { workspace = true, default-features = false, features = [ + "virtual_storage", +] } diskann-utils = { workspace = true, features = ["testing"] } [features] diff --git a/diskann-tools/src/utils/ground_truth.rs b/diskann-tools/src/utils/ground_truth.rs index e96f7ae8f..31e69b2b2 100644 --- a/diskann-tools/src/utils/ground_truth.rs +++ b/diskann-tools/src/utils/ground_truth.rs @@ -4,7 +4,7 @@ */ use bit_set::BitSet; -use diskann_label_filter::{eval_query_expr, read_and_parse_queries, read_baselabels}; +use diskann_label_filter::{eval_query_expr, read_and_parse_queries, read_baselabels, ASTExpr}; use std::{io::Write, mem::size_of, str::FromStr}; @@ -25,18 +25,97 @@ use diskann_utils::views::Matrix; use diskann_vector::{distance::Metric, DistanceFunction}; use itertools::Itertools; use rayon::prelude::*; +use serde_json::{Map, Value}; use crate::utils::{search_index_utils, CMDResult, CMDToolError}; +/// Expands a JSON object with array-valued fields into multiple objects with scalar values. +/// For example: {"country": ["AU", "NZ"], "year": 2007} +/// becomes: [{"country": "AU", "year": 2007}, {"country": "NZ", "year": 2007}] +/// +/// If multiple fields have arrays, all combinations are generated. +fn expand_array_fields(value: &Value) -> Vec { + match value { + Value::Object(map) => { + // Start with a single empty object + let mut results: Vec> = vec![Map::new()]; + + for (key, val) in map.iter() { + if let Value::Array(arr) = val { + // Expand: for each existing result, create copies for each array element + let mut new_results: Vec> = Vec::new(); + for existing in results.iter() { + for item in arr.iter() { + let mut new_map: Map = existing.clone(); + new_map.insert(key.clone(), item.clone()); + new_results.push(new_map); + } + } + // If array is empty, keep existing results without this key + if !arr.is_empty() { + results = new_results; + } + } else { + // Non-array field: add to all existing results + for existing in results.iter_mut() { + existing.insert(key.clone(), val.clone()); + } + } + } + + results.into_iter().map(Value::Object).collect() + } + // If not an object, return as-is + _ => vec![value.clone()], + } +} + +/// Evaluates a query expression against a label, expanding array fields first. +/// Returns true if any expanded variant matches the query. +fn eval_query_with_array_expansion(query_expr: &ASTExpr, label: &Value) -> bool { + let expanded = expand_array_fields(label); + expanded.iter().any(|item| eval_query_expr(query_expr, item)) +} + pub fn read_labels_and_compute_bitmap( base_label_filename: &str, query_label_filename: &str, ) -> CMDResult> { // Read base labels let base_labels = read_baselabels(base_label_filename)?; + tracing::info!( + "Loaded {} base labels from {}", + base_labels.len(), + base_label_filename + ); + + // Print first few base labels for debugging + for (i, label) in base_labels.iter().take(3).enumerate() { + tracing::debug!( + "Base label sample [{}]: doc_id={}, label={}", + i, + label.doc_id, + label.label + ); + } // Parse queries and evaluate against labels let parsed_queries = read_and_parse_queries(query_label_filename)?; + tracing::info!( + "Loaded {} queries from {}", + parsed_queries.len(), + query_label_filename + ); + + // Print first few queries for debugging + for (i, (query_id, query_expr)) in parsed_queries.iter().take(3).enumerate() { + tracing::debug!( + "Query sample [{}]: query_id={}, expr={:?}", + i, + query_id, + query_expr + ); + } // using the global threadpool is fine here #[allow(clippy::disallowed_methods)] @@ -45,7 +124,15 @@ pub fn read_labels_and_compute_bitmap( .map(|(_query_id, query_expr)| { let mut bitmap = BitSet::new(); for base_label in base_labels.iter() { - if eval_query_expr(query_expr, &base_label.label) { + // Handle case where base_label.label is an array - check if any element matches + // Also expand array-valued fields within objects (e.g., {"country": ["AU", "NZ"]}) + let matches = if let Some(array) = base_label.label.as_array() { + array.iter().any(|item| eval_query_with_array_expansion(query_expr, item)) + } else { + eval_query_with_array_expansion(query_expr, &base_label.label) + }; + + if matches { bitmap.insert(base_label.doc_id); } } @@ -53,6 +140,38 @@ pub fn read_labels_and_compute_bitmap( }) .collect(); + // Debug: Print match statistics for each query + let total_matches: usize = query_bitmaps.iter().map(|b| b.len()).sum(); + let queries_with_matches = query_bitmaps.iter().filter(|b| !b.is_empty()).count(); + tracing::info!( + "Filter matching summary: {} total matches across {} queries ({} queries have matches)", + total_matches, + query_bitmaps.len(), + queries_with_matches + ); + + // Print per-query match counts + for (i, bitmap) in query_bitmaps.iter().enumerate() { + if i < 10 || bitmap.is_empty() { + tracing::debug!( + "Query {}: {} base vectors matched the filter", + i, + bitmap.len() + ); + } + } + + // If no matches, print more diagnostic info + if total_matches == 0 { + tracing::warn!("WARNING: No base vectors matched any query filters!"); + tracing::warn!("This could indicate a format mismatch between base labels and query filters."); + + // Try to identify what keys exist in base labels vs queries + if let Some(first_label) = base_labels.first() { + tracing::warn!("First base label (full): doc_id={}, label={}", first_label.doc_id, first_label.label); + } + } + Ok(query_bitmaps) } @@ -195,6 +314,44 @@ pub fn compute_ground_truth_from_datafiles< assert_ne!(ground_truth.len(), 0, "No ground-truth results computed"); + // Debug: Print top K matches for each query + tracing::info!( + "Ground truth computed for {} queries with recall_at={}", + ground_truth.len(), + recall_at + ); + for (query_idx, npq) in ground_truth.iter().enumerate() { + let neighbors: Vec<_> = npq.iter().collect(); + let neighbor_count = neighbors.len(); + + if query_idx < 10 { + // Print top K IDs and distances for first 10 queries + let top_ids: Vec = neighbors.iter().take(10).map(|n| n.id).collect(); + let top_dists: Vec = neighbors.iter().take(10).map(|n| n.distance).collect(); + tracing::debug!( + "Query {}: {} neighbors found. Top IDs: {:?}, Top distances: {:?}", + query_idx, + neighbor_count, + top_ids, + top_dists + ); + } + + if neighbor_count == 0 { + tracing::warn!("Query {} has 0 neighbors in ground truth!", query_idx); + } + } + + // Summary stats + let total_neighbors: usize = ground_truth.iter().map(|npq| npq.iter().count()).sum(); + let queries_with_neighbors = ground_truth.iter().filter(|npq| npq.iter().count() > 0).count(); + tracing::info!( + "Ground truth summary: {} total neighbors, {} queries have neighbors, {} queries have 0 neighbors", + total_neighbors, + queries_with_neighbors, + ground_truth.len() - queries_with_neighbors + ); + if has_vector_filters || has_query_bitmaps { let ground_truth_collection = ground_truth .into_iter() From ec2091ffb510a245970e0fdec83bb46955cbefe7 Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Tue, 10 Feb 2026 11:08:49 +0530 Subject: [PATCH 02/50] Before merging with main --- Cargo.lock | 340 ------------------ .../src/utils/flatten_utils.rs | 2 +- 2 files changed, 1 insertion(+), 341 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e80330d7d..665b6c6df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,12 +2,6 @@ # It is not intended for manual editing. version = 4 -[[package]] -name = "adler2" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" - [[package]] name = "aho-corasick" version = "1.1.4" @@ -103,30 +97,12 @@ dependencies = [ "rustversion", ] -[[package]] -name = "assert_ok" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c770ef7624541db11cce57929f00e737fef89157d7c1cd1977b20ee74fefd84" - [[package]] name = "autocfg" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" -[[package]] -name = "base64" -version = "0.22.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" - -[[package]] -name = "base64ct" -version = "1.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" - [[package]] name = "bf-tree" version = "0.4.5" @@ -225,16 +201,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" -[[package]] -name = "cc" -version = "1.2.52" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd4932aefd12402b36c60956a4fe0035421f544799057659ff86f923657aada3" -dependencies = [ - "find-msvc-tools", - "shlex", -] - [[package]] name = "cfg-if" version = "1.0.0" @@ -327,31 +293,6 @@ dependencies = [ "windows-sys 0.61.2", ] -[[package]] -name = "core-foundation" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" -dependencies = [ - "core-foundation-sys", - "libc", -] - -[[package]] -name = "core-foundation-sys" -version = "0.8.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" - -[[package]] -name = "crc32fast" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" -dependencies = [ - "cfg-if", -] - [[package]] name = "criterion" version = "0.5.1" @@ -419,43 +360,12 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" -[[package]] -name = "csv" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52cd9d68cf7efc6ddfaaee42e7288d3a99d613d4b50f76ce9827ae0c6e14f938" -dependencies = [ - "csv-core", - "itoa", - "ryu", - "serde_core", -] - -[[package]] -name = "csv-core" -version = "0.1.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "704a3c26996a80471189265814dbc2c257598b96b8a7feae2d31ace646bb9782" -dependencies = [ - "memchr", -] - [[package]] name = "defer" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "930c7171c8df9fb1782bdf9b918ed9ed2d33d1d22300abb754f9085bc48bf8e8" -[[package]] -name = "der" -version = "0.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" -dependencies = [ - "pem-rfc7468", - "zeroize", -] - [[package]] name = "derive_more" version = "2.1.1" @@ -718,14 +628,11 @@ name = "diskann-tools" version = "0.41.0" dependencies = [ "anyhow", - "arc-swap", - "assert_ok", "bincode", "bit-set", "bytemuck", "byteorder", "clap", - "csv", "diskann", "diskann-disk", "diskann-label-filter", @@ -737,7 +644,6 @@ dependencies = [ "itertools 0.13.0", "num_cpus", "opentelemetry", - "opentelemetry_sdk", "ordered-float", "rand 0.9.2", "rand_distr", @@ -745,11 +651,8 @@ dependencies = [ "rstest", "serde", "serde_json", - "tokio", - "toml 0.8.23", "tracing", "tracing-subscriber", - "ureq", "vfs", ] @@ -956,12 +859,6 @@ dependencies = [ "windows-sys 0.60.2", ] -[[package]] -name = "find-msvc-tools" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f449e6c6c08c865631d4890cfacf252b3d396c9bcc83adb6623cdb02a8336c41" - [[package]] name = "flatbuffers" version = "25.12.19" @@ -972,16 +869,6 @@ dependencies = [ "rustc_version", ] -[[package]] -name = "flate2" -version = "1.1.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b375d6465b98090a5f25b1c7703f3859783755aa9a80433b36e0379a3ec2f369" -dependencies = [ - "crc32fast", - "miniz_oxide", -] - [[package]] name = "fnv" version = "1.0.7" @@ -1000,21 +887,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" -[[package]] -name = "foreign-types" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" -dependencies = [ - "foreign-types-shared", -] - -[[package]] -name = "foreign-types-shared" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" - [[package]] name = "futures" version = "0.3.31" @@ -1322,22 +1194,6 @@ dependencies = [ "paste", ] -[[package]] -name = "http" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" -dependencies = [ - "bytes", - "itoa", -] - -[[package]] -name = "httparse" -version = "1.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" - [[package]] name = "iai-callgrind" version = "0.14.2" @@ -1559,16 +1415,6 @@ version = "2.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" -[[package]] -name = "miniz_oxide" -version = "0.8.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" -dependencies = [ - "adler2", - "simd-adler32", -] - [[package]] name = "mio" version = "1.1.1" @@ -1650,23 +1496,6 @@ dependencies = [ "nano-gemm-core", ] -[[package]] -name = "native-tls" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" -dependencies = [ - "libc", - "log", - "openssl", - "openssl-probe", - "openssl-sys", - "schannel", - "security-framework", - "security-framework-sys", - "tempfile", -] - [[package]] name = "never-say-never" version = "6.6.666" @@ -1730,50 +1559,6 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" -[[package]] -name = "openssl" -version = "0.10.75" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" -dependencies = [ - "bitflags 2.10.0", - "cfg-if", - "foreign-types", - "libc", - "once_cell", - "openssl-macros", - "openssl-sys", -] - -[[package]] -name = "openssl-macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.113", -] - -[[package]] -name = "openssl-probe" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" - -[[package]] -name = "openssl-sys" -version = "0.9.111" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" -dependencies = [ - "cc", - "libc", - "pkg-config", - "vcpkg", -] - [[package]] name = "opentelemetry" version = "0.30.0" @@ -1842,15 +1627,6 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" -[[package]] -name = "pem-rfc7468" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" -dependencies = [ - "base64ct", -] - [[package]] name = "percent-encoding" version = "2.3.2" @@ -1889,12 +1665,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" -[[package]] -name = "pkg-config" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" - [[package]] name = "plotters" version = "0.3.7" @@ -2326,15 +2096,6 @@ dependencies = [ "windows-sys 0.61.2", ] -[[package]] -name = "rustls-pki-types" -version = "1.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21e6f2ab2928ca4291b86736a8bd920a277a399bba1589409d72154ff87c1282" -dependencies = [ - "zeroize", -] - [[package]] name = "rustversion" version = "1.0.22" @@ -2384,15 +2145,6 @@ dependencies = [ "sdd", ] -[[package]] -name = "schannel" -version = "0.1.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" -dependencies = [ - "windows-sys 0.61.2", -] - [[package]] name = "scopeguard" version = "1.2.0" @@ -2405,29 +2157,6 @@ version = "4.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63d45f3526312c9c90d717aac28d37010e623fbd7ca6f21503e69784e86f40" -[[package]] -name = "security-framework" -version = "2.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" -dependencies = [ - "bitflags 2.10.0", - "core-foundation", - "core-foundation-sys", - "libc", - "security-framework-sys", -] - -[[package]] -name = "security-framework-sys" -version = "2.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "semver" version = "1.0.27" @@ -2523,12 +2252,6 @@ dependencies = [ "lazy_static", ] -[[package]] -name = "shlex" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" - [[package]] name = "signal-hook-registry" version = "1.4.8" @@ -2539,12 +2262,6 @@ dependencies = [ "libc", ] -[[package]] -name = "simd-adler32" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" - [[package]] name = "slab" version = "0.4.11" @@ -2922,42 +2639,6 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" -[[package]] -name = "ureq" -version = "3.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d39cb1dbab692d82a977c0392ffac19e188bd9186a9f32806f0aaa859d75585a" -dependencies = [ - "base64", - "der", - "flate2", - "log", - "native-tls", - "percent-encoding", - "rustls-pki-types", - "ureq-proto", - "utf-8", - "webpki-root-certs", -] - -[[package]] -name = "ureq-proto" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f" -dependencies = [ - "base64", - "http", - "httparse", - "log", -] - -[[package]] -name = "utf-8" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" - [[package]] name = "utf8parse" version = "0.2.2" @@ -2970,12 +2651,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" -[[package]] -name = "vcpkg" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" - [[package]] name = "version_check" version = "0.9.5" @@ -3090,15 +2765,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "webpki-root-certs" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36a29fc0408b113f68cf32637857ab740edfafdf460c326cd2afaa2d84cc05dc" -dependencies = [ - "rustls-pki-types", -] - [[package]] name = "winapi-util" version = "0.1.11" @@ -3305,12 +2971,6 @@ dependencies = [ "syn 2.0.113", ] -[[package]] -name = "zeroize" -version = "1.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" - [[package]] name = "zmij" version = "1.0.11" diff --git a/diskann-label-filter/src/utils/flatten_utils.rs b/diskann-label-filter/src/utils/flatten_utils.rs index 83c9f80f9..16404af4b 100644 --- a/diskann-label-filter/src/utils/flatten_utils.rs +++ b/diskann-label-filter/src/utils/flatten_utils.rs @@ -154,7 +154,7 @@ fn flatten_json_pointer_inner( } Value::Array(arr) => { for (i, item) in arr.iter().enumerate() { - flatten_recursive(item, stack.push(&String::from(""), separator), out, separator); + flatten_recursive(item, stack.push(&i, separator), out, separator); } } _ => { From a949024b8283390d49b43061973abbb74653d17d Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Mon, 16 Feb 2026 14:40:55 +0530 Subject: [PATCH 03/50] Working version of inline beta search --- .../example/document-filter.json | 34 + .../src/backend/document_index/benchmark.rs | 1038 ++++++++++++++ .../src/backend/document_index/mod.rs | 13 + diskann-benchmark/src/backend/index/result.rs | 13 + diskann-benchmark/src/backend/mod.rs | 2 + .../src/inputs/document_index.rs | 177 +++ diskann-benchmark/src/inputs/mod.rs | 2 + diskann-benchmark/src/utils/recall.rs | 1 + diskann-benchmark/src/utils/tokio.rs | 7 + diskann-label-filter/src/attribute.rs | 1 - diskann-label-filter/src/document.rs | 4 +- .../ast_label_id_mapper.rs | 15 +- .../document_insert_strategy.rs | 274 ++++ .../document_provider.rs | 2 +- .../encoded_filter_expr.rs | 19 +- .../roaring_attribute_store.rs | 2 +- .../encoded_document_accessor.rs | 14 +- .../inline_beta_search/inline_beta_filter.rs | 67 +- diskann-label-filter/src/lib.rs | 1 + diskann-label-filter/src/parser/format.rs | 2 - .../provider/async_/inmem/full_precision.rs | 1218 +++++++++-------- diskann-tools/src/utils/ground_truth.rs | 37 +- .../disk_index_search/data.256.label.jsonl | 4 +- 23 files changed, 2307 insertions(+), 640 deletions(-) create mode 100644 diskann-benchmark/example/document-filter.json create mode 100644 diskann-benchmark/src/backend/document_index/benchmark.rs create mode 100644 diskann-benchmark/src/backend/document_index/mod.rs create mode 100644 diskann-benchmark/src/inputs/document_index.rs create mode 100644 diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs diff --git a/diskann-benchmark/example/document-filter.json b/diskann-benchmark/example/document-filter.json new file mode 100644 index 000000000..d6e9e13b2 --- /dev/null +++ b/diskann-benchmark/example/document-filter.json @@ -0,0 +1,34 @@ +{ + "search_directories": [ + "test_data/disk_index_search" + ], + "jobs": [ + { + "type": "document-index-build", + "content": { + "build": { + "data_type": "float32", + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "data_labels": "data.256.label.jsonl", + "distance": "squared_l2", + "max_degree": 32, + "l_build": 50, + "alpha": 1.2 + }, + "search": { + "queries": "disk_index_sample_query_10pts.fbin", + "query_predicates": "query.10.label.jsonl", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_filter_res.bin", + "beta": 0.5, + "runs": [ + { + "search_n": 20, + "search_l": [20, 30, 40], + "recall_k": 10 + } + ] + } + } + } + ] +} \ No newline at end of file diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs new file mode 100644 index 000000000..dffe669ff --- /dev/null +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -0,0 +1,1038 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Benchmark for DocumentInsertStrategy which allows inserting Documents +//! (vector + attributes) into a DiskANN index built with DocumentProvider. +//! Also benchmarks filtered search using InlineBetaStrategy. + +use std::io::Write; +use std::num::NonZeroUsize; +use std::path::Path; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; + +use anyhow::Result; +use diskann::{ + graph::{ + config::Builder as ConfigBuilder, config::MaxDegree, config::PruneKind, + search_output_buffer, DiskANNIndex, SearchParams, StartPointStrategy, + }, + provider::DefaultContext, + utils::{async_tools, IntoUsize}, +}; +use diskann_benchmark_runner::{ + dispatcher::{DispatchRule, FailureScore, MatchScore}, + output::Output, + registry::Benchmarks, + utils::{datatype::DataType, percentiles, MicroSeconds}, + Any, Checkpoint, +}; +use diskann_label_filter::{ + attribute::{Attribute, AttributeValue}, + document::Document, + encoded_attribute_provider::{ + document_insert_strategy::DocumentInsertStrategy, document_provider::DocumentProvider, + roaring_attribute_store::RoaringAttributeStore, + }, + inline_beta_search::inline_beta_filter::InlineBetaStrategy, + query::FilteredQuery, + read_and_parse_queries, read_baselabels, ASTExpr, +}; +use diskann_providers::model::graph::provider::async_::{ + common::{self, NoStore, TableBasedDeletes}, + inmem::{CreateFullPrecision, DefaultProvider, DefaultProviderParameters, SetStartPoints}, +}; +use diskann_utils::views::Matrix; +use indicatif::{ProgressBar, ProgressStyle}; +use serde::Serialize; + +use crate::{ + inputs::document_index::DocumentIndexBuild, + utils::{ + self, + datafiles::{self, BinFile}, + recall, + }, +}; + +/// Register the document index benchmarks. +pub(crate) fn register_benchmarks(benchmarks: &mut Benchmarks) { + benchmarks.register::>( + "document-index-build", + |job, checkpoint, out| { + let stats = job.run(checkpoint, out)?; + Ok(serde_json::to_value(stats)?) + }, + ); +} + +/// Document index benchmark job. +pub(super) struct DocumentIndexJob<'a> { + input: &'a DocumentIndexBuild, +} + +impl<'a> DocumentIndexJob<'a> { + fn new(input: &'a DocumentIndexBuild) -> Self { + Self { input } + } +} + +impl diskann_benchmark_runner::dispatcher::Map for DocumentIndexJob<'static> { + type Type<'a> = DocumentIndexJob<'a>; +} + +// Dispatch from the concrete input type +impl<'a> DispatchRule<&'a DocumentIndexBuild> for DocumentIndexJob<'a> { + type Error = std::convert::Infallible; + + fn try_match(_from: &&'a DocumentIndexBuild) -> Result { + Ok(MatchScore(1)) + } + + fn convert(from: &'a DocumentIndexBuild) -> Result { + Ok(DocumentIndexJob::new(from)) + } + + fn description( + f: &mut std::fmt::Formatter<'_>, + _from: Option<&&'a DocumentIndexBuild>, + ) -> std::fmt::Result { + writeln!(f, "tag: \"{}\"", DocumentIndexBuild::tag()) + } +} + +// Central dispatch mapping from Any +impl<'a> DispatchRule<&'a Any> for DocumentIndexJob<'a> { + type Error = anyhow::Error; + + fn try_match(from: &&'a Any) -> Result { + from.try_match::() + } + + fn convert(from: &'a Any) -> Result { + from.convert::() + } + + fn description(f: &mut std::fmt::Formatter, from: Option<&&'a Any>) -> std::fmt::Result { + Any::description::(f, from, DocumentIndexBuild::tag()) + } +} +/// Convert a HashMap to Vec +fn hashmap_to_attributes(map: std::collections::HashMap) -> Vec { + map.into_iter() + .map(|(k, v)| Attribute::from_value(k, v)) + .collect() +} + +/// Compute the index of the row closest to the medoid (centroid) of the data. +fn compute_medoid_index(data: &Matrix) -> usize +where + T: bytemuck::Pod + Copy + 'static, +{ + use diskann_vector::{distance::SquaredL2, PureDistanceFunction}; + + let dim = data.ncols(); + if dim == 0 || data.nrows() == 0 { + return 0; + } + + // Compute the centroid (mean of all rows) as f64 for precision + let mut sum = vec![0.0f64; dim]; + for i in 0..data.nrows() { + let row = data.row(i); + for (j, &v) in row.iter().enumerate() { + // Convert T to f64 for summation using bytemuck + let f64_val: f64 = if std::any::TypeId::of::() == std::any::TypeId::of::() { + let f32_val: f32 = bytemuck::cast(v); + f32_val as f64 + } else if std::any::TypeId::of::() == std::any::TypeId::of::() { + let u8_val: u8 = bytemuck::cast(v); + u8_val as f64 + } else if std::any::TypeId::of::() == std::any::TypeId::of::() { + let i8_val: i8 = bytemuck::cast(v); + i8_val as f64 + } else { + 0.0 + }; + sum[j] += f64_val; + } + } + + // Convert centroid to f32 and compute distances + let centroid_f32: Vec = sum + .iter() + .map(|s| (s / data.nrows() as f64) as f32) + .collect(); + + // Find the row closest to the centroid + let mut min_dist = f32::MAX; + let mut medoid_idx = 0; + for i in 0..data.nrows() { + let row = data.row(i); + let row_f32: Vec = row + .iter() + .map(|&v| { + if std::any::TypeId::of::() == std::any::TypeId::of::() { + bytemuck::cast(v) + } else if std::any::TypeId::of::() == std::any::TypeId::of::() { + let u8_val: u8 = bytemuck::cast(v); + u8_val as f32 + } else if std::any::TypeId::of::() == std::any::TypeId::of::() { + let i8_val: i8 = bytemuck::cast(v); + i8_val as f32 + } else { + 0.0 + } + }) + .collect(); + let d = SquaredL2::evaluate(centroid_f32.as_slice(), row_f32.as_slice()); + if d < min_dist { + min_dist = d; + medoid_idx = i; + } + } + + medoid_idx +} + +impl<'a> DocumentIndexJob<'a> { + fn run( + self, + _checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, + ) -> Result { + // Print the input description + writeln!(output, "{}", self.input)?; + + let build = &self.input.build; + + // Dispatch based on data type - retain original type without conversion + match build.data_type { + DataType::Float32 => self.run_typed::(output), + DataType::UInt8 => self.run_typed::(output), + DataType::Int8 => self.run_typed::(output), + _ => Err(anyhow::anyhow!( + "Unsupported data type: {:?}. Supported types: float32, uint8, int8.", + build.data_type + )), + } + } + + fn run_typed(self, mut output: &mut dyn Output) -> Result + where + T: bytemuck::Pod + Copy + Send + Sync + 'static + std::fmt::Debug, + T: diskann::graph::SampleableForStart + diskann_utils::future::AsyncFriendly, + T: diskann::utils::VectorRepr + diskann_utils::sampling::WithApproximateNorm, + { + let build = &self.input.build; + + // 1. Load vectors from data file in the original data type + writeln!(output, "Loading vectors ({})...", build.data_type)?; + let timer = std::time::Instant::now(); + let data_path: &Path = build.data.as_ref(); + writeln!(output, "Data path is: {}", data_path.to_string_lossy())?; + let data: Matrix = datafiles::load_dataset(BinFile(data_path))?; + let data_load_time: MicroSeconds = timer.elapsed().into(); + let num_vectors = data.nrows(); + let dim = data.ncols(); + writeln!( + output, + " Loaded {} vectors of dimension {}", + num_vectors, dim + )?; + + // 2. Load and parse labels from the data_labels file + writeln!(output, "Loading labels...")?; + let timer = std::time::Instant::now(); + let label_path: &Path = build.data_labels.as_ref(); + let labels = read_baselabels(label_path)?; + let label_load_time: MicroSeconds = timer.elapsed().into(); + let label_count = labels.len(); + writeln!(output, " Loaded {} label documents", label_count)?; + + if num_vectors != label_count { + return Err(anyhow::anyhow!( + "Mismatch: {} vectors but {} label documents", + num_vectors, + label_count + )); + } + + // Convert labels to attribute vectors + let attributes: Vec> = labels + .into_iter() + .map(|doc| hashmap_to_attributes(doc.flatten_metadata_with_separator(""))) + .collect(); + + // 3. Create the index configuration + let metric = build.distance.into(); + let prune_kind = PruneKind::from_metric(metric); + let mut config_builder = ConfigBuilder::new( + build.max_degree, // pruned_degree + MaxDegree::Same, // max_degree + build.l_build, // l_build + prune_kind, // prune_kind + ); + config_builder.alpha(build.alpha); + let config = config_builder.build()?; + + // 4. Create the data provider directly + writeln!(output, "Creating index...")?; + let params = DefaultProviderParameters { + max_points: num_vectors, + frozen_points: diskann::utils::ONE, + metric, + dim, + prefetch_lookahead: None, + prefetch_cache_line_level: None, + max_degree: build.max_degree as u32, + }; + + // Create the underlying provider + let fp_precursor = CreateFullPrecision::::new(dim, None); + let inner_provider = + DefaultProvider::new_empty(params, fp_precursor, NoStore, TableBasedDeletes)?; + + // Set start points using medoid strategy + let start_points = StartPointStrategy::Medoid + .compute(data.as_view()) + .map_err(|e| anyhow::anyhow!("Failed to compute start points: {}", e))?; + inner_provider.set_start_points(start_points.row_iter())?; + + // 5. Create DocumentProvider wrapping the inner provider + let attribute_store = RoaringAttributeStore::::new(); + + // Store attributes for the start point (medoid) + // Start points are stored at indices num_vectors..num_vectors+frozen_points + let medoid_idx = compute_medoid_index(&data); + let start_point_id = num_vectors as u32; // Start points begin at max_points + let medoid_attrs = attributes.get(medoid_idx).cloned().unwrap_or_default(); + use diskann_label_filter::traits::attribute_store::AttributeStore; + attribute_store.set_element(&start_point_id, &medoid_attrs)?; + + let doc_provider = DocumentProvider::new(inner_provider, attribute_store); + + // Create a new DiskANNIndex with DocumentProvider + let doc_index = Arc::new(DiskANNIndex::new(config, doc_provider, None)); + + // 6. Build index by inserting vectors and attributes (parallel) + writeln!( + output, + "Building index with {} vectors using {} threads...", + num_vectors, build.num_threads + )?; + let timer = std::time::Instant::now(); + + let insert_strategy: DocumentInsertStrategy<_, [T]> = + DocumentInsertStrategy::new(common::FullPrecision); + let rt = utils::tokio::runtime(build.num_threads)?; + + // Create control block for parallel work distribution + let data_arc = Arc::new(data); + let attributes_arc = Arc::new(attributes); + let control_block = DocumentControlBlock::new( + data_arc.clone(), + attributes_arc.clone(), + output.draw_target(), + )?; + + let num_tasks = build.num_threads; + let insert_latencies = rt.block_on(async { + let tasks: Vec<_> = (0..num_tasks) + .map(|_| { + let block = control_block.clone(); + let index = doc_index.clone(); + let strategy = insert_strategy; + tokio::spawn(async move { + let mut latencies = Vec::::new(); + let ctx = DefaultContext; + loop { + match block.next() { + Some((id, vector, attrs)) => { + let doc = Document::new(vector, attrs); + let start = std::time::Instant::now(); + let result = + index.insert(strategy, &ctx, &(id as u32), &doc).await; + latencies.push(MicroSeconds::from(start.elapsed())); + + if let Err(e) = result { + block.cancel(); + return Err(e); + } + } + None => return Ok(latencies), + } + } + }) + }) + .collect(); + + // Collect results from all tasks + let mut all_latencies = Vec::with_capacity(num_vectors); + for task in tasks { + let task_latencies = task.await??; + all_latencies.extend(task_latencies); + } + Ok::<_, anyhow::Error>(all_latencies) + })?; + + let build_time: MicroSeconds = timer.elapsed().into(); + writeln!(output, " Index built in {} s", build_time.as_seconds())?; + + let insert_percentiles = percentiles::compute_percentiles(&mut insert_latencies.clone())?; + // ===================== + // Search Phase + // ===================== + let search_input = &self.input.search; + + // Load query vectors (same type as data for compatible distance computation) + writeln!(output, "\nLoading query vectors...")?; + let query_path: &Path = search_input.queries.as_ref(); + let queries: Matrix = datafiles::load_dataset(BinFile(query_path))?; + let num_queries = queries.nrows(); + writeln!(output, " Loaded {} queries", num_queries)?; + + // Load and parse query predicates + writeln!(output, "Loading query predicates...")?; + let predicate_path: &Path = search_input.query_predicates.as_ref(); + let parsed_predicates = read_and_parse_queries(predicate_path)?; + writeln!(output, " Loaded {} predicates", parsed_predicates.len())?; + + if num_queries != parsed_predicates.len() { + return Err(anyhow::anyhow!( + "Mismatch: {} queries but {} predicates", + num_queries, + parsed_predicates.len() + )); + } + + // Load groundtruth + writeln!(output, "Loading groundtruth...")?; + let gt_path: &Path = search_input.groundtruth.as_ref(); + let groundtruth: Vec> = datafiles::load_range_groundtruth(BinFile(gt_path))?; + writeln!( + output, + " Loaded groundtruth with {} rows", + groundtruth.len() + )?; + + // Run filtered searches + writeln!( + output, + "\nRunning filtered searches (beta={})...", + search_input.beta + )?; + let mut search_results = Vec::new(); + + for num_threads in &search_input.num_threads { + for run in &search_input.runs { + for &search_l in &run.search_l { + writeln!( + output, + " threads={}, search_n={}, search_l={}...", + num_threads, run.search_n, search_l + )?; + + let search_run_result = run_filtered_search( + &doc_index, + &queries, + &parsed_predicates, + &groundtruth, + search_input.beta, + *num_threads, + run.search_n, + search_l, + run.recall_k, + search_input.reps, + )?; + + writeln!( + output, + " recall={:.4}, mean_qps={:.1}", + search_run_result.recall.average, + if search_run_result.qps.is_empty() { + 0.0 + } else { + search_run_result.qps.iter().sum::() + / search_run_result.qps.len() as f64 + } + )?; + + search_results.push(search_run_result); + } + } + } + + let stats = DocumentIndexStats { + num_vectors, + dim, + label_count, + data_load_time, + label_load_time, + build_time, + insert_latencies: insert_percentiles, + build_params: BuildParamsStats { + max_degree: build.max_degree, + l_build: build.l_build, + alpha: build.alpha, + }, + search: search_results, + }; + + writeln!(output, "\n{}", stats)?; + Ok(stats) + } +} +/// Local results from a partition of queries. +struct SearchLocalResults { + ids: Matrix, + distances: Vec>, + latencies: Vec, + comparisons: Vec, + hops: Vec, +} + +impl SearchLocalResults { + fn merge(all: &[SearchLocalResults]) -> anyhow::Result { + let first = all + .first() + .ok_or_else(|| anyhow::anyhow!("empty results"))?; + let num_ids = first.ids.ncols(); + let total_rows: usize = all.iter().map(|r| r.ids.nrows()).sum(); + + let mut ids = Matrix::new(0, total_rows, num_ids); + let mut output_row = 0; + for r in all { + for input_row in r.ids.row_iter() { + ids.row_mut(output_row).copy_from_slice(input_row); + output_row += 1; + } + } + + let mut distances = Vec::new(); + let mut latencies = Vec::new(); + let mut comparisons = Vec::new(); + let mut hops = Vec::new(); + for r in all { + distances.extend_from_slice(&r.distances); + latencies.extend_from_slice(&r.latencies); + comparisons.extend_from_slice(&r.comparisons); + hops.extend_from_slice(&r.hops); + } + + Ok(Self { + ids, + distances, + latencies, + comparisons, + hops, + }) + } +} + +/// Run filtered search with the given parameters. +#[allow(clippy::too_many_arguments)] +fn run_filtered_search( + index: &Arc>, + queries: &Matrix, + predicates: &[(usize, ASTExpr)], + groundtruth: &Vec>, + beta: f32, + num_threads: NonZeroUsize, + search_n: usize, + search_l: usize, + recall_k: usize, + reps: NonZeroUsize, +) -> anyhow::Result +where + T: bytemuck::Pod + Copy + Send + Sync + 'static, + DP: diskann::provider::DataProvider< + Context = DefaultContext, + ExternalId = u32, + InternalId = u32, + > + Send + + Sync + + 'static, + InlineBetaStrategy: + diskann::graph::glue::SearchStrategy>>, +{ + let rt = utils::tokio::runtime(num_threads.get())?; + let num_queries = queries.nrows(); + + let mut all_rep_results = Vec::with_capacity(reps.get()); + let mut rep_latencies = Vec::with_capacity(reps.get()); + + for _ in 0..reps.get() { + let start = std::time::Instant::now(); + let results = rt.block_on(run_search_parallel( + index.clone(), + queries, + predicates, + beta, + num_threads, + search_n, + search_l, + ))?; + rep_latencies.push(MicroSeconds::from(start.elapsed())); + all_rep_results.push(results); + } + + // Merge results from first rep for recall calculation + let merged = SearchLocalResults::merge(&all_rep_results[0])?; + + // Compute recall + let recall_metrics: recall::RecallMetrics = + (&recall::knn(groundtruth, None, &merged.ids, recall_k, search_n, false)?).into(); + + // Compute per-query details (only for queries with recall < 1) + let per_query_details: Vec = (0..num_queries) + .filter_map(|query_idx| { + let result_ids: Vec = merged + .ids + .row(query_idx) + .iter() + .copied() + .filter(|&id| id != u32::MAX) + .collect(); + let result_distances: Vec = merged + .distances + .get(query_idx) + .map(|d| d.iter().copied().filter(|&dist| dist != f32::MAX).collect()) + .unwrap_or_default(); + // Only keep top 20 from ground truth + let gt_ids: Vec = groundtruth + .get(query_idx) + .map(|gt| gt.iter().take(20).copied().collect()) + .unwrap_or_default(); + + // Compute per-query recall: intersection of result_ids with gt_ids / recall_k + let result_set: std::collections::HashSet = result_ids.iter().copied().collect(); + let gt_set: std::collections::HashSet = + gt_ids.iter().take(recall_k).copied().collect(); + let intersection = result_set.intersection(>_set).count(); + let per_query_recall = if gt_set.is_empty() { + 1.0 + } else { + intersection as f64 / gt_set.len() as f64 + }; + + // Only include queries with imperfect recall + if per_query_recall >= 1.0 { + return None; + } + + let (_, ref ast_expr) = predicates[query_idx]; + let filter_str = format!("{:?}", ast_expr); + + Some(PerQueryDetails { + query_id: query_idx, + filter: filter_str, + recall: per_query_recall, + result_ids, + result_distances, + groundtruth_ids: gt_ids, + }) + }) + .collect(); + + // Compute QPS from rep latencies + let qps: Vec = rep_latencies + .iter() + .map(|l| num_queries as f64 / l.as_seconds()) + .collect(); + + // Aggregate per-query latencies across all reps + let (all_latencies, all_cmps, all_hops): (Vec<_>, Vec<_>, Vec<_>) = all_rep_results + .iter() + .map(|results| { + let mut lat = Vec::new(); + let mut cmp = Vec::new(); + let mut hop = Vec::new(); + for r in results { + lat.extend_from_slice(&r.latencies); + cmp.extend_from_slice(&r.comparisons); + hop.extend_from_slice(&r.hops); + } + (lat, cmp, hop) + }) + .fold( + (Vec::new(), Vec::new(), Vec::new()), + |(mut a, mut b, mut c): (Vec, Vec, Vec), (x, y, z)| { + a.extend(x); + b.extend(y); + c.extend(z); + (a, b, c) + }, + ); + + let mut query_latencies = all_latencies; + let percentiles::Percentiles { mean, p90, p99, .. } = + percentiles::compute_percentiles(&mut query_latencies)?; + + let mean_cmps = if all_cmps.is_empty() { + 0.0 + } else { + all_cmps.iter().map(|&x| x as f32).sum::() / all_cmps.len() as f32 + }; + let mean_hops = if all_hops.is_empty() { + 0.0 + } else { + all_hops.iter().map(|&x| x as f32).sum::() / all_hops.len() as f32 + }; + + Ok(SearchRunStats { + num_threads: num_threads.get(), + num_queries, + search_n, + search_l, + recall: recall_metrics, + qps, + wall_clock_time: rep_latencies, + mean_latency: mean, + p90_latency: p90, + p99_latency: p99, + mean_cmps, + mean_hops, + per_query_details: Some(per_query_details), + }) +} +async fn run_search_parallel( + index: Arc>, + queries: &Matrix, + predicates: &[(usize, ASTExpr)], + beta: f32, + num_tasks: NonZeroUsize, + search_n: usize, + search_l: usize, +) -> anyhow::Result> +where + T: bytemuck::Pod + Copy + Send + Sync + 'static, + DP: diskann::provider::DataProvider< + Context = DefaultContext, + ExternalId = u32, + InternalId = u32, + > + Send + + Sync + + 'static, + InlineBetaStrategy: + diskann::graph::glue::SearchStrategy>>, +{ + let num_queries = queries.nrows(); + + // Plan query partitions + let partitions: Result, _> = (0..num_tasks.get()) + .map(|task_id| async_tools::partition(num_queries, num_tasks, task_id)) + .collect(); + let partitions = partitions?; + + // We need to clone data for each task + let queries_arc = Arc::new(queries.clone()); + let predicates_arc = Arc::new(predicates.to_vec()); + + let handles: Vec<_> = partitions + .into_iter() + .map(|range| { + let index = index.clone(); + let queries = queries_arc.clone(); + let predicates = predicates_arc.clone(); + tokio::spawn(async move { + run_search_local(index, queries, predicates, beta, range, search_n, search_l).await + }) + }) + .collect(); + + let mut results = Vec::new(); + for h in handles { + results.push(h.await??); + } + + Ok(results) +} + +async fn run_search_local( + index: Arc>, + queries: Arc>, + predicates: Arc>, + beta: f32, + range: std::ops::Range, + search_n: usize, + search_l: usize, +) -> anyhow::Result +where + T: bytemuck::Pod + Copy + Send + Sync + 'static, + DP: diskann::provider::DataProvider< + Context = DefaultContext, + ExternalId = u32, + InternalId = u32, + > + Send + + Sync, + InlineBetaStrategy: + diskann::graph::glue::SearchStrategy>>, +{ + let mut ids = Matrix::new(0, range.len(), search_n); + let mut all_distances: Vec> = Vec::with_capacity(range.len()); + let mut latencies = Vec::with_capacity(range.len()); + let mut comparisons = Vec::with_capacity(range.len()); + let mut hops = Vec::with_capacity(range.len()); + + let ctx = DefaultContext; + let search_params = SearchParams::new_default(search_n, search_l)?; + + for (output_idx, query_idx) in range.enumerate() { + let query_vec = queries.row(query_idx); + let (_, ref ast_expr) = predicates[query_idx]; + + let strategy = InlineBetaStrategy::new(beta, common::FullPrecision); + let query_vec_owned = query_vec.to_vec(); + let filtered_query: FilteredQuery> = + FilteredQuery::new(query_vec_owned, ast_expr.clone()); + + let start = std::time::Instant::now(); + + let mut distances = vec![0.0f32; search_n]; + let result_ids = ids.row_mut(output_idx); + let mut result_buffer = search_output_buffer::IdDistance::new(result_ids, &mut distances); + + let stats = index + .search( + &strategy, + &ctx, + &filtered_query, + &search_params, + &mut result_buffer, + ) + .await?; + + let result_count = stats.result_count.into_usize(); + result_ids[result_count..].fill(u32::MAX); + distances[result_count..].fill(f32::MAX); + + latencies.push(MicroSeconds::from(start.elapsed())); + comparisons.push(stats.cmps); + hops.push(stats.hops); + all_distances.push(distances); + } + + Ok(SearchLocalResults { + ids, + distances: all_distances, + latencies, + comparisons, + hops, + }) +} +#[derive(Debug, Serialize)] +pub struct BuildParamsStats { + pub max_degree: usize, + pub l_build: usize, + pub alpha: f32, +} + +/// Helper module for serializing arrays as compact single-line JSON strings +mod compact_array { + use serde::Serializer; + + pub fn serialize_u32_vec(vec: &Vec, serializer: S) -> Result + where + S: Serializer, + { + // Serialize as a string containing the compact JSON array + let compact = serde_json::to_string(vec).unwrap_or_default(); + serializer.serialize_str(&compact) + } + + pub fn serialize_f32_vec(vec: &Vec, serializer: S) -> Result + where + S: Serializer, + { + // Serialize as a string containing the compact JSON array + let compact = serde_json::to_string(vec).unwrap_or_default(); + serializer.serialize_str(&compact) + } +} + +/// Per-query detailed results for debugging/analysis +#[derive(Debug, Serialize)] +pub struct PerQueryDetails { + pub query_id: usize, + pub filter: String, + pub recall: f64, + #[serde(serialize_with = "compact_array::serialize_u32_vec")] + pub result_ids: Vec, + #[serde(serialize_with = "compact_array::serialize_f32_vec")] + pub result_distances: Vec, + #[serde(serialize_with = "compact_array::serialize_u32_vec")] + pub groundtruth_ids: Vec, +} + +/// Results from a single search configuration (one search_l value). +#[derive(Debug, Serialize)] +pub struct SearchRunStats { + pub num_threads: usize, + pub num_queries: usize, + pub search_n: usize, + pub search_l: usize, + pub recall: recall::RecallMetrics, + pub qps: Vec, + pub wall_clock_time: Vec, + pub mean_latency: f64, + pub p90_latency: MicroSeconds, + pub p99_latency: MicroSeconds, + pub mean_cmps: f32, + pub mean_hops: f32, + #[serde(skip_serializing_if = "Option::is_none")] + pub per_query_details: Option>, +} + +#[derive(Debug, Serialize)] +pub struct DocumentIndexStats { + pub num_vectors: usize, + pub dim: usize, + pub label_count: usize, + pub data_load_time: MicroSeconds, + pub label_load_time: MicroSeconds, + pub build_time: MicroSeconds, + pub insert_latencies: percentiles::Percentiles, + pub build_params: BuildParamsStats, + pub search: Vec, +} + +impl std::fmt::Display for DocumentIndexStats { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "Document Index Build Stats:")?; + writeln!(f, " Vectors: {} x {}", self.num_vectors, self.dim)?; + writeln!(f, " Label Count: {}", self.label_count)?; + writeln!( + f, + " Data Load Time: {} s", + self.data_load_time.as_seconds() + )?; + writeln!( + f, + " Label Load Time: {} s", + self.label_load_time.as_seconds() + )?; + writeln!(f, " Total Build Time: {} s", self.build_time.as_seconds())?; + writeln!(f, " Insert Latencies:")?; + writeln!(f, " Mean: {} us", self.insert_latencies.mean)?; + writeln!(f, " P50: {} us", self.insert_latencies.median)?; + writeln!(f, " P90: {} us", self.insert_latencies.p90)?; + writeln!(f, " P99: {} us", self.insert_latencies.p99)?; + writeln!(f, " Build Parameters:")?; + writeln!(f, " max_degree (R): {}", self.build_params.max_degree)?; + writeln!(f, " l_build (L): {}", self.build_params.l_build)?; + writeln!(f, " alpha: {}", self.build_params.alpha)?; + + if !self.search.is_empty() { + writeln!(f, "\nFiltered Search Results:")?; + writeln!( + f, + " {:>8} {:>8} {:>10} {:>10} {:>15} {:>12} {:>12} {:>10} {:>8} {:>10} {:>12}", + "L", "KNN", "Avg Cmps", "Avg Hops", "QPS -mean(max)", "Avg Latency", "p99 Latency", "Recall", "Threads", "Queries", "WallClock(s)" + )?; + for s in &self.search { + let mean_qps = if s.qps.is_empty() { + 0.0 + } else { + s.qps.iter().sum::() / s.qps.len() as f64 + }; + let max_qps = s.qps.iter().cloned().fold(0.0_f64, f64::max); + let mean_wall_clock = if s.wall_clock_time.is_empty() { + 0.0 + } else { + s.wall_clock_time.iter().map(|t| t.as_seconds()).sum::() / s.wall_clock_time.len() as f64 + }; + writeln!( + f, + " {:>8} {:>8} {:>10.1} {:>10.1} {:>7.1}({:>5.1}) {:>12.1} {:>12} {:>10.4} {:>8} {:>10} {:>12.3}", + s.search_l, + s.search_n, + s.mean_cmps, + s.mean_hops, + mean_qps, + max_qps, + s.mean_latency, + s.p99_latency, + s.recall.average, + s.num_threads, + s.num_queries, + mean_wall_clock + )?; + } + } + Ok(()) + } +} + +// ================================ +// Parallel Build Support +// ================================ + +fn make_progress_bar( + nrows: usize, + draw_target: indicatif::ProgressDrawTarget, +) -> anyhow::Result { + let progress = ProgressBar::with_draw_target(Some(nrows as u64), draw_target); + progress.set_style(ProgressStyle::with_template( + "Building [{elapsed_precise}] {wide_bar} {percent}", + )?); + Ok(progress) +} + +/// Control block for parallel document insertion. +/// Manages work distribution and progress tracking across multiple tasks. +struct DocumentControlBlock { + data: Arc>, + attributes: Arc>>, + position: AtomicUsize, + cancel: AtomicBool, + progress: ProgressBar, +} + +impl DocumentControlBlock { + fn new( + data: Arc>, + attributes: Arc>>, + draw_target: indicatif::ProgressDrawTarget, + ) -> anyhow::Result> { + let nrows = data.nrows(); + Ok(Arc::new(Self { + data, + attributes, + position: AtomicUsize::new(0), + cancel: AtomicBool::new(false), + progress: make_progress_bar(nrows, draw_target)?, + })) + } + + /// Return the next document data to insert: (id, vector_slice, attributes). + fn next(&self) -> Option<(usize, &[T], Vec)> { + let cancel = self.cancel.load(Ordering::Relaxed); + if cancel { + None + } else { + let i = self.position.fetch_add(1, Ordering::Relaxed); + match self.data.get_row(i) { + Some(row) => { + let attrs = self.attributes.get(i).cloned().unwrap_or_default(); + self.progress.inc(1); + Some((i, row, attrs)) + } + None => None, + } + } + } + + /// Tell all users of the control block to cancel and return early. + fn cancel(&self) { + self.cancel.store(true, Ordering::Relaxed); + } +} + +impl Drop for DocumentControlBlock { + fn drop(&mut self) { + self.progress.finish(); + } +} diff --git a/diskann-benchmark/src/backend/document_index/mod.rs b/diskann-benchmark/src/backend/document_index/mod.rs new file mode 100644 index 000000000..9937590cc --- /dev/null +++ b/diskann-benchmark/src/backend/document_index/mod.rs @@ -0,0 +1,13 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Backend benchmark implementation for document index with label filters. +//! +//! This benchmark tests the DocumentInsertStrategy which enables inserting +//! Document objects (vector + attributes) into a DiskANN index. + +mod benchmark; + +pub(crate) use benchmark::register_benchmarks; diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs index c7e2ab75c..21d74f915 100644 --- a/diskann-benchmark/src/backend/index/result.rs +++ b/diskann-benchmark/src/backend/index/result.rs @@ -109,6 +109,7 @@ impl std::fmt::Display for AggregatedSearchResults { #[derive(Debug, Serialize)] pub(super) struct SearchResults { pub(super) num_tasks: usize, + pub(super) num_queries: usize, pub(super) search_n: usize, pub(super) search_l: usize, pub(super) qps: Vec, @@ -143,6 +144,7 @@ impl SearchResults { Self { num_tasks: setup.tasks.into(), + num_queries: recall.num_queries, search_n: parameters.k_value, search_l: parameters.l_value, qps, @@ -182,6 +184,8 @@ where "p99 Latency", "Recall", "Threads", + "Queries", + "WallClock(s)", ] } else { &[ @@ -194,6 +198,8 @@ where "p99 Latency", "Recall", "Threads", + "Queries", + "WallClock(s)", ] }; @@ -237,6 +243,13 @@ where ); row.insert(format!("{:3}", r.recall.average), col_idx + 7); row.insert(r.num_tasks, col_idx + 8); + row.insert(r.num_queries, col_idx + 9); + let mean_wall_clock = if r.search_latencies.is_empty() { + 0.0 + } else { + r.search_latencies.iter().map(|t| t.as_seconds()).sum::() / r.search_latencies.len() as f64 + }; + row.insert(format!("{:.3}", mean_wall_clock), col_idx + 10); }); write!(f, "{}", table) diff --git a/diskann-benchmark/src/backend/mod.rs b/diskann-benchmark/src/backend/mod.rs index 24fe91d7e..5dc1967de 100644 --- a/diskann-benchmark/src/backend/mod.rs +++ b/diskann-benchmark/src/backend/mod.rs @@ -4,6 +4,7 @@ */ mod disk_index; +mod document_index; mod exhaustive; mod filters; mod index; @@ -13,4 +14,5 @@ pub(crate) fn register_benchmarks(registry: &mut diskann_benchmark_runner::regis disk_index::register_benchmarks(registry); index::register_benchmarks(registry); filters::register_benchmarks(registry); + document_index::register_benchmarks(registry); } diff --git a/diskann-benchmark/src/inputs/document_index.rs b/diskann-benchmark/src/inputs/document_index.rs new file mode 100644 index 000000000..b1a36e48a --- /dev/null +++ b/diskann-benchmark/src/inputs/document_index.rs @@ -0,0 +1,177 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Input types for document index benchmarks using DocumentInsertStrategy. + +use std::num::NonZeroUsize; + +use anyhow::Context; +use diskann_benchmark_runner::{ + files::InputFile, utils::datatype::DataType, CheckDeserialization, Checker, +}; +use serde::{Deserialize, Serialize}; + +use super::async_::GraphSearch; +use crate::inputs::{as_input, Example, Input}; + +////////////// +// Registry // +////////////// + +as_input!(DocumentIndexBuild); + +pub(super) fn register_inputs( + registry: &mut diskann_benchmark_runner::registry::Inputs, +) -> anyhow::Result<()> { + registry.register(Input::::new())?; + Ok(()) +} + +/// Build parameters for document index construction. +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct DocumentBuildParams { + pub(crate) data_type: DataType, + pub(crate) data: InputFile, + pub(crate) data_labels: InputFile, + pub(crate) distance: crate::utils::SimilarityMeasure, + pub(crate) max_degree: usize, + pub(crate) l_build: usize, + pub(crate) alpha: f32, + #[serde(default = "default_num_threads")] + pub(crate) num_threads: usize, +} + +fn default_num_threads() -> usize { + 1 +} + +impl CheckDeserialization for DocumentBuildParams { + fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.data.check_deserialization(checker)?; + self.data_labels.check_deserialization(checker)?; + if self.max_degree == 0 { + return Err(anyhow::anyhow!("max_degree must be > 0")); + } + if self.l_build == 0 { + return Err(anyhow::anyhow!("l_build must be > 0")); + } + if self.alpha <= 0.0 { + return Err(anyhow::anyhow!("alpha must be > 0")); + } + Ok(()) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct DocumentSearchParams { + pub(crate) queries: InputFile, + pub(crate) query_predicates: InputFile, + pub(crate) groundtruth: InputFile, + pub(crate) beta: f32, + #[serde(default = "default_reps")] + pub(crate) reps: NonZeroUsize, + #[serde(default = "default_thread_counts")] + pub(crate) num_threads: Vec, + pub(crate) runs: Vec, +} + +fn default_reps() -> NonZeroUsize { + NonZeroUsize::new(5).unwrap() +} +fn default_thread_counts() -> Vec { + vec![NonZeroUsize::new(1).unwrap()] +} + +impl CheckDeserialization for DocumentSearchParams { + fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.queries.check_deserialization(checker)?; + self.query_predicates.check_deserialization(checker)?; + self.groundtruth.check_deserialization(checker)?; + if self.beta <= 0.0 || self.beta > 1.0 { + return Err(anyhow::anyhow!( + "beta must be in range (0, 1], got: {}", + self.beta + )); + } + for (i, run) in self.runs.iter_mut().enumerate() { + run.check_deserialization(checker) + .with_context(|| format!("search run {}", i))?; + } + Ok(()) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct DocumentIndexBuild { + pub(crate) build: DocumentBuildParams, + pub(crate) search: DocumentSearchParams, +} + +impl DocumentIndexBuild { + pub(crate) const fn tag() -> &'static str { + "document-index-build" + } +} + +impl CheckDeserialization for DocumentIndexBuild { + fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.build.check_deserialization(checker)?; + self.search.check_deserialization(checker)?; + Ok(()) + } +} + +impl Example for DocumentIndexBuild { + fn example() -> Self { + Self { + build: DocumentBuildParams { + data_type: DataType::Float32, + data: InputFile::new("data.fbin"), + data_labels: InputFile::new("data.label.jsonl"), + distance: crate::utils::SimilarityMeasure::SquaredL2, + max_degree: 32, + l_build: 50, + alpha: 1.2, + num_threads: 1, + }, + search: DocumentSearchParams { + queries: InputFile::new("queries.fbin"), + query_predicates: InputFile::new("query.label.jsonl"), + groundtruth: InputFile::new("groundtruth.bin"), + beta: 0.5, + reps: default_reps(), + num_threads: default_thread_counts(), + runs: vec![GraphSearch { + search_n: 10, + search_l: vec![20, 30, 40, 50], + recall_k: 10, + }], + }, + } + } +} + +impl std::fmt::Display for DocumentIndexBuild { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "Document Index Build with Label Filters\n")?; + writeln!(f, "tag: \"{}\"", Self::tag())?; + writeln!( + f, + "\nBuild: data={}, labels={}, R={}, L={}, alpha={}", + self.build.data.display(), + self.build.data_labels.display(), + self.build.max_degree, + self.build.l_build, + self.build.alpha + )?; + writeln!( + f, + "Search: queries={}, beta={}", + self.search.queries.display(), + self.search.beta + )?; + Ok(()) + } +} diff --git a/diskann-benchmark/src/inputs/mod.rs b/diskann-benchmark/src/inputs/mod.rs index a0ae1a982..65de65a41 100644 --- a/diskann-benchmark/src/inputs/mod.rs +++ b/diskann-benchmark/src/inputs/mod.rs @@ -5,6 +5,7 @@ pub(crate) mod async_; pub(crate) mod disk; +pub(crate) mod document_index; pub(crate) mod exhaustive; pub(crate) mod filters; pub(crate) mod save_and_load; @@ -16,6 +17,7 @@ pub(crate) fn register_inputs( exhaustive::register_inputs(registry)?; disk::register_inputs(registry)?; filters::register_inputs(registry)?; + document_index::register_inputs(registry)?; Ok(()) } diff --git a/diskann-benchmark/src/utils/recall.rs b/diskann-benchmark/src/utils/recall.rs index dcbe86d94..50ef7e430 100644 --- a/diskann-benchmark/src/utils/recall.rs +++ b/diskann-benchmark/src/utils/recall.rs @@ -3,6 +3,7 @@ * Licensed under the MIT license. */ +pub(crate) use benchmark_core::recall::knn; use diskann_benchmark_core as benchmark_core; use serde::Serialize; diff --git a/diskann-benchmark/src/utils/tokio.rs b/diskann-benchmark/src/utils/tokio.rs index 72dbeb918..21c78abb2 100644 --- a/diskann-benchmark/src/utils/tokio.rs +++ b/diskann-benchmark/src/utils/tokio.rs @@ -3,6 +3,13 @@ * Licensed under the MIT license. */ +/// Create a generic multi-threaded runtime with `num_threads`. +pub(crate) fn runtime(num_threads: usize) -> anyhow::Result { + Ok(tokio::runtime::Builder::new_multi_thread() + .worker_threads(num_threads) + .build()?) +} + /// Create a current-thread runtime and block on the given future. /// Only for functions that don't need multi-threading pub(crate) fn block_on(future: F) -> F::Output { diff --git a/diskann-label-filter/src/attribute.rs b/diskann-label-filter/src/attribute.rs index f0d99bfd9..9eb7ff500 100644 --- a/diskann-label-filter/src/attribute.rs +++ b/diskann-label-filter/src/attribute.rs @@ -5,7 +5,6 @@ use std::fmt::Display; use std::hash::{Hash, Hasher}; -use std::io::Write; use serde_json::Value; use thiserror::Error; diff --git a/diskann-label-filter/src/document.rs b/diskann-label-filter/src/document.rs index 31cad4772..5c817525c 100644 --- a/diskann-label-filter/src/document.rs +++ b/diskann-label-filter/src/document.rs @@ -8,12 +8,12 @@ use diskann_utils::reborrow::Reborrow; ///Simple container class that clients can use to /// supply diskann with a vector and its attributes -pub struct Document<'a, V> { +pub struct Document<'a, V: ?Sized> { vector: &'a V, attributes: Vec, } -impl<'a, V> Document<'a, V> { +impl<'a, V: ?Sized> Document<'a, V> { pub fn new(vector: &'a V, attributes: Vec) -> Self { Self { vector, attributes } } diff --git a/diskann-label-filter/src/encoded_attribute_provider/ast_label_id_mapper.rs b/diskann-label-filter/src/encoded_attribute_provider/ast_label_id_mapper.rs index 0fa21cc02..8b39d8731 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/ast_label_id_mapper.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/ast_label_id_mapper.rs @@ -31,19 +31,14 @@ impl ASTLabelIdMapper { Self { attribute_map } } - fn _lookup( - encoder: &AttributeEncoder, - attribute: &Attribute, - field: &str, - op: &CompareOp, - ) -> ANNResult> { + fn _lookup(encoder: &AttributeEncoder, attribute: &Attribute) -> ANNResult> { match encoder.get(attribute) { Some(attribute_id) => Ok(ASTIdExpr::Terminal(attribute_id)), None => Err(ANNError::message( ANNErrorKind::Opaque, format!( - "{}+{} present in the query does not exist in the dataset.", - field, op + "{} present in the query does not exist in the dataset.", + attribute ), )), } @@ -120,10 +115,10 @@ impl ASTVisitor for ASTLabelIdMapper { if let Some(attribute) = label_or_none { match self.attribute_map.read() { - Ok(guard) => Self::_lookup(&guard, &attribute, field, op), + Ok(guard) => Self::_lookup(&guard, &attribute), Err(poison_error) => { let attr_map = poison_error.into_inner(); - Self::_lookup(&attr_map, &attribute, field, op) + Self::_lookup(&attr_map, &attribute) } } } else { diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs new file mode 100644 index 000000000..850976a32 --- /dev/null +++ b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs @@ -0,0 +1,274 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ + +//! A strategy wrapper that enables insertion of [Document] objects into a +//! [DiskANNIndex] using a [DocumentProvider]. + +use std::marker::PhantomData; + +use diskann::{ + graph::{ + glue::{ + ExpandBeam, InsertStrategy, PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, + }, + SearchOutputBuffer, + }, + neighbor::Neighbor, + provider::{Accessor, BuildQueryComputer, DataProvider, DelegateNeighbor, HasId}, + ANNResult, +}; + +use super::document_provider::DocumentProvider; +use crate::document::Document; +use crate::encoded_attribute_provider::roaring_attribute_store::RoaringAttributeStore; + +/// A strategy wrapper that enables insertion of [Document] objects. +pub struct DocumentInsertStrategy { + inner: Inner, + _phantom: PhantomData VT>, +} + +impl Clone for DocumentInsertStrategy { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + _phantom: PhantomData, + } + } +} + +impl Copy for DocumentInsertStrategy {} + +impl DocumentInsertStrategy { + pub fn new(inner: Inner) -> Self { + Self { + inner, + _phantom: PhantomData, + } + } + + pub fn inner(&self) -> &Inner { + &self.inner + } +} + +/// Wrapper accessor for Document queries +pub struct DocumentSearchAccessor { + inner: Inner, + _phantom: PhantomData VT>, +} + +impl DocumentSearchAccessor { + pub fn new(inner: Inner) -> Self { + Self { + inner, + _phantom: PhantomData, + } + } +} + +impl HasId for DocumentSearchAccessor +where + Inner: HasId, + VT: ?Sized, +{ + type Id = Inner::Id; +} + +impl Accessor for DocumentSearchAccessor +where + Inner: Accessor, + VT: ?Sized, +{ + type ElementRef<'a> = Inner::ElementRef<'a>; + type Element<'a> + = Inner::Element<'a> + where + Self: 'a; + type Extended = Inner::Extended; + type GetError = Inner::GetError; + + fn get_element( + &mut self, + id: Self::Id, + ) -> impl std::future::Future, Self::GetError>> + Send { + self.inner.get_element(id) + } + + fn on_elements_unordered( + &mut self, + itr: Itr, + f: F, + ) -> impl std::future::Future> + Send + where + Self: Sync, + Itr: Iterator + Send, + F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), + { + self.inner.on_elements_unordered(itr, f) + } +} + +impl<'doc, Inner, VT> BuildQueryComputer> for DocumentSearchAccessor +where + Inner: BuildQueryComputer, + VT: ?Sized, +{ + type QueryComputerError = Inner::QueryComputerError; + type QueryComputer = Inner::QueryComputer; + + fn build_query_computer( + &self, + from: &Document<'doc, VT>, + ) -> Result { + self.inner.build_query_computer(from.vector()) + } +} + +impl<'this, Inner, VT> DelegateNeighbor<'this> for DocumentSearchAccessor +where + Inner: DelegateNeighbor<'this>, + VT: ?Sized, +{ + type Delegate = Inner::Delegate; + fn delegate_neighbor(&'this mut self) -> Self::Delegate { + self.inner.delegate_neighbor() + } +} + +impl<'doc, Inner, VT> ExpandBeam> for DocumentSearchAccessor +where + Inner: ExpandBeam, + VT: ?Sized, +{ +} + +impl SearchExt for DocumentSearchAccessor +where + Inner: SearchExt, + VT: ?Sized, +{ + fn starting_points( + &self, + ) -> impl std::future::Future>> + Send { + self.inner.starting_points() + } + fn terminate_early(&mut self) -> bool { + self.inner.terminate_early() + } +} + +#[derive(Debug, Default, Clone, Copy)] +pub struct CopyIdsForDocument; + +impl<'doc, A, VT> SearchPostProcess> for CopyIdsForDocument +where + A: BuildQueryComputer>, + VT: ?Sized, +{ + type Error = std::convert::Infallible; + + fn post_process( + &self, + _accessor: &mut A, + _query: &Document<'doc, VT>, + _computer: &>>::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl std::future::Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + let count = output.extend(candidates.map(|n| (n.id, n.distance))); + std::future::ready(Ok(count)) + } +} + +impl<'doc, Inner, DP, VT> + SearchStrategy>, Document<'doc, VT>> + for DocumentInsertStrategy +where + Inner: InsertStrategy, + DP: DataProvider, + VT: Sync + Send + ?Sized + 'static, +{ + type QueryComputer = Inner::QueryComputer; + type PostProcessor = CopyIdsForDocument; + type SearchAccessorError = Inner::SearchAccessorError; + type SearchAccessor<'a> = DocumentSearchAccessor, VT>; + + fn search_accessor<'a>( + &'a self, + provider: &'a DocumentProvider>, + context: &'a > as DataProvider>::Context, + ) -> Result, Self::SearchAccessorError> { + let inner_accessor = self + .inner + .search_accessor(provider.inner_provider(), context)?; + Ok(DocumentSearchAccessor::new(inner_accessor)) + } + + fn post_processor(&self) -> Self::PostProcessor { + CopyIdsForDocument + } +} + +impl<'doc, Inner, DP, VT> + InsertStrategy>, Document<'doc, VT>> + for DocumentInsertStrategy +where + Inner: InsertStrategy, + DP: DataProvider, + VT: Sync + Send + ?Sized + 'static, +{ + type PruneStrategy = DocumentPruneStrategy; + + fn prune_strategy(&self) -> Self::PruneStrategy { + DocumentPruneStrategy::new(self.inner.prune_strategy()) + } + + fn insert_search_accessor<'a>( + &'a self, + provider: &'a DocumentProvider>, + context: &'a > as DataProvider>::Context, + ) -> Result, Self::SearchAccessorError> { + let inner_accessor = self + .inner + .insert_search_accessor(provider.inner_provider(), context)?; + Ok(DocumentSearchAccessor::new(inner_accessor)) + } +} + +#[derive(Clone, Copy)] +pub struct DocumentPruneStrategy { + inner: Inner, +} + +impl DocumentPruneStrategy { + pub fn new(inner: Inner) -> Self { + Self { inner } + } +} + +impl PruneStrategy>> + for DocumentPruneStrategy +where + DP: DataProvider, + Inner: PruneStrategy, +{ + type DistanceComputer = Inner::DistanceComputer; + type PruneAccessor<'a> = Inner::PruneAccessor<'a>; + type PruneAccessorError = Inner::PruneAccessorError; + + fn prune_accessor<'a>( + &'a self, + provider: &'a DocumentProvider>, + context: &'a > as DataProvider>::Context, + ) -> Result, Self::PruneAccessorError> { + self.inner + .prune_accessor(provider.inner_provider(), context) + } +} diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs b/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs index 6b496271b..1fabf5f54 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs @@ -77,7 +77,7 @@ impl<'a, VT, DP, AS> SetElement> for DocumentProvider where DP: DataProvider + Delete + SetElement, AS: AttributeStore + AsyncFriendly, - VT: Sync + Send, + VT: Sync + Send + ?Sized, { type SetError = ANNError; type Guard = >::Guard; diff --git a/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs b/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs index d56cb13c1..370ef25ae 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs @@ -5,8 +5,6 @@ use std::sync::{Arc, RwLock}; -use diskann::ANNResult; - use crate::{ encoded_attribute_provider::{ ast_id_expr::ASTIdExpr, ast_label_id_mapper::ASTLabelIdMapper, @@ -16,20 +14,21 @@ use crate::{ }; pub(crate) struct EncodedFilterExpr { - ast_id_expr: ASTIdExpr, + ast_id_expr: Option>, } impl EncodedFilterExpr { - pub fn new( - ast_expr: &ASTExpr, - attribute_map: Arc>, - ) -> ANNResult { + pub fn new(ast_expr: &ASTExpr, attribute_map: Arc>) -> Self { let mut mapper = ASTLabelIdMapper::new(attribute_map); - let ast_id_expr = ast_expr.accept(&mut mapper)?; - Ok(Self { ast_id_expr }) + match ast_expr.accept(&mut mapper) { + Ok(ast_id_expr) => Self { + ast_id_expr: Some(ast_id_expr), + }, + Err(_e) => Self { ast_id_expr: None }, + } } - pub(crate) fn encoded_filter_expr(&self) -> &ASTIdExpr { + pub(crate) fn encoded_filter_expr(&self) -> &Option> { &self.ast_id_expr } } diff --git a/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs b/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs index 6b82a68b1..c69589ba0 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs @@ -15,7 +15,7 @@ use diskann::{utils::VectorId, ANNError, ANNErrorKind, ANNResult}; use diskann_utils::future::AsyncFriendly; use std::sync::{Arc, RwLock}; -pub(crate) struct RoaringAttributeStore +pub struct RoaringAttributeStore where IT: VectorId + AsyncFriendly, { diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 962d361d7..1def9a406 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -28,7 +28,7 @@ use crate::{ type AttrAccessor = EncodedAttributeAccessor::Id>>; -pub(crate) struct EncodedDocumentAccessor +pub struct EncodedDocumentAccessor where IA: HasId, { @@ -136,7 +136,7 @@ where Some(set) => Ok(set.into_owned()), None => Err(ANNError::message( ANNErrorKind::IndexError, - "No labels were found for vector", + format!("No labels were found for vector:{:?}", id), )), } })?; @@ -220,12 +220,20 @@ where .inner_accessor .build_query_computer(from.query()) .into_ann_result()?; - let id_query = EncodedFilterExpr::new(from.filter_expr(), self.attribute_map.clone())?; + let id_query = EncodedFilterExpr::new(from.filter_expr(), self.attribute_map.clone()); + let is_valid_filter = id_query.encoded_filter_expr().is_some(); + if !is_valid_filter { + tracing::warn!( + "Failed to convert {} into an id expr. This will now be an unfiltered search.", + from.filter_expr() + ); + } Ok(InlineBetaComputer::new( inner_computer, self.beta_value, id_query, + is_valid_filter, )) } } diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index b25b1746f..f03f36c12 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -28,6 +28,13 @@ pub struct InlineBetaStrategy { inner: Strategy, } +impl InlineBetaStrategy { + /// Create a new InlineBetaStrategy with the given beta value and inner strategy. + pub fn new(beta: f32, inner: Strategy) -> Self { + Self { beta, inner } + } +} + impl SearchStrategy>, FilteredQuery> for InlineBetaStrategy @@ -72,6 +79,7 @@ pub struct InlineBetaComputer { inner_computer: Inner, beta_value: f32, filter_expr: EncodedFilterExpr, + is_valid_filter: bool, //optimization to avoid evaluating empty predicates. } impl InlineBetaComputer { @@ -79,17 +87,23 @@ impl InlineBetaComputer { inner_computer: Inner, beta_value: f32, filter_expr: EncodedFilterExpr, + is_valid_filter: bool, ) -> Self { Self { inner_computer, beta_value, filter_expr, + is_valid_filter, } } pub(crate) fn filter_expr(&self) -> &EncodedFilterExpr { &self.filter_expr } + + pub(crate) fn is_valid_filter(&self) -> bool { + self.is_valid_filter + } } impl PreprocessedDistanceFunction, f32> @@ -101,22 +115,35 @@ where let (vec, attrs) = changing.destructure(); let sim = self.inner_computer.evaluate_similarity(vec); let pred_eval = PredicateEvaluator::new(attrs); - match self.filter_expr.encoded_filter_expr().accept(&pred_eval) { - Ok(matched) => { - if matched { - sim * self.beta_value - } else { - sim + if self.is_valid_filter { + match self + .filter_expr + .encoded_filter_expr() + .as_ref() + .unwrap() + .accept(&pred_eval) + { + Ok(matched) => { + if matched { + return sim * self.beta_value; + } else { + return sim; + } + } + Err(_) => { + //If predicate evaluation fails for any reason, we simply revert + //to unfiltered search. + tracing::warn!("Predicate evaluation failed"); + return sim; } } - Err(_) => { - //TODO: If predicate evaluation fails, we are taking the approach that we will simply - //return the score returned by the inner computer, as though no predicate was specified. - tracing::warn!( - "Predicate evaluation failed in OnlineBetaComputer::evaluate_similarity()" - ); - sim - } + } else { + //If predicate evaluation fails, we will return the score returned by the + //inner computer, as though no predicate was specified. + tracing::warn!( + "Predicate evaluation failed in OnlineBetaComputer::evaluate_similarity()" + ); + sim } } } @@ -155,8 +182,16 @@ where let doc = accessor.get_element(candidate.id).await?; let pe = PredicateEvaluator::new(doc.attributes()); - if computer.filter_expr().encoded_filter_expr().accept(&pe)? { - filtered_candidates.push(Neighbor::new(candidate.id, candidate.distance)); + if computer.is_valid_filter() { + if computer + .filter_expr() + .encoded_filter_expr() + .as_ref() + .unwrap() + .accept(&pe)? + { + filtered_candidates.push(Neighbor::new(candidate.id, candidate.distance)); + } } } diff --git a/diskann-label-filter/src/lib.rs b/diskann-label-filter/src/lib.rs index 106845f98..273475b15 100644 --- a/diskann-label-filter/src/lib.rs +++ b/diskann-label-filter/src/lib.rs @@ -40,6 +40,7 @@ pub mod encoded_attribute_provider { pub(crate) mod ast_id_expr; pub(crate) mod ast_label_id_mapper; pub(crate) mod attribute_encoder; + pub mod document_insert_strategy; pub mod document_provider; pub mod encoded_attribute_accessor; pub(crate) mod encoded_filter_expr; diff --git a/diskann-label-filter/src/parser/format.rs b/diskann-label-filter/src/parser/format.rs index 5e9e3a9c1..c042d8338 100644 --- a/diskann-label-filter/src/parser/format.rs +++ b/diskann-label-filter/src/parser/format.rs @@ -15,10 +15,8 @@ pub struct Document { /// label in raw json format #[serde(flatten)] pub label: serde_json::Value, - } - /// Represents a query expression as defined in the RFC. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct QueryExpression { diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index e74419a46..9a48488fe 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -1,580 +1,638 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use std::{collections::HashMap, fmt::Debug, future::Future}; - -use diskann::{ - ANNError, ANNResult, - graph::{ - SearchOutputBuffer, - glue::{ - self, ExpandBeam, FillSet, FilterStartPoints, InplaceDeleteStrategy, InsertStrategy, - PruneStrategy, SearchExt, SearchStrategy, - }, - }, - neighbor::Neighbor, - provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, - ExecutionContext, HasId, - }, - utils::{IntoUsize, VectorRepr}, -}; -use diskann_utils::future::AsyncFriendly; -use diskann_vector::{DistanceFunction, distance::Metric}; - -use crate::model::graph::{ - provider::async_::{ - FastMemoryVectorProviderAsync, SimpleNeighborProviderAsync, - common::{ - CreateVectorStore, FullPrecision, Internal, NoDeletes, NoStore, Panics, - PrefetchCacheLineLevel, SetElementHelper, - }, - inmem::DefaultProvider, - postprocess::{AsDeletionCheck, DeletionCheck, RemoveDeletedIdsAndCopy}, - }, - traits::AdHoc, -}; - -/// A type alias for the DefaultProvider with full-precision as the primary vector store. -pub type FullPrecisionProvider = - DefaultProvider, Q, D, Ctx>; - -/// The default full-precision vector store. -pub type FullPrecisionStore = FastMemoryVectorProviderAsync>; - -/// A default full-precision vector store provider. -#[derive(Clone)] -pub struct CreateFullPrecision { - dim: usize, - prefetch_cache_line_level: Option, - _phantom: std::marker::PhantomData, -} - -impl CreateFullPrecision -where - T: VectorRepr, -{ - /// Create a new full-precision vector store provider. - pub fn new(dim: usize, prefetch_cache_line_level: Option) -> Self { - Self { - dim, - prefetch_cache_line_level, - _phantom: std::marker::PhantomData, - } - } -} - -impl CreateVectorStore for CreateFullPrecision -where - T: VectorRepr, -{ - type Target = FullPrecisionStore; - fn create( - self, - max_points: usize, - metric: Metric, - prefetch_lookahead: Option, - ) -> Self::Target { - FullPrecisionStore::new( - max_points, - self.dim, - metric, - self.prefetch_cache_line_level, - prefetch_lookahead, - ) - } -} - -//////////////// -// SetElement // -//////////////// - -impl SetElementHelper for FullPrecisionStore -where - T: VectorRepr, -{ - /// Set the element at the given index. - fn set_element(&self, id: &u32, element: &[T]) -> Result<(), ANNError> { - unsafe { self.set_vector_sync(id.into_usize(), element) } - } -} - -////////////////// -// FullAccessor // -////////////////// - -/// An accessor for retrieving full-precision vectors from the `DefaultProvider`. -/// -/// This type implements the following traits: -/// -/// * [`Accessor`] for the [`DefaultProvider`]. -/// * [`ComputerAccessor`] for comparing full-precision distances. -/// * [`BuildQueryComputer`]. -pub struct FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, -{ - /// The host provider. - provider: &'a FullPrecisionProvider, - - /// A buffer for resolving iterators given during bulk operations. - /// - /// The accessor reuses this allocation to amortize allocation cost over multiple bulk - /// operations. - id_buffer: Vec, -} - -impl GetFullPrecision for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, -{ - type Repr = T; - fn as_full_precision(&self) -> &FullPrecisionStore { - &self.provider.base_vectors - } -} - -impl HasId for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, -{ - type Id = u32; -} - -impl SearchExt for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - fn starting_points(&self) -> impl Future>> { - std::future::ready(self.provider.starting_points()) - } -} - -impl<'a, T, Q, D, Ctx> FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - pub fn new(provider: &'a FullPrecisionProvider) -> Self { - Self { - provider, - id_buffer: Vec::new(), - } - } -} - -impl<'a, T, Q, D, Ctx> DelegateNeighbor<'a> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type Delegate = &'a SimpleNeighborProviderAsync; - - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - self.provider.neighbors() - } -} - -impl<'a, T, Q, D, Ctx> Accessor for FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - /// The extended element inherets the lifetime of the Accessor. - type Extended = &'a [T]; - - /// This accessor returns raw slices. There *is* a chance of racing when the fast - /// providers are used. We just have to live with it. - /// - /// NOTE: We intentionally don't use `'b` here since our implementation borrows - /// the inner `Opaque` from the underlying provider. - type Element<'b> - = &'a [T] - where - Self: 'b; - - /// `ElementRef` has an arbitrarily short lifetime. - type ElementRef<'b> = &'b [T]; - - /// Choose to panic on an out-of-bounds access rather than propagate an error. - type GetError = Panics; - - /// Return the full-precision vector stored at index `i`. - /// - /// This function always completes synchronously. - #[inline(always)] - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - // SAFETY: We've decided to live with UB (undefined behavior) that can result from - // potentially mixing unsynchronized reads and writes on the underlying memory. - std::future::ready(Ok(unsafe { - self.provider.base_vectors.get_vector_sync(id.into_usize()) - })) - } - - /// Perform a bulk operation. - /// - /// This implementation uses prefetching. - fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> impl Future> + Send - where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), - { - // Reuse the internal buffer to collect the results and give us random access - // capabilities. - let id_buffer = &mut self.id_buffer; - id_buffer.clear(); - id_buffer.extend(itr); - - let len = id_buffer.len(); - let lookahead = self.provider.base_vectors.prefetch_lookahead(); - - // Prefetch the first few vectors. - for id in id_buffer.iter().take(lookahead) { - self.provider.base_vectors.prefetch_hint(id.into_usize()); - } - - for (i, id) in id_buffer.iter().enumerate() { - // Prefetch `lookahead` iterations ahead as long as it is safe. - if lookahead > 0 && i + lookahead < len { - self.provider - .base_vectors - .prefetch_hint(id_buffer[i + lookahead].into_usize()); - } - - // Invoke the passed closure on the full-precision vector. - // - // SAFETY: We're accepting the consequences of potential unsynchronized, - // concurrent mutation. - f( - unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }, - *id, - ) - } - - std::future::ready(Ok(())) - } -} - -impl BuildDistanceComputer for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type DistanceComputerError = Panics; - type DistanceComputer = T::Distance; - - fn build_distance_computer( - &self, - ) -> Result { - Ok(T::distance( - self.provider.metric, - Some(self.provider.base_vectors.dim()), - )) - } -} - -impl BuildQueryComputer<[T]> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type QueryComputerError = Panics; - type QueryComputer = T::QueryDistance; - - fn build_query_computer( - &self, - from: &[T], - ) -> Result { - Ok(T::query_distance(from, self.provider.metric)) - } -} - -impl ExpandBeam<[T]> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ -} - -impl FillSet for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - async fn fill_set( - &mut self, - set: &mut HashMap, - itr: Itr, - ) -> Result<(), Self::GetError> - where - Itr: Iterator + Send + Sync, - { - for i in itr { - set.entry(i).or_insert_with(|| unsafe { - self.provider.base_vectors.get_vector_sync(i.into_usize()) - }); - } - Ok(()) - } -} - -//-------------------// -// In-mem Extensions // -//-------------------// - -impl<'a, T, Q, D, Ctx> AsDeletionCheck for FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type Checker = D; - fn as_deletion_check(&self) -> &D { - &self.provider.deleted - } -} - -////////////////// -// Post Process // -////////////////// - -pub trait GetFullPrecision { - type Repr: VectorRepr; - fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync>; -} - -/// A [`SearchPostProcess`]or that: -/// -/// 1. Filters out deleted ids from being returned. -/// 2. Reranks a candidate stream using full-precision distances. -/// 3. Copies back the results to the output buffer. -#[derive(Debug, Default, Clone, Copy)] -pub struct Rerank; - -impl glue::SearchPostProcess for Rerank -where - T: VectorRepr, - A: BuildQueryComputer<[T], Id = u32> + GetFullPrecision + AsDeletionCheck, -{ - type Error = Panics; - - fn post_process( - &self, - accessor: &mut A, - query: &[T], - _computer: &A::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator>, - B: SearchOutputBuffer + ?Sized, - { - let full = accessor.as_full_precision(); - let checker = accessor.as_deletion_check(); - let f = full.distance(); - - // Filter before computing the full precision distances. - let mut reranked: Vec<(u32, f32)> = candidates - .filter_map(|n| { - if checker.deletion_check(n.id) { - None - } else { - Some(( - n.id, - f.evaluate_similarity(query, unsafe { - full.get_vector_sync(n.id.into_usize()) - }), - )) - } - }) - .collect(); - - // Sort the full precision distances. - reranked - .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); - // Store the reranked results. - std::future::ready(Ok(output.extend(reranked))) - } -} - -//////////////// -// Strategies // -//////////////// - -// A layered approach is used for search strategies. The `Internal` version does the heavy -// lifting in terms of establishing accessors and post processing. -// -// However, during post-processing, the `Internal` versions of strategies will not filter -// out the start points. The publicly exposed types *will* filter out the start points. -// -// This layered approach allows algorithms like `InplaceDeleteStrategy` that need to adjust -// the adjacency list for the start point to reuse the `Internal` strategies. - -/// Perform a search entirely in the full-precision space. -/// -/// Starting points are not filtered out of the final results. -impl SearchStrategy, [T]> - for Internal -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type QueryComputer = T::QueryDistance; - type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type SearchAccessorError = Panics; - type PostProcessor = RemoveDeletedIdsAndCopy; - - fn search_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - -/// Perform a search entirely in the full-precision space. -/// -/// Starting points are not filtered out of the final results. -impl SearchStrategy, [T]> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type QueryComputer = T::QueryDistance; - type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type SearchAccessorError = Panics; - type PostProcessor = glue::Pipeline; - - fn search_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - -// Pruning -impl PruneStrategy> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type DistanceComputer = T::Distance; - type PruneAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type PruneAccessorError = diskann::error::Infallible; - - fn prune_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::PruneAccessorError> { - Ok(FullAccessor::new(provider)) - } -} - -/// Implementing this trait allows `FullPrecision` to be used for multi-insert. -impl<'a, T, Q, D, Ctx> glue::AsElement<&'a [T]> for FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type Error = diskann::error::Infallible; - fn as_element( - &mut self, - vector: &'a [T], - _id: Self::Id, - ) -> impl Future, Self::Error>> + Send { - std::future::ready(Ok(vector)) - } -} - -impl InsertStrategy, [T]> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type PruneStrategy = Self; - fn prune_strategy(&self) -> Self::PruneStrategy { - *self - } -} - -// Inplace Delete // -impl InplaceDeleteStrategy> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type DeleteElementError = Panics; - type DeleteElement<'a> = [T]; - type DeleteElementGuard = Box<[T]>; - type PruneStrategy = Self; - type SearchStrategy = Internal; - fn search_strategy(&self) -> Self::SearchStrategy { - Internal(Self) - } - - fn prune_strategy(&self) -> Self::PruneStrategy { - Self - } - - async fn get_delete_element<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - id: u32, - ) -> Result { - Ok(unsafe { provider.base_vectors.get_vector_sync(id.into_usize()) }.into()) - } -} +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::{collections::HashMap, fmt::Debug, future::Future}; + +use diskann::{ + ANNError, ANNResult, + graph::{ + SearchOutputBuffer, + glue::{ + self, ExpandBeam, FillSet, FilterStartPoints, InplaceDeleteStrategy, InsertStrategy, + PruneStrategy, SearchExt, SearchStrategy, + }, + }, + neighbor::Neighbor, + provider::{ + Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, + ExecutionContext, HasId, + }, + utils::{IntoUsize, VectorRepr}, +}; +use diskann_utils::future::AsyncFriendly; +use diskann_vector::{DistanceFunction, distance::Metric}; + +use crate::model::graph::{ + provider::async_::{ + FastMemoryVectorProviderAsync, SimpleNeighborProviderAsync, + common::{ + CreateVectorStore, FullPrecision, Internal, NoDeletes, NoStore, Panics, + PrefetchCacheLineLevel, SetElementHelper, + }, + inmem::DefaultProvider, + postprocess::{AsDeletionCheck, DeletionCheck, RemoveDeletedIdsAndCopy}, + }, + traits::AdHoc, +}; + +/// A type alias for the DefaultProvider with full-precision as the primary vector store. +pub type FullPrecisionProvider = + DefaultProvider, Q, D, Ctx>; + +/// The default full-precision vector store. +pub type FullPrecisionStore = FastMemoryVectorProviderAsync>; + +/// A default full-precision vector store provider. +#[derive(Clone)] +pub struct CreateFullPrecision { + dim: usize, + prefetch_cache_line_level: Option, + _phantom: std::marker::PhantomData, +} + +impl CreateFullPrecision +where + T: VectorRepr, +{ + /// Create a new full-precision vector store provider. + pub fn new(dim: usize, prefetch_cache_line_level: Option) -> Self { + Self { + dim, + prefetch_cache_line_level, + _phantom: std::marker::PhantomData, + } + } +} + +impl CreateVectorStore for CreateFullPrecision +where + T: VectorRepr, +{ + type Target = FullPrecisionStore; + fn create( + self, + max_points: usize, + metric: Metric, + prefetch_lookahead: Option, + ) -> Self::Target { + FullPrecisionStore::new( + max_points, + self.dim, + metric, + self.prefetch_cache_line_level, + prefetch_lookahead, + ) + } +} + +//////////////// +// SetElement // +//////////////// + +impl SetElementHelper for FullPrecisionStore +where + T: VectorRepr, +{ + /// Set the element at the given index. + fn set_element(&self, id: &u32, element: &[T]) -> Result<(), ANNError> { + unsafe { self.set_vector_sync(id.into_usize(), element) } + } +} + +////////////////// +// FullAccessor // +////////////////// + +/// An accessor for retrieving full-precision vectors from the `DefaultProvider`. +/// +/// This type implements the following traits: +/// +/// * [`Accessor`] for the [`DefaultProvider`]. +/// * [`ComputerAccessor`] for comparing full-precision distances. +/// * [`BuildQueryComputer`]. +pub struct FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, +{ + /// The host provider. + provider: &'a FullPrecisionProvider, + + /// A buffer for resolving iterators given during bulk operations. + /// + /// The accessor reuses this allocation to amortize allocation cost over multiple bulk + /// operations. + id_buffer: Vec, +} + +impl GetFullPrecision for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, +{ + type Repr = T; + fn as_full_precision(&self) -> &FullPrecisionStore { + &self.provider.base_vectors + } +} + +impl HasId for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, +{ + type Id = u32; +} + +impl SearchExt for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + fn starting_points(&self) -> impl Future>> { + std::future::ready(self.provider.starting_points()) + } +} + +impl<'a, T, Q, D, Ctx> FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + pub fn new(provider: &'a FullPrecisionProvider) -> Self { + Self { + provider, + id_buffer: Vec::new(), + } + } +} + +impl<'a, T, Q, D, Ctx> DelegateNeighbor<'a> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type Delegate = &'a SimpleNeighborProviderAsync; + + fn delegate_neighbor(&'a mut self) -> Self::Delegate { + self.provider.neighbors() + } +} + +impl<'a, T, Q, D, Ctx> Accessor for FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + /// The extended element inherets the lifetime of the Accessor. + type Extended = &'a [T]; + + /// This accessor returns raw slices. There *is* a chance of racing when the fast + /// providers are used. We just have to live with it. + /// + /// NOTE: We intentionally don't use `'b` here since our implementation borrows + /// the inner `Opaque` from the underlying provider. + type Element<'b> + = &'a [T] + where + Self: 'b; + + /// `ElementRef` has an arbitrarily short lifetime. + type ElementRef<'b> = &'b [T]; + + /// Choose to panic on an out-of-bounds access rather than propagate an error. + type GetError = Panics; + + /// Return the full-precision vector stored at index `i`. + /// + /// This function always completes synchronously. + #[inline(always)] + fn get_element( + &mut self, + id: Self::Id, + ) -> impl Future, Self::GetError>> + Send { + // SAFETY: We've decided to live with UB (undefined behavior) that can result from + // potentially mixing unsynchronized reads and writes on the underlying memory. + std::future::ready(Ok(unsafe { + self.provider.base_vectors.get_vector_sync(id.into_usize()) + })) + } + + /// Perform a bulk operation. + /// + /// This implementation uses prefetching. + fn on_elements_unordered( + &mut self, + itr: Itr, + mut f: F, + ) -> impl Future> + Send + where + Self: Sync, + Itr: Iterator + Send, + F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), + { + // Reuse the internal buffer to collect the results and give us random access + // capabilities. + let id_buffer = &mut self.id_buffer; + id_buffer.clear(); + id_buffer.extend(itr); + + let len = id_buffer.len(); + let lookahead = self.provider.base_vectors.prefetch_lookahead(); + + // Prefetch the first few vectors. + for id in id_buffer.iter().take(lookahead) { + self.provider.base_vectors.prefetch_hint(id.into_usize()); + } + + for (i, id) in id_buffer.iter().enumerate() { + // Prefetch `lookahead` iterations ahead as long as it is safe. + if lookahead > 0 && i + lookahead < len { + self.provider + .base_vectors + .prefetch_hint(id_buffer[i + lookahead].into_usize()); + } + + // Invoke the passed closure on the full-precision vector. + // + // SAFETY: We're accepting the consequences of potential unsynchronized, + // concurrent mutation. + f( + unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }, + *id, + ) + } + + std::future::ready(Ok(())) + } +} + +impl BuildDistanceComputer for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type DistanceComputerError = Panics; + type DistanceComputer = T::Distance; + + fn build_distance_computer( + &self, + ) -> Result { + Ok(T::distance( + self.provider.metric, + Some(self.provider.base_vectors.dim()), + )) + } +} + +impl BuildQueryComputer<[T]> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type QueryComputerError = Panics; + type QueryComputer = T::QueryDistance; + + fn build_query_computer( + &self, + from: &[T], + ) -> Result { + Ok(T::query_distance(from, self.provider.metric)) + } +} + +impl ExpandBeam<[T]> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ +} + +/// Support for Vec queries that delegates to the [T] impl via deref. +/// This allows InlineBetaStrategy to use Vec queries with FullAccessor. +impl BuildQueryComputer> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type QueryComputerError = Panics; + type QueryComputer = T::QueryDistance; + + fn build_query_computer( + &self, + from: &Vec, + ) -> Result { + // Delegate to [T] impl via deref + Ok(T::query_distance(from.as_slice(), self.provider.metric)) + } +} + +/// Support for Vec queries that delegates to the [T] impl. +impl ExpandBeam> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr + Clone, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ +} + +impl FillSet for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + async fn fill_set( + &mut self, + set: &mut HashMap, + itr: Itr, + ) -> Result<(), Self::GetError> + where + Itr: Iterator + Send + Sync, + { + for i in itr { + set.entry(i).or_insert_with(|| unsafe { + self.provider.base_vectors.get_vector_sync(i.into_usize()) + }); + } + Ok(()) + } +} + +//-------------------// +// In-mem Extensions // +//-------------------// + +impl<'a, T, Q, D, Ctx> AsDeletionCheck for FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type Checker = D; + fn as_deletion_check(&self) -> &D { + &self.provider.deleted + } +} + +////////////////// +// Post Process // +////////////////// + +pub trait GetFullPrecision { + type Repr: VectorRepr; + fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync>; +} + +/// A [`SearchPostProcess`]or that: +/// +/// 1. Filters out deleted ids from being returned. +/// 2. Reranks a candidate stream using full-precision distances. +/// 3. Copies back the results to the output buffer. +#[derive(Debug, Default, Clone, Copy)] +pub struct Rerank; + +impl glue::SearchPostProcess for Rerank +where + T: VectorRepr, + A: BuildQueryComputer<[T], Id = u32> + GetFullPrecision + AsDeletionCheck, +{ + type Error = Panics; + + fn post_process( + &self, + accessor: &mut A, + query: &[T], + _computer: &A::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator>, + B: SearchOutputBuffer + ?Sized, + { + let full = accessor.as_full_precision(); + let checker = accessor.as_deletion_check(); + let f = full.distance(); + + // Filter before computing the full precision distances. + let mut reranked: Vec<(u32, f32)> = candidates + .filter_map(|n| { + if checker.deletion_check(n.id) { + None + } else { + Some(( + n.id, + f.evaluate_similarity(query, unsafe { + full.get_vector_sync(n.id.into_usize()) + }), + )) + } + }) + .collect(); + + // Sort the full precision distances. + reranked + .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + // Store the reranked results. + std::future::ready(Ok(output.extend(reranked))) + } +} + +//////////////// +// Strategies // +//////////////// + +// A layered approach is used for search strategies. The `Internal` version does the heavy +// lifting in terms of establishing accessors and post processing. +// +// However, during post-processing, the `Internal` versions of strategies will not filter +// out the start points. The publicly exposed types *will* filter out the start points. +// +// This layered approach allows algorithms like `InplaceDeleteStrategy` that need to adjust +// the adjacency list for the start point to reuse the `Internal` strategies. + +/// Perform a search entirely in the full-precision space. +/// +/// Starting points are not filtered out of the final results. +impl SearchStrategy, [T]> + for Internal +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type QueryComputer = T::QueryDistance; + type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; + type SearchAccessorError = Panics; + type PostProcessor = RemoveDeletedIdsAndCopy; + + fn search_accessor<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + ) -> Result, Self::SearchAccessorError> { + Ok(FullAccessor::new(provider)) + } + + fn post_processor(&self) -> Self::PostProcessor { + Default::default() + } +} + +/// Perform a search entirely in the full-precision space. +/// +/// Starting points are not filtered out of the final results. +impl SearchStrategy, [T]> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type QueryComputer = T::QueryDistance; + type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; + type SearchAccessorError = Panics; + type PostProcessor = glue::Pipeline; + + fn search_accessor<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + ) -> Result, Self::SearchAccessorError> { + Ok(FullAccessor::new(provider)) + } + + fn post_processor(&self) -> Self::PostProcessor { + Default::default() + } +} + +/// Support for Vec queries that delegates to the [T] impl. +/// This allows InlineBetaStrategy to use Vec queries with FullPrecision. +impl SearchStrategy, Vec> for FullPrecision +where + T: VectorRepr + Clone, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type QueryComputer = T::QueryDistance; + type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; + type SearchAccessorError = Panics; + type PostProcessor = glue::Pipeline; + + fn search_accessor<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + ) -> Result, Self::SearchAccessorError> { + Ok(FullAccessor::new(provider)) + } + + fn post_processor(&self) -> Self::PostProcessor { + Default::default() + } +} + +// Pruning +impl PruneStrategy> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type DistanceComputer = T::Distance; + type PruneAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; + type PruneAccessorError = diskann::error::Infallible; + + fn prune_accessor<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + ) -> Result, Self::PruneAccessorError> { + Ok(FullAccessor::new(provider)) + } +} + +/// Implementing this trait allows `FullPrecision` to be used for multi-insert. +impl<'a, T, Q, D, Ctx> glue::AsElement<&'a [T]> for FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type Error = diskann::error::Infallible; + fn as_element( + &mut self, + vector: &'a [T], + _id: Self::Id, + ) -> impl Future, Self::Error>> + Send { + std::future::ready(Ok(vector)) + } +} + +impl InsertStrategy, [T]> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type PruneStrategy = Self; + fn prune_strategy(&self) -> Self::PruneStrategy { + *self + } +} + +// Inplace Delete // +impl InplaceDeleteStrategy> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type DeleteElementError = Panics; + type DeleteElement<'a> = [T]; + type DeleteElementGuard = Box<[T]>; + type PruneStrategy = Self; + type SearchStrategy = Internal; + fn search_strategy(&self) -> Self::SearchStrategy { + Internal(Self) + } + + fn prune_strategy(&self) -> Self::PruneStrategy { + Self + } + + async fn get_delete_element<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + id: u32, + ) -> Result { + Ok(unsafe { provider.base_vectors.get_vector_sync(id.into_usize()) }.into()) + } +} diff --git a/diskann-tools/src/utils/ground_truth.rs b/diskann-tools/src/utils/ground_truth.rs index 31e69b2b2..8c2fa29f6 100644 --- a/diskann-tools/src/utils/ground_truth.rs +++ b/diskann-tools/src/utils/ground_truth.rs @@ -32,14 +32,14 @@ use crate::utils::{search_index_utils, CMDResult, CMDToolError}; /// Expands a JSON object with array-valued fields into multiple objects with scalar values. /// For example: {"country": ["AU", "NZ"], "year": 2007} /// becomes: [{"country": "AU", "year": 2007}, {"country": "NZ", "year": 2007}] -/// +/// /// If multiple fields have arrays, all combinations are generated. fn expand_array_fields(value: &Value) -> Vec { match value { Value::Object(map) => { // Start with a single empty object let mut results: Vec> = vec![Map::new()]; - + for (key, val) in map.iter() { if let Value::Array(arr) = val { // Expand: for each existing result, create copies for each array element @@ -62,7 +62,7 @@ fn expand_array_fields(value: &Value) -> Vec { } } } - + results.into_iter().map(Value::Object).collect() } // If not an object, return as-is @@ -74,7 +74,9 @@ fn expand_array_fields(value: &Value) -> Vec { /// Returns true if any expanded variant matches the query. fn eval_query_with_array_expansion(query_expr: &ASTExpr, label: &Value) -> bool { let expanded = expand_array_fields(label); - expanded.iter().any(|item| eval_query_expr(query_expr, item)) + expanded + .iter() + .any(|item| eval_query_expr(query_expr, item)) } pub fn read_labels_and_compute_bitmap( @@ -127,11 +129,13 @@ pub fn read_labels_and_compute_bitmap( // Handle case where base_label.label is an array - check if any element matches // Also expand array-valued fields within objects (e.g., {"country": ["AU", "NZ"]}) let matches = if let Some(array) = base_label.label.as_array() { - array.iter().any(|item| eval_query_with_array_expansion(query_expr, item)) + array + .iter() + .any(|item| eval_query_with_array_expansion(query_expr, item)) } else { eval_query_with_array_expansion(query_expr, &base_label.label) }; - + if matches { bitmap.insert(base_label.doc_id); } @@ -164,11 +168,17 @@ pub fn read_labels_and_compute_bitmap( // If no matches, print more diagnostic info if total_matches == 0 { tracing::warn!("WARNING: No base vectors matched any query filters!"); - tracing::warn!("This could indicate a format mismatch between base labels and query filters."); - + tracing::warn!( + "This could indicate a format mismatch between base labels and query filters." + ); + // Try to identify what keys exist in base labels vs queries if let Some(first_label) = base_labels.first() { - tracing::warn!("First base label (full): doc_id={}, label={}", first_label.doc_id, first_label.label); + tracing::warn!( + "First base label (full): doc_id={}, label={}", + first_label.doc_id, + first_label.label + ); } } @@ -323,7 +333,7 @@ pub fn compute_ground_truth_from_datafiles< for (query_idx, npq) in ground_truth.iter().enumerate() { let neighbors: Vec<_> = npq.iter().collect(); let neighbor_count = neighbors.len(); - + if query_idx < 10 { // Print top K IDs and distances for first 10 queries let top_ids: Vec = neighbors.iter().take(10).map(|n| n.id).collect(); @@ -336,7 +346,7 @@ pub fn compute_ground_truth_from_datafiles< top_dists ); } - + if neighbor_count == 0 { tracing::warn!("Query {} has 0 neighbors in ground truth!", query_idx); } @@ -344,7 +354,10 @@ pub fn compute_ground_truth_from_datafiles< // Summary stats let total_neighbors: usize = ground_truth.iter().map(|npq| npq.iter().count()).sum(); - let queries_with_neighbors = ground_truth.iter().filter(|npq| npq.iter().count() > 0).count(); + let queries_with_neighbors = ground_truth + .iter() + .filter(|npq| npq.iter().count() > 0) + .count(); tracing::info!( "Ground truth summary: {} total neighbors, {} queries have neighbors, {} queries have 0 neighbors", total_neighbors, diff --git a/test_data/disk_index_search/data.256.label.jsonl b/test_data/disk_index_search/data.256.label.jsonl index 83254af7b..a99cde8e2 100644 --- a/test_data/disk_index_search/data.256.label.jsonl +++ b/test_data/disk_index_search/data.256.label.jsonl @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7f8b6b99ca32173557689712d3fb5da30c5e4111130fd2accbccf32f5ce3e47e -size 17702 +oid sha256:92576896b10780a2cd80a16030f8384610498b76453f57fadeacb854379e0acf +size 17701 From 98ad4f7d28b3a77e613105ecab3fc3b6d162c872 Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Tue, 17 Feb 2026 11:57:11 +0530 Subject: [PATCH 04/50] Removing unnecessary stats --- .../src/backend/document_index/benchmark.rs | 118 ++---------------- diskann-benchmark/src/backend/index/result.rs | 13 -- 2 files changed, 12 insertions(+), 119 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index dffe669ff..ed2915974 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -586,57 +586,6 @@ where let recall_metrics: recall::RecallMetrics = (&recall::knn(groundtruth, None, &merged.ids, recall_k, search_n, false)?).into(); - // Compute per-query details (only for queries with recall < 1) - let per_query_details: Vec = (0..num_queries) - .filter_map(|query_idx| { - let result_ids: Vec = merged - .ids - .row(query_idx) - .iter() - .copied() - .filter(|&id| id != u32::MAX) - .collect(); - let result_distances: Vec = merged - .distances - .get(query_idx) - .map(|d| d.iter().copied().filter(|&dist| dist != f32::MAX).collect()) - .unwrap_or_default(); - // Only keep top 20 from ground truth - let gt_ids: Vec = groundtruth - .get(query_idx) - .map(|gt| gt.iter().take(20).copied().collect()) - .unwrap_or_default(); - - // Compute per-query recall: intersection of result_ids with gt_ids / recall_k - let result_set: std::collections::HashSet = result_ids.iter().copied().collect(); - let gt_set: std::collections::HashSet = - gt_ids.iter().take(recall_k).copied().collect(); - let intersection = result_set.intersection(>_set).count(); - let per_query_recall = if gt_set.is_empty() { - 1.0 - } else { - intersection as f64 / gt_set.len() as f64 - }; - - // Only include queries with imperfect recall - if per_query_recall >= 1.0 { - return None; - } - - let (_, ref ast_expr) = predicates[query_idx]; - let filter_str = format!("{:?}", ast_expr); - - Some(PerQueryDetails { - query_id: query_idx, - filter: filter_str, - recall: per_query_recall, - result_ids, - result_distances, - groundtruth_ids: gt_ids, - }) - }) - .collect(); - // Compute QPS from rep latencies let qps: Vec = rep_latencies .iter() @@ -684,18 +633,15 @@ where Ok(SearchRunStats { num_threads: num_threads.get(), - num_queries, search_n, search_l, recall: recall_metrics, qps, - wall_clock_time: rep_latencies, mean_latency: mean, p90_latency: p90, p99_latency: p99, mean_cmps, mean_hops, - per_query_details: Some(per_query_details), }) } async fn run_search_parallel( @@ -830,60 +776,19 @@ pub struct BuildParamsStats { pub alpha: f32, } -/// Helper module for serializing arrays as compact single-line JSON strings -mod compact_array { - use serde::Serializer; - - pub fn serialize_u32_vec(vec: &Vec, serializer: S) -> Result - where - S: Serializer, - { - // Serialize as a string containing the compact JSON array - let compact = serde_json::to_string(vec).unwrap_or_default(); - serializer.serialize_str(&compact) - } - - pub fn serialize_f32_vec(vec: &Vec, serializer: S) -> Result - where - S: Serializer, - { - // Serialize as a string containing the compact JSON array - let compact = serde_json::to_string(vec).unwrap_or_default(); - serializer.serialize_str(&compact) - } -} - -/// Per-query detailed results for debugging/analysis -#[derive(Debug, Serialize)] -pub struct PerQueryDetails { - pub query_id: usize, - pub filter: String, - pub recall: f64, - #[serde(serialize_with = "compact_array::serialize_u32_vec")] - pub result_ids: Vec, - #[serde(serialize_with = "compact_array::serialize_f32_vec")] - pub result_distances: Vec, - #[serde(serialize_with = "compact_array::serialize_u32_vec")] - pub groundtruth_ids: Vec, -} - /// Results from a single search configuration (one search_l value). #[derive(Debug, Serialize)] pub struct SearchRunStats { pub num_threads: usize, - pub num_queries: usize, pub search_n: usize, pub search_l: usize, pub recall: recall::RecallMetrics, pub qps: Vec, - pub wall_clock_time: Vec, pub mean_latency: f64, pub p90_latency: MicroSeconds, pub p99_latency: MicroSeconds, pub mean_cmps: f32, pub mean_hops: f32, - #[serde(skip_serializing_if = "Option::is_none")] - pub per_query_details: Option>, } #[derive(Debug, Serialize)] @@ -929,8 +834,16 @@ impl std::fmt::Display for DocumentIndexStats { writeln!(f, "\nFiltered Search Results:")?; writeln!( f, - " {:>8} {:>8} {:>10} {:>10} {:>15} {:>12} {:>12} {:>10} {:>8} {:>10} {:>12}", - "L", "KNN", "Avg Cmps", "Avg Hops", "QPS -mean(max)", "Avg Latency", "p99 Latency", "Recall", "Threads", "Queries", "WallClock(s)" + " {:>8} {:>8} {:>10} {:>10} {:>15} {:>12} {:>12} {:>10} {:>8}", + "L", + "KNN", + "Avg Cmps", + "Avg Hops", + "QPS -mean(max)", + "Avg Latency", + "p99 Latency", + "Recall", + "Threads" )?; for s in &self.search { let mean_qps = if s.qps.is_empty() { @@ -939,14 +852,9 @@ impl std::fmt::Display for DocumentIndexStats { s.qps.iter().sum::() / s.qps.len() as f64 }; let max_qps = s.qps.iter().cloned().fold(0.0_f64, f64::max); - let mean_wall_clock = if s.wall_clock_time.is_empty() { - 0.0 - } else { - s.wall_clock_time.iter().map(|t| t.as_seconds()).sum::() / s.wall_clock_time.len() as f64 - }; writeln!( f, - " {:>8} {:>8} {:>10.1} {:>10.1} {:>7.1}({:>5.1}) {:>12.1} {:>12} {:>10.4} {:>8} {:>10} {:>12.3}", + " {:>8} {:>8} {:>10.1} {:>10.1} {:>7.1}({:>5.1}) {:>12.1} {:>12} {:>10.4} {:>8}", s.search_l, s.search_n, s.mean_cmps, @@ -956,9 +864,7 @@ impl std::fmt::Display for DocumentIndexStats { s.mean_latency, s.p99_latency, s.recall.average, - s.num_threads, - s.num_queries, - mean_wall_clock + s.num_threads )?; } } diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs index 21d74f915..c7e2ab75c 100644 --- a/diskann-benchmark/src/backend/index/result.rs +++ b/diskann-benchmark/src/backend/index/result.rs @@ -109,7 +109,6 @@ impl std::fmt::Display for AggregatedSearchResults { #[derive(Debug, Serialize)] pub(super) struct SearchResults { pub(super) num_tasks: usize, - pub(super) num_queries: usize, pub(super) search_n: usize, pub(super) search_l: usize, pub(super) qps: Vec, @@ -144,7 +143,6 @@ impl SearchResults { Self { num_tasks: setup.tasks.into(), - num_queries: recall.num_queries, search_n: parameters.k_value, search_l: parameters.l_value, qps, @@ -184,8 +182,6 @@ where "p99 Latency", "Recall", "Threads", - "Queries", - "WallClock(s)", ] } else { &[ @@ -198,8 +194,6 @@ where "p99 Latency", "Recall", "Threads", - "Queries", - "WallClock(s)", ] }; @@ -243,13 +237,6 @@ where ); row.insert(format!("{:3}", r.recall.average), col_idx + 7); row.insert(r.num_tasks, col_idx + 8); - row.insert(r.num_queries, col_idx + 9); - let mean_wall_clock = if r.search_latencies.is_empty() { - 0.0 - } else { - r.search_latencies.iter().map(|t| t.as_seconds()).sum::() / r.search_latencies.len() as f64 - }; - row.insert(format!("{:.3}", mean_wall_clock), col_idx + 10); }); write!(f, "{}", table) From edfdee6e7dd3d63f03dc8484390d1f5ebd604c4c Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Mon, 2 Mar 2026 15:54:44 +0530 Subject: [PATCH 05/50] Fix clippy warnings --- .../roaring_attribute_store.rs | 9 +++++++++ .../src/inline_beta_search/inline_beta_filter.rs | 15 +++++++-------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs b/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs index c69589ba0..cab250c76 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs @@ -24,6 +24,15 @@ where inv_index: Arc>>, } +impl Default for RoaringAttributeStore +where + IT: VectorId, +{ + fn default() -> Self { + Self::new() + } +} + impl RoaringAttributeStore where IT: VectorId, diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index f03f36c12..76a78de67 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -125,16 +125,16 @@ where { Ok(matched) => { if matched { - return sim * self.beta_value; + sim * self.beta_value } else { - return sim; + sim } } Err(_) => { //If predicate evaluation fails for any reason, we simply revert //to unfiltered search. tracing::warn!("Predicate evaluation failed"); - return sim; + sim } } } else { @@ -182,16 +182,15 @@ where let doc = accessor.get_element(candidate.id).await?; let pe = PredicateEvaluator::new(doc.attributes()); - if computer.is_valid_filter() { - if computer + if computer.is_valid_filter() + && computer .filter_expr() .encoded_filter_expr() .as_ref() .unwrap() .accept(&pe)? - { - filtered_candidates.push(Neighbor::new(candidate.id, candidate.distance)); - } + { + filtered_candidates.push(Neighbor::new(candidate.id, candidate.distance)); } } From cbb77f64e54a38038f64fd6fba7b9643e626b9e1 Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:44:47 +0530 Subject: [PATCH 06/50] use search and build benchmark apis --- .../src/backend/document_index/benchmark.rs | 630 ++++++++---------- 1 file changed, 288 insertions(+), 342 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index dffe669ff..51b3467bf 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -20,7 +20,12 @@ use diskann::{ search_output_buffer, DiskANNIndex, SearchParams, StartPointStrategy, }, provider::DefaultContext, - utils::{async_tools, IntoUsize}, +}; +use diskann_benchmark_core::{ + build::{self, AsProgress, Build, Parallelism, Progress}, + recall, + search as search_api, + tokio, }; use diskann_benchmark_runner::{ dispatcher::{DispatchRule, FailureScore, MatchScore}, @@ -485,228 +490,237 @@ impl<'a> DocumentIndexJob<'a> { Ok(stats) } } -/// Local results from a partition of queries. -struct SearchLocalResults { - ids: Matrix, - distances: Vec>, - latencies: Vec, - comparisons: Vec, - hops: Vec, +/// Per-query output from [`FilteredSearcher::search`]. +struct FilteredSearchOutput { + distances: Vec, + comparisons: u32, + hops: u32, } -impl SearchLocalResults { - fn merge(all: &[SearchLocalResults]) -> anyhow::Result { - let first = all - .first() - .ok_or_else(|| anyhow::anyhow!("empty results"))?; - let num_ids = first.ids.ncols(); - let total_rows: usize = all.iter().map(|r| r.ids.nrows()).sum(); - - let mut ids = Matrix::new(0, total_rows, num_ids); - let mut output_row = 0; - for r in all { - for input_row in r.ids.row_iter() { - ids.row_mut(output_row).copy_from_slice(input_row); - output_row += 1; - } - } - - let mut distances = Vec::new(); - let mut latencies = Vec::new(); - let mut comparisons = Vec::new(); - let mut hops = Vec::new(); - for r in all { - distances.extend_from_slice(&r.distances); - latencies.extend_from_slice(&r.latencies); - comparisons.extend_from_slice(&r.comparisons); - hops.extend_from_slice(&r.hops); - } - - Ok(Self { - ids, - distances, - latencies, - comparisons, - hops, - }) - } +/// Implements [`search_api::Search`] for parallelized inline-beta filtered search. +/// +/// Each query is paired with a predicate at the same index in `predicates`. The +/// [`InlineBetaStrategy`] is used with a [`FilteredQuery`] containing the raw vector +/// and the predicate's [`ASTExpr`]. +struct FilteredSearcher +where + DP: diskann::provider::DataProvider, +{ + index: Arc>, + queries: Arc>, + predicates: Arc>, + beta: f32, } -/// Run filtered search with the given parameters. -#[allow(clippy::too_many_arguments)] -fn run_filtered_search( - index: &Arc>, - queries: &Matrix, - predicates: &[(usize, ASTExpr)], - groundtruth: &Vec>, - beta: f32, - num_threads: NonZeroUsize, - search_n: usize, - search_l: usize, - recall_k: usize, - reps: NonZeroUsize, -) -> anyhow::Result +impl search_api::Search for FilteredSearcher where - T: bytemuck::Pod + Copy + Send + Sync + 'static, - DP: diskann::provider::DataProvider< - Context = DefaultContext, - ExternalId = u32, - InternalId = u32, - > + Send + DP: diskann::provider::DataProvider + + Send + Sync + 'static, - InlineBetaStrategy: - diskann::graph::glue::SearchStrategy>>, + InlineBetaStrategy: diskann::graph::glue::SearchStrategy>, u32>, + T: bytemuck::Pod + Copy + Send + Sync + 'static, { - let rt = utils::tokio::runtime(num_threads.get())?; - let num_queries = queries.nrows(); + type Id = DP::ExternalId; + type Parameters = SearchParams; + type Output = FilteredSearchOutput; - let mut all_rep_results = Vec::with_capacity(reps.get()); - let mut rep_latencies = Vec::with_capacity(reps.get()); + fn num_queries(&self) -> usize { + self.queries.nrows() + } - for _ in 0..reps.get() { - let start = std::time::Instant::now(); - let results = rt.block_on(run_search_parallel( - index.clone(), - queries, - predicates, - beta, - num_threads, - search_n, - search_l, - ))?; - rep_latencies.push(MicroSeconds::from(start.elapsed())); - all_rep_results.push(results); + fn id_count(&self, parameters: &SearchParams) -> search_api::IdCount { + search_api::IdCount::Fixed( + NonZeroUsize::new(parameters.k_value).unwrap_or(diskann::utils::ONE), + ) } - // Merge results from first rep for recall calculation - let merged = SearchLocalResults::merge(&all_rep_results[0])?; - - // Compute recall - let recall_metrics: recall::RecallMetrics = - (&recall::knn(groundtruth, None, &merged.ids, recall_k, search_n, false)?).into(); - - // Compute per-query details (only for queries with recall < 1) - let per_query_details: Vec = (0..num_queries) - .filter_map(|query_idx| { - let result_ids: Vec = merged - .ids - .row(query_idx) - .iter() - .copied() - .filter(|&id| id != u32::MAX) - .collect(); - let result_distances: Vec = merged - .distances - .get(query_idx) - .map(|d| d.iter().copied().filter(|&dist| dist != f32::MAX).collect()) - .unwrap_or_default(); - // Only keep top 20 from ground truth - let gt_ids: Vec = groundtruth - .get(query_idx) - .map(|gt| gt.iter().take(20).copied().collect()) - .unwrap_or_default(); - - // Compute per-query recall: intersection of result_ids with gt_ids / recall_k - let result_set: std::collections::HashSet = result_ids.iter().copied().collect(); - let gt_set: std::collections::HashSet = - gt_ids.iter().take(recall_k).copied().collect(); - let intersection = result_set.intersection(>_set).count(); - let per_query_recall = if gt_set.is_empty() { - 1.0 - } else { - intersection as f64 / gt_set.len() as f64 - }; + async fn search( + &self, + parameters: &SearchParams, + buffer: &mut O, + index: usize, + ) -> diskann::ANNResult + where + O: diskann::graph::SearchOutputBuffer + Send, + { + let ctx = DefaultContext; + let query_vec = self.queries.row(index); + let (_, ref ast_expr) = self.predicates[index]; + let strategy = InlineBetaStrategy::new(self.beta, common::FullPrecision); + let filtered_query = FilteredQuery::new(query_vec, ast_expr.clone()); + + // Use a concrete IdDistance scratch buffer so that both the IDs and distances + // are captured. Afterwards, the valid IDs are forwarded into the framework buffer. + let k = parameters.k_value; + let mut ids = vec![0u32; k]; + let mut distances = vec![0.0f32; k]; + let mut scratch = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + + let stats = self + .index + .search(&strategy, &ctx, &filtered_query, parameters, &mut scratch) + .await?; - // Only include queries with imperfect recall - if per_query_recall >= 1.0 { - return None; + let count = scratch.current_len(); + for (&id, &dist) in std::iter::zip(&ids[..count], &distances[..count]) { + if buffer.push(id, dist).is_full() { + break; } + } + + Ok(FilteredSearchOutput { + distances: distances[..count].to_vec(), + comparisons: stats.cmps, + hops: stats.hops, + }) + } +} + +/// Aggregates per-rep [`search_api::SearchResults`] into a [`SearchRunStats`]. +struct FilteredSearchAggregator<'a> { + groundtruth: &'a Vec>, + predicates: &'a [(usize, ASTExpr)], + recall_k: usize, +} + +impl search_api::Aggregate + for FilteredSearchAggregator<'_> +{ + type Output = SearchRunStats; + + fn aggregate( + &mut self, + run: search_api::Run, + results: Vec>, + ) -> anyhow::Result { + let parameters = run.parameters(); + let search_n = parameters.k_value; + let num_queries = results.first().map(|r| r.len()).unwrap_or(0); + + // Recall from first rep only. + let recall_metrics: SerializableRecallMetrics = match results.first() { + Some(first) => (&recall::knn( + self.groundtruth, + None, + first.ids().as_rows(), + self.recall_k, + search_n, + true, + )?) + .into(), + None => anyhow::bail!("no search results"), + }; - let (_, ref ast_expr) = predicates[query_idx]; - let filter_str = format!("{:?}", ast_expr); + // Per-query details from first rep (only queries with recall < 1). + let first = results.first().unwrap(); + let per_query_details: Vec = (0..num_queries) + .filter_map(|query_idx| { + let result_ids: Vec = first.ids().as_rows().row(query_idx).to_vec(); + let result_distances: Vec = first + .output() + .get(query_idx) + .map(|o| o.distances.clone()) + .unwrap_or_default(); + let gt_ids: Vec = self + .groundtruth + .get(query_idx) + .map(|gt| gt.iter().take(20).copied().collect()) + .unwrap_or_default(); + + let result_set: std::collections::HashSet = + result_ids.iter().copied().collect(); + let gt_set: std::collections::HashSet = + gt_ids.iter().take(self.recall_k).copied().collect(); + let intersection = result_set.intersection(>_set).count(); + let per_query_recall = if gt_set.is_empty() { + 1.0 + } else { + intersection as f64 / gt_set.len() as f64 + }; - Some(PerQueryDetails { - query_id: query_idx, - filter: filter_str, - recall: per_query_recall, - result_ids, - result_distances, - groundtruth_ids: gt_ids, + if per_query_recall >= 1.0 { + return None; + } + + let (_, ref ast_expr) = self.predicates[query_idx]; + Some(PerQueryDetails { + query_id: query_idx, + filter: format!("{:?}", ast_expr), + recall: per_query_recall, + result_ids, + result_distances, + groundtruth_ids: gt_ids, + }) }) - }) - .collect(); + .collect(); - // Compute QPS from rep latencies - let qps: Vec = rep_latencies - .iter() - .map(|l| num_queries as f64 / l.as_seconds()) - .collect(); + // Wall-clock latency and QPS per rep. + let rep_latencies: Vec = + results.iter().map(|r| r.end_to_end_latency()).collect(); + let qps: Vec = rep_latencies + .iter() + .map(|l| num_queries as f64 / l.as_seconds()) + .collect(); - // Aggregate per-query latencies across all reps - let (all_latencies, all_cmps, all_hops): (Vec<_>, Vec<_>, Vec<_>) = all_rep_results - .iter() - .map(|results| { - let mut lat = Vec::new(); - let mut cmp = Vec::new(); - let mut hop = Vec::new(); - for r in results { - lat.extend_from_slice(&r.latencies); - cmp.extend_from_slice(&r.comparisons); - hop.extend_from_slice(&r.hops); + // Per-query latencies, comparisons, and hops aggregated across all reps. + let mut all_latencies: Vec = Vec::new(); + let mut all_cmps: Vec = Vec::new(); + let mut all_hops: Vec = Vec::new(); + for r in &results { + all_latencies.extend_from_slice(r.latencies()); + for o in r.output() { + all_cmps.push(o.comparisons); + all_hops.push(o.hops); } - (lat, cmp, hop) - }) - .fold( - (Vec::new(), Vec::new(), Vec::new()), - |(mut a, mut b, mut c): (Vec, Vec, Vec), (x, y, z)| { - a.extend(x); - b.extend(y); - c.extend(z); - (a, b, c) - }, - ); + } - let mut query_latencies = all_latencies; - let percentiles::Percentiles { mean, p90, p99, .. } = - percentiles::compute_percentiles(&mut query_latencies)?; + let percentiles::Percentiles { mean, p90, p99, .. } = + percentiles::compute_percentiles(&mut all_latencies)?; - let mean_cmps = if all_cmps.is_empty() { - 0.0 - } else { - all_cmps.iter().map(|&x| x as f32).sum::() / all_cmps.len() as f32 - }; - let mean_hops = if all_hops.is_empty() { - 0.0 - } else { - all_hops.iter().map(|&x| x as f32).sum::() / all_hops.len() as f32 - }; + let mean_cmps = if all_cmps.is_empty() { + 0.0 + } else { + all_cmps.iter().map(|&x| x as f32).sum::() / all_cmps.len() as f32 + }; + let mean_hops = if all_hops.is_empty() { + 0.0 + } else { + all_hops.iter().map(|&x| x as f32).sum::() / all_hops.len() as f32 + }; - Ok(SearchRunStats { - num_threads: num_threads.get(), - num_queries, - search_n, - search_l, - recall: recall_metrics, - qps, - wall_clock_time: rep_latencies, - mean_latency: mean, - p90_latency: p90, - p99_latency: p99, - mean_cmps, - mean_hops, - per_query_details: Some(per_query_details), - }) + Ok(SearchRunStats { + num_threads: run.setup().threads.get(), + num_queries, + search_n, + search_l: parameters.l_value, + recall: recall_metrics, + qps, + wall_clock_time: rep_latencies, + mean_latency: mean, + p90_latency: p90, + p99_latency: p99, + mean_cmps, + mean_hops, + per_query_details: Some(per_query_details), + }) + } } -async fn run_search_parallel( - index: Arc>, + +/// Run filtered search with the given parameters. +#[allow(clippy::too_many_arguments)] +fn run_filtered_search( + index: &Arc>, queries: &Matrix, predicates: &[(usize, ASTExpr)], + groundtruth: &Vec>, beta: f32, - num_tasks: NonZeroUsize, + num_threads: NonZeroUsize, search_n: usize, search_l: usize, -) -> anyhow::Result> + recall_k: usize, + reps: NonZeroUsize, +) -> anyhow::Result where T: bytemuck::Pod + Copy + Send + Sync + 'static, DP: diskann::provider::DataProvider< @@ -719,109 +733,31 @@ where InlineBetaStrategy: diskann::graph::glue::SearchStrategy>>, { - let num_queries = queries.nrows(); - - // Plan query partitions - let partitions: Result, _> = (0..num_tasks.get()) - .map(|task_id| async_tools::partition(num_queries, num_tasks, task_id)) - .collect(); - let partitions = partitions?; - - // We need to clone data for each task - let queries_arc = Arc::new(queries.clone()); - let predicates_arc = Arc::new(predicates.to_vec()); - - let handles: Vec<_> = partitions - .into_iter() - .map(|range| { - let index = index.clone(); - let queries = queries_arc.clone(); - let predicates = predicates_arc.clone(); - tokio::spawn(async move { - run_search_local(index, queries, predicates, beta, range, search_n, search_l).await - }) - }) - .collect(); - - let mut results = Vec::new(); - for h in handles { - results.push(h.await??); - } - - Ok(results) -} - -async fn run_search_local( - index: Arc>, - queries: Arc>, - predicates: Arc>, - beta: f32, - range: std::ops::Range, - search_n: usize, - search_l: usize, -) -> anyhow::Result -where - T: bytemuck::Pod + Copy + Send + Sync + 'static, - DP: diskann::provider::DataProvider< - Context = DefaultContext, - ExternalId = u32, - InternalId = u32, - > + Send - + Sync, - InlineBetaStrategy: - diskann::graph::glue::SearchStrategy>>, -{ - let mut ids = Matrix::new(0, range.len(), search_n); - let mut all_distances: Vec> = Vec::with_capacity(range.len()); - let mut latencies = Vec::with_capacity(range.len()); - let mut comparisons = Vec::with_capacity(range.len()); - let mut hops = Vec::with_capacity(range.len()); - - let ctx = DefaultContext; - let search_params = SearchParams::new_default(search_n, search_l)?; - - for (output_idx, query_idx) in range.enumerate() { - let query_vec = queries.row(query_idx); - let (_, ref ast_expr) = predicates[query_idx]; - - let strategy = InlineBetaStrategy::new(beta, common::FullPrecision); - let query_vec_owned = query_vec.to_vec(); - let filtered_query: FilteredQuery> = - FilteredQuery::new(query_vec_owned, ast_expr.clone()); - - let start = std::time::Instant::now(); - - let mut distances = vec![0.0f32; search_n]; - let result_ids = ids.row_mut(output_idx); - let mut result_buffer = search_output_buffer::IdDistance::new(result_ids, &mut distances); - - let stats = index - .search( - &strategy, - &ctx, - &filtered_query, - &search_params, - &mut result_buffer, - ) - .await?; - - let result_count = stats.result_count.into_usize(); - result_ids[result_count..].fill(u32::MAX); - distances[result_count..].fill(f32::MAX); + let searcher = Arc::new(FilteredSearcher { + index: index.clone(), + queries: Arc::new(queries.clone()), + predicates: Arc::new(predicates.to_vec()), + beta, + }); + + let parameters = SearchParams::new_default(search_n, search_l)?; + let setup = search_api::Setup { + threads: num_threads, + tasks: num_threads, + reps, + }; - latencies.push(MicroSeconds::from(start.elapsed())); - comparisons.push(stats.cmps); - hops.push(stats.hops); - all_distances.push(distances); - } + let mut results = search_api::search_all( + searcher, + [search_api::Run::new(parameters, setup)], + FilteredSearchAggregator { + groundtruth, + predicates, + recall_k, + }, + )?; - Ok(SearchLocalResults { - ids, - distances: all_distances, - latencies, - comparisons, - hops, - }) + results.pop().ok_or_else(|| anyhow::anyhow!("no search results")) } #[derive(Debug, Serialize)] pub struct BuildParamsStats { @@ -970,69 +906,79 @@ impl std::fmt::Display for DocumentIndexStats { // Parallel Build Support // ================================ -fn make_progress_bar( - nrows: usize, - draw_target: indicatif::ProgressDrawTarget, -) -> anyhow::Result { - let progress = ProgressBar::with_draw_target(Some(nrows as u64), draw_target); - progress.set_style(ProgressStyle::with_template( - "Building [{elapsed_precise}] {wide_bar} {percent}", - )?); - Ok(progress) -} - -/// Control block for parallel document insertion. -/// Manages work distribution and progress tracking across multiple tasks. -struct DocumentControlBlock { +/// Implements [`Build`] for parallel document insertion into a [`DiskANNIndex`] +/// backed by a [`DocumentProvider`]. Each call to [`Build::build`] inserts a +/// contiguous range of vectors and their associated attributes. +struct DocumentIndexBuilder { + index: Arc>, data: Arc>, attributes: Arc>>, - position: AtomicUsize, - cancel: AtomicBool, - progress: ProgressBar, + strategy: DocumentInsertStrategy, } -impl DocumentControlBlock { +impl DocumentIndexBuilder { fn new( + index: Arc>, data: Arc>, attributes: Arc>>, - draw_target: indicatif::ProgressDrawTarget, - ) -> anyhow::Result> { - let nrows = data.nrows(); - Ok(Arc::new(Self { + strategy: DocumentInsertStrategy, + ) -> Arc { + Arc::new(Self { + index, data, attributes, - position: AtomicUsize::new(0), - cancel: AtomicBool::new(false), - progress: make_progress_bar(nrows, draw_target)?, - })) + strategy, + }) } +} - /// Return the next document data to insert: (id, vector_slice, attributes). - fn next(&self) -> Option<(usize, &[T], Vec)> { - let cancel = self.cancel.load(Ordering::Relaxed); - if cancel { - None - } else { - let i = self.position.fetch_add(1, Ordering::Relaxed); - match self.data.get_row(i) { - Some(row) => { - let attrs = self.attributes.get(i).cloned().unwrap_or_default(); - self.progress.inc(1); - Some((i, row, attrs)) - } - None => None, - } +impl Build for DocumentIndexBuilder +where + DP: diskann::provider::DataProvider + + for<'doc> diskann::provider::SetElement> + + AsyncFriendly, + for<'doc> DocumentInsertStrategy: + diskann::graph::glue::InsertStrategy>, + DocumentInsertStrategy: AsyncFriendly, + T: AsyncFriendly, +{ + type Output = (); + + fn num_data(&self) -> usize { + self.data.nrows() + } + + async fn build(&self, range: std::ops::Range) -> diskann::ANNResult { + let ctx = DefaultContext; + for i in range { + let attrs = self.attributes.get(i).cloned().unwrap_or_default(); + let doc = Document::new(self.data.row(i), attrs); + self.index + .insert(self.strategy, &ctx, &(i as u32), &doc) + .await?; } + Ok(()) + } +} + +/// Adapts an already-constructed [`ProgressBar`] into the [`AsProgress`] / [`Progress`] +/// traits expected by [`build_tracked`]. +struct IndicatifAsProgress(ProgressBar); + +struct IndicatifProgress(ProgressBar); + +impl Progress for IndicatifProgress { + fn progress(&self, handled: usize) { + self.0.inc(handled as u64); } - /// Tell all users of the control block to cancel and return early. - fn cancel(&self) { - self.cancel.store(true, Ordering::Relaxed); + fn finish(&self) { + self.0.finish(); } } -impl Drop for DocumentControlBlock { - fn drop(&mut self) { - self.progress.finish(); +impl AsProgress for IndicatifAsProgress { + fn as_progress(&self, _max: usize) -> Arc { + Arc::new(IndicatifProgress(self.0.clone())) } } From 670782f4bec3079a5ee144fca16a160b56dea011 Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:47:17 +0530 Subject: [PATCH 07/50] Rename struct for recall metrics --- diskann-benchmark/src/backend/document_index/benchmark.rs | 4 ++-- diskann-benchmark/src/utils/recall.rs | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index 51b3467bf..3a88ba9b7 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -58,7 +58,7 @@ use crate::{ utils::{ self, datafiles::{self, BinFile}, - recall, + recall::SerializableRecallMetrics, }, }; @@ -810,7 +810,7 @@ pub struct SearchRunStats { pub num_queries: usize, pub search_n: usize, pub search_l: usize, - pub recall: recall::RecallMetrics, + pub recall: SerializableRecallMetrics, pub qps: Vec, pub wall_clock_time: Vec, pub mean_latency: f64, diff --git a/diskann-benchmark/src/utils/recall.rs b/diskann-benchmark/src/utils/recall.rs index 50ef7e430..a7e0e39ab 100644 --- a/diskann-benchmark/src/utils/recall.rs +++ b/diskann-benchmark/src/utils/recall.rs @@ -2,15 +2,13 @@ * Copyright (c) Microsoft Corporation. * Licensed under the MIT license. */ - -pub(crate) use benchmark_core::recall::knn; use diskann_benchmark_core as benchmark_core; use serde::Serialize; #[derive(Debug, Clone, Serialize)] #[non_exhaustive] -pub(crate) struct RecallMetrics { +pub(crate) struct SerializableRecallMetrics(benchmark_core::recall::RecallMetrics) { /// The `k` value for `k-recall-at-n`. pub(crate) recall_k: usize, /// The `n` value for `k-recall-at-n`. @@ -25,7 +23,7 @@ pub(crate) struct RecallMetrics { pub(crate) maximum: usize, } -impl From<&benchmark_core::recall::RecallMetrics> for RecallMetrics { +impl From<&benchmark_core::recall::RecallMetrics> for SerializableRecallMetrics { fn from(m: &benchmark_core::recall::RecallMetrics) -> Self { Self { recall_k: m.recall_k, From 6c2c967d6cbd1d37d353054864ccfc2d8133d63f Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:47:51 +0530 Subject: [PATCH 08/50] Use copyIds --- .../document_insert_strategy.rs | 39 ++----------------- 1 file changed, 3 insertions(+), 36 deletions(-) diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs index 850976a32..9a3bad9a0 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs @@ -9,13 +9,7 @@ use std::marker::PhantomData; use diskann::{ - graph::{ - glue::{ - ExpandBeam, InsertStrategy, PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, - }, - SearchOutputBuffer, - }, - neighbor::Neighbor, + graph::glue::{self, ExpandBeam, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy}, provider::{Accessor, BuildQueryComputer, DataProvider, DelegateNeighbor, HasId}, ANNResult, }; @@ -160,33 +154,6 @@ where } } -#[derive(Debug, Default, Clone, Copy)] -pub struct CopyIdsForDocument; - -impl<'doc, A, VT> SearchPostProcess> for CopyIdsForDocument -where - A: BuildQueryComputer>, - VT: ?Sized, -{ - type Error = std::convert::Infallible; - - fn post_process( - &self, - _accessor: &mut A, - _query: &Document<'doc, VT>, - _computer: &>>::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl std::future::Future> + Send - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized, - { - let count = output.extend(candidates.map(|n| (n.id, n.distance))); - std::future::ready(Ok(count)) - } -} - impl<'doc, Inner, DP, VT> SearchStrategy>, Document<'doc, VT>> for DocumentInsertStrategy @@ -196,7 +163,7 @@ where VT: Sync + Send + ?Sized + 'static, { type QueryComputer = Inner::QueryComputer; - type PostProcessor = CopyIdsForDocument; + type PostProcessor = glue::CopyIds; type SearchAccessorError = Inner::SearchAccessorError; type SearchAccessor<'a> = DocumentSearchAccessor, VT>; @@ -212,7 +179,7 @@ where } fn post_processor(&self) -> Self::PostProcessor { - CopyIdsForDocument + glue::CopyIds } } From d13dc7f303c08535adfc67a18bf6417a1051d17c Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:48:10 +0530 Subject: [PATCH 09/50] Use renamed struct in SearchResults --- diskann-benchmark/src/backend/index/result.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs index 21d74f915..7f9514613 100644 --- a/diskann-benchmark/src/backend/index/result.rs +++ b/diskann-benchmark/src/backend/index/result.rs @@ -117,7 +117,7 @@ pub(super) struct SearchResults { pub(super) mean_latencies: Vec, pub(super) p90_latencies: Vec, pub(super) p99_latencies: Vec, - pub(super) recall: utils::recall::RecallMetrics, + pub(super) recall: utils::recall::SerializableRecallMetrics, pub(super) mean_cmps: f32, pub(super) mean_hops: f32, } From bd19bdebece118c70165f53d4fd32af80b76b580 Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:49:04 +0530 Subject: [PATCH 10/50] Evaluate query progressively without flattening and cloning the json kv map --- diskann-tools/src/utils/ground_truth.rs | 83 +++++++++++++------------ 1 file changed, 42 insertions(+), 41 deletions(-) diff --git a/diskann-tools/src/utils/ground_truth.rs b/diskann-tools/src/utils/ground_truth.rs index 8c2fa29f6..678e83dfd 100644 --- a/diskann-tools/src/utils/ground_truth.rs +++ b/diskann-tools/src/utils/ground_truth.rs @@ -29,56 +29,57 @@ use serde_json::{Map, Value}; use crate::utils::{search_index_utils, CMDResult, CMDToolError}; -/// Expands a JSON object with array-valued fields into multiple objects with scalar values. -/// For example: {"country": ["AU", "NZ"], "year": 2007} -/// becomes: [{"country": "AU", "year": 2007}, {"country": "NZ", "year": 2007}] +/// Evaluates a query expression against a label, expanding array-valued fields by recursion. /// -/// If multiple fields have arrays, all combinations are generated. -fn expand_array_fields(value: &Value) -> Vec { - match value { +/// For each key in the JSON object, if the value is an array the expression is evaluated +/// against one element at a time (any-match semantics) without materialising the full +/// Cartesian product. Non-object labels are evaluated directly. +fn eval_query_with_array_expansion(query_expr: &ASTExpr, label: &Value) -> bool { + match label { Value::Object(map) => { - // Start with a single empty object - let mut results: Vec> = vec![Map::new()]; - - for (key, val) in map.iter() { - if let Value::Array(arr) = val { - // Expand: for each existing result, create copies for each array element - let mut new_results: Vec> = Vec::new(); - for existing in results.iter() { - for item in arr.iter() { - let mut new_map: Map = existing.clone(); - new_map.insert(key.clone(), item.clone()); - new_results.push(new_map); - } - } - // If array is empty, keep existing results without this key - if !arr.is_empty() { - results = new_results; - } - } else { - // Non-array field: add to all existing results - for existing in results.iter_mut() { - existing.insert(key.clone(), val.clone()); + let entries: Vec<(&String, &Value)> = map.iter().collect(); + eval_map_recursive(query_expr, &entries, Map::new()) + } + _ => eval_query_expr(query_expr, label), + } +} + +/// Walk `entries` one field at a time, accumulating scalar values into `current`. +/// +/// * Scalar fields are inserted directly and the walk continues with the remaining entries. +/// * Array fields branch once per element; evaluation short-circuits on the first branch +/// that returns `true`. +/// * An empty array is treated as an absent field (preserving the previous behaviour). +/// * When all fields have been consumed, `eval_query_expr` is called on the accumulated object. +fn eval_map_recursive( + query_expr: &ASTExpr, + entries: &[(&String, &Value)], + mut current: Map, +) -> bool { + match entries { + [] => eval_query_expr(query_expr, &Value::Object(current)), + [(key, Value::Array(arr)), rest @ ..] => { + if arr.is_empty() { + // Omit this key, matching the original behaviour for empty arrays. + eval_map_recursive(query_expr, rest, current) + } else { + for item in arr { + let mut branch = current.clone(); + branch.insert((*key).clone(), item.clone()); + if eval_map_recursive(query_expr, rest, branch) { + return true; } } + false } - - results.into_iter().map(Value::Object).collect() } - // If not an object, return as-is - _ => vec![value.clone()], + [(key, val), rest @ ..] => { + current.insert((*key).clone(), (*val).clone()); + eval_map_recursive(query_expr, rest, current) + } } } -/// Evaluates a query expression against a label, expanding array fields first. -/// Returns true if any expanded variant matches the query. -fn eval_query_with_array_expansion(query_expr: &ASTExpr, label: &Value) -> bool { - let expanded = expand_array_fields(label); - expanded - .iter() - .any(|item| eval_query_expr(query_expr, item)) -} - pub fn read_labels_and_compute_bitmap( base_label_filename: &str, query_label_filename: &str, From 8e3a89bdcfb4754107b4e907bfb9a45744ce272b Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:49:39 +0530 Subject: [PATCH 11/50] Use config api to validate values --- .../src/inputs/document_index.rs | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/diskann-benchmark/src/inputs/document_index.rs b/diskann-benchmark/src/inputs/document_index.rs index b1a36e48a..11e36d5e3 100644 --- a/diskann-benchmark/src/inputs/document_index.rs +++ b/diskann-benchmark/src/inputs/document_index.rs @@ -51,15 +51,17 @@ impl CheckDeserialization for DocumentBuildParams { fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { self.data.check_deserialization(checker)?; self.data_labels.check_deserialization(checker)?; - if self.max_degree == 0 { - return Err(anyhow::anyhow!("max_degree must be > 0")); - } - if self.l_build == 0 { - return Err(anyhow::anyhow!("l_build must be > 0")); - } - if self.alpha <= 0.0 { - return Err(anyhow::anyhow!("alpha must be > 0")); - } + + // checking if the max_degree, l_build and alpha values are valid. + use diskann::graph::config::{Builder, MaxDegree, PruneKind}; + let mut builder = Builder::new( + self.max_degree, + MaxDegree::Value(self.max_degree), + self.l_build, + PruneKind::Occluding, + ); + builder.alpha(self.alpha); + builder.build()?; Ok(()) } } From 40b131485f908d4f9d2be60e88ad04e6104ffc41 Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:49:58 +0530 Subject: [PATCH 12/50] Specify number of threads explicitly --- diskann-benchmark/example/document-filter.json | 3 ++- diskann-benchmark/src/inputs/document_index.rs | 5 ----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/diskann-benchmark/example/document-filter.json b/diskann-benchmark/example/document-filter.json index d6e9e13b2..d60cd4806 100644 --- a/diskann-benchmark/example/document-filter.json +++ b/diskann-benchmark/example/document-filter.json @@ -13,7 +13,8 @@ "distance": "squared_l2", "max_degree": 32, "l_build": 50, - "alpha": 1.2 + "alpha": 1.2, + "num_threads": 4 }, "search": { "queries": "disk_index_sample_query_10pts.fbin", diff --git a/diskann-benchmark/src/inputs/document_index.rs b/diskann-benchmark/src/inputs/document_index.rs index 11e36d5e3..f1d2d7c67 100644 --- a/diskann-benchmark/src/inputs/document_index.rs +++ b/diskann-benchmark/src/inputs/document_index.rs @@ -39,14 +39,9 @@ pub(crate) struct DocumentBuildParams { pub(crate) max_degree: usize, pub(crate) l_build: usize, pub(crate) alpha: f32, - #[serde(default = "default_num_threads")] pub(crate) num_threads: usize, } -fn default_num_threads() -> usize { - 1 -} - impl CheckDeserialization for DocumentBuildParams { fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { self.data.check_deserialization(checker)?; From 3d2397254c882a1299888c26f54ba7a782658daa Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:51:23 +0530 Subject: [PATCH 13/50] Error when the visit of input expression fails when creating encodedFilterExpression --- .../encoded_filter_expr.rs | 18 ++++---- .../encoded_document_accessor.rs | 10 +--- .../inline_beta_search/inline_beta_filter.rs | 46 +++++-------------- 3 files changed, 22 insertions(+), 52 deletions(-) diff --git a/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs b/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs index 370ef25ae..b621e347c 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs @@ -5,6 +5,8 @@ use std::sync::{Arc, RwLock}; +use diskann::ANNResult; + use crate::{ encoded_attribute_provider::{ ast_id_expr::ASTIdExpr, ast_label_id_mapper::ASTLabelIdMapper, @@ -14,21 +16,19 @@ use crate::{ }; pub(crate) struct EncodedFilterExpr { - ast_id_expr: Option>, + ast_id_expr: ASTIdExpr, } impl EncodedFilterExpr { - pub fn new(ast_expr: &ASTExpr, attribute_map: Arc>) -> Self { + pub fn try_create(ast_expr: &ASTExpr, attribute_map: Arc>) -> ANNResult { let mut mapper = ASTLabelIdMapper::new(attribute_map); - match ast_expr.accept(&mut mapper) { - Ok(ast_id_expr) => Self { - ast_id_expr: Some(ast_id_expr), - }, - Err(_e) => Self { ast_id_expr: None }, - } + let ast_id_expr = ast_expr.accept(&mut mapper)?; + Ok(Self { + ast_id_expr, + }) } - pub(crate) fn encoded_filter_expr(&self) -> &Option> { + pub(crate) fn encoded_filter_expr(&self) -> &ASTIdExpr { &self.ast_id_expr } } diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 1def9a406..50e835dc7 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -220,20 +220,12 @@ where .inner_accessor .build_query_computer(from.query()) .into_ann_result()?; - let id_query = EncodedFilterExpr::new(from.filter_expr(), self.attribute_map.clone()); - let is_valid_filter = id_query.encoded_filter_expr().is_some(); - if !is_valid_filter { - tracing::warn!( - "Failed to convert {} into an id expr. This will now be an unfiltered search.", - from.filter_expr() - ); - } + let id_query = EncodedFilterExpr::try_create(from.filter_expr(), self.attribute_map.clone())?; Ok(InlineBetaComputer::new( inner_computer, self.beta_value, id_query, - is_valid_filter, )) } } diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 76a78de67..6093451c2 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -115,35 +115,20 @@ where let (vec, attrs) = changing.destructure(); let sim = self.inner_computer.evaluate_similarity(vec); let pred_eval = PredicateEvaluator::new(attrs); - if self.is_valid_filter { - match self - .filter_expr - .encoded_filter_expr() - .as_ref() - .unwrap() - .accept(&pred_eval) - { - Ok(matched) => { - if matched { - sim * self.beta_value - } else { - sim - } - } - Err(_) => { - //If predicate evaluation fails for any reason, we simply revert - //to unfiltered search. - tracing::warn!("Predicate evaluation failed"); + match self.filter_expr.encoded_filter_expr().accept(&pred_eval) { + Ok(matched) => { + if matched { + sim * self.beta_value + } else { sim } } - } else { - //If predicate evaluation fails, we will return the score returned by the - //inner computer, as though no predicate was specified. - tracing::warn!( - "Predicate evaluation failed in OnlineBetaComputer::evaluate_similarity()" - ); - sim + Err(_) => { + //If predicate evaluation fails for any reason, we simply revert + //to unfiltered search. + tracing::warn!("Predicate evaluation failed"); + sim + } } } } @@ -182,14 +167,7 @@ where let doc = accessor.get_element(candidate.id).await?; let pe = PredicateEvaluator::new(doc.attributes()); - if computer.is_valid_filter() - && computer - .filter_expr() - .encoded_filter_expr() - .as_ref() - .unwrap() - .accept(&pe)? - { + if computer.filter_expr().encoded_filter_expr().accept(&pe)? { filtered_candidates.push(Neighbor::new(candidate.id, candidate.distance)); } } From 9e35ccba00fc17eb5d29c6e84973ef5bd705cc1c Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:52:07 +0530 Subject: [PATCH 14/50] remove new runtime method added, use method in benchmark::core --- diskann-benchmark/src/utils/tokio.rs | 7 ------- 1 file changed, 7 deletions(-) diff --git a/diskann-benchmark/src/utils/tokio.rs b/diskann-benchmark/src/utils/tokio.rs index 21c78abb2..72dbeb918 100644 --- a/diskann-benchmark/src/utils/tokio.rs +++ b/diskann-benchmark/src/utils/tokio.rs @@ -3,13 +3,6 @@ * Licensed under the MIT license. */ -/// Create a generic multi-threaded runtime with `num_threads`. -pub(crate) fn runtime(num_threads: usize) -> anyhow::Result { - Ok(tokio::runtime::Builder::new_multi_thread() - .worker_threads(num_threads) - .build()?) -} - /// Create a current-thread runtime and block on the given future. /// Only for functions that don't need multi-threading pub(crate) fn block_on(future: F) -> F::Output { From 5a8c5601720ffd78421f30d331e481e85a4f8100 Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:55:10 +0530 Subject: [PATCH 15/50] Use dispatch rule to validate benchmark type support --- .../src/backend/document_index/benchmark.rs | 39 +++++++++++++------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index 3a88ba9b7..96d9f42e3 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -64,36 +64,48 @@ use crate::{ /// Register the document index benchmarks. pub(crate) fn register_benchmarks(benchmarks: &mut Benchmarks) { - benchmarks.register::>( - "document-index-build", - |job, checkpoint, out| { - let stats = job.run(checkpoint, out)?; + benchmarks.register::>( + "document-index-build-f32", + |job, _checkpoint, out| { + let stats = job.run(out)?; Ok(serde_json::to_value(stats)?) }, ); } /// Document index benchmark job. -pub(super) struct DocumentIndexJob<'a> { +pub(super) struct DocumentIndexJob<'a, T> { input: &'a DocumentIndexBuild, + _type: std::marker::PhantomData, } -impl<'a> DocumentIndexJob<'a> { +impl<'a, T> DocumentIndexJob<'a, T> { fn new(input: &'a DocumentIndexBuild) -> Self { - Self { input } + Self { + input, + _type: std::marker::PhantomData, + } } } -impl diskann_benchmark_runner::dispatcher::Map for DocumentIndexJob<'static> { - type Type<'a> = DocumentIndexJob<'a>; +impl diskann_benchmark_runner::dispatcher::Map for DocumentIndexJob<'static, T> { + type Type<'a> = DocumentIndexJob<'a, T>; } // Dispatch from the concrete input type -impl<'a> DispatchRule<&'a DocumentIndexBuild> for DocumentIndexJob<'a> { +impl<'a, T> DispatchRule<&'a DocumentIndexBuild> for DocumentIndexJob<'a, T> +where + datatype::Type: DispatchRule, +{ type Error = std::convert::Infallible; fn try_match(_from: &&'a DocumentIndexBuild) -> Result { - Ok(MatchScore(1)) + match _from.build.data_type { + datatype::DataType::Float32 => Ok(MatchScore(0)), + datatype::DataType::UInt8 => Ok(MatchScore(0)), + datatype::DataType::Int8 => Ok(MatchScore(0)), + _ => Err(datatype::MATCH_FAIL), + } } fn convert(from: &'a DocumentIndexBuild) -> Result { @@ -109,7 +121,10 @@ impl<'a> DispatchRule<&'a DocumentIndexBuild> for DocumentIndexJob<'a> { } // Central dispatch mapping from Any -impl<'a> DispatchRule<&'a Any> for DocumentIndexJob<'a> { +impl<'a, T> DispatchRule<&'a Any> for DocumentIndexJob<'a, T> +where + datatype::Type: DispatchRule, +{ type Error = anyhow::Error; fn try_match(from: &&'a Any) -> Result { From de2036535cab301e9a5cd1c3c6ac202c0d49457d Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:56:34 +0530 Subject: [PATCH 16/50] Use compute_medioid helper --- .../src/backend/document_index/benchmark.rs | 127 ++++++------------ 1 file changed, 39 insertions(+), 88 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index 96d9f42e3..692cbc0a0 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -50,6 +50,8 @@ use diskann_providers::model::graph::provider::async_::{ inmem::{CreateFullPrecision, DefaultProvider, DefaultProviderParameters, SetStartPoints}, }; use diskann_utils::views::Matrix; +use diskann_vector::PureDistanceFunction; +use diskann_vector::distance::SquaredL2; use indicatif::{ProgressBar, ProgressStyle}; use serde::Serialize; @@ -146,105 +148,54 @@ fn hashmap_to_attributes(map: std::collections::HashMap) .collect() } -/// Compute the index of the row closest to the medoid (centroid) of the data. -fn compute_medoid_index(data: &Matrix) -> usize +fn find_medoid_index(x: MatrixView<'_, T>, y: &[T]) -> Option where - T: bytemuck::Pod + Copy + 'static, + for<'a> diskann_vector::distance::SquaredL2: PureDistanceFunction<&'a [T], &'a [T], f32>, { - use diskann_vector::{distance::SquaredL2, PureDistanceFunction}; - - let dim = data.ncols(); - if dim == 0 || data.nrows() == 0 { - return 0; - } - - // Compute the centroid (mean of all rows) as f64 for precision - let mut sum = vec![0.0f64; dim]; - for i in 0..data.nrows() { - let row = data.row(i); - for (j, &v) in row.iter().enumerate() { - // Convert T to f64 for summation using bytemuck - let f64_val: f64 = if std::any::TypeId::of::() == std::any::TypeId::of::() { - let f32_val: f32 = bytemuck::cast(v); - f32_val as f64 - } else if std::any::TypeId::of::() == std::any::TypeId::of::() { - let u8_val: u8 = bytemuck::cast(v); - u8_val as f64 - } else if std::any::TypeId::of::() == std::any::TypeId::of::() { - let i8_val: i8 = bytemuck::cast(v); - i8_val as f64 - } else { - 0.0 - }; - sum[j] += f64_val; + let mut min_dist = f32::INFINITY; + let mut min_ind = x.nrows(); + for (i, row) in x.row_iter().enumerate() { + let dist = SquaredL2::evaluate(row, y); + if dist < min_dist { + min_dist = dist; + min_ind = i; } } - // Convert centroid to f32 and compute distances - let centroid_f32: Vec = sum - .iter() - .map(|s| (s / data.nrows() as f64) as f32) - .collect(); - - // Find the row closest to the centroid - let mut min_dist = f32::MAX; - let mut medoid_idx = 0; - for i in 0..data.nrows() { - let row = data.row(i); - let row_f32: Vec = row - .iter() - .map(|&v| { - if std::any::TypeId::of::() == std::any::TypeId::of::() { - bytemuck::cast(v) - } else if std::any::TypeId::of::() == std::any::TypeId::of::() { - let u8_val: u8 = bytemuck::cast(v); - u8_val as f32 - } else if std::any::TypeId::of::() == std::any::TypeId::of::() { - let i8_val: i8 = bytemuck::cast(v); - i8_val as f32 - } else { - 0.0 - } - }) - .collect(); - let d = SquaredL2::evaluate(centroid_f32.as_slice(), row_f32.as_slice()); - if d < min_dist { - min_dist = d; - medoid_idx = i; - } + // No closest neighbor found. + if min_ind == x.nrows() { + None + } else { + Some(min_ind) } - - medoid_idx } -impl<'a> DocumentIndexJob<'a> { - fn run( - self, - _checkpoint: Checkpoint<'_>, - mut output: &mut dyn Output, - ) -> Result { - // Print the input description - writeln!(output, "{}", self.input)?; +/// Compute the index of the row closest to the medoid (centroid) of the data. +fn compute_medoid_index(data: &Matrix) -> anyhow::Result +where + T: bytemuck::Pod + Copy + 'static + ComputeMedoid, + for<'a> diskann_vector::distance::SquaredL2: PureDistanceFunction<&'a [T], &'a [T], f32>, +{ + let dim = data.ncols(); + if dim == 0 || data.nrows() == 0 { + return Ok(0); + } - let build = &self.input.build; + // returns row closes to centroid. + let medoid = T::compute_medoid(data.as_view()); - // Dispatch based on data type - retain original type without conversion - match build.data_type { - DataType::Float32 => self.run_typed::(output), - DataType::UInt8 => self.run_typed::(output), - DataType::Int8 => self.run_typed::(output), - _ => Err(anyhow::anyhow!( - "Unsupported data type: {:?}. Supported types: float32, uint8, int8.", - build.data_type - )), - } - } + find_medoid_index(data.as_view(), medoid.as_slice()) + .ok_or_else(|| anyhow::anyhow!("Failed to find medoid index: no closest row found")) +} - fn run_typed(self, mut output: &mut dyn Output) -> Result +impl<'a, T> DocumentIndexJob<'a, T> { + fn run(self, mut output: &mut dyn Output) -> Result where - T: bytemuck::Pod + Copy + Send + Sync + 'static + std::fmt::Debug, - T: diskann::graph::SampleableForStart + diskann_utils::future::AsyncFriendly, - T: diskann::utils::VectorRepr + diskann_utils::sampling::WithApproximateNorm, + T: diskann::utils::VectorRepr + + diskann::graph::SampleableForStart + + diskann_utils::sampling::WithApproximateNorm + + 'static, + for<'b> diskann_vector::distance::SquaredL2: PureDistanceFunction<&'b [T], &'b [T]> { let build = &self.input.build; @@ -326,7 +277,7 @@ impl<'a> DocumentIndexJob<'a> { // Store attributes for the start point (medoid) // Start points are stored at indices num_vectors..num_vectors+frozen_points - let medoid_idx = compute_medoid_index(&data); + let medoid_idx = compute_medoid_index(&data)?; let start_point_id = num_vectors as u32; // Start points begin at max_points let medoid_attrs = attributes.get(medoid_idx).cloned().unwrap_or_default(); use diskann_label_filter::traits::attribute_store::AttributeStore; From 9c477cdefb890e4ab2213e4d162d400d846c214f Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:56:57 +0530 Subject: [PATCH 17/50] Remaining changes from search + build api refactor --- .../src/backend/document_index/benchmark.rs | 99 +++++++++---------- 1 file changed, 44 insertions(+), 55 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index 692cbc0a0..1f64ffccb 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -10,14 +10,13 @@ use std::io::Write; use std::num::NonZeroUsize; use std::path::Path; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; use anyhow::Result; use diskann::{ graph::{ config::Builder as ConfigBuilder, config::MaxDegree, config::PruneKind, - search_output_buffer, DiskANNIndex, SearchParams, StartPointStrategy, + search_output_buffer, DiskANNIndex, SearchOutputBuffer, SearchParams, StartPointStrategy, }, provider::DefaultContext, }; @@ -31,8 +30,8 @@ use diskann_benchmark_runner::{ dispatcher::{DispatchRule, FailureScore, MatchScore}, output::Output, registry::Benchmarks, - utils::{datatype::DataType, percentiles, MicroSeconds}, - Any, Checkpoint, + utils::{datatype, percentiles, MicroSeconds}, + Any, }; use diskann_label_filter::{ attribute::{Attribute, AttributeValue}, @@ -49,6 +48,8 @@ use diskann_providers::model::graph::provider::async_::{ common::{self, NoStore, TableBasedDeletes}, inmem::{CreateFullPrecision, DefaultProvider, DefaultProviderParameters, SetStartPoints}, }; +use diskann_utils::{future::AsyncFriendly, sampling::medoid::ComputeMedoid}; +use diskann_utils::views::MatrixView; use diskann_utils::views::Matrix; use diskann_vector::PureDistanceFunction; use diskann_vector::distance::SquaredL2; @@ -58,7 +59,6 @@ use serde::Serialize; use crate::{ inputs::document_index::DocumentIndexBuild, utils::{ - self, datafiles::{self, BinFile}, recall::SerializableRecallMetrics, }, @@ -296,58 +296,33 @@ impl<'a, T> DocumentIndexJob<'a, T> { )?; let timer = std::time::Instant::now(); - let insert_strategy: DocumentInsertStrategy<_, [T]> = - DocumentInsertStrategy::new(common::FullPrecision); - let rt = utils::tokio::runtime(build.num_threads)?; - - // Create control block for parallel work distribution + let rt = tokio::runtime(build.num_threads)?; let data_arc = Arc::new(data); let attributes_arc = Arc::new(attributes); - let control_block = DocumentControlBlock::new( + + let builder = DocumentIndexBuilder::new( + doc_index.clone(), data_arc.clone(), attributes_arc.clone(), - output.draw_target(), - )?; - - let num_tasks = build.num_threads; - let insert_latencies = rt.block_on(async { - let tasks: Vec<_> = (0..num_tasks) - .map(|_| { - let block = control_block.clone(); - let index = doc_index.clone(); - let strategy = insert_strategy; - tokio::spawn(async move { - let mut latencies = Vec::::new(); - let ctx = DefaultContext; - loop { - match block.next() { - Some((id, vector, attrs)) => { - let doc = Document::new(vector, attrs); - let start = std::time::Instant::now(); - let result = - index.insert(strategy, &ctx, &(id as u32), &doc).await; - latencies.push(MicroSeconds::from(start.elapsed())); - - if let Err(e) = result { - block.cancel(); - return Err(e); - } - } - None => return Ok(latencies), - } - } - }) - }) - .collect(); - - // Collect results from all tasks - let mut all_latencies = Vec::with_capacity(num_vectors); - for task in tasks { - let task_latencies = task.await??; - all_latencies.extend(task_latencies); - } - Ok::<_, anyhow::Error>(all_latencies) - })?; + DocumentInsertStrategy::new(common::FullPrecision), + ); + let num_tasks = NonZeroUsize::new(build.num_threads).unwrap_or(diskann::utils::ONE); + let parallelism = Parallelism::dynamic(diskann::utils::ONE, num_tasks); + let progress = IndicatifAsProgress({ + let bar = ProgressBar::with_draw_target(Some(num_vectors as u64), output.draw_target()); + bar.set_style( + ProgressStyle::with_template("Building [{elapsed_precise}] {wide_bar} {percent}") + .expect("valid template"), + ); + bar + }); + let build_results = + build::build_tracked(builder, parallelism, &rt, Some(&progress))?; + let insert_latencies: Vec = build_results + .take_output() + .into_iter() + .map(|r| r.latency) + .collect(); let build_time: MicroSeconds = timer.elapsed().into(); writeln!(output, " Index built in {} s", build_time.as_seconds())?; @@ -832,7 +807,17 @@ impl std::fmt::Display for DocumentIndexStats { writeln!( f, " {:>8} {:>8} {:>10} {:>10} {:>15} {:>12} {:>12} {:>10} {:>8} {:>10} {:>12}", - "L", "KNN", "Avg Cmps", "Avg Hops", "QPS -mean(max)", "Avg Latency", "p99 Latency", "Recall", "Threads", "Queries", "WallClock(s)" + "L", + "KNN", + "Avg Cmps", + "Avg Hops", + "QPS -mean(max)", + "Avg Latency", + "p99 Latency", + "Recall", + "Threads", + "Queries", + "WallClock(s)" )?; for s in &self.search { let mean_qps = if s.qps.is_empty() { @@ -844,7 +829,11 @@ impl std::fmt::Display for DocumentIndexStats { let mean_wall_clock = if s.wall_clock_time.is_empty() { 0.0 } else { - s.wall_clock_time.iter().map(|t| t.as_seconds()).sum::() / s.wall_clock_time.len() as f64 + s.wall_clock_time + .iter() + .map(|t| t.as_seconds()) + .sum::() + / s.wall_clock_time.len() as f64 }; writeln!( f, From 7f244328209c65d13f67c8d8d9a2a78f0733c9b4 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Thu, 12 Mar 2026 15:05:33 +0530 Subject: [PATCH 18/50] Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- diskann-tools/Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/diskann-tools/Cargo.toml b/diskann-tools/Cargo.toml index 1b4b3408e..0803dae32 100644 --- a/diskann-tools/Cargo.toml +++ b/diskann-tools/Cargo.toml @@ -5,6 +5,7 @@ version.workspace = true authors.workspace = true description.workspace = true documentation.workspace = true +license.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html From 9c1870aff4667865d6f081511b54fe3180f22bfc Mon Sep 17 00:00:00 2001 From: sampathrg Date: Thu, 12 Mar 2026 15:17:46 +0530 Subject: [PATCH 19/50] fix merge errors --- .../src/backend/document_index/benchmark.rs | 49 +++++++++++++++++-- .../inline_beta_search/inline_beta_filter.rs | 7 --- 2 files changed, 46 insertions(+), 10 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index ccde2efbf..b12445b16 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -486,7 +486,7 @@ where O: diskann::graph::SearchOutputBuffer + Send, { let ctx = DefaultContext; - let query_vec = self.queries.row(index); + let query_vec = self.queries.row(index).to_vec(); let (_, ref ast_expr) = self.predicates[index]; let strategy = InlineBetaStrategy::new(self.beta, common::FullPrecision); let filtered_query = FilteredQuery::new(query_vec, ast_expr.clone()); @@ -707,19 +707,60 @@ pub struct BuildParamsStats { pub alpha: f32, } +/// Helper module for serializing arrays as compact single-line JSON strings +mod compact_array { + use serde::Serializer; + + pub fn serialize_u32_vec(vec: &Vec, serializer: S) -> Result + where + S: Serializer, + { + // Serialize as a string containing the compact JSON array + let compact = serde_json::to_string(vec).unwrap_or_default(); + serializer.serialize_str(&compact) + } + + pub fn serialize_f32_vec(vec: &Vec, serializer: S) -> Result + where + S: Serializer, + { + // Serialize as a string containing the compact JSON array + let compact = serde_json::to_string(vec).unwrap_or_default(); + serializer.serialize_str(&compact) + } +} + +/// Per-query detailed results for debugging/analysis +#[derive(Debug, Serialize)] +pub struct PerQueryDetails { + pub query_id: usize, + pub filter: String, + pub recall: f64, + #[serde(serialize_with = "compact_array::serialize_u32_vec")] + pub result_ids: Vec, + #[serde(serialize_with = "compact_array::serialize_f32_vec")] + pub result_distances: Vec, + #[serde(serialize_with = "compact_array::serialize_u32_vec")] + pub groundtruth_ids: Vec, +} + /// Results from a single search configuration (one search_l value). #[derive(Debug, Serialize)] pub struct SearchRunStats { pub num_threads: usize, + pub num_queries: usize, pub search_n: usize, pub search_l: usize, pub recall: SerializableRecallMetrics, pub qps: Vec, + pub wall_clock_time: Vec, pub mean_latency: f64, pub p90_latency: MicroSeconds, pub p99_latency: MicroSeconds, pub mean_cmps: f32, pub mean_hops: f32, + #[serde(skip_serializing_if = "Option::is_none")] + pub per_query_details: Option>, } #[derive(Debug, Serialize)] @@ -796,7 +837,7 @@ impl std::fmt::Display for DocumentIndexStats { }; writeln!( f, - " {:>8} {:>8} {:>10.1} {:>10.1} {:>7.1}({:>5.1}) {:>12.1} {:>12} {:>10.4} {:>8}", + " {:>8} {:>8} {:>10.1} {:>10.1} {:>7.1}({:>5.1}) {:>12.1} {:>12} {:>10.4} {:>8} {:>10} {:>12.3}", s.search_l, s.search_n, s.mean_cmps, @@ -806,7 +847,9 @@ impl std::fmt::Display for DocumentIndexStats { s.mean_latency, s.p99_latency, s.recall.average, - s.num_threads + s.num_threads, + s.num_queries, + mean_wall_clock )?; } } diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 6093451c2..58087b24e 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -79,7 +79,6 @@ pub struct InlineBetaComputer { inner_computer: Inner, beta_value: f32, filter_expr: EncodedFilterExpr, - is_valid_filter: bool, //optimization to avoid evaluating empty predicates. } impl InlineBetaComputer { @@ -87,23 +86,17 @@ impl InlineBetaComputer { inner_computer: Inner, beta_value: f32, filter_expr: EncodedFilterExpr, - is_valid_filter: bool, ) -> Self { Self { inner_computer, beta_value, filter_expr, - is_valid_filter, } } pub(crate) fn filter_expr(&self) -> &EncodedFilterExpr { &self.filter_expr } - - pub(crate) fn is_valid_filter(&self) -> bool { - self.is_valid_filter - } } impl PreprocessedDistanceFunction, f32> From 1591956df376ac2ef533be349ca6cd42f7a533ca Mon Sep 17 00:00:00 2001 From: sampathrg Date: Thu, 12 Mar 2026 15:37:05 +0530 Subject: [PATCH 20/50] Fix merge errors white recall metrics --- diskann-benchmark/src/utils/recall.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diskann-benchmark/src/utils/recall.rs b/diskann-benchmark/src/utils/recall.rs index a7e0e39ab..9628a6205 100644 --- a/diskann-benchmark/src/utils/recall.rs +++ b/diskann-benchmark/src/utils/recall.rs @@ -8,7 +8,7 @@ use serde::Serialize; #[derive(Debug, Clone, Serialize)] #[non_exhaustive] -pub(crate) struct SerializableRecallMetrics(benchmark_core::recall::RecallMetrics) { +pub(crate) struct SerializableRecallMetrics { /// The `k` value for `k-recall-at-n`. pub(crate) recall_k: usize, /// The `n` value for `k-recall-at-n`. From f13e8498727c24b92daf820bb7a374dc56cc3c01 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 16 Mar 2026 15:40:04 +0530 Subject: [PATCH 21/50] Remove the need for Vec variants --- .../src/backend/document_index/benchmark.rs | 37 ++++++------ .../encoded_document_accessor.rs | 10 ++-- .../inline_beta_search/inline_beta_filter.rs | 11 ++-- diskann-label-filter/src/query.rs | 10 ++-- .../provider/async_/inmem/full_precision.rs | 56 ------------------- 5 files changed, 35 insertions(+), 89 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index b12445b16..38da6bfc7 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -22,9 +22,7 @@ use diskann::{ }; use diskann_benchmark_core::{ build::{self, AsProgress, Build, Parallelism, Progress}, - recall, - search as search_api, - tokio, + recall, search as search_api, tokio, }; use diskann_benchmark_runner::{ dispatcher::{DispatchRule, FailureScore, MatchScore}, @@ -48,11 +46,11 @@ use diskann_providers::model::graph::provider::async_::{ common::{self, NoStore, TableBasedDeletes}, inmem::{CreateFullPrecision, DefaultProvider, DefaultProviderParameters, SetStartPoints}, }; -use diskann_utils::{future::AsyncFriendly, sampling::medoid::ComputeMedoid}; -use diskann_utils::views::MatrixView; use diskann_utils::views::Matrix; -use diskann_vector::PureDistanceFunction; +use diskann_utils::views::MatrixView; +use diskann_utils::{future::AsyncFriendly, sampling::medoid::ComputeMedoid}; use diskann_vector::distance::SquaredL2; +use diskann_vector::PureDistanceFunction; use indicatif::{ProgressBar, ProgressStyle}; use serde::Serialize; @@ -148,7 +146,7 @@ fn hashmap_to_attributes(map: std::collections::HashMap) .collect() } -fn find_medoid_index(x: MatrixView<'_, T>, y: &[T]) -> Option +fn find_medoid_index(x: MatrixView<'_, T>, y: &[T]) -> Option where for<'a> diskann_vector::distance::SquaredL2: PureDistanceFunction<&'a [T], &'a [T], f32>, { @@ -195,7 +193,7 @@ impl<'a, T> DocumentIndexJob<'a, T> { + diskann::graph::SampleableForStart + diskann_utils::sampling::WithApproximateNorm + 'static, - for<'b> diskann_vector::distance::SquaredL2: PureDistanceFunction<&'b [T], &'b [T]> + for<'b> diskann_vector::distance::SquaredL2: PureDistanceFunction<&'b [T], &'b [T]>, { let build = &self.input.build; @@ -316,8 +314,7 @@ impl<'a, T> DocumentIndexJob<'a, T> { ); bar }); - let build_results = - build::build_tracked(builder, parallelism, &rt, Some(&progress))?; + let build_results = build::build_tracked(builder, parallelism, &rt, Some(&progress))?; let insert_latencies: Vec = build_results .take_output() .into_iter() @@ -455,11 +452,15 @@ where impl search_api::Search for FilteredSearcher where - DP: diskann::provider::DataProvider - + Send + DP: diskann::provider::DataProvider< + Context = DefaultContext, + ExternalId = u32, + InternalId = u32, + > + Send + Sync + 'static, - InlineBetaStrategy: diskann::graph::glue::SearchStrategy>, u32>, + for<'a> InlineBetaStrategy: + diskann::graph::glue::SearchStrategy, u32>, T: bytemuck::Pod + Copy + Send + Sync + 'static, { type Id = DP::ExternalId; @@ -486,7 +487,7 @@ where O: diskann::graph::SearchOutputBuffer + Send, { let ctx = DefaultContext; - let query_vec = self.queries.row(index).to_vec(); + let query_vec = self.queries.row(index); let (_, ref ast_expr) = self.predicates[index]; let strategy = InlineBetaStrategy::new(self.beta, common::FullPrecision); let filtered_query = FilteredQuery::new(query_vec, ast_expr.clone()); @@ -671,8 +672,8 @@ where > + Send + Sync + 'static, - InlineBetaStrategy: - diskann::graph::glue::SearchStrategy>>, + for<'a> InlineBetaStrategy: + diskann::graph::glue::SearchStrategy>, { let searcher = Arc::new(FilteredSearcher { index: index.clone(), @@ -698,7 +699,9 @@ where }, )?; - results.pop().ok_or_else(|| anyhow::anyhow!("no search results")) + results + .pop() + .ok_or_else(|| anyhow::anyhow!("no search results")) } #[derive(Debug, Serialize)] pub struct BuildParamsStats { diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 50e835dc7..0b658fd4d 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -11,7 +11,7 @@ use diskann::{ provider::{Accessor, AsNeighbor, BuildQueryComputer, DelegateNeighbor, HasId}, ANNError, ANNErrorKind, }; -use diskann_utils::{future::AsyncFriendly, Reborrow}; +use diskann_utils::Reborrow; use roaring::RoaringTreemap; use crate::traits::attribute_accessor::AttributeAccessor; @@ -204,17 +204,17 @@ where } } -impl BuildQueryComputer> for EncodedDocumentAccessor +impl<'a, IA, Q> BuildQueryComputer> for EncodedDocumentAccessor where IA: BuildQueryComputer, - Q: AsyncFriendly + Clone, + Q: Send + Sync + ?Sized, { type QueryComputerError = ANNError; type QueryComputer = InlineBetaComputer; fn build_query_computer( &self, - from: &FilteredQuery, + from: &FilteredQuery<'a, Q>, ) -> Result { let inner_computer = self .inner_accessor @@ -234,7 +234,7 @@ impl ExpandBeam for EncodedDocumentAccessor where IA: Accessor, EncodedDocumentAccessor: BuildQueryComputer + AsNeighbor, - Q: Clone + AsyncFriendly, + Q: Send + Sync + ?Sized, { } diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 58087b24e..8d1784029 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -9,7 +9,6 @@ use diskann::neighbor::Neighbor; use diskann::provider::{Accessor, BuildQueryComputer, DataProvider}; use diskann::ANNError; -use diskann_utils::future::AsyncFriendly; use diskann_vector::PreprocessedDistanceFunction; use roaring::RoaringTreemap; @@ -36,12 +35,12 @@ impl InlineBetaStrategy { } impl - SearchStrategy>, FilteredQuery> + SearchStrategy>, FilteredQuery<'_, Q>> for InlineBetaStrategy where DP: DataProvider, Strategy: SearchStrategy, - Q: AsyncFriendly + Clone, + Q: Send + Sync + ?Sized, { type QueryComputer = InlineBetaComputer; type PostProcessor = FilterResults; @@ -130,19 +129,19 @@ pub struct FilterResults { inner_post_processor: IPP, } -impl SearchPostProcess, FilteredQuery> +impl<'a, Q, IA, IPP> SearchPostProcess, FilteredQuery<'a, Q>> for FilterResults where IA: BuildQueryComputer, - Q: Clone + AsyncFriendly, IPP: SearchPostProcess + Send + Sync, + Q: Send + Sync + ?Sized, { type Error = ANNError; async fn post_process( &self, accessor: &mut EncodedDocumentAccessor, - query: &FilteredQuery, + query: &FilteredQuery<'a, Q>, computer: &InlineBetaComputer<>::QueryComputer>, candidates: I, output: &mut B, diff --git a/diskann-label-filter/src/query.rs b/diskann-label-filter/src/query.rs index 15c42501f..d85406b5d 100644 --- a/diskann-label-filter/src/query.rs +++ b/diskann-label-filter/src/query.rs @@ -9,17 +9,17 @@ use crate::ASTExpr; /// The Readme.md file in the label-filter folder describes the format /// of the query expression. #[derive(Clone)] -pub struct FilteredQuery { - query: V, +pub struct FilteredQuery<'a, V : ?Sized> { + query: &'a V, filter_expr: ASTExpr, } -impl FilteredQuery { - pub fn new(query: V, filter_expr: ASTExpr) -> Self { +impl<'a, V: ?Sized> FilteredQuery<'a, V> { + pub fn new(query: &'a V, filter_expr: ASTExpr) -> Self { Self { query, filter_expr } } - pub(crate) fn query(&self) -> &V { + pub(crate) fn query(&self) -> &'a V { &self.query } diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 9a48488fe..f83b2ae25 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -321,36 +321,6 @@ where { } -/// Support for Vec queries that delegates to the [T] impl via deref. -/// This allows InlineBetaStrategy to use Vec queries with FullAccessor. -impl BuildQueryComputer> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type QueryComputerError = Panics; - type QueryComputer = T::QueryDistance; - - fn build_query_computer( - &self, - from: &Vec, - ) -> Result { - // Delegate to [T] impl via deref - Ok(T::query_distance(from.as_slice(), self.provider.metric)) - } -} - -/// Support for Vec queries that delegates to the [T] impl. -impl ExpandBeam> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr + Clone, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ -} impl FillSet for FullAccessor<'_, T, Q, D, Ctx> where @@ -527,32 +497,6 @@ where } } -/// Support for Vec queries that delegates to the [T] impl. -/// This allows InlineBetaStrategy to use Vec queries with FullPrecision. -impl SearchStrategy, Vec> for FullPrecision -where - T: VectorRepr + Clone, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type QueryComputer = T::QueryDistance; - type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type SearchAccessorError = Panics; - type PostProcessor = glue::Pipeline; - - fn search_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} // Pruning impl PruneStrategy> for FullPrecision From 7a1244a18d826d650e3ed184c206cf4296b6b04b Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 16 Mar 2026 15:43:50 +0530 Subject: [PATCH 22/50] Undo unecessary change --- test_data/disk_index_search/data.256.label.jsonl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test_data/disk_index_search/data.256.label.jsonl b/test_data/disk_index_search/data.256.label.jsonl index a99cde8e2..83254af7b 100644 --- a/test_data/disk_index_search/data.256.label.jsonl +++ b/test_data/disk_index_search/data.256.label.jsonl @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:92576896b10780a2cd80a16030f8384610498b76453f57fadeacb854379e0acf -size 17701 +oid sha256:7f8b6b99ca32173557689712d3fb5da30c5e4111130fd2accbccf32f5ce3e47e +size 17702 From 86208c89da5af9b4677e6aa688cd96d03686c2b2 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 16 Mar 2026 15:58:29 +0530 Subject: [PATCH 23/50] Remove whitespaces from file --- .../provider/async_/inmem/full_precision.rs | 1162 ++++++++--------- 1 file changed, 580 insertions(+), 582 deletions(-) diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index f83b2ae25..e74419a46 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -1,582 +1,580 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use std::{collections::HashMap, fmt::Debug, future::Future}; - -use diskann::{ - ANNError, ANNResult, - graph::{ - SearchOutputBuffer, - glue::{ - self, ExpandBeam, FillSet, FilterStartPoints, InplaceDeleteStrategy, InsertStrategy, - PruneStrategy, SearchExt, SearchStrategy, - }, - }, - neighbor::Neighbor, - provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, - ExecutionContext, HasId, - }, - utils::{IntoUsize, VectorRepr}, -}; -use diskann_utils::future::AsyncFriendly; -use diskann_vector::{DistanceFunction, distance::Metric}; - -use crate::model::graph::{ - provider::async_::{ - FastMemoryVectorProviderAsync, SimpleNeighborProviderAsync, - common::{ - CreateVectorStore, FullPrecision, Internal, NoDeletes, NoStore, Panics, - PrefetchCacheLineLevel, SetElementHelper, - }, - inmem::DefaultProvider, - postprocess::{AsDeletionCheck, DeletionCheck, RemoveDeletedIdsAndCopy}, - }, - traits::AdHoc, -}; - -/// A type alias for the DefaultProvider with full-precision as the primary vector store. -pub type FullPrecisionProvider = - DefaultProvider, Q, D, Ctx>; - -/// The default full-precision vector store. -pub type FullPrecisionStore = FastMemoryVectorProviderAsync>; - -/// A default full-precision vector store provider. -#[derive(Clone)] -pub struct CreateFullPrecision { - dim: usize, - prefetch_cache_line_level: Option, - _phantom: std::marker::PhantomData, -} - -impl CreateFullPrecision -where - T: VectorRepr, -{ - /// Create a new full-precision vector store provider. - pub fn new(dim: usize, prefetch_cache_line_level: Option) -> Self { - Self { - dim, - prefetch_cache_line_level, - _phantom: std::marker::PhantomData, - } - } -} - -impl CreateVectorStore for CreateFullPrecision -where - T: VectorRepr, -{ - type Target = FullPrecisionStore; - fn create( - self, - max_points: usize, - metric: Metric, - prefetch_lookahead: Option, - ) -> Self::Target { - FullPrecisionStore::new( - max_points, - self.dim, - metric, - self.prefetch_cache_line_level, - prefetch_lookahead, - ) - } -} - -//////////////// -// SetElement // -//////////////// - -impl SetElementHelper for FullPrecisionStore -where - T: VectorRepr, -{ - /// Set the element at the given index. - fn set_element(&self, id: &u32, element: &[T]) -> Result<(), ANNError> { - unsafe { self.set_vector_sync(id.into_usize(), element) } - } -} - -////////////////// -// FullAccessor // -////////////////// - -/// An accessor for retrieving full-precision vectors from the `DefaultProvider`. -/// -/// This type implements the following traits: -/// -/// * [`Accessor`] for the [`DefaultProvider`]. -/// * [`ComputerAccessor`] for comparing full-precision distances. -/// * [`BuildQueryComputer`]. -pub struct FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, -{ - /// The host provider. - provider: &'a FullPrecisionProvider, - - /// A buffer for resolving iterators given during bulk operations. - /// - /// The accessor reuses this allocation to amortize allocation cost over multiple bulk - /// operations. - id_buffer: Vec, -} - -impl GetFullPrecision for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, -{ - type Repr = T; - fn as_full_precision(&self) -> &FullPrecisionStore { - &self.provider.base_vectors - } -} - -impl HasId for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, -{ - type Id = u32; -} - -impl SearchExt for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - fn starting_points(&self) -> impl Future>> { - std::future::ready(self.provider.starting_points()) - } -} - -impl<'a, T, Q, D, Ctx> FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - pub fn new(provider: &'a FullPrecisionProvider) -> Self { - Self { - provider, - id_buffer: Vec::new(), - } - } -} - -impl<'a, T, Q, D, Ctx> DelegateNeighbor<'a> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type Delegate = &'a SimpleNeighborProviderAsync; - - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - self.provider.neighbors() - } -} - -impl<'a, T, Q, D, Ctx> Accessor for FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - /// The extended element inherets the lifetime of the Accessor. - type Extended = &'a [T]; - - /// This accessor returns raw slices. There *is* a chance of racing when the fast - /// providers are used. We just have to live with it. - /// - /// NOTE: We intentionally don't use `'b` here since our implementation borrows - /// the inner `Opaque` from the underlying provider. - type Element<'b> - = &'a [T] - where - Self: 'b; - - /// `ElementRef` has an arbitrarily short lifetime. - type ElementRef<'b> = &'b [T]; - - /// Choose to panic on an out-of-bounds access rather than propagate an error. - type GetError = Panics; - - /// Return the full-precision vector stored at index `i`. - /// - /// This function always completes synchronously. - #[inline(always)] - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - // SAFETY: We've decided to live with UB (undefined behavior) that can result from - // potentially mixing unsynchronized reads and writes on the underlying memory. - std::future::ready(Ok(unsafe { - self.provider.base_vectors.get_vector_sync(id.into_usize()) - })) - } - - /// Perform a bulk operation. - /// - /// This implementation uses prefetching. - fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> impl Future> + Send - where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), - { - // Reuse the internal buffer to collect the results and give us random access - // capabilities. - let id_buffer = &mut self.id_buffer; - id_buffer.clear(); - id_buffer.extend(itr); - - let len = id_buffer.len(); - let lookahead = self.provider.base_vectors.prefetch_lookahead(); - - // Prefetch the first few vectors. - for id in id_buffer.iter().take(lookahead) { - self.provider.base_vectors.prefetch_hint(id.into_usize()); - } - - for (i, id) in id_buffer.iter().enumerate() { - // Prefetch `lookahead` iterations ahead as long as it is safe. - if lookahead > 0 && i + lookahead < len { - self.provider - .base_vectors - .prefetch_hint(id_buffer[i + lookahead].into_usize()); - } - - // Invoke the passed closure on the full-precision vector. - // - // SAFETY: We're accepting the consequences of potential unsynchronized, - // concurrent mutation. - f( - unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }, - *id, - ) - } - - std::future::ready(Ok(())) - } -} - -impl BuildDistanceComputer for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type DistanceComputerError = Panics; - type DistanceComputer = T::Distance; - - fn build_distance_computer( - &self, - ) -> Result { - Ok(T::distance( - self.provider.metric, - Some(self.provider.base_vectors.dim()), - )) - } -} - -impl BuildQueryComputer<[T]> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type QueryComputerError = Panics; - type QueryComputer = T::QueryDistance; - - fn build_query_computer( - &self, - from: &[T], - ) -> Result { - Ok(T::query_distance(from, self.provider.metric)) - } -} - -impl ExpandBeam<[T]> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ -} - - -impl FillSet for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - async fn fill_set( - &mut self, - set: &mut HashMap, - itr: Itr, - ) -> Result<(), Self::GetError> - where - Itr: Iterator + Send + Sync, - { - for i in itr { - set.entry(i).or_insert_with(|| unsafe { - self.provider.base_vectors.get_vector_sync(i.into_usize()) - }); - } - Ok(()) - } -} - -//-------------------// -// In-mem Extensions // -//-------------------// - -impl<'a, T, Q, D, Ctx> AsDeletionCheck for FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type Checker = D; - fn as_deletion_check(&self) -> &D { - &self.provider.deleted - } -} - -////////////////// -// Post Process // -////////////////// - -pub trait GetFullPrecision { - type Repr: VectorRepr; - fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync>; -} - -/// A [`SearchPostProcess`]or that: -/// -/// 1. Filters out deleted ids from being returned. -/// 2. Reranks a candidate stream using full-precision distances. -/// 3. Copies back the results to the output buffer. -#[derive(Debug, Default, Clone, Copy)] -pub struct Rerank; - -impl glue::SearchPostProcess for Rerank -where - T: VectorRepr, - A: BuildQueryComputer<[T], Id = u32> + GetFullPrecision + AsDeletionCheck, -{ - type Error = Panics; - - fn post_process( - &self, - accessor: &mut A, - query: &[T], - _computer: &A::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator>, - B: SearchOutputBuffer + ?Sized, - { - let full = accessor.as_full_precision(); - let checker = accessor.as_deletion_check(); - let f = full.distance(); - - // Filter before computing the full precision distances. - let mut reranked: Vec<(u32, f32)> = candidates - .filter_map(|n| { - if checker.deletion_check(n.id) { - None - } else { - Some(( - n.id, - f.evaluate_similarity(query, unsafe { - full.get_vector_sync(n.id.into_usize()) - }), - )) - } - }) - .collect(); - - // Sort the full precision distances. - reranked - .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); - // Store the reranked results. - std::future::ready(Ok(output.extend(reranked))) - } -} - -//////////////// -// Strategies // -//////////////// - -// A layered approach is used for search strategies. The `Internal` version does the heavy -// lifting in terms of establishing accessors and post processing. -// -// However, during post-processing, the `Internal` versions of strategies will not filter -// out the start points. The publicly exposed types *will* filter out the start points. -// -// This layered approach allows algorithms like `InplaceDeleteStrategy` that need to adjust -// the adjacency list for the start point to reuse the `Internal` strategies. - -/// Perform a search entirely in the full-precision space. -/// -/// Starting points are not filtered out of the final results. -impl SearchStrategy, [T]> - for Internal -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type QueryComputer = T::QueryDistance; - type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type SearchAccessorError = Panics; - type PostProcessor = RemoveDeletedIdsAndCopy; - - fn search_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - -/// Perform a search entirely in the full-precision space. -/// -/// Starting points are not filtered out of the final results. -impl SearchStrategy, [T]> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type QueryComputer = T::QueryDistance; - type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type SearchAccessorError = Panics; - type PostProcessor = glue::Pipeline; - - fn search_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - - -// Pruning -impl PruneStrategy> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type DistanceComputer = T::Distance; - type PruneAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type PruneAccessorError = diskann::error::Infallible; - - fn prune_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::PruneAccessorError> { - Ok(FullAccessor::new(provider)) - } -} - -/// Implementing this trait allows `FullPrecision` to be used for multi-insert. -impl<'a, T, Q, D, Ctx> glue::AsElement<&'a [T]> for FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type Error = diskann::error::Infallible; - fn as_element( - &mut self, - vector: &'a [T], - _id: Self::Id, - ) -> impl Future, Self::Error>> + Send { - std::future::ready(Ok(vector)) - } -} - -impl InsertStrategy, [T]> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type PruneStrategy = Self; - fn prune_strategy(&self) -> Self::PruneStrategy { - *self - } -} - -// Inplace Delete // -impl InplaceDeleteStrategy> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type DeleteElementError = Panics; - type DeleteElement<'a> = [T]; - type DeleteElementGuard = Box<[T]>; - type PruneStrategy = Self; - type SearchStrategy = Internal; - fn search_strategy(&self) -> Self::SearchStrategy { - Internal(Self) - } - - fn prune_strategy(&self) -> Self::PruneStrategy { - Self - } - - async fn get_delete_element<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - id: u32, - ) -> Result { - Ok(unsafe { provider.base_vectors.get_vector_sync(id.into_usize()) }.into()) - } -} +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::{collections::HashMap, fmt::Debug, future::Future}; + +use diskann::{ + ANNError, ANNResult, + graph::{ + SearchOutputBuffer, + glue::{ + self, ExpandBeam, FillSet, FilterStartPoints, InplaceDeleteStrategy, InsertStrategy, + PruneStrategy, SearchExt, SearchStrategy, + }, + }, + neighbor::Neighbor, + provider::{ + Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, + ExecutionContext, HasId, + }, + utils::{IntoUsize, VectorRepr}, +}; +use diskann_utils::future::AsyncFriendly; +use diskann_vector::{DistanceFunction, distance::Metric}; + +use crate::model::graph::{ + provider::async_::{ + FastMemoryVectorProviderAsync, SimpleNeighborProviderAsync, + common::{ + CreateVectorStore, FullPrecision, Internal, NoDeletes, NoStore, Panics, + PrefetchCacheLineLevel, SetElementHelper, + }, + inmem::DefaultProvider, + postprocess::{AsDeletionCheck, DeletionCheck, RemoveDeletedIdsAndCopy}, + }, + traits::AdHoc, +}; + +/// A type alias for the DefaultProvider with full-precision as the primary vector store. +pub type FullPrecisionProvider = + DefaultProvider, Q, D, Ctx>; + +/// The default full-precision vector store. +pub type FullPrecisionStore = FastMemoryVectorProviderAsync>; + +/// A default full-precision vector store provider. +#[derive(Clone)] +pub struct CreateFullPrecision { + dim: usize, + prefetch_cache_line_level: Option, + _phantom: std::marker::PhantomData, +} + +impl CreateFullPrecision +where + T: VectorRepr, +{ + /// Create a new full-precision vector store provider. + pub fn new(dim: usize, prefetch_cache_line_level: Option) -> Self { + Self { + dim, + prefetch_cache_line_level, + _phantom: std::marker::PhantomData, + } + } +} + +impl CreateVectorStore for CreateFullPrecision +where + T: VectorRepr, +{ + type Target = FullPrecisionStore; + fn create( + self, + max_points: usize, + metric: Metric, + prefetch_lookahead: Option, + ) -> Self::Target { + FullPrecisionStore::new( + max_points, + self.dim, + metric, + self.prefetch_cache_line_level, + prefetch_lookahead, + ) + } +} + +//////////////// +// SetElement // +//////////////// + +impl SetElementHelper for FullPrecisionStore +where + T: VectorRepr, +{ + /// Set the element at the given index. + fn set_element(&self, id: &u32, element: &[T]) -> Result<(), ANNError> { + unsafe { self.set_vector_sync(id.into_usize(), element) } + } +} + +////////////////// +// FullAccessor // +////////////////// + +/// An accessor for retrieving full-precision vectors from the `DefaultProvider`. +/// +/// This type implements the following traits: +/// +/// * [`Accessor`] for the [`DefaultProvider`]. +/// * [`ComputerAccessor`] for comparing full-precision distances. +/// * [`BuildQueryComputer`]. +pub struct FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, +{ + /// The host provider. + provider: &'a FullPrecisionProvider, + + /// A buffer for resolving iterators given during bulk operations. + /// + /// The accessor reuses this allocation to amortize allocation cost over multiple bulk + /// operations. + id_buffer: Vec, +} + +impl GetFullPrecision for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, +{ + type Repr = T; + fn as_full_precision(&self) -> &FullPrecisionStore { + &self.provider.base_vectors + } +} + +impl HasId for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, +{ + type Id = u32; +} + +impl SearchExt for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + fn starting_points(&self) -> impl Future>> { + std::future::ready(self.provider.starting_points()) + } +} + +impl<'a, T, Q, D, Ctx> FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + pub fn new(provider: &'a FullPrecisionProvider) -> Self { + Self { + provider, + id_buffer: Vec::new(), + } + } +} + +impl<'a, T, Q, D, Ctx> DelegateNeighbor<'a> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type Delegate = &'a SimpleNeighborProviderAsync; + + fn delegate_neighbor(&'a mut self) -> Self::Delegate { + self.provider.neighbors() + } +} + +impl<'a, T, Q, D, Ctx> Accessor for FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + /// The extended element inherets the lifetime of the Accessor. + type Extended = &'a [T]; + + /// This accessor returns raw slices. There *is* a chance of racing when the fast + /// providers are used. We just have to live with it. + /// + /// NOTE: We intentionally don't use `'b` here since our implementation borrows + /// the inner `Opaque` from the underlying provider. + type Element<'b> + = &'a [T] + where + Self: 'b; + + /// `ElementRef` has an arbitrarily short lifetime. + type ElementRef<'b> = &'b [T]; + + /// Choose to panic on an out-of-bounds access rather than propagate an error. + type GetError = Panics; + + /// Return the full-precision vector stored at index `i`. + /// + /// This function always completes synchronously. + #[inline(always)] + fn get_element( + &mut self, + id: Self::Id, + ) -> impl Future, Self::GetError>> + Send { + // SAFETY: We've decided to live with UB (undefined behavior) that can result from + // potentially mixing unsynchronized reads and writes on the underlying memory. + std::future::ready(Ok(unsafe { + self.provider.base_vectors.get_vector_sync(id.into_usize()) + })) + } + + /// Perform a bulk operation. + /// + /// This implementation uses prefetching. + fn on_elements_unordered( + &mut self, + itr: Itr, + mut f: F, + ) -> impl Future> + Send + where + Self: Sync, + Itr: Iterator + Send, + F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), + { + // Reuse the internal buffer to collect the results and give us random access + // capabilities. + let id_buffer = &mut self.id_buffer; + id_buffer.clear(); + id_buffer.extend(itr); + + let len = id_buffer.len(); + let lookahead = self.provider.base_vectors.prefetch_lookahead(); + + // Prefetch the first few vectors. + for id in id_buffer.iter().take(lookahead) { + self.provider.base_vectors.prefetch_hint(id.into_usize()); + } + + for (i, id) in id_buffer.iter().enumerate() { + // Prefetch `lookahead` iterations ahead as long as it is safe. + if lookahead > 0 && i + lookahead < len { + self.provider + .base_vectors + .prefetch_hint(id_buffer[i + lookahead].into_usize()); + } + + // Invoke the passed closure on the full-precision vector. + // + // SAFETY: We're accepting the consequences of potential unsynchronized, + // concurrent mutation. + f( + unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }, + *id, + ) + } + + std::future::ready(Ok(())) + } +} + +impl BuildDistanceComputer for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type DistanceComputerError = Panics; + type DistanceComputer = T::Distance; + + fn build_distance_computer( + &self, + ) -> Result { + Ok(T::distance( + self.provider.metric, + Some(self.provider.base_vectors.dim()), + )) + } +} + +impl BuildQueryComputer<[T]> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type QueryComputerError = Panics; + type QueryComputer = T::QueryDistance; + + fn build_query_computer( + &self, + from: &[T], + ) -> Result { + Ok(T::query_distance(from, self.provider.metric)) + } +} + +impl ExpandBeam<[T]> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ +} + +impl FillSet for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + async fn fill_set( + &mut self, + set: &mut HashMap, + itr: Itr, + ) -> Result<(), Self::GetError> + where + Itr: Iterator + Send + Sync, + { + for i in itr { + set.entry(i).or_insert_with(|| unsafe { + self.provider.base_vectors.get_vector_sync(i.into_usize()) + }); + } + Ok(()) + } +} + +//-------------------// +// In-mem Extensions // +//-------------------// + +impl<'a, T, Q, D, Ctx> AsDeletionCheck for FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type Checker = D; + fn as_deletion_check(&self) -> &D { + &self.provider.deleted + } +} + +////////////////// +// Post Process // +////////////////// + +pub trait GetFullPrecision { + type Repr: VectorRepr; + fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync>; +} + +/// A [`SearchPostProcess`]or that: +/// +/// 1. Filters out deleted ids from being returned. +/// 2. Reranks a candidate stream using full-precision distances. +/// 3. Copies back the results to the output buffer. +#[derive(Debug, Default, Clone, Copy)] +pub struct Rerank; + +impl glue::SearchPostProcess for Rerank +where + T: VectorRepr, + A: BuildQueryComputer<[T], Id = u32> + GetFullPrecision + AsDeletionCheck, +{ + type Error = Panics; + + fn post_process( + &self, + accessor: &mut A, + query: &[T], + _computer: &A::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator>, + B: SearchOutputBuffer + ?Sized, + { + let full = accessor.as_full_precision(); + let checker = accessor.as_deletion_check(); + let f = full.distance(); + + // Filter before computing the full precision distances. + let mut reranked: Vec<(u32, f32)> = candidates + .filter_map(|n| { + if checker.deletion_check(n.id) { + None + } else { + Some(( + n.id, + f.evaluate_similarity(query, unsafe { + full.get_vector_sync(n.id.into_usize()) + }), + )) + } + }) + .collect(); + + // Sort the full precision distances. + reranked + .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + // Store the reranked results. + std::future::ready(Ok(output.extend(reranked))) + } +} + +//////////////// +// Strategies // +//////////////// + +// A layered approach is used for search strategies. The `Internal` version does the heavy +// lifting in terms of establishing accessors and post processing. +// +// However, during post-processing, the `Internal` versions of strategies will not filter +// out the start points. The publicly exposed types *will* filter out the start points. +// +// This layered approach allows algorithms like `InplaceDeleteStrategy` that need to adjust +// the adjacency list for the start point to reuse the `Internal` strategies. + +/// Perform a search entirely in the full-precision space. +/// +/// Starting points are not filtered out of the final results. +impl SearchStrategy, [T]> + for Internal +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type QueryComputer = T::QueryDistance; + type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; + type SearchAccessorError = Panics; + type PostProcessor = RemoveDeletedIdsAndCopy; + + fn search_accessor<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + ) -> Result, Self::SearchAccessorError> { + Ok(FullAccessor::new(provider)) + } + + fn post_processor(&self) -> Self::PostProcessor { + Default::default() + } +} + +/// Perform a search entirely in the full-precision space. +/// +/// Starting points are not filtered out of the final results. +impl SearchStrategy, [T]> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type QueryComputer = T::QueryDistance; + type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; + type SearchAccessorError = Panics; + type PostProcessor = glue::Pipeline; + + fn search_accessor<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + ) -> Result, Self::SearchAccessorError> { + Ok(FullAccessor::new(provider)) + } + + fn post_processor(&self) -> Self::PostProcessor { + Default::default() + } +} + +// Pruning +impl PruneStrategy> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type DistanceComputer = T::Distance; + type PruneAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; + type PruneAccessorError = diskann::error::Infallible; + + fn prune_accessor<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + ) -> Result, Self::PruneAccessorError> { + Ok(FullAccessor::new(provider)) + } +} + +/// Implementing this trait allows `FullPrecision` to be used for multi-insert. +impl<'a, T, Q, D, Ctx> glue::AsElement<&'a [T]> for FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type Error = diskann::error::Infallible; + fn as_element( + &mut self, + vector: &'a [T], + _id: Self::Id, + ) -> impl Future, Self::Error>> + Send { + std::future::ready(Ok(vector)) + } +} + +impl InsertStrategy, [T]> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type PruneStrategy = Self; + fn prune_strategy(&self) -> Self::PruneStrategy { + *self + } +} + +// Inplace Delete // +impl InplaceDeleteStrategy> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type DeleteElementError = Panics; + type DeleteElement<'a> = [T]; + type DeleteElementGuard = Box<[T]>; + type PruneStrategy = Self; + type SearchStrategy = Internal; + fn search_strategy(&self) -> Self::SearchStrategy { + Internal(Self) + } + + fn prune_strategy(&self) -> Self::PruneStrategy { + Self + } + + async fn get_delete_element<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + id: u32, + ) -> Result { + Ok(unsafe { provider.base_vectors.get_vector_sync(id.into_usize()) }.into()) + } +} From 4441dc735266292b0e7dcc35feee165cffdd9f6c Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 16 Mar 2026 16:08:56 +0530 Subject: [PATCH 24/50] Fix formatting errors --- .../encoded_attribute_provider/encoded_filter_expr.rs | 9 +++++---- .../src/inline_beta_search/encoded_document_accessor.rs | 3 ++- .../src/inline_beta_search/inline_beta_filter.rs | 6 ++++-- diskann-label-filter/src/query.rs | 2 +- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs b/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs index b621e347c..0ebdaf72f 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs @@ -20,12 +20,13 @@ pub(crate) struct EncodedFilterExpr { } impl EncodedFilterExpr { - pub fn try_create(ast_expr: &ASTExpr, attribute_map: Arc>) -> ANNResult { + pub fn try_create( + ast_expr: &ASTExpr, + attribute_map: Arc>, + ) -> ANNResult { let mut mapper = ASTLabelIdMapper::new(attribute_map); let ast_id_expr = ast_expr.accept(&mut mapper)?; - Ok(Self { - ast_id_expr, - }) + Ok(Self { ast_id_expr }) } pub(crate) fn encoded_filter_expr(&self) -> &ASTIdExpr { diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 0b658fd4d..418526af7 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -220,7 +220,8 @@ where .inner_accessor .build_query_computer(from.query()) .into_ann_result()?; - let id_query = EncodedFilterExpr::try_create(from.filter_expr(), self.attribute_map.clone())?; + let id_query = + EncodedFilterExpr::try_create(from.filter_expr(), self.attribute_map.clone())?; Ok(InlineBetaComputer::new( inner_computer, diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 8d1784029..ae9a045d6 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -35,8 +35,10 @@ impl InlineBetaStrategy { } impl - SearchStrategy>, FilteredQuery<'_, Q>> - for InlineBetaStrategy + SearchStrategy< + DocumentProvider>, + FilteredQuery<'_, Q>, + > for InlineBetaStrategy where DP: DataProvider, Strategy: SearchStrategy, diff --git a/diskann-label-filter/src/query.rs b/diskann-label-filter/src/query.rs index d85406b5d..142867360 100644 --- a/diskann-label-filter/src/query.rs +++ b/diskann-label-filter/src/query.rs @@ -9,7 +9,7 @@ use crate::ASTExpr; /// The Readme.md file in the label-filter folder describes the format /// of the query expression. #[derive(Clone)] -pub struct FilteredQuery<'a, V : ?Sized> { +pub struct FilteredQuery<'a, V: ?Sized> { query: &'a V, filter_expr: ASTExpr, } From 5ed93b9ee4e50c5b50a6d07598660f36f8a2f5a5 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 16 Mar 2026 16:12:56 +0530 Subject: [PATCH 25/50] Fix clippy warning --- diskann-label-filter/src/query.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diskann-label-filter/src/query.rs b/diskann-label-filter/src/query.rs index 142867360..b1a7a6409 100644 --- a/diskann-label-filter/src/query.rs +++ b/diskann-label-filter/src/query.rs @@ -20,7 +20,7 @@ impl<'a, V: ?Sized> FilteredQuery<'a, V> { } pub(crate) fn query(&self) -> &'a V { - &self.query + self.query } pub(crate) fn filter_expr(&self) -> &ASTExpr { From 072edb5e48d85adae0757de4c6a77e0b5d9f038f Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 16 Mar 2026 16:56:16 +0530 Subject: [PATCH 26/50] Fix build errors after merge with main - Use Knn instead of the old SearchParams --- .../src/backend/document_index/benchmark.rs | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index 38da6bfc7..2d5b6a7ef 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -15,8 +15,8 @@ use std::sync::Arc; use anyhow::Result; use diskann::{ graph::{ - config::Builder as ConfigBuilder, config::MaxDegree, config::PruneKind, - search_output_buffer, DiskANNIndex, SearchOutputBuffer, SearchParams, StartPointStrategy, + config::Builder as ConfigBuilder, config::MaxDegree, config::PruneKind, search::Knn, + search_output_buffer, DiskANNIndex, SearchOutputBuffer, StartPointStrategy, }, provider::DefaultContext, }; @@ -428,6 +428,7 @@ impl<'a, T> DocumentIndexJob<'a, T> { Ok(stats) } } + /// Per-query output from [`FilteredSearcher::search`]. struct FilteredSearchOutput { distances: Vec, @@ -464,22 +465,20 @@ where T: bytemuck::Pod + Copy + Send + Sync + 'static, { type Id = DP::ExternalId; - type Parameters = SearchParams; + type Parameters = Knn; type Output = FilteredSearchOutput; fn num_queries(&self) -> usize { self.queries.nrows() } - fn id_count(&self, parameters: &SearchParams) -> search_api::IdCount { - search_api::IdCount::Fixed( - NonZeroUsize::new(parameters.k_value).unwrap_or(diskann::utils::ONE), - ) + fn id_count(&self, parameters: &Knn) -> search_api::IdCount { + search_api::IdCount::Fixed(parameters.k_value()) } async fn search( &self, - parameters: &SearchParams, + parameters: &Knn, buffer: &mut O, index: usize, ) -> diskann::ANNResult @@ -494,14 +493,14 @@ where // Use a concrete IdDistance scratch buffer so that both the IDs and distances // are captured. Afterwards, the valid IDs are forwarded into the framework buffer. - let k = parameters.k_value; + let k = parameters.k_value().get(); let mut ids = vec![0u32; k]; let mut distances = vec![0.0f32; k]; let mut scratch = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let stats = self + let stats = &self .index - .search(&strategy, &ctx, &filtered_query, parameters, &mut scratch) + .search(*parameters, &strategy, &ctx, &filtered_query, &mut scratch) .await?; let count = scratch.current_len(); @@ -526,18 +525,16 @@ struct FilteredSearchAggregator<'a> { recall_k: usize, } -impl search_api::Aggregate - for FilteredSearchAggregator<'_> -{ +impl search_api::Aggregate for FilteredSearchAggregator<'_> { type Output = SearchRunStats; fn aggregate( &mut self, - run: search_api::Run, + run: search_api::Run, results: Vec>, ) -> anyhow::Result { let parameters = run.parameters(); - let search_n = parameters.k_value; + let search_n = parameters.k_value().get(); let num_queries = results.first().map(|r| r.len()).unwrap_or(0); // Recall from first rep only. @@ -635,7 +632,7 @@ impl search_api::Aggregate num_threads: run.setup().threads.get(), num_queries, search_n, - search_l: parameters.l_value, + search_l: parameters.l_value().get(), recall: recall_metrics, qps, wall_clock_time: rep_latencies, @@ -682,7 +679,7 @@ where beta, }); - let parameters = SearchParams::new_default(search_n, search_l)?; + let parameters = Knn::new_default(search_n, search_l)?; let setup = search_api::Setup { threads: num_threads, tasks: num_threads, From 6a31782bdb86e97b2e68637d422bbc9b864aa697 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 16 Mar 2026 18:52:08 +0530 Subject: [PATCH 27/50] Undo rename of RecallMetrics --- diskann-benchmark/src/backend/document_index/benchmark.rs | 6 +++--- diskann-benchmark/src/backend/index/result.rs | 2 +- diskann-benchmark/src/utils/recall.rs | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index 2d5b6a7ef..d954f1f47 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -58,7 +58,7 @@ use crate::{ inputs::document_index::DocumentIndexBuild, utils::{ datafiles::{self, BinFile}, - recall::SerializableRecallMetrics, + recall::RecallMetrics, }, }; @@ -538,7 +538,7 @@ impl search_api::Aggregate for FilteredSearchAgg let num_queries = results.first().map(|r| r.len()).unwrap_or(0); // Recall from first rep only. - let recall_metrics: SerializableRecallMetrics = match results.first() { + let recall_metrics: RecallMetrics = match results.first() { Some(first) => (&recall::knn( self.groundtruth, None, @@ -751,7 +751,7 @@ pub struct SearchRunStats { pub num_queries: usize, pub search_n: usize, pub search_l: usize, - pub recall: SerializableRecallMetrics, + pub recall: RecallMetrics, pub qps: Vec, pub wall_clock_time: Vec, pub mean_latency: f64, diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs index 429d4f060..1d6102f9b 100644 --- a/diskann-benchmark/src/backend/index/result.rs +++ b/diskann-benchmark/src/backend/index/result.rs @@ -116,7 +116,7 @@ pub(super) struct SearchResults { pub(super) mean_latencies: Vec, pub(super) p90_latencies: Vec, pub(super) p99_latencies: Vec, - pub(super) recall: utils::recall::SerializableRecallMetrics, + pub(super) recall: utils::recall::RecallMetrics, pub(super) mean_cmps: f32, pub(super) mean_hops: f32, } diff --git a/diskann-benchmark/src/utils/recall.rs b/diskann-benchmark/src/utils/recall.rs index 9628a6205..c0ed813cd 100644 --- a/diskann-benchmark/src/utils/recall.rs +++ b/diskann-benchmark/src/utils/recall.rs @@ -8,7 +8,7 @@ use serde::Serialize; #[derive(Debug, Clone, Serialize)] #[non_exhaustive] -pub(crate) struct SerializableRecallMetrics { +pub(crate) struct RecallMetrics { /// The `k` value for `k-recall-at-n`. pub(crate) recall_k: usize, /// The `n` value for `k-recall-at-n`. @@ -23,7 +23,7 @@ pub(crate) struct SerializableRecallMetrics { pub(crate) maximum: usize, } -impl From<&benchmark_core::recall::RecallMetrics> for SerializableRecallMetrics { +impl From<&benchmark_core::recall::RecallMetrics> for RecallMetrics { fn from(m: &benchmark_core::recall::RecallMetrics) -> Self { Self { recall_k: m.recall_k, From a70ee53049601238d7ee61c19d3955e2deaadd59 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Wed, 18 Mar 2026 13:23:28 +0530 Subject: [PATCH 28/50] Remove fallback to unfiltered search. --- .../inline_beta_search/inline_beta_filter.rs | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index ae9a045d6..8d6392381 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -109,20 +109,10 @@ where let (vec, attrs) = changing.destructure(); let sim = self.inner_computer.evaluate_similarity(vec); let pred_eval = PredicateEvaluator::new(attrs); - match self.filter_expr.encoded_filter_expr().accept(&pred_eval) { - Ok(matched) => { - if matched { - sim * self.beta_value - } else { - sim - } - } - Err(_) => { - //If predicate evaluation fails for any reason, we simply revert - //to unfiltered search. - tracing::warn!("Predicate evaluation failed"); - sim - } + if self.filter_expr.encoded_filter_expr().accept(&pred_eval).expect("Expected predicate evaluation to not error out!") { + sim * self.beta_value + } else { + sim } } } From a313013cdeb50cdb9806563b9321bf8ff5f2581d Mon Sep 17 00:00:00 2001 From: sampathrg Date: Wed, 18 Mar 2026 13:48:02 +0530 Subject: [PATCH 29/50] Fix formatting error --- .../src/inline_beta_search/inline_beta_filter.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 8d6392381..68dcad630 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -109,7 +109,12 @@ where let (vec, attrs) = changing.destructure(); let sim = self.inner_computer.evaluate_similarity(vec); let pred_eval = PredicateEvaluator::new(attrs); - if self.filter_expr.encoded_filter_expr().accept(&pred_eval).expect("Expected predicate evaluation to not error out!") { + if self + .filter_expr + .encoded_filter_expr() + .accept(&pred_eval) + .expect("Expected predicate evaluation to not error out!") + { sim * self.beta_value } else { sim From 7ea6e4ef08d699a75062290348f5aa4a475e8101 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 23 Mar 2026 12:53:36 +0530 Subject: [PATCH 30/50] Address review comments --- .../example/document-filter.json | 4 + .../src/backend/document_index/benchmark.rs | 162 +++++------------- diskann-benchmark/src/backend/index/build.rs | 4 +- diskann-benchmark/src/backend/index/mod.rs | 2 +- .../src/inputs/document_index.rs | 31 ++-- .../document_insert_strategy.rs | 2 +- 6 files changed, 73 insertions(+), 132 deletions(-) diff --git a/diskann-benchmark/example/document-filter.json b/diskann-benchmark/example/document-filter.json index d60cd4806..0bce1572d 100644 --- a/diskann-benchmark/example/document-filter.json +++ b/diskann-benchmark/example/document-filter.json @@ -21,6 +21,10 @@ "query_predicates": "query.10.label.jsonl", "groundtruth": "disk_index_10pts_idx_uint32_truth_search_filter_res.bin", "beta": 0.5, + "reps": 5, + "num_threads": [ + 1 + ], "runs": [ { "search_n": 20, diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index d954f1f47..6cd1ec1fa 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -15,20 +15,21 @@ use std::sync::Arc; use anyhow::Result; use diskann::{ graph::{ - config::Builder as ConfigBuilder, config::MaxDegree, config::PruneKind, search::Knn, + search::Knn, search_output_buffer, DiskANNIndex, SearchOutputBuffer, StartPointStrategy, }, provider::DefaultContext, + ANNError, ANNErrorKind, }; use diskann_benchmark_core::{ - build::{self, AsProgress, Build, Parallelism, Progress}, + build::{self, Build, Parallelism}, recall, search as search_api, tokio, }; use diskann_benchmark_runner::{ dispatcher::{DispatchRule, FailureScore, MatchScore}, output::Output, registry::Benchmarks, - utils::{datatype, percentiles, MicroSeconds}, + utils::{datatype, fmt, percentiles, MicroSeconds}, Any, }; use diskann_label_filter::{ @@ -40,8 +41,11 @@ use diskann_label_filter::{ }, inline_beta_search::inline_beta_filter::InlineBetaStrategy, query::FilteredQuery, - read_and_parse_queries, read_baselabels, ASTExpr, + read_and_parse_queries, read_baselabels, + traits::attribute_store::AttributeStore, + ASTExpr, }; + use diskann_providers::model::graph::provider::async_::{ common::{self, NoStore, TableBasedDeletes}, inmem::{CreateFullPrecision, DefaultProvider, DefaultProviderParameters, SetStartPoints}, @@ -51,10 +55,10 @@ use diskann_utils::views::MatrixView; use diskann_utils::{future::AsyncFriendly, sampling::medoid::ComputeMedoid}; use diskann_vector::distance::SquaredL2; use diskann_vector::PureDistanceFunction; -use indicatif::{ProgressBar, ProgressStyle}; use serde::Serialize; use crate::{ + backend::index::build::ProgressMeter, inputs::document_index::DocumentIndexBuild, utils::{ datafiles::{self, BinFile}, @@ -100,24 +104,12 @@ where type Error = std::convert::Infallible; fn try_match(_from: &&'a DocumentIndexBuild) -> Result { - match _from.build.data_type { - datatype::DataType::Float32 => Ok(MatchScore(0)), - datatype::DataType::UInt8 => Ok(MatchScore(0)), - datatype::DataType::Int8 => Ok(MatchScore(0)), - _ => Err(datatype::MATCH_FAIL), - } + datatype::Type::::try_match(&_from.build.data_type) } fn convert(from: &'a DocumentIndexBuild) -> Result { Ok(DocumentIndexJob::new(from)) } - - fn description( - f: &mut std::fmt::Formatter<'_>, - _from: Option<&&'a DocumentIndexBuild>, - ) -> std::fmt::Result { - writeln!(f, "tag: \"{}\"", DocumentIndexBuild::tag()) - } } // Central dispatch mapping from Any @@ -236,23 +228,14 @@ impl<'a, T> DocumentIndexJob<'a, T> { .collect(); // 3. Create the index configuration - let metric = build.distance.into(); - let prune_kind = PruneKind::from_metric(metric); - let mut config_builder = ConfigBuilder::new( - build.max_degree, // pruned_degree - MaxDegree::Same, // max_degree - build.l_build, // l_build - prune_kind, // prune_kind - ); - config_builder.alpha(build.alpha); - let config = config_builder.build()?; + let config = build.build_config()?; // 4. Create the data provider directly writeln!(output, "Creating index...")?; let params = DefaultProviderParameters { max_points: num_vectors, frozen_points: diskann::utils::ONE, - metric, + metric: build.distance.into(), dim, prefetch_lookahead: None, prefetch_cache_line_level: None, @@ -278,7 +261,6 @@ impl<'a, T> DocumentIndexJob<'a, T> { let medoid_idx = compute_medoid_index(&data)?; let start_point_id = num_vectors as u32; // Start points begin at max_points let medoid_attrs = attributes.get(medoid_idx).cloned().unwrap_or_default(); - use diskann_label_filter::traits::attribute_store::AttributeStore; attribute_store.set_element(&start_point_id, &medoid_attrs)?; let doc_provider = DocumentProvider::new(inner_provider, attribute_store); @@ -306,15 +288,8 @@ impl<'a, T> DocumentIndexJob<'a, T> { ); let num_tasks = NonZeroUsize::new(build.num_threads).unwrap_or(diskann::utils::ONE); let parallelism = Parallelism::dynamic(diskann::utils::ONE, num_tasks); - let progress = IndicatifAsProgress({ - let bar = ProgressBar::with_draw_target(Some(num_vectors as u64), output.draw_target()); - bar.set_style( - ProgressStyle::with_template("Building [{elapsed_precise}] {wide_bar} {percent}") - .expect("valid template"), - ); - bar - }); - let build_results = build::build_tracked(builder, parallelism, &rt, Some(&progress))?; + let build_results = + build::build_tracked(builder, parallelism, &rt, Some(&ProgressMeter::new(output)))?; let insert_latencies: Vec = build_results .take_output() .into_iter() @@ -416,11 +391,6 @@ impl<'a, T> DocumentIndexJob<'a, T> { label_load_time, build_time, insert_latencies: insert_percentiles, - build_params: BuildParamsStats { - max_degree: build.max_degree, - l_build: build.l_build, - alpha: build.alpha, - }, search: search_results, }; @@ -700,12 +670,6 @@ where .pop() .ok_or_else(|| anyhow::anyhow!("no search results")) } -#[derive(Debug, Serialize)] -pub struct BuildParamsStats { - pub max_degree: usize, - pub l_build: usize, - pub alpha: f32, -} /// Helper module for serializing arrays as compact single-line JSON strings mod compact_array { @@ -772,7 +736,6 @@ pub struct DocumentIndexStats { pub label_load_time: MicroSeconds, pub build_time: MicroSeconds, pub insert_latencies: percentiles::Percentiles, - pub build_params: BuildParamsStats, pub search: Vec, } @@ -797,16 +760,9 @@ impl std::fmt::Display for DocumentIndexStats { writeln!(f, " P50: {} us", self.insert_latencies.median)?; writeln!(f, " P90: {} us", self.insert_latencies.p90)?; writeln!(f, " P99: {} us", self.insert_latencies.p99)?; - writeln!(f, " Build Parameters:")?; - writeln!(f, " max_degree (R): {}", self.build_params.max_degree)?; - writeln!(f, " l_build (L): {}", self.build_params.l_build)?; - writeln!(f, " alpha: {}", self.build_params.alpha)?; if !self.search.is_empty() { - writeln!(f, "\nFiltered Search Results:")?; - writeln!( - f, - " {:>8} {:>8} {:>10} {:>10} {:>15} {:>12} {:>12} {:>10} {:>8} {:>10} {:>12}", + let header = [ "L", "KNN", "Avg Cmps", @@ -817,41 +773,34 @@ impl std::fmt::Display for DocumentIndexStats { "Recall", "Threads", "Queries", - "WallClock(s)" - )?; - for s in &self.search { - let mean_qps = if s.qps.is_empty() { - 0.0 - } else { - s.qps.iter().sum::() / s.qps.len() as f64 - }; + "WallClock(s)", + ]; + writeln!(f, "\nFiltered Search Results:")?; + let mut table = fmt::Table::new(header, self.search.len()); + self.search.iter().enumerate().for_each(|(row_idx, s)| { + let mut row = table.row(row_idx); + let mean_qps = percentiles::mean(&s.qps).unwrap_or(0.0); let max_qps = s.qps.iter().cloned().fold(0.0_f64, f64::max); - let mean_wall_clock = if s.wall_clock_time.is_empty() { - 0.0 - } else { - s.wall_clock_time + let mean_wall_clock = percentiles::mean( + &s.wall_clock_time .iter() - .map(|t| t.as_seconds()) - .sum::() - / s.wall_clock_time.len() as f64 - }; - writeln!( - f, - " {:>8} {:>8} {:>10.1} {:>10.1} {:>7.1}({:>5.1}) {:>12.1} {:>12} {:>10.4} {:>8} {:>10} {:>12.3}", - s.search_l, - s.search_n, - s.mean_cmps, - s.mean_hops, - mean_qps, - max_qps, - s.mean_latency, - s.p99_latency, - s.recall.average, - s.num_threads, - s.num_queries, - mean_wall_clock - )?; - } + .map(|l| l.as_seconds()) + .collect::>(), + ) + .unwrap_or(0.0); + row.insert(s.search_l, 0); + row.insert(s.search_n, 1); + row.insert(format!("{:.1}", s.mean_cmps), 2); + row.insert(format!("{:.1}", s.mean_hops), 3); + row.insert(format!("{:.1}({:.1})", mean_qps, max_qps), 4); + row.insert(format!("{:.1} s", s.mean_latency), 5); + row.insert(format!("{:.1} s", s.p99_latency), 6); + row.insert(format!("{:.4}", s.recall.average), 7); + row.insert(s.num_threads, 8); + row.insert(s.num_queries, 9); + row.insert(format!("{:.3} s", mean_wall_clock), 10); + }); + write!(f, "{}", table)?; } Ok(()) } @@ -906,7 +855,12 @@ where async fn build(&self, range: std::ops::Range) -> diskann::ANNResult { let ctx = DefaultContext; for i in range { - let attrs = self.attributes.get(i).cloned().unwrap_or_default(); + let attrs = self.attributes.get(i).cloned().ok_or_else(|| { + ANNError::message( + ANNErrorKind::Opaque, + format!("Failed to get attributes at index {}", i), + ) + })?; let doc = Document::new(self.data.row(i), attrs); self.index .insert(self.strategy, &ctx, &(i as u32), &doc) @@ -915,25 +869,3 @@ where Ok(()) } } - -/// Adapts an already-constructed [`ProgressBar`] into the [`AsProgress`] / [`Progress`] -/// traits expected by [`build_tracked`]. -struct IndicatifAsProgress(ProgressBar); - -struct IndicatifProgress(ProgressBar); - -impl Progress for IndicatifProgress { - fn progress(&self, handled: usize) { - self.0.inc(handled as u64); - } - - fn finish(&self) { - self.0.finish(); - } -} - -impl AsProgress for IndicatifAsProgress { - fn as_progress(&self, _max: usize) -> Arc { - Arc::new(IndicatifProgress(self.0.clone())) - } -} diff --git a/diskann-benchmark/src/backend/index/build.rs b/diskann-benchmark/src/backend/index/build.rs index ef6284d2a..b674bf5e8 100644 --- a/diskann-benchmark/src/backend/index/build.rs +++ b/diskann-benchmark/src/backend/index/build.rs @@ -213,12 +213,12 @@ impl std::fmt::Display for BuildStats { } } -pub struct ProgressMeter<'a> { +pub(crate) struct ProgressMeter<'a> { output: &'a mut dyn Output, } impl<'a> ProgressMeter<'a> { - pub fn new(output: &'a mut dyn Output) -> Self { + pub(crate) fn new(output: &'a mut dyn Output) -> Self { Self { output } } } diff --git a/diskann-benchmark/src/backend/index/mod.rs b/diskann-benchmark/src/backend/index/mod.rs index 269887c6d..07ed0ccb8 100644 --- a/diskann-benchmark/src/backend/index/mod.rs +++ b/diskann-benchmark/src/backend/index/mod.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -mod build; +pub(crate) mod build; mod search; mod streaming; diff --git a/diskann-benchmark/src/inputs/document_index.rs b/diskann-benchmark/src/inputs/document_index.rs index f1d2d7c67..4d3e72235 100644 --- a/diskann-benchmark/src/inputs/document_index.rs +++ b/diskann-benchmark/src/inputs/document_index.rs @@ -8,6 +8,7 @@ use std::num::NonZeroUsize; use anyhow::Context; +use diskann::graph::{Config, config::{Builder, MaxDegree, PruneKind, ConfigError}}; use diskann_benchmark_runner::{ files::InputFile, utils::datatype::DataType, CheckDeserialization, Checker, }; @@ -42,21 +43,27 @@ pub(crate) struct DocumentBuildParams { pub(crate) num_threads: usize, } +impl DocumentBuildParams { + pub(crate) fn build_config(&self) -> Result { + let metric = self.distance.into(); + let prune_kind = PruneKind::from_metric(metric); + let mut config_builder = Builder::new( + self.max_degree, // pruned_degree + MaxDegree::default_slack(), // max_degree + self.l_build, + prune_kind, + ); + config_builder.alpha(self.alpha); + let config = config_builder.build()?; + Ok(config) + } +} + impl CheckDeserialization for DocumentBuildParams { fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { self.data.check_deserialization(checker)?; self.data_labels.check_deserialization(checker)?; - - // checking if the max_degree, l_build and alpha values are valid. - use diskann::graph::config::{Builder, MaxDegree, PruneKind}; - let mut builder = Builder::new( - self.max_degree, - MaxDegree::Value(self.max_degree), - self.l_build, - PruneKind::Occluding, - ); - builder.alpha(self.alpha); - builder.build()?; + self.build_config()?; Ok(()) } } @@ -67,9 +74,7 @@ pub(crate) struct DocumentSearchParams { pub(crate) query_predicates: InputFile, pub(crate) groundtruth: InputFile, pub(crate) beta: f32, - #[serde(default = "default_reps")] pub(crate) reps: NonZeroUsize, - #[serde(default = "default_thread_counts")] pub(crate) num_threads: Vec, pub(crate) runs: Vec, } diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs index 9a3bad9a0..6270af72e 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs @@ -233,7 +233,7 @@ where fn prune_accessor<'a>( &'a self, provider: &'a DocumentProvider>, - context: &'a > as DataProvider>::Context, + context: &'a DP::Context, ) -> Result, Self::PruneAccessorError> { self.inner .prune_accessor(provider.inner_provider(), context) From 8010d374ae211136a9aa0b7ab1da913502694054 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 23 Mar 2026 18:27:33 +0530 Subject: [PATCH 31/50] Formatting + revert to old names for some functions --- diskann-benchmark/src/inputs/document_index.rs | 9 ++++++--- diskann-benchmark/src/utils/recall.rs | 1 + .../encoded_attribute_provider/encoded_filter_expr.rs | 2 +- .../src/inline_beta_search/encoded_document_accessor.rs | 3 +-- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/diskann-benchmark/src/inputs/document_index.rs b/diskann-benchmark/src/inputs/document_index.rs index 4d3e72235..f1ed3c063 100644 --- a/diskann-benchmark/src/inputs/document_index.rs +++ b/diskann-benchmark/src/inputs/document_index.rs @@ -8,7 +8,10 @@ use std::num::NonZeroUsize; use anyhow::Context; -use diskann::graph::{Config, config::{Builder, MaxDegree, PruneKind, ConfigError}}; +use diskann::graph::{ + config::{Builder, ConfigError, MaxDegree, PruneKind}, + Config, +}; use diskann_benchmark_runner::{ files::InputFile, utils::datatype::DataType, CheckDeserialization, Checker, }; @@ -48,8 +51,8 @@ impl DocumentBuildParams { let metric = self.distance.into(); let prune_kind = PruneKind::from_metric(metric); let mut config_builder = Builder::new( - self.max_degree, // pruned_degree - MaxDegree::default_slack(), // max_degree + self.max_degree, // pruned_degree + MaxDegree::default_slack(), // max_degree self.l_build, prune_kind, ); diff --git a/diskann-benchmark/src/utils/recall.rs b/diskann-benchmark/src/utils/recall.rs index c0ed813cd..dcbe86d94 100644 --- a/diskann-benchmark/src/utils/recall.rs +++ b/diskann-benchmark/src/utils/recall.rs @@ -2,6 +2,7 @@ * Copyright (c) Microsoft Corporation. * Licensed under the MIT license. */ + use diskann_benchmark_core as benchmark_core; use serde::Serialize; diff --git a/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs b/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs index 0ebdaf72f..d56cb13c1 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs @@ -20,7 +20,7 @@ pub(crate) struct EncodedFilterExpr { } impl EncodedFilterExpr { - pub fn try_create( + pub fn new( ast_expr: &ASTExpr, attribute_map: Arc>, ) -> ANNResult { diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 418526af7..ab82dad56 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -220,8 +220,7 @@ where .inner_accessor .build_query_computer(from.query()) .into_ann_result()?; - let id_query = - EncodedFilterExpr::try_create(from.filter_expr(), self.attribute_map.clone())?; + let id_query = EncodedFilterExpr::new(from.filter_expr(), self.attribute_map.clone())?; Ok(InlineBetaComputer::new( inner_computer, From 8a86e0d1e5507802d3f672f1d1bfe2c724c1b3ea Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 23 Mar 2026 18:31:37 +0530 Subject: [PATCH 32/50] Add some unit tests + smoke test for benchmark --- diskann-benchmark/src/main.rs | 33 +++ diskann-label-filter/Cargo.toml | 2 + .../document_insert_strategy.rs | 171 ++++++++++--- .../inline_beta_search/inline_beta_filter.rs | 227 ++++++++++++++++++ 4 files changed, 404 insertions(+), 29 deletions(-) diff --git a/diskann-benchmark/src/main.rs b/diskann-benchmark/src/main.rs index b3de5901e..7bc382782 100644 --- a/diskann-benchmark/src/main.rs +++ b/diskann-benchmark/src/main.rs @@ -779,4 +779,37 @@ mod tests { let mut output = Memory::new(); cli.check_target(&mut output).unwrap(); } + + #[test] + fn document_filter_integration() { + let input_path = example_directory().join("document-filter.json"); + + let tempdir = tempfile::tempdir().unwrap(); + let output_path = tempdir.path().join("output.json"); + assert!(!output_path.exists()); + + let modified_input_path = tempdir.path().join("input.json"); + + let mut raw = value_from_file(&input_path); + prefix_search_directories(&mut raw, &root_directory()); + save_to_file(&modified_input_path, &raw); + + let command = Commands::Run { + input_file: modified_input_path.to_owned(), + output_file: output_path.to_owned(), + dry_run: false, + }; + let cli = Cli::from_commands(command, true); + let mut output = Memory::new(); + + cli.run(&mut output).unwrap(); + + let output = String::from_utf8(output.into_inner()).unwrap(); + println!("output = {}", output); + // Check that the results file is generated. + assert!(output_path.exists()); + + let results: Vec = load_from_file(&output_path); + assert_eq!(results.len(), num_jobs(&raw)); + } } diff --git a/diskann-label-filter/Cargo.toml b/diskann-label-filter/Cargo.toml index 98fe3879c..b204dae23 100644 --- a/diskann-label-filter/Cargo.toml +++ b/diskann-label-filter/Cargo.toml @@ -33,6 +33,8 @@ tempfile.workspace = true anyhow.workspace = true futures-util.workspace = true tracing.workspace = true +diskann = { workspace = true, features = ["testing"] } +tokio = { workspace = true, features = ["rt"] } diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs index 6270af72e..15f40e5fe 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs @@ -6,8 +6,6 @@ //! A strategy wrapper that enables insertion of [Document] objects into a //! [DiskANNIndex] using a [DocumentProvider]. -use std::marker::PhantomData; - use diskann::{ graph::glue::{self, ExpandBeam, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy}, provider::{Accessor, BuildQueryComputer, DataProvider, DelegateNeighbor, HasId}, @@ -19,28 +17,23 @@ use crate::document::Document; use crate::encoded_attribute_provider::roaring_attribute_store::RoaringAttributeStore; /// A strategy wrapper that enables insertion of [Document] objects. -pub struct DocumentInsertStrategy { +pub struct DocumentInsertStrategy { inner: Inner, - _phantom: PhantomData VT>, } -impl Clone for DocumentInsertStrategy { +impl Clone for DocumentInsertStrategy { fn clone(&self) -> Self { Self { inner: self.inner.clone(), - _phantom: PhantomData, } } } -impl Copy for DocumentInsertStrategy {} +impl Copy for DocumentInsertStrategy {} -impl DocumentInsertStrategy { +impl DocumentInsertStrategy { pub fn new(inner: Inner) -> Self { - Self { - inner, - _phantom: PhantomData, - } + Self { inner } } pub fn inner(&self) -> &Inner { @@ -49,32 +42,30 @@ impl DocumentInsertStrategy { } /// Wrapper accessor for Document queries -pub struct DocumentSearchAccessor { +pub struct DocumentSearchAccessor { inner: Inner, - _phantom: PhantomData VT>, + // _phantom: PhantomData VT>, } -impl DocumentSearchAccessor { +impl DocumentSearchAccessor { pub fn new(inner: Inner) -> Self { Self { inner, - _phantom: PhantomData, + // _phantom: PhantomData, } } } -impl HasId for DocumentSearchAccessor +impl HasId for DocumentSearchAccessor where Inner: HasId, - VT: ?Sized, { type Id = Inner::Id; } -impl Accessor for DocumentSearchAccessor +impl Accessor for DocumentSearchAccessor where Inner: Accessor, - VT: ?Sized, { type ElementRef<'a> = Inner::ElementRef<'a>; type Element<'a> @@ -105,7 +96,7 @@ where } } -impl<'doc, Inner, VT> BuildQueryComputer> for DocumentSearchAccessor +impl<'doc, Inner, VT> BuildQueryComputer> for DocumentSearchAccessor where Inner: BuildQueryComputer, VT: ?Sized, @@ -121,10 +112,9 @@ where } } -impl<'this, Inner, VT> DelegateNeighbor<'this> for DocumentSearchAccessor +impl<'this, Inner> DelegateNeighbor<'this> for DocumentSearchAccessor where Inner: DelegateNeighbor<'this>, - VT: ?Sized, { type Delegate = Inner::Delegate; fn delegate_neighbor(&'this mut self) -> Self::Delegate { @@ -132,17 +122,16 @@ where } } -impl<'doc, Inner, VT> ExpandBeam> for DocumentSearchAccessor +impl<'doc, Inner, VT> ExpandBeam> for DocumentSearchAccessor where Inner: ExpandBeam, VT: ?Sized, { } -impl SearchExt for DocumentSearchAccessor +impl SearchExt for DocumentSearchAccessor where Inner: SearchExt, - VT: ?Sized, { fn starting_points( &self, @@ -156,7 +145,7 @@ where impl<'doc, Inner, DP, VT> SearchStrategy>, Document<'doc, VT>> - for DocumentInsertStrategy + for DocumentInsertStrategy where Inner: InsertStrategy, DP: DataProvider, @@ -165,7 +154,7 @@ where type QueryComputer = Inner::QueryComputer; type PostProcessor = glue::CopyIds; type SearchAccessorError = Inner::SearchAccessorError; - type SearchAccessor<'a> = DocumentSearchAccessor, VT>; + type SearchAccessor<'a> = DocumentSearchAccessor>; fn search_accessor<'a>( &'a self, @@ -185,7 +174,7 @@ where impl<'doc, Inner, DP, VT> InsertStrategy>, Document<'doc, VT>> - for DocumentInsertStrategy + for DocumentInsertStrategy where Inner: InsertStrategy, DP: DataProvider, @@ -239,3 +228,127 @@ where .prune_accessor(provider.inner_provider(), context) } } + +#[cfg(test)] +mod tests { + use super::{DocumentInsertStrategy, DocumentPruneStrategy, DocumentSearchAccessor}; + use diskann::{ + graph::{ + glue::{InsertStrategy, PruneStrategy, SearchExt, SearchStrategy}, + test::provider::{Config, Context, Provider, StartPoint, Strategy}, + }, + provider::BuildQueryComputer, + }; + use diskann_vector::distance::Metric; + + use crate::{ + document::Document, + encoded_attribute_provider::{ + document_provider::DocumentProvider, roaring_attribute_store::RoaringAttributeStore, + }, + }; + + // --------------------------------------------------------------------------- + // Helpers + // --------------------------------------------------------------------------- + + /// Build a minimal test provider with a single start point and three dimensions. + fn make_test_provider() -> Provider { + let config = Config::new( + Metric::L2, + 10, + StartPoint::new(u32::MAX, vec![1.0f32, 2.0, 0.0]), + ) + .expect("test provider config should be valid"); + Provider::new(config) + } + + fn make_doc_provider( + provider: Provider, + ) -> DocumentProvider> { + DocumentProvider::new(provider, RoaringAttributeStore::new()) + } + + /// `search_accessor` successfully creates a `DocumentSearchAccessor` wrapping the + /// inner accessor. + #[test] + fn test_search_accessor_creates_wrapped_accessor() { + let strategy = DocumentInsertStrategy::new(Strategy::new()); + let provider = make_doc_provider(make_test_provider()); + let context = Context::new(); + + let result = as SearchStrategy< + DocumentProvider>, + Document<'_, [f32]>, + >>::search_accessor(&strategy, &provider, &context); + + assert!(result.is_ok()); + } + + #[test] + fn test_insert_search_accessor_creates_wrapped_accessor() { + let strategy = DocumentInsertStrategy::new(Strategy::new()); + let provider = make_doc_provider(make_test_provider()); + let context = Context::new(); + + let result = as InsertStrategy< + DocumentProvider>, + Document<'_, [f32]>, + >>::insert_search_accessor(&strategy, &provider, &context); + + assert!(result.is_ok()); + } + + #[test] + fn test_prune_accessor_delegates_to_inner_provider() { + let doc_prune_strategy = DocumentPruneStrategy::new(Strategy::new()); + let provider = make_doc_provider(make_test_provider()); + let context = Context::new(); + + let result = as PruneStrategy< + DocumentProvider>, + >>::prune_accessor(&doc_prune_strategy, &provider, &context); + + assert!(result.is_ok()); + } + + #[test] + fn test_build_query_computer_extracts_vector_from_document() { + let provider = make_test_provider(); + let context = Context::new(); + let strategy_inner = Strategy::new(); + let inner_accessor = strategy_inner + .search_accessor(&provider, &context) + .expect("creating search accessor should succeed"); + let doc_accessor = DocumentSearchAccessor::new(inner_accessor); + + let vector = vec![1.0f32, 2.0, 0.0]; + let doc = Document::new(vector.as_slice(), vec![]); + + let result = as BuildQueryComputer< + Document<'_, [f32]>, + >>::build_query_computer(&doc_accessor, &doc); + + assert!( + result.is_ok(), + "build_query_computer should succeed for a valid vector" + ); + } + + #[test] + fn test_terminate_early_delegates_to_inner() { + let provider = make_test_provider(); + let context = Context::new(); + let strategy_inner = Strategy::new(); + let mut inner_accessor = strategy_inner + .search_accessor(&provider, &context) + .expect("creating search accessor should succeed"); + let inner_terminate_early = inner_accessor.terminate_early(); + let mut doc_accessor = DocumentSearchAccessor::new(inner_accessor); + assert_eq!( + inner_terminate_early, + doc_accessor.terminate_early(), + "terminate_early should have same value as inner accessor" + ); + } +} diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 68dcad630..b1cfbca4b 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -175,3 +175,230 @@ where .map_err(|e| e.into()) } } + +#[cfg(test)] +mod tests { + use std::sync::{Arc, RwLock}; + + use diskann::{ + graph::{ + glue::{self, SearchPostProcess, SearchStrategy}, + search_output_buffer::IdDistance, + test::provider::{Config, Context, Provider, StartPoint, Strategy}, + }, + neighbor::Neighbor, + provider::{BuildQueryComputer, SetElement}, + }; + use diskann_vector::{distance::Metric, PreprocessedDistanceFunction}; + use roaring::RoaringTreemap; + use serde_json::Value; + + use crate::{ + attribute::{Attribute, AttributeValue}, + document::EncodedDocument, + encoded_attribute_provider::{ + attribute_encoder::AttributeEncoder, encoded_filter_expr::EncodedFilterExpr, + roaring_attribute_store::RoaringAttributeStore, + }, + inline_beta_search::encoded_document_accessor::EncodedDocumentAccessor, + query::FilteredQuery, + traits::attribute_store::AttributeStore, + ASTExpr, CompareOp, + }; + + use super::{FilterResults, InlineBetaComputer}; + + // ----------------------------------------------------------------------- + // Stub inner distance computer + // ----------------------------------------------------------------------- + + /// Always returns a fixed constant distance, regardless of the vector value. + struct ConstComputer(f32); + + impl PreprocessedDistanceFunction<&[f32], f32> for ConstComputer { + fn evaluate_similarity(&self, _: &[f32]) -> f32 { + self.0 + } + } + + // ----------------------------------------------------------------------- + // Helper: build an AttributeEncoder + ASTExpr for `field == value`, + // returning (attr_map, ast_expr, encoded_id_of_that_attribute). + // ----------------------------------------------------------------------- + + fn setup_encoder_and_filter( + field: &str, + value: &str, + ) -> (Arc>, ASTExpr, u64) { + let mut encoder = AttributeEncoder::new(); + let attr = Attribute::from_value(field, AttributeValue::String(value.to_owned())); + let encoded_id = encoder.insert(&attr); + let attr_map = Arc::new(RwLock::new(encoder)); + let ast_expr = ASTExpr::Compare { + field: field.to_string(), + op: CompareOp::Eq(Value::String(value.to_string())), + }; + (attr_map, ast_expr, encoded_id) + } + + // ----------------------------------------------------------------------- + // Test 1: when the filter matches, evaluate_similarity returns inner * beta + // ----------------------------------------------------------------------- + + #[test] + fn test_evaluate_similarity_filter_match_scales_by_beta() { + let (attr_map, ast_expr, color_red_id) = setup_encoder_and_filter("color", "red"); + let filter_expr = EncodedFilterExpr::new(&ast_expr, attr_map).expect("filter expr"); + + let beta = 2.5_f32; + let inner_dist = 4.0_f32; + let computer = InlineBetaComputer::new(ConstComputer(inner_dist), beta, filter_expr); + + // Bitmap contains the encoded ID for "color=red" → predicate matches + let mut matching_map = RoaringTreemap::new(); + matching_map.insert(color_red_id); + let doc = EncodedDocument::new(&[1.0f32, 0.0][..], &matching_map); + + assert_eq!( + computer.evaluate_similarity(doc), + inner_dist * beta, + "a matched filter should multiply the inner similarity by beta" + ); + } + + // ----------------------------------------------------------------------- + // Test 2: when the filter does not match, evaluate_similarity is unchanged + // ----------------------------------------------------------------------- + + #[test] + fn test_evaluate_similarity_no_filter_match_preserves_score() { + let (attr_map, ast_expr, _) = setup_encoder_and_filter("color", "red"); + let filter_expr = EncodedFilterExpr::new(&ast_expr, attr_map).expect("filter expr"); + + let beta = 2.5_f32; + let inner_dist = 4.0_f32; + let computer = InlineBetaComputer::new(ConstComputer(inner_dist), beta, filter_expr); + + // Empty bitmap → no attribute matches the predicate + let empty_map = RoaringTreemap::new(); + let doc = EncodedDocument::new(&[1.0f32, 0.0][..], &empty_map); + + assert_eq!( + computer.evaluate_similarity(doc), + inner_dist, + "an unmatched filter should leave the inner similarity unchanged" + ); + } + + // ----------------------------------------------------------------------- + // Test 3: post_process forwards only filter-matching candidates to the + // inner post processor (and therefore to the output buffer). + // ----------------------------------------------------------------------- + + #[test] + fn test_post_process_only_passes_matching_candidates_to_inner() { + let rt = tokio::runtime::Builder::new_current_thread() + .build() + .expect("test tokio runtime"); + + // IDs 0 and 1 carry color=red (should pass the filter) + // IDs 2 and 3 carry color=blue (should be dropped by the filter) + let attr_store = RoaringAttributeStore::::new(); + let red = Attribute::from_value("color", AttributeValue::String("red".to_owned())); + let blue = Attribute::from_value("color", AttributeValue::String("blue".to_owned())); + for id in 0u32..2 { + attr_store + .set_element(&id, std::slice::from_ref(&red)) + .expect("set red attr"); + } + for id in 2u32..4 { + attr_store + .set_element(&id, std::slice::from_ref(&blue)) + .expect("set blue attr"); + } + + // The attribute_map is shared so EncodedFilterExpr sees the same encodings + // as those stored by the attribute store. + let attr_map = attr_store.attribute_map(); + + let ast_expr = ASTExpr::Compare { + field: "color".to_string(), + op: CompareOp::Eq(Value::String("red".to_string())), + }; + let filter_expr = EncodedFilterExpr::new(&ast_expr, attr_map.clone()).expect("filter expr"); + + // Build the inner vector provider: start point at u32::MAX + 2-D zero vectors for 0..3 + let config = Config::new(Metric::L2, 10, StartPoint::new(u32::MAX, vec![1.0f32, 0.0])) + .expect("provider config"); + let inner_provider = Provider::new(config); + let ctx = Context::new(); + rt.block_on(async { + for id in 0u32..4 { + inner_provider + .set_element(&ctx, &id, &[0.0f32, 0.0] as &[f32]) + .await + .expect("add vector to inner provider"); + } + }); + + // Obtain the inner search accessor and derive an inner computer from it + let strategy = Strategy::new(); + let inner_accessor = strategy + .search_accessor(&inner_provider, &ctx) + .expect("inner accessor"); + let inner_computer = inner_accessor + .build_query_computer(&[0.0f32, 0.0][..]) + .expect("inner computer"); + + // Wrap accessor + attribute store into an EncodedDocumentAccessor + let attribute_accessor = attr_store.attribute_accessor().expect("attribute accessor"); + let mut doc_accessor = + EncodedDocumentAccessor::new(inner_accessor, attribute_accessor, attr_map, 2.0); + + let computer = InlineBetaComputer::new(inner_computer, 2.0, filter_expr); + + // Four candidates: 0 and 1 match (red); 2 and 3 do not (blue) + let candidates = [ + Neighbor::new(0u32, 1.0_f32), + Neighbor::new(1u32, 2.0_f32), + Neighbor::new(2u32, 3.0_f32), + Neighbor::new(3u32, 4.0_f32), + ]; + + let mut ids = [u32::MAX; 4]; + let mut distances = [f32::MAX; 4]; + let mut output = IdDistance::new(&mut ids, &mut distances); + + let query_vec = [0.0f32, 0.0]; + let filter_query = FilteredQuery::new(&query_vec[..], ast_expr); + + // CopyIds simply copies whatever it receives into the output buffer, + // so the output reflects exactly what FilterResults lets through. + let count = rt + .block_on( + FilterResults { + inner_post_processor: glue::CopyIds, + } + .post_process( + &mut doc_accessor, + &filter_query, + &computer, + candidates.into_iter(), + &mut output, + ), + ) + .expect("post_process"); + + // Only the two red-labeled candidates should have been forwarded + assert_eq!(count, 2, "exactly 2 of 4 candidates should pass the filter"); + let passed = &ids[..count]; + assert!( + passed.contains(&0), + "ID 0 (color=red) should pass the filter" + ); + assert!( + passed.contains(&1), + "ID 1 (color=red) should pass the filter" + ); + } +} From 5664dcaced77cdda5a3e619941496808886ec857 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 23 Mar 2026 18:32:26 +0530 Subject: [PATCH 33/50] Update output serializer + remove unnecessary type parameter --- .../src/backend/document_index/benchmark.rs | 37 +++---------------- 1 file changed, 5 insertions(+), 32 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index 6cd1ec1fa..a284e93db 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -15,8 +15,7 @@ use std::sync::Arc; use anyhow::Result; use diskann::{ graph::{ - search::Knn, - search_output_buffer, DiskANNIndex, SearchOutputBuffer, StartPointStrategy, + search::Knn, search_output_buffer, DiskANNIndex, SearchOutputBuffer, StartPointStrategy, }, provider::DefaultContext, ANNError, ANNErrorKind, @@ -671,40 +670,14 @@ where .ok_or_else(|| anyhow::anyhow!("no search results")) } -/// Helper module for serializing arrays as compact single-line JSON strings -mod compact_array { - use serde::Serializer; - - pub fn serialize_u32_vec(vec: &Vec, serializer: S) -> Result - where - S: Serializer, - { - // Serialize as a string containing the compact JSON array - let compact = serde_json::to_string(vec).unwrap_or_default(); - serializer.serialize_str(&compact) - } - - pub fn serialize_f32_vec(vec: &Vec, serializer: S) -> Result - where - S: Serializer, - { - // Serialize as a string containing the compact JSON array - let compact = serde_json::to_string(vec).unwrap_or_default(); - serializer.serialize_str(&compact) - } -} - /// Per-query detailed results for debugging/analysis #[derive(Debug, Serialize)] pub struct PerQueryDetails { pub query_id: usize, pub filter: String, pub recall: f64, - #[serde(serialize_with = "compact_array::serialize_u32_vec")] pub result_ids: Vec, - #[serde(serialize_with = "compact_array::serialize_f32_vec")] pub result_distances: Vec, - #[serde(serialize_with = "compact_array::serialize_u32_vec")] pub groundtruth_ids: Vec, } @@ -817,7 +790,7 @@ struct DocumentIndexBuilder { index: Arc>, data: Arc>, attributes: Arc>>, - strategy: DocumentInsertStrategy, + strategy: DocumentInsertStrategy, } impl DocumentIndexBuilder { @@ -825,7 +798,7 @@ impl DocumentIndexBuilder { index: Arc>, data: Arc>, attributes: Arc>>, - strategy: DocumentInsertStrategy, + strategy: DocumentInsertStrategy, ) -> Arc { Arc::new(Self { index, @@ -841,9 +814,9 @@ where DP: diskann::provider::DataProvider + for<'doc> diskann::provider::SetElement> + AsyncFriendly, - for<'doc> DocumentInsertStrategy: + for<'doc> DocumentInsertStrategy: diskann::graph::glue::InsertStrategy>, - DocumentInsertStrategy: AsyncFriendly, + DocumentInsertStrategy: AsyncFriendly, T: AsyncFriendly, { type Output = (); From b03bef40f315cff77d137413dec31d5e9604787c Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 23 Mar 2026 18:51:18 +0530 Subject: [PATCH 34/50] Put the benchmarks and the smoke test behind a feature --- .github/workflows/ci.yml | 2 +- diskann-benchmark/Cargo.toml | 3 ++ .../src/backend/document_index/mod.rs | 23 ++++++++++- diskann-benchmark/src/main.rs | 40 +++++++++++++++++-- 4 files changed, 62 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0a31f86cc..fb8da8552 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -293,7 +293,7 @@ jobs: --cargo-profile ci \ --config "$RUST_CONFIG" \ --features \ - virtual_storage,bf_tree,spherical-quantization,product-quantization,tracing,experimental_diversity_search + virtual_storage,bf_tree,spherical-quantization,product-quantization,tracing,experimental_diversity_search,document-index cargo test --locked --doc --workspace --profile ci --config "$RUST_CONFIG" diff --git a/diskann-benchmark/Cargo.toml b/diskann-benchmark/Cargo.toml index bebaf4b8e..39d64b1bb 100644 --- a/diskann-benchmark/Cargo.toml +++ b/diskann-benchmark/Cargo.toml @@ -63,6 +63,9 @@ scalar-quantization = [] # Enable minmax-quantization based algorithms minmax-quantization = [] +# Enable Document Index benchmarks +document-index = [] + # Enable Disk Index benchmarks disk-index = [ "diskann-disk/perf_test", diff --git a/diskann-benchmark/src/backend/document_index/mod.rs b/diskann-benchmark/src/backend/document_index/mod.rs index 9937590cc..022470578 100644 --- a/diskann-benchmark/src/backend/document_index/mod.rs +++ b/diskann-benchmark/src/backend/document_index/mod.rs @@ -8,6 +8,25 @@ //! This benchmark tests the DocumentInsertStrategy which enables inserting //! Document objects (vector + attributes) into a DiskANN index. -mod benchmark; +use diskann_benchmark_runner::registry::Benchmarks; -pub(crate) use benchmark::register_benchmarks; +cfg_if::cfg_if! { + if #[cfg(feature = "document-index")] { + mod benchmark; + + /// Register document index benchmarks when the `document-index` feature is enabled. + pub(crate) fn register_benchmarks(registry: &mut Benchmarks) { + benchmark::register_benchmarks(registry); + } + } else { + crate::utils::stub_impl!( + "document-index", + inputs::document_index::DocumentIndexBuild + ); + + /// Register a stub that guides users to enable the `document-index` feature. + pub(crate) fn register_benchmarks(registry: &mut Benchmarks) { + imp::register("document-index", registry); + } + } +} diff --git a/diskann-benchmark/src/main.rs b/diskann-benchmark/src/main.rs index 7bc382782..cdca116e3 100644 --- a/diskann-benchmark/src/main.rs +++ b/diskann-benchmark/src/main.rs @@ -794,8 +794,17 @@ mod tests { prefix_search_directories(&mut raw, &root_directory()); save_to_file(&modified_input_path, &raw); + run_document_filter_integration(&modified_input_path, &output_path, &raw); + } + + #[cfg(feature = "document-index")] + fn run_document_filter_integration( + input_path: &std::path::Path, + output_path: &std::path::Path, + raw: &serde_json::Value, + ) { let command = Commands::Run { - input_file: modified_input_path.to_owned(), + input_file: input_path.to_owned(), output_file: output_path.to_owned(), dry_run: false, }; @@ -809,7 +818,32 @@ mod tests { // Check that the results file is generated. assert!(output_path.exists()); - let results: Vec = load_from_file(&output_path); - assert_eq!(results.len(), num_jobs(&raw)); + let results: Vec = load_from_file(output_path); + assert_eq!(results.len(), num_jobs(raw)); + } + + #[cfg(not(feature = "document-index"))] + fn run_document_filter_integration( + input_path: &std::path::Path, + output_path: &std::path::Path, + _raw: &serde_json::Value, + ) { + let command = Commands::Run { + input_file: input_path.to_owned(), + output_file: output_path.to_owned(), + dry_run: false, + }; + let cli = Cli::from_commands(command, true); + let mut output = Memory::new(); + + let err = cli.run(&mut output).unwrap_err(); + println!("err = {:?}", err); + + let output = String::from_utf8(output.into_inner()).unwrap(); + assert!(output.contains("\"document-index\" feature")); + println!("output = {}", output); + + // The output file should not have been created because we failed. + assert!(!output_path.exists()); } } From c1b7a3c8aa1af0188d389d9ce7a93acdf498be60 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 23 Mar 2026 18:57:10 +0530 Subject: [PATCH 35/50] changes to Cargo.lock --- Cargo.lock | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.lock b/Cargo.lock index 6a06f4dfc..67a10f71b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -799,6 +799,7 @@ dependencies = [ "serde_json", "tempfile", "thiserror 2.0.17", + "tokio", "tracing", ] From 5ebe4770c4a202c65a464fa5f237c8e31769460b Mon Sep 17 00:00:00 2001 From: sampathrg Date: Thu, 26 Mar 2026 19:05:07 +0530 Subject: [PATCH 36/50] Move the tests to a separate folder as is the convention --- .../document_insert_strategy.rs | 123 --------- .../inline_beta_search/inline_beta_filter.rs | 233 +----------------- diskann-label-filter/src/lib.rs | 4 + .../tests/document_insert_strategy_test.rs | 126 ++++++++++ .../src/tests/inline_beta_filter_test.rs | 226 +++++++++++++++++ 5 files changed, 362 insertions(+), 350 deletions(-) create mode 100644 diskann-label-filter/src/tests/document_insert_strategy_test.rs create mode 100644 diskann-label-filter/src/tests/inline_beta_filter_test.rs diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs index 15f40e5fe..f26596147 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs @@ -229,126 +229,3 @@ where } } -#[cfg(test)] -mod tests { - use super::{DocumentInsertStrategy, DocumentPruneStrategy, DocumentSearchAccessor}; - use diskann::{ - graph::{ - glue::{InsertStrategy, PruneStrategy, SearchExt, SearchStrategy}, - test::provider::{Config, Context, Provider, StartPoint, Strategy}, - }, - provider::BuildQueryComputer, - }; - use diskann_vector::distance::Metric; - - use crate::{ - document::Document, - encoded_attribute_provider::{ - document_provider::DocumentProvider, roaring_attribute_store::RoaringAttributeStore, - }, - }; - - // --------------------------------------------------------------------------- - // Helpers - // --------------------------------------------------------------------------- - - /// Build a minimal test provider with a single start point and three dimensions. - fn make_test_provider() -> Provider { - let config = Config::new( - Metric::L2, - 10, - StartPoint::new(u32::MAX, vec![1.0f32, 2.0, 0.0]), - ) - .expect("test provider config should be valid"); - Provider::new(config) - } - - fn make_doc_provider( - provider: Provider, - ) -> DocumentProvider> { - DocumentProvider::new(provider, RoaringAttributeStore::new()) - } - - /// `search_accessor` successfully creates a `DocumentSearchAccessor` wrapping the - /// inner accessor. - #[test] - fn test_search_accessor_creates_wrapped_accessor() { - let strategy = DocumentInsertStrategy::new(Strategy::new()); - let provider = make_doc_provider(make_test_provider()); - let context = Context::new(); - - let result = as SearchStrategy< - DocumentProvider>, - Document<'_, [f32]>, - >>::search_accessor(&strategy, &provider, &context); - - assert!(result.is_ok()); - } - - #[test] - fn test_insert_search_accessor_creates_wrapped_accessor() { - let strategy = DocumentInsertStrategy::new(Strategy::new()); - let provider = make_doc_provider(make_test_provider()); - let context = Context::new(); - - let result = as InsertStrategy< - DocumentProvider>, - Document<'_, [f32]>, - >>::insert_search_accessor(&strategy, &provider, &context); - - assert!(result.is_ok()); - } - - #[test] - fn test_prune_accessor_delegates_to_inner_provider() { - let doc_prune_strategy = DocumentPruneStrategy::new(Strategy::new()); - let provider = make_doc_provider(make_test_provider()); - let context = Context::new(); - - let result = as PruneStrategy< - DocumentProvider>, - >>::prune_accessor(&doc_prune_strategy, &provider, &context); - - assert!(result.is_ok()); - } - - #[test] - fn test_build_query_computer_extracts_vector_from_document() { - let provider = make_test_provider(); - let context = Context::new(); - let strategy_inner = Strategy::new(); - let inner_accessor = strategy_inner - .search_accessor(&provider, &context) - .expect("creating search accessor should succeed"); - let doc_accessor = DocumentSearchAccessor::new(inner_accessor); - - let vector = vec![1.0f32, 2.0, 0.0]; - let doc = Document::new(vector.as_slice(), vec![]); - - let result = as BuildQueryComputer< - Document<'_, [f32]>, - >>::build_query_computer(&doc_accessor, &doc); - - assert!( - result.is_ok(), - "build_query_computer should succeed for a valid vector" - ); - } - - #[test] - fn test_terminate_early_delegates_to_inner() { - let provider = make_test_provider(); - let context = Context::new(); - let strategy_inner = Strategy::new(); - let mut inner_accessor = strategy_inner - .search_accessor(&provider, &context) - .expect("creating search accessor should succeed"); - let inner_terminate_early = inner_accessor.terminate_early(); - let mut doc_accessor = DocumentSearchAccessor::new(inner_accessor); - assert_eq!( - inner_terminate_early, - doc_accessor.terminate_early(), - "terminate_early should have same value as inner accessor" - ); - } -} diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index b1cfbca4b..33b7d7fc7 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -126,6 +126,12 @@ pub struct FilterResults { inner_post_processor: IPP, } +impl FilterResults { + pub(crate) fn new(inner_post_processor: IPP) -> Self { + Self { inner_post_processor } + } +} + impl<'a, Q, IA, IPP> SearchPostProcess, FilteredQuery<'a, Q>> for FilterResults where @@ -175,230 +181,3 @@ where .map_err(|e| e.into()) } } - -#[cfg(test)] -mod tests { - use std::sync::{Arc, RwLock}; - - use diskann::{ - graph::{ - glue::{self, SearchPostProcess, SearchStrategy}, - search_output_buffer::IdDistance, - test::provider::{Config, Context, Provider, StartPoint, Strategy}, - }, - neighbor::Neighbor, - provider::{BuildQueryComputer, SetElement}, - }; - use diskann_vector::{distance::Metric, PreprocessedDistanceFunction}; - use roaring::RoaringTreemap; - use serde_json::Value; - - use crate::{ - attribute::{Attribute, AttributeValue}, - document::EncodedDocument, - encoded_attribute_provider::{ - attribute_encoder::AttributeEncoder, encoded_filter_expr::EncodedFilterExpr, - roaring_attribute_store::RoaringAttributeStore, - }, - inline_beta_search::encoded_document_accessor::EncodedDocumentAccessor, - query::FilteredQuery, - traits::attribute_store::AttributeStore, - ASTExpr, CompareOp, - }; - - use super::{FilterResults, InlineBetaComputer}; - - // ----------------------------------------------------------------------- - // Stub inner distance computer - // ----------------------------------------------------------------------- - - /// Always returns a fixed constant distance, regardless of the vector value. - struct ConstComputer(f32); - - impl PreprocessedDistanceFunction<&[f32], f32> for ConstComputer { - fn evaluate_similarity(&self, _: &[f32]) -> f32 { - self.0 - } - } - - // ----------------------------------------------------------------------- - // Helper: build an AttributeEncoder + ASTExpr for `field == value`, - // returning (attr_map, ast_expr, encoded_id_of_that_attribute). - // ----------------------------------------------------------------------- - - fn setup_encoder_and_filter( - field: &str, - value: &str, - ) -> (Arc>, ASTExpr, u64) { - let mut encoder = AttributeEncoder::new(); - let attr = Attribute::from_value(field, AttributeValue::String(value.to_owned())); - let encoded_id = encoder.insert(&attr); - let attr_map = Arc::new(RwLock::new(encoder)); - let ast_expr = ASTExpr::Compare { - field: field.to_string(), - op: CompareOp::Eq(Value::String(value.to_string())), - }; - (attr_map, ast_expr, encoded_id) - } - - // ----------------------------------------------------------------------- - // Test 1: when the filter matches, evaluate_similarity returns inner * beta - // ----------------------------------------------------------------------- - - #[test] - fn test_evaluate_similarity_filter_match_scales_by_beta() { - let (attr_map, ast_expr, color_red_id) = setup_encoder_and_filter("color", "red"); - let filter_expr = EncodedFilterExpr::new(&ast_expr, attr_map).expect("filter expr"); - - let beta = 2.5_f32; - let inner_dist = 4.0_f32; - let computer = InlineBetaComputer::new(ConstComputer(inner_dist), beta, filter_expr); - - // Bitmap contains the encoded ID for "color=red" → predicate matches - let mut matching_map = RoaringTreemap::new(); - matching_map.insert(color_red_id); - let doc = EncodedDocument::new(&[1.0f32, 0.0][..], &matching_map); - - assert_eq!( - computer.evaluate_similarity(doc), - inner_dist * beta, - "a matched filter should multiply the inner similarity by beta" - ); - } - - // ----------------------------------------------------------------------- - // Test 2: when the filter does not match, evaluate_similarity is unchanged - // ----------------------------------------------------------------------- - - #[test] - fn test_evaluate_similarity_no_filter_match_preserves_score() { - let (attr_map, ast_expr, _) = setup_encoder_and_filter("color", "red"); - let filter_expr = EncodedFilterExpr::new(&ast_expr, attr_map).expect("filter expr"); - - let beta = 2.5_f32; - let inner_dist = 4.0_f32; - let computer = InlineBetaComputer::new(ConstComputer(inner_dist), beta, filter_expr); - - // Empty bitmap → no attribute matches the predicate - let empty_map = RoaringTreemap::new(); - let doc = EncodedDocument::new(&[1.0f32, 0.0][..], &empty_map); - - assert_eq!( - computer.evaluate_similarity(doc), - inner_dist, - "an unmatched filter should leave the inner similarity unchanged" - ); - } - - // ----------------------------------------------------------------------- - // Test 3: post_process forwards only filter-matching candidates to the - // inner post processor (and therefore to the output buffer). - // ----------------------------------------------------------------------- - - #[test] - fn test_post_process_only_passes_matching_candidates_to_inner() { - let rt = tokio::runtime::Builder::new_current_thread() - .build() - .expect("test tokio runtime"); - - // IDs 0 and 1 carry color=red (should pass the filter) - // IDs 2 and 3 carry color=blue (should be dropped by the filter) - let attr_store = RoaringAttributeStore::::new(); - let red = Attribute::from_value("color", AttributeValue::String("red".to_owned())); - let blue = Attribute::from_value("color", AttributeValue::String("blue".to_owned())); - for id in 0u32..2 { - attr_store - .set_element(&id, std::slice::from_ref(&red)) - .expect("set red attr"); - } - for id in 2u32..4 { - attr_store - .set_element(&id, std::slice::from_ref(&blue)) - .expect("set blue attr"); - } - - // The attribute_map is shared so EncodedFilterExpr sees the same encodings - // as those stored by the attribute store. - let attr_map = attr_store.attribute_map(); - - let ast_expr = ASTExpr::Compare { - field: "color".to_string(), - op: CompareOp::Eq(Value::String("red".to_string())), - }; - let filter_expr = EncodedFilterExpr::new(&ast_expr, attr_map.clone()).expect("filter expr"); - - // Build the inner vector provider: start point at u32::MAX + 2-D zero vectors for 0..3 - let config = Config::new(Metric::L2, 10, StartPoint::new(u32::MAX, vec![1.0f32, 0.0])) - .expect("provider config"); - let inner_provider = Provider::new(config); - let ctx = Context::new(); - rt.block_on(async { - for id in 0u32..4 { - inner_provider - .set_element(&ctx, &id, &[0.0f32, 0.0] as &[f32]) - .await - .expect("add vector to inner provider"); - } - }); - - // Obtain the inner search accessor and derive an inner computer from it - let strategy = Strategy::new(); - let inner_accessor = strategy - .search_accessor(&inner_provider, &ctx) - .expect("inner accessor"); - let inner_computer = inner_accessor - .build_query_computer(&[0.0f32, 0.0][..]) - .expect("inner computer"); - - // Wrap accessor + attribute store into an EncodedDocumentAccessor - let attribute_accessor = attr_store.attribute_accessor().expect("attribute accessor"); - let mut doc_accessor = - EncodedDocumentAccessor::new(inner_accessor, attribute_accessor, attr_map, 2.0); - - let computer = InlineBetaComputer::new(inner_computer, 2.0, filter_expr); - - // Four candidates: 0 and 1 match (red); 2 and 3 do not (blue) - let candidates = [ - Neighbor::new(0u32, 1.0_f32), - Neighbor::new(1u32, 2.0_f32), - Neighbor::new(2u32, 3.0_f32), - Neighbor::new(3u32, 4.0_f32), - ]; - - let mut ids = [u32::MAX; 4]; - let mut distances = [f32::MAX; 4]; - let mut output = IdDistance::new(&mut ids, &mut distances); - - let query_vec = [0.0f32, 0.0]; - let filter_query = FilteredQuery::new(&query_vec[..], ast_expr); - - // CopyIds simply copies whatever it receives into the output buffer, - // so the output reflects exactly what FilterResults lets through. - let count = rt - .block_on( - FilterResults { - inner_post_processor: glue::CopyIds, - } - .post_process( - &mut doc_accessor, - &filter_query, - &computer, - candidates.into_iter(), - &mut output, - ), - ) - .expect("post_process"); - - // Only the two red-labeled candidates should have been forwarded - assert_eq!(count, 2, "exactly 2 of 4 candidates should pass the filter"); - let passed = &ids[..count]; - assert!( - passed.contains(&0), - "ID 0 (color=red) should pass the filter" - ); - assert!( - passed.contains(&1), - "ID 1 (color=red) should pass the filter" - ); - } -} diff --git a/diskann-label-filter/src/lib.rs b/diskann-label-filter/src/lib.rs index 273475b15..414b868d4 100644 --- a/diskann-label-filter/src/lib.rs +++ b/diskann-label-filter/src/lib.rs @@ -53,6 +53,10 @@ pub mod tests { #[cfg(test)] pub mod common; #[cfg(test)] + pub mod document_insert_strategy_test; + #[cfg(test)] + pub mod inline_beta_filter_test; + #[cfg(test)] pub mod roaring_attribute_store_test; } diff --git a/diskann-label-filter/src/tests/document_insert_strategy_test.rs b/diskann-label-filter/src/tests/document_insert_strategy_test.rs new file mode 100644 index 000000000..2d7f32b16 --- /dev/null +++ b/diskann-label-filter/src/tests/document_insert_strategy_test.rs @@ -0,0 +1,126 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ + +use diskann::{ + graph::{ + glue::{InsertStrategy, PruneStrategy, SearchExt, SearchStrategy}, + test::provider::{Config, Context, Provider, StartPoint, Strategy}, + }, + provider::BuildQueryComputer, +}; +use diskann_vector::distance::Metric; + +use crate::{ + document::Document, + encoded_attribute_provider::{ + document_insert_strategy::{ + DocumentInsertStrategy, DocumentPruneStrategy, DocumentSearchAccessor, + }, + document_provider::DocumentProvider, + roaring_attribute_store::RoaringAttributeStore, + }, +}; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Build a minimal test provider with a single start point and three dimensions. +fn make_test_provider() -> Provider { + let config = Config::new( + Metric::L2, + 10, + StartPoint::new(u32::MAX, vec![1.0f32, 2.0, 0.0]), + ) + .expect("test provider config should be valid"); + Provider::new(config) +} + +fn make_doc_provider( + provider: Provider, +) -> DocumentProvider> { + DocumentProvider::new(provider, RoaringAttributeStore::new()) +} + +/// `search_accessor` successfully creates a `DocumentSearchAccessor` wrapping the +/// inner accessor. +#[test] +fn test_search_accessor_creates_wrapped_accessor() { + let strategy = DocumentInsertStrategy::new(Strategy::new()); + let provider = make_doc_provider(make_test_provider()); + let context = Context::new(); + + let result = as SearchStrategy< + DocumentProvider>, + Document<'_, [f32]>, + >>::search_accessor(&strategy, &provider, &context); + + assert!(result.is_ok()); +} + +#[test] +fn test_insert_search_accessor_creates_wrapped_accessor() { + let strategy = DocumentInsertStrategy::new(Strategy::new()); + let provider = make_doc_provider(make_test_provider()); + let context = Context::new(); + + let result = as InsertStrategy< + DocumentProvider>, + Document<'_, [f32]>, + >>::insert_search_accessor(&strategy, &provider, &context); + + assert!(result.is_ok()); +} + +#[test] +fn test_prune_accessor_delegates_to_inner_provider() { + let doc_prune_strategy = DocumentPruneStrategy::new(Strategy::new()); + let provider = make_doc_provider(make_test_provider()); + let context = Context::new(); + + let result = as PruneStrategy< + DocumentProvider>, + >>::prune_accessor(&doc_prune_strategy, &provider, &context); + + assert!(result.is_ok()); +} + +#[test] +fn test_build_query_computer_extracts_vector_from_document() { + let provider = make_test_provider(); + let context = Context::new(); + let strategy_inner = Strategy::new(); + let inner_accessor = strategy_inner + .search_accessor(&provider, &context) + .expect("creating search accessor should succeed"); + let doc_accessor = DocumentSearchAccessor::new(inner_accessor); + + let vector = vec![1.0f32, 2.0, 0.0]; + let doc = Document::new(vector.as_slice(), vec![]); + + let result = as BuildQueryComputer>>::build_query_computer(&doc_accessor, &doc); + + assert!( + result.is_ok(), + "build_query_computer should succeed for a valid vector" + ); +} + +#[test] +fn test_terminate_early_delegates_to_inner() { + let provider = make_test_provider(); + let context = Context::new(); + let strategy_inner = Strategy::new(); + let mut inner_accessor = strategy_inner + .search_accessor(&provider, &context) + .expect("creating search accessor should succeed"); + let inner_terminate_early = inner_accessor.terminate_early(); + let mut doc_accessor = DocumentSearchAccessor::new(inner_accessor); + assert_eq!( + inner_terminate_early, + doc_accessor.terminate_early(), + "terminate_early should have same value as inner accessor" + ); +} diff --git a/diskann-label-filter/src/tests/inline_beta_filter_test.rs b/diskann-label-filter/src/tests/inline_beta_filter_test.rs new file mode 100644 index 000000000..0ca16512d --- /dev/null +++ b/diskann-label-filter/src/tests/inline_beta_filter_test.rs @@ -0,0 +1,226 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ + +use std::sync::{Arc, RwLock}; + +use diskann::{ + graph::{ + glue::{self, SearchPostProcess, SearchStrategy}, + search_output_buffer::IdDistance, + test::provider::{Config, Context, Provider, StartPoint, Strategy}, + }, + neighbor::Neighbor, + provider::{BuildQueryComputer, SetElement}, +}; +use diskann_vector::{distance::Metric, PreprocessedDistanceFunction}; +use roaring::RoaringTreemap; +use serde_json::Value; + +use crate::{ + attribute::{Attribute, AttributeValue}, + document::EncodedDocument, + encoded_attribute_provider::{ + attribute_encoder::AttributeEncoder, encoded_filter_expr::EncodedFilterExpr, + roaring_attribute_store::RoaringAttributeStore, + }, + inline_beta_search::{ + encoded_document_accessor::EncodedDocumentAccessor, + inline_beta_filter::{FilterResults, InlineBetaComputer}, + }, + query::FilteredQuery, + traits::attribute_store::AttributeStore, + ASTExpr, CompareOp, +}; + +// ----------------------------------------------------------------------- +// Stub inner distance computer +// ----------------------------------------------------------------------- + +/// Always returns a fixed constant distance, regardless of the vector value. +struct ConstComputer(f32); + +impl PreprocessedDistanceFunction<&[f32], f32> for ConstComputer { + fn evaluate_similarity(&self, _: &[f32]) -> f32 { + self.0 + } +} + +// ----------------------------------------------------------------------- +// Helper: build an AttributeEncoder + ASTExpr for `field == value`, +// returning (attr_map, ast_expr, encoded_id_of_that_attribute). +// ----------------------------------------------------------------------- + +fn setup_encoder_and_filter( + field: &str, + value: &str, +) -> (Arc>, ASTExpr, u64) { + let mut encoder = AttributeEncoder::new(); + let attr = Attribute::from_value(field, AttributeValue::String(value.to_owned())); + let encoded_id = encoder.insert(&attr); + let attr_map = Arc::new(RwLock::new(encoder)); + let ast_expr = ASTExpr::Compare { + field: field.to_string(), + op: CompareOp::Eq(Value::String(value.to_string())), + }; + (attr_map, ast_expr, encoded_id) +} + +// ----------------------------------------------------------------------- +// Test 1: when the filter matches, evaluate_similarity returns inner * beta +// ----------------------------------------------------------------------- + +#[test] +fn test_evaluate_similarity_filter_match_scales_by_beta() { + let (attr_map, ast_expr, color_red_id) = setup_encoder_and_filter("color", "red"); + let filter_expr = EncodedFilterExpr::new(&ast_expr, attr_map).expect("filter expr"); + + let beta = 2.5_f32; + let inner_dist = 4.0_f32; + let computer = InlineBetaComputer::new(ConstComputer(inner_dist), beta, filter_expr); + + // Bitmap contains the encoded ID for "color=red" → predicate matches + let mut matching_map = RoaringTreemap::new(); + matching_map.insert(color_red_id); + let doc = EncodedDocument::new(&[1.0f32, 0.0][..], &matching_map); + + assert_eq!( + computer.evaluate_similarity(doc), + inner_dist * beta, + "a matched filter should multiply the inner similarity by beta" + ); +} + +// ----------------------------------------------------------------------- +// Test 2: when the filter does not match, evaluate_similarity is unchanged +// ----------------------------------------------------------------------- + +#[test] +fn test_evaluate_similarity_no_filter_match_preserves_score() { + let (attr_map, ast_expr, _) = setup_encoder_and_filter("color", "red"); + let filter_expr = EncodedFilterExpr::new(&ast_expr, attr_map).expect("filter expr"); + + let beta = 2.5_f32; + let inner_dist = 4.0_f32; + let computer = InlineBetaComputer::new(ConstComputer(inner_dist), beta, filter_expr); + + // Empty bitmap → no attribute matches the predicate + let empty_map = RoaringTreemap::new(); + let doc = EncodedDocument::new(&[1.0f32, 0.0][..], &empty_map); + + assert_eq!( + computer.evaluate_similarity(doc), + inner_dist, + "an unmatched filter should leave the inner similarity unchanged" + ); +} + +// ----------------------------------------------------------------------- +// Test 3: post_process forwards only filter-matching candidates to the +// inner post processor (and therefore to the output buffer). +// ----------------------------------------------------------------------- + +#[test] +fn test_post_process_only_passes_matching_candidates_to_inner() { + let rt = tokio::runtime::Builder::new_current_thread() + .build() + .expect("test tokio runtime"); + + // IDs 0 and 1 carry color=red (should pass the filter) + // IDs 2 and 3 carry color=blue (should be dropped by the filter) + let attr_store = RoaringAttributeStore::::new(); + let red = Attribute::from_value("color", AttributeValue::String("red".to_owned())); + let blue = Attribute::from_value("color", AttributeValue::String("blue".to_owned())); + for id in 0u32..2 { + attr_store + .set_element(&id, std::slice::from_ref(&red)) + .expect("set red attr"); + } + for id in 2u32..4 { + attr_store + .set_element(&id, std::slice::from_ref(&blue)) + .expect("set blue attr"); + } + + // The attribute_map is shared so EncodedFilterExpr sees the same encodings + // as those stored by the attribute store. + let attr_map = attr_store.attribute_map(); + + let ast_expr = ASTExpr::Compare { + field: "color".to_string(), + op: CompareOp::Eq(Value::String("red".to_string())), + }; + let filter_expr = EncodedFilterExpr::new(&ast_expr, attr_map.clone()).expect("filter expr"); + + // Build the inner vector provider: start point at u32::MAX + 2-D zero vectors for 0..3 + let config = Config::new(Metric::L2, 10, StartPoint::new(u32::MAX, vec![1.0f32, 0.0])) + .expect("provider config"); + let inner_provider = Provider::new(config); + let ctx = Context::new(); + rt.block_on(async { + for id in 0u32..4 { + inner_provider + .set_element(&ctx, &id, &[0.0f32, 0.0] as &[f32]) + .await + .expect("add vector to inner provider"); + } + }); + + // Obtain the inner search accessor and derive an inner computer from it + let strategy = Strategy::new(); + let inner_accessor = strategy + .search_accessor(&inner_provider, &ctx) + .expect("inner accessor"); + let inner_computer = inner_accessor + .build_query_computer(&[0.0f32, 0.0][..]) + .expect("inner computer"); + + // Wrap accessor + attribute store into an EncodedDocumentAccessor + let attribute_accessor = attr_store.attribute_accessor().expect("attribute accessor"); + let mut doc_accessor = + EncodedDocumentAccessor::new(inner_accessor, attribute_accessor, attr_map, 2.0); + + let computer = InlineBetaComputer::new(inner_computer, 2.0, filter_expr); + + // Four candidates: 0 and 1 match (red); 2 and 3 do not (blue) + let candidates = [ + Neighbor::new(0u32, 1.0_f32), + Neighbor::new(1u32, 2.0_f32), + Neighbor::new(2u32, 3.0_f32), + Neighbor::new(3u32, 4.0_f32), + ]; + + let mut ids = [u32::MAX; 4]; + let mut distances = [f32::MAX; 4]; + let mut output = IdDistance::new(&mut ids, &mut distances); + + let query_vec = [0.0f32, 0.0]; + let filter_query = FilteredQuery::new(&query_vec[..], ast_expr); + + // CopyIds simply copies whatever it receives into the output buffer, + // so the output reflects exactly what FilterResults lets through. + let count = rt + .block_on( + FilterResults::new(glue::CopyIds).post_process( + &mut doc_accessor, + &filter_query, + &computer, + candidates.into_iter(), + &mut output, + ), + ) + .expect("post_process"); + + // Only the two red-labeled candidates should have been forwarded + assert_eq!(count, 2, "exactly 2 of 4 candidates should pass the filter"); + let passed = &ids[..count]; + assert!( + passed.contains(&0), + "ID 0 (color=red) should pass the filter" + ); + assert!( + passed.contains(&1), + "ID 1 (color=red) should pass the filter" + ); +} From c9a97bdcf8232fa9984c275107f8f979b3bf8323 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Thu, 26 Mar 2026 19:46:32 +0530 Subject: [PATCH 37/50] Fix formatting --- .../document_insert_strategy.rs | 1 - .../src/inline_beta_search/inline_beta_filter.rs | 5 ++++- .../src/tests/document_insert_strategy_test.rs | 4 +--- .../src/tests/inline_beta_filter_test.rs | 16 +++++++--------- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs index f26596147..aca755833 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs @@ -228,4 +228,3 @@ where .prune_accessor(provider.inner_provider(), context) } } - diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 33b7d7fc7..ebed7d95f 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -127,8 +127,11 @@ pub struct FilterResults { } impl FilterResults { + #[cfg(test)] pub(crate) fn new(inner_post_processor: IPP) -> Self { - Self { inner_post_processor } + Self { + inner_post_processor, + } } } diff --git a/diskann-label-filter/src/tests/document_insert_strategy_test.rs b/diskann-label-filter/src/tests/document_insert_strategy_test.rs index 2d7f32b16..5fb78dd75 100644 --- a/diskann-label-filter/src/tests/document_insert_strategy_test.rs +++ b/diskann-label-filter/src/tests/document_insert_strategy_test.rs @@ -38,9 +38,7 @@ fn make_test_provider() -> Provider { Provider::new(config) } -fn make_doc_provider( - provider: Provider, -) -> DocumentProvider> { +fn make_doc_provider(provider: Provider) -> DocumentProvider> { DocumentProvider::new(provider, RoaringAttributeStore::new()) } diff --git a/diskann-label-filter/src/tests/inline_beta_filter_test.rs b/diskann-label-filter/src/tests/inline_beta_filter_test.rs index 0ca16512d..079f91eff 100644 --- a/diskann-label-filter/src/tests/inline_beta_filter_test.rs +++ b/diskann-label-filter/src/tests/inline_beta_filter_test.rs @@ -201,15 +201,13 @@ fn test_post_process_only_passes_matching_candidates_to_inner() { // CopyIds simply copies whatever it receives into the output buffer, // so the output reflects exactly what FilterResults lets through. let count = rt - .block_on( - FilterResults::new(glue::CopyIds).post_process( - &mut doc_accessor, - &filter_query, - &computer, - candidates.into_iter(), - &mut output, - ), - ) + .block_on(FilterResults::new(glue::CopyIds).post_process( + &mut doc_accessor, + &filter_query, + &computer, + candidates.into_iter(), + &mut output, + )) .expect("post_process"); // Only the two red-labeled candidates should have been forwarded From cc6a9b64131610ab1b105e0c6149e54e721837bf Mon Sep 17 00:00:00 2001 From: sampathrg Date: Thu, 26 Mar 2026 21:20:07 +0530 Subject: [PATCH 38/50] Fix formatting --- .../encoded_attribute_provider/document_insert_strategy.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs index aca755833..2c487bd80 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs @@ -44,15 +44,11 @@ impl DocumentInsertStrategy { /// Wrapper accessor for Document queries pub struct DocumentSearchAccessor { inner: Inner, - // _phantom: PhantomData VT>, } impl DocumentSearchAccessor { pub fn new(inner: Inner) -> Self { - Self { - inner, - // _phantom: PhantomData, - } + Self { inner } } } From 47a2e6946a397f048b88da872f58d09d0ee996f5 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Thu, 26 Mar 2026 21:44:25 +0530 Subject: [PATCH 39/50] Fix build errors after merging with main --- .../src/backend/document_index/benchmark.rs | 11 ++++++----- .../document_insert_strategy.rs | 7 +------ .../src/inline_beta_search/inline_beta_filter.rs | 4 ++-- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index a284e93db..8f3d56175 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -15,7 +15,8 @@ use std::sync::Arc; use anyhow::Result; use diskann::{ graph::{ - search::Knn, search_output_buffer, DiskANNIndex, SearchOutputBuffer, StartPointStrategy, + glue, search::Knn, search_output_buffer, DiskANNIndex, SearchOutputBuffer, + StartPointStrategy, }, provider::DefaultContext, ANNError, ANNErrorKind, @@ -429,8 +430,8 @@ where > + Send + Sync + 'static, - for<'a> InlineBetaStrategy: - diskann::graph::glue::SearchStrategy, u32>, + for<'a> InlineBetaStrategy: glue::SearchStrategy> + + glue::DefaultPostProcessor, u32>, T: bytemuck::Pod + Copy + Send + Sync + 'static, { type Id = DP::ExternalId; @@ -638,8 +639,8 @@ where > + Send + Sync + 'static, - for<'a> InlineBetaStrategy: - diskann::graph::glue::SearchStrategy>, + for<'a> InlineBetaStrategy: glue::SearchStrategy> + + glue::DefaultPostProcessor, u32>, { let searcher = Arc::new(FilteredSearcher { index: index.clone(), diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs index 2c487bd80..6e6886eda 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs @@ -7,7 +7,7 @@ //! [DiskANNIndex] using a [DocumentProvider]. use diskann::{ - graph::glue::{self, ExpandBeam, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy}, + graph::glue::{ExpandBeam, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy}, provider::{Accessor, BuildQueryComputer, DataProvider, DelegateNeighbor, HasId}, ANNResult, }; @@ -148,7 +148,6 @@ where VT: Sync + Send + ?Sized + 'static, { type QueryComputer = Inner::QueryComputer; - type PostProcessor = glue::CopyIds; type SearchAccessorError = Inner::SearchAccessorError; type SearchAccessor<'a> = DocumentSearchAccessor>; @@ -162,10 +161,6 @@ where .search_accessor(provider.inner_provider(), context)?; Ok(DocumentSearchAccessor::new(inner_accessor)) } - - fn post_processor(&self) -> Self::PostProcessor { - glue::CopyIds - } } impl<'doc, Inner, DP, VT> diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index bac127b37..f7a4f505d 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -74,12 +74,12 @@ where impl diskann::graph::glue::DefaultPostProcessor< DocumentProvider>, - FilteredQuery, + FilteredQuery<'_, Q>, > for InlineBetaStrategy where DP: DataProvider, Strategy: diskann::graph::glue::DefaultPostProcessor, - Q: AsyncFriendly + Clone, + Q: Send + Sync + ?Sized, { type Processor = FilterResults; From 28af8d40af621ecb1edd739a36e3b054b8c8a233 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Wed, 1 Apr 2026 20:46:17 +0530 Subject: [PATCH 40/50] Fix build errors after merge. This commit mostly changes the type of the documentItem on which insertStrategy and setElement are implemented to references. This is due to the change in index.insert requiring copy trait on the element being set. Other changes include fixes to api breakage. --- .../src/backend/document_index/benchmark.rs | 82 +++++++++---------- .../src/inputs/document_index.rs | 4 +- .../document_insert_strategy.rs | 32 +++++--- .../document_provider.rs | 6 +- .../encoded_document_accessor.rs | 3 +- .../inline_beta_search/inline_beta_filter.rs | 14 ++-- 6 files changed, 71 insertions(+), 70 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index 8f3d56175..395888a56 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -19,10 +19,12 @@ use diskann::{ StartPointStrategy, }, provider::DefaultContext, + utils::VectorRepr, ANNError, ANNErrorKind, }; use diskann_benchmark_core::{ - build::{self, Build, Parallelism}, + build, + build::{Build, Parallelism}, recall, search as search_api, tokio, }; use diskann_benchmark_runner::{ @@ -30,7 +32,7 @@ use diskann_benchmark_runner::{ output::Output, registry::Benchmarks, utils::{datatype, fmt, percentiles, MicroSeconds}, - Any, + Benchmark, }; use diskann_label_filter::{ attribute::{Attribute, AttributeValue}, @@ -68,13 +70,7 @@ use crate::{ /// Register the document index benchmarks. pub(crate) fn register_benchmarks(benchmarks: &mut Benchmarks) { - benchmarks.register::>( - "document-index-build-f32", - |job, _checkpoint, out| { - let stats = job.run(out)?; - Ok(serde_json::to_value(stats)?) - }, - ); + benchmarks.register::>("document-index-build-f32"); } /// Document index benchmark job. @@ -92,45 +88,41 @@ impl<'a, T> DocumentIndexJob<'a, T> { } } -impl diskann_benchmark_runner::dispatcher::Map for DocumentIndexJob<'static, T> { - type Type<'a> = DocumentIndexJob<'a, T>; -} - -// Dispatch from the concrete input type -impl<'a, T> DispatchRule<&'a DocumentIndexBuild> for DocumentIndexJob<'a, T> -where - datatype::Type: DispatchRule, -{ - type Error = std::convert::Infallible; - - fn try_match(_from: &&'a DocumentIndexBuild) -> Result { - datatype::Type::::try_match(&_from.build.data_type) - } - - fn convert(from: &'a DocumentIndexBuild) -> Result { - Ok(DocumentIndexJob::new(from)) - } -} - -// Central dispatch mapping from Any -impl<'a, T> DispatchRule<&'a Any> for DocumentIndexJob<'a, T> +impl Benchmark for DocumentIndexJob<'static, T> where + T: VectorRepr + + diskann::graph::SampleableForStart + + diskann_utils::sampling::WithApproximateNorm + + 'static, datatype::Type: DispatchRule, + for<'b> diskann_vector::distance::SquaredL2: PureDistanceFunction<&'b [T], &'b [T]>, { - type Error = anyhow::Error; + type Input = DocumentIndexBuild; + type Output = DocumentIndexStats; - fn try_match(from: &&'a Any) -> Result { - from.try_match::() + fn try_match(input: &Self::Input) -> std::result::Result { + datatype::Type::::try_match(&input.build.data_type) } - fn convert(from: &'a Any) -> Result { - from.convert::() + fn description( + f: &mut std::fmt::Formatter<'_>, + input: Option<&Self::Input>, + ) -> std::fmt::Result { + match input { + Some(arg) => datatype::Type::::description(f, Some(&arg.build.data_type)), + None => datatype::Type::::description(f, None::<&datatype::DataType>), + } } - fn description(f: &mut std::fmt::Formatter, from: Option<&&'a Any>) -> std::fmt::Result { - Any::description::(f, from, DocumentIndexBuild::tag()) + fn run( + input: &Self::Input, + _checkpoint: diskann_benchmark_runner::Checkpoint<'_>, + output: &mut dyn Output, + ) -> anyhow::Result { + DocumentIndexJob::::new(input).run(output) } } + /// Convert a HashMap to Vec fn hashmap_to_attributes(map: std::collections::HashMap) -> Vec { map.into_iter() @@ -430,8 +422,8 @@ where > + Send + Sync + 'static, - for<'a> InlineBetaStrategy: glue::SearchStrategy> - + glue::DefaultPostProcessor, u32>, + for<'a> InlineBetaStrategy: glue::SearchStrategy> + + glue::DefaultPostProcessor, u32>, T: bytemuck::Pod + Copy + Send + Sync + 'static, { type Id = DP::ExternalId; @@ -639,8 +631,8 @@ where > + Send + Sync + 'static, - for<'a> InlineBetaStrategy: glue::SearchStrategy> - + glue::DefaultPostProcessor, u32>, + for<'a> InlineBetaStrategy: glue::SearchStrategy> + + glue::DefaultPostProcessor, u32>, { let searcher = Arc::new(FilteredSearcher { index: index.clone(), @@ -813,10 +805,10 @@ impl DocumentIndexBuilder { impl Build for DocumentIndexBuilder where DP: diskann::provider::DataProvider - + for<'doc> diskann::provider::SetElement> + + for<'a> diskann::provider::SetElement<&'a Document<'a, [T]>> + AsyncFriendly, - for<'doc> DocumentInsertStrategy: - diskann::graph::glue::InsertStrategy>, + for<'a> DocumentInsertStrategy: + diskann::graph::glue::InsertStrategy>, DocumentInsertStrategy: AsyncFriendly, T: AsyncFriendly, { diff --git a/diskann-benchmark/src/inputs/document_index.rs b/diskann-benchmark/src/inputs/document_index.rs index f1ed3c063..3d45d2fa3 100644 --- a/diskann-benchmark/src/inputs/document_index.rs +++ b/diskann-benchmark/src/inputs/document_index.rs @@ -18,7 +18,7 @@ use diskann_benchmark_runner::{ use serde::{Deserialize, Serialize}; use super::async_::GraphSearch; -use crate::inputs::{as_input, Example, Input}; +use crate::inputs::{as_input, Example}; ////////////// // Registry // @@ -29,7 +29,7 @@ as_input!(DocumentIndexBuild); pub(super) fn register_inputs( registry: &mut diskann_benchmark_runner::registry::Inputs, ) -> anyhow::Result<()> { - registry.register(Input::::new())?; + registry.register::()?; Ok(()) } diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs index 6e6886eda..0e8176b8f 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs @@ -68,7 +68,6 @@ where = Inner::Element<'a> where Self: 'a; - type Extended = Inner::Extended; type GetError = Inner::GetError; fn get_element( @@ -92,9 +91,9 @@ where } } -impl<'doc, Inner, VT> BuildQueryComputer> for DocumentSearchAccessor +impl<'doc, Inner, VT> BuildQueryComputer<&'doc Document<'doc, VT>> for DocumentSearchAccessor where - Inner: BuildQueryComputer, + Inner: BuildQueryComputer<&'doc VT>, VT: ?Sized, { type QueryComputerError = Inner::QueryComputerError; @@ -102,7 +101,7 @@ where fn build_query_computer( &self, - from: &Document<'doc, VT>, + from: &'doc Document<'doc, VT>, ) -> Result { self.inner.build_query_computer(from.vector()) } @@ -118,9 +117,9 @@ where } } -impl<'doc, Inner, VT> ExpandBeam> for DocumentSearchAccessor +impl<'doc, Inner, VT> ExpandBeam<&'doc Document<'doc, VT>> for DocumentSearchAccessor where - Inner: ExpandBeam, + Inner: ExpandBeam<&'doc VT>, VT: ?Sized, { } @@ -140,10 +139,12 @@ where } impl<'doc, Inner, DP, VT> - SearchStrategy>, Document<'doc, VT>> - for DocumentInsertStrategy + SearchStrategy< + DocumentProvider>, + &'doc Document<'doc, VT>, + > for DocumentInsertStrategy where - Inner: InsertStrategy, + Inner: InsertStrategy, DP: DataProvider, VT: Sync + Send + ?Sized + 'static, { @@ -164,10 +165,12 @@ where } impl<'doc, Inner, DP, VT> - InsertStrategy>, Document<'doc, VT>> - for DocumentInsertStrategy + InsertStrategy< + DocumentProvider>, + &'doc Document<'doc, VT>, + > for DocumentInsertStrategy where - Inner: InsertStrategy, + Inner: InsertStrategy, DP: DataProvider, VT: Sync + Send + ?Sized + 'static, { @@ -209,6 +212,7 @@ where type DistanceComputer = Inner::DistanceComputer; type PruneAccessor<'a> = Inner::PruneAccessor<'a>; type PruneAccessorError = Inner::PruneAccessorError; + type WorkingSet = Inner::WorkingSet; fn prune_accessor<'a>( &'a self, @@ -218,4 +222,8 @@ where self.inner .prune_accessor(provider.inner_provider(), context) } + + fn create_working_set(&self, _capacity: usize) -> Self::WorkingSet { + self.inner.create_working_set(_capacity) + } } diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs b/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs index a80acfe41..33200c40b 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs @@ -74,9 +74,9 @@ where } } -impl<'a, VT, DP, AS> SetElement> for DocumentProvider +impl<'doc, VT, DP, AS> SetElement<&'doc Document<'doc, VT>> for DocumentProvider where - DP: DataProvider + Delete + SetElement<&'a VT>, + DP: DataProvider + Delete + SetElement<&'doc VT>, AS: AttributeStore + AsyncFriendly, VT: Sync + Send + ?Sized, { @@ -86,7 +86,7 @@ where &self, context: &Self::Context, id: &Self::ExternalId, - element: Document<'a, VT>, + element: &'doc Document<'doc, VT>, ) -> Result { let guard = self .inner_provider diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 853d6f63a..d2c3fd3fe 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -162,9 +162,10 @@ where } } -impl<'q, IA, Q> BuildQueryComputer<&'q FilteredQuery> for EncodedDocumentAccessor +impl<'q, IA, Q> BuildQueryComputer<&'q FilteredQuery<'q, Q>> for EncodedDocumentAccessor where IA: BuildQueryComputer<&'q Q>, + Q: ?Sized, { type QueryComputerError = ANNError; type QueryComputer = InlineBetaComputer; diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index be2320041..47fd5dab5 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -37,12 +37,12 @@ impl InlineBetaStrategy { impl<'q, DP, Strategy, Q> SearchStrategy< DocumentProvider>, - &'q FilteredQuery, + &'q FilteredQuery<'q, Q>, > for InlineBetaStrategy where DP: DataProvider, Strategy: SearchStrategy, - Q: Send + Sync, + Q: Send + Sync + ?Sized, { type QueryComputer = InlineBetaComputer; type SearchAccessorError = ANNError; @@ -74,12 +74,12 @@ where impl<'q, DP, Strategy, Q> diskann::graph::glue::DefaultPostProcessor< DocumentProvider>, - &'q FilteredQuery, + &'q FilteredQuery<'q, Q>, > for InlineBetaStrategy where DP: DataProvider, Strategy: diskann::graph::glue::DefaultPostProcessor, - Q: Send + Sync, + Q: Send + Sync + ?Sized, { type Processor = FilterResults; @@ -149,11 +149,11 @@ impl FilterResults { } } -impl<'a, 'q, Q, IA, IPP> SearchPostProcess, &'q FilteredQuery<'a, Q>> +impl<'q, Q, IA, IPP> SearchPostProcess, &'q FilteredQuery<'q, Q>> for FilterResults where IA: BuildQueryComputer<&'q Q>, - Q: Send + Sync, + Q: Send + Sync + ?Sized, IPP: SearchPostProcess + Send + Sync, { type Error = ANNError; @@ -161,7 +161,7 @@ where async fn post_process( &self, accessor: &mut EncodedDocumentAccessor, - query: &'q FilteredQuery, + query: &'q FilteredQuery<'q, Q>, computer: &InlineBetaComputer<>::QueryComputer>, candidates: I, output: &mut B, From cba7147528fdb17c02c69d19e0ae267a9847d71c Mon Sep 17 00:00:00 2001 From: sampathrg Date: Thu, 2 Apr 2026 23:34:17 +0530 Subject: [PATCH 41/50] Fix build errors in test --- .../src/tests/document_insert_strategy_test.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/diskann-label-filter/src/tests/document_insert_strategy_test.rs b/diskann-label-filter/src/tests/document_insert_strategy_test.rs index 5fb78dd75..a27de2c0a 100644 --- a/diskann-label-filter/src/tests/document_insert_strategy_test.rs +++ b/diskann-label-filter/src/tests/document_insert_strategy_test.rs @@ -52,7 +52,7 @@ fn test_search_accessor_creates_wrapped_accessor() { let result = as SearchStrategy< DocumentProvider>, - Document<'_, [f32]>, + &Document<'_, [f32]>, >>::search_accessor(&strategy, &provider, &context); assert!(result.is_ok()); @@ -66,7 +66,7 @@ fn test_insert_search_accessor_creates_wrapped_accessor() { let result = as InsertStrategy< DocumentProvider>, - Document<'_, [f32]>, + &Document<'_, [f32]>, >>::insert_search_accessor(&strategy, &provider, &context); assert!(result.is_ok()); @@ -98,7 +98,7 @@ fn test_build_query_computer_extracts_vector_from_document() { let vector = vec![1.0f32, 2.0, 0.0]; let doc = Document::new(vector.as_slice(), vec![]); - let result = as BuildQueryComputer>>::build_query_computer(&doc_accessor, &doc); + let result = as BuildQueryComputer<&Document<'_, [f32]>>>::build_query_computer(&doc_accessor, &doc); assert!( result.is_ok(), From aaf0488e986941cc0681909e8e3212e940d956e0 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Fri, 3 Apr 2026 15:08:59 +0530 Subject: [PATCH 42/50] Remove lifetime bound and ?Sized requirement on the inner query of FilteredQuery --- .../src/backend/document_index/benchmark.rs | 8 +++--- .../encoded_document_accessor.rs | 6 ++--- .../inline_beta_search/inline_beta_filter.rs | 25 ++++++++++--------- diskann-label-filter/src/query.rs | 17 ++++++++----- 4 files changed, 31 insertions(+), 25 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index 395888a56..51c71a020 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -422,8 +422,8 @@ where > + Send + Sync + 'static, - for<'a> InlineBetaStrategy: glue::SearchStrategy> - + glue::DefaultPostProcessor, u32>, + for<'a> InlineBetaStrategy: glue::SearchStrategy> + + glue::DefaultPostProcessor, u32>, T: bytemuck::Pod + Copy + Send + Sync + 'static, { type Id = DP::ExternalId; @@ -631,8 +631,8 @@ where > + Send + Sync + 'static, - for<'a> InlineBetaStrategy: glue::SearchStrategy> - + glue::DefaultPostProcessor, u32>, + for<'a> InlineBetaStrategy: glue::SearchStrategy> + + glue::DefaultPostProcessor, u32>, { let searcher = Arc::new(FilteredSearcher { index: index.clone(), diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index d2c3fd3fe..15f3ce023 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -162,10 +162,10 @@ where } } -impl<'q, IA, Q> BuildQueryComputer<&'q FilteredQuery<'q, Q>> for EncodedDocumentAccessor +impl<'q, IA, Q> BuildQueryComputer<&'q FilteredQuery> for EncodedDocumentAccessor where - IA: BuildQueryComputer<&'q Q>, - Q: ?Sized, + IA: BuildQueryComputer, + Q: Reborrow<'q>, { type QueryComputerError = ANNError; type QueryComputer = InlineBetaComputer; diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 47fd5dab5..45ddf3a92 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -9,6 +9,7 @@ use diskann::neighbor::Neighbor; use diskann::provider::{Accessor, BuildQueryComputer, DataProvider}; use diskann::ANNError; +use diskann_utils::Reborrow; use diskann_vector::PreprocessedDistanceFunction; use roaring::RoaringTreemap; @@ -37,12 +38,12 @@ impl InlineBetaStrategy { impl<'q, DP, Strategy, Q> SearchStrategy< DocumentProvider>, - &'q FilteredQuery<'q, Q>, + &'q FilteredQuery, > for InlineBetaStrategy where DP: DataProvider, - Strategy: SearchStrategy, - Q: Send + Sync + ?Sized, + Strategy: SearchStrategy, + Q: Send + Sync + Reborrow<'q>, { type QueryComputer = InlineBetaComputer; type SearchAccessorError = ANNError; @@ -74,12 +75,12 @@ where impl<'q, DP, Strategy, Q> diskann::graph::glue::DefaultPostProcessor< DocumentProvider>, - &'q FilteredQuery<'q, Q>, + &'q FilteredQuery, > for InlineBetaStrategy where DP: DataProvider, - Strategy: diskann::graph::glue::DefaultPostProcessor, - Q: Send + Sync + ?Sized, + Strategy: diskann::graph::glue::DefaultPostProcessor, + Q: Send + Sync + Reborrow<'q>, { type Processor = FilterResults; @@ -149,20 +150,20 @@ impl FilterResults { } } -impl<'q, Q, IA, IPP> SearchPostProcess, &'q FilteredQuery<'q, Q>> +impl<'q, Q, IA, IPP> SearchPostProcess, &'q FilteredQuery> for FilterResults where - IA: BuildQueryComputer<&'q Q>, - Q: Send + Sync + ?Sized, - IPP: SearchPostProcess + Send + Sync, + IA: BuildQueryComputer, + Q: Send + Sync + Reborrow<'q>, + IPP: SearchPostProcess + Send + Sync, { type Error = ANNError; async fn post_process( &self, accessor: &mut EncodedDocumentAccessor, - query: &'q FilteredQuery<'q, Q>, - computer: &InlineBetaComputer<>::QueryComputer>, + query: &'q FilteredQuery, + computer: &InlineBetaComputer<>::QueryComputer>, candidates: I, output: &mut B, ) -> Result diff --git a/diskann-label-filter/src/query.rs b/diskann-label-filter/src/query.rs index b1a7a6409..777c91d10 100644 --- a/diskann-label-filter/src/query.rs +++ b/diskann-label-filter/src/query.rs @@ -3,24 +3,29 @@ * Licensed under the MIT license. */ +use diskann_utils::Reborrow; + use crate::ASTExpr; /// Type that can be used to specify a query with a filter expression. /// The Readme.md file in the label-filter folder describes the format /// of the query expression. #[derive(Clone)] -pub struct FilteredQuery<'a, V: ?Sized> { - query: &'a V, +pub struct FilteredQuery { + query: V, filter_expr: ASTExpr, } -impl<'a, V: ?Sized> FilteredQuery<'a, V> { - pub fn new(query: &'a V, filter_expr: ASTExpr) -> Self { +impl FilteredQuery { + pub fn new(query: V, filter_expr: ASTExpr) -> Self { Self { query, filter_expr } } - pub(crate) fn query(&self) -> &'a V { - self.query + pub(crate) fn query<'a>(&'a self) -> V::Target + where + V: Reborrow<'a>, + { + self.query.reborrow() } pub(crate) fn filter_expr(&self) -> &ASTExpr { From 5182b7a1135ead2ea5639fe32aae8fe676d1463d Mon Sep 17 00:00:00 2001 From: sampathrg Date: Fri, 3 Apr 2026 16:27:17 +0530 Subject: [PATCH 43/50] Remove unecessary clone() calls and hold a reference to the expression in the FilteredQuerySearch instead of owning it --- .../src/backend/document_index/benchmark.rs | 18 +++++++++--------- .../encoded_document_accessor.rs | 4 ++-- .../inline_beta_search/inline_beta_filter.rs | 8 ++++---- diskann-label-filter/src/query.rs | 17 ++++++++--------- .../src/tests/inline_beta_filter_test.rs | 2 +- 5 files changed, 24 insertions(+), 25 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index 51c71a020..61ab8d596 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -274,15 +274,15 @@ impl<'a, T> DocumentIndexJob<'a, T> { let builder = DocumentIndexBuilder::new( doc_index.clone(), - data_arc.clone(), - attributes_arc.clone(), + data_arc, + attributes_arc, DocumentInsertStrategy::new(common::FullPrecision), ); let num_tasks = NonZeroUsize::new(build.num_threads).unwrap_or(diskann::utils::ONE); let parallelism = Parallelism::dynamic(diskann::utils::ONE, num_tasks); let build_results = build::build_tracked(builder, parallelism, &rt, Some(&ProgressMeter::new(output)))?; - let insert_latencies: Vec = build_results + let mut insert_latencies: Vec = build_results .take_output() .into_iter() .map(|r| r.latency) @@ -291,7 +291,7 @@ impl<'a, T> DocumentIndexJob<'a, T> { let build_time: MicroSeconds = timer.elapsed().into(); writeln!(output, " Index built in {} s", build_time.as_seconds())?; - let insert_percentiles = percentiles::compute_percentiles(&mut insert_latencies.clone())?; + let insert_percentiles = percentiles::compute_percentiles(&mut insert_latencies)?; // ===================== // Search Phase // ===================== @@ -422,8 +422,8 @@ where > + Send + Sync + 'static, - for<'a> InlineBetaStrategy: glue::SearchStrategy> - + glue::DefaultPostProcessor, u32>, + for<'a> InlineBetaStrategy: glue::SearchStrategy> + + glue::DefaultPostProcessor, u32>, T: bytemuck::Pod + Copy + Send + Sync + 'static, { type Id = DP::ExternalId; @@ -451,7 +451,7 @@ where let query_vec = self.queries.row(index); let (_, ref ast_expr) = self.predicates[index]; let strategy = InlineBetaStrategy::new(self.beta, common::FullPrecision); - let filtered_query = FilteredQuery::new(query_vec, ast_expr.clone()); + let filtered_query = FilteredQuery::new(query_vec, ast_expr); // Use a concrete IdDistance scratch buffer so that both the IDs and distances // are captured. Afterwards, the valid IDs are forwarded into the framework buffer. @@ -631,8 +631,8 @@ where > + Send + Sync + 'static, - for<'a> InlineBetaStrategy: glue::SearchStrategy> - + glue::DefaultPostProcessor, u32>, + for<'a> InlineBetaStrategy: glue::SearchStrategy> + + glue::DefaultPostProcessor, u32>, { let searcher = Arc::new(FilteredSearcher { index: index.clone(), diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 15f3ce023..58a7c3fa0 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -162,7 +162,7 @@ where } } -impl<'q, IA, Q> BuildQueryComputer<&'q FilteredQuery> for EncodedDocumentAccessor +impl<'q, IA, Q> BuildQueryComputer<&'q FilteredQuery<'_, Q>> for EncodedDocumentAccessor where IA: BuildQueryComputer, Q: Reborrow<'q>, @@ -172,7 +172,7 @@ where fn build_query_computer( &self, - from: &'q FilteredQuery, + from: &'q FilteredQuery<'_, Q>, ) -> Result { let inner_computer = self .inner_accessor diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 45ddf3a92..a3b4d011c 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -38,7 +38,7 @@ impl InlineBetaStrategy { impl<'q, DP, Strategy, Q> SearchStrategy< DocumentProvider>, - &'q FilteredQuery, + &'q FilteredQuery<'_, Q>, > for InlineBetaStrategy where DP: DataProvider, @@ -75,7 +75,7 @@ where impl<'q, DP, Strategy, Q> diskann::graph::glue::DefaultPostProcessor< DocumentProvider>, - &'q FilteredQuery, + &'q FilteredQuery<'_, Q>, > for InlineBetaStrategy where DP: DataProvider, @@ -150,7 +150,7 @@ impl FilterResults { } } -impl<'q, Q, IA, IPP> SearchPostProcess, &'q FilteredQuery> +impl<'q, Q, IA, IPP> SearchPostProcess, &'q FilteredQuery<'_, Q>> for FilterResults where IA: BuildQueryComputer, @@ -162,7 +162,7 @@ where async fn post_process( &self, accessor: &mut EncodedDocumentAccessor, - query: &'q FilteredQuery, + query: &'q FilteredQuery<'_, Q>, computer: &InlineBetaComputer<>::QueryComputer>, candidates: I, output: &mut B, diff --git a/diskann-label-filter/src/query.rs b/diskann-label-filter/src/query.rs index 777c91d10..0d957f626 100644 --- a/diskann-label-filter/src/query.rs +++ b/diskann-label-filter/src/query.rs @@ -10,25 +10,24 @@ use crate::ASTExpr; /// Type that can be used to specify a query with a filter expression. /// The Readme.md file in the label-filter folder describes the format /// of the query expression. -#[derive(Clone)] -pub struct FilteredQuery { +pub struct FilteredQuery<'a, V> { query: V, - filter_expr: ASTExpr, + filter_expr: &'a ASTExpr, } -impl FilteredQuery { - pub fn new(query: V, filter_expr: ASTExpr) -> Self { +impl<'a, V> FilteredQuery<'a, V> { + pub fn new(query: V, filter_expr: &'a ASTExpr) -> Self { Self { query, filter_expr } } - pub(crate) fn query<'a>(&'a self) -> V::Target + pub(crate) fn query<'b>(&'b self) -> V::Target where - V: Reborrow<'a>, + V: Reborrow<'b>, { self.query.reborrow() } - pub(crate) fn filter_expr(&self) -> &ASTExpr { - &self.filter_expr + pub(crate) fn filter_expr(&self) -> &'a ASTExpr { + self.filter_expr } } diff --git a/diskann-label-filter/src/tests/inline_beta_filter_test.rs b/diskann-label-filter/src/tests/inline_beta_filter_test.rs index 079f91eff..8b280c95f 100644 --- a/diskann-label-filter/src/tests/inline_beta_filter_test.rs +++ b/diskann-label-filter/src/tests/inline_beta_filter_test.rs @@ -196,7 +196,7 @@ fn test_post_process_only_passes_matching_candidates_to_inner() { let mut output = IdDistance::new(&mut ids, &mut distances); let query_vec = [0.0f32, 0.0]; - let filter_query = FilteredQuery::new(&query_vec[..], ast_expr); + let filter_query = FilteredQuery::new(&query_vec[..], &ast_expr); // CopyIds simply copies whatever it receives into the output buffer, // so the output reflects exactly what FilterResults lets through. From f32e90d6bd95efecb6cb221ba4580fba318854ee Mon Sep 17 00:00:00 2001 From: sampathrg Date: Fri, 3 Apr 2026 17:53:53 +0530 Subject: [PATCH 44/50] Remove loops that print debug info for arbitrary number of elements --- diskann-tools/src/utils/ground_truth.rs | 49 +------------------------ 1 file changed, 1 insertion(+), 48 deletions(-) diff --git a/diskann-tools/src/utils/ground_truth.rs b/diskann-tools/src/utils/ground_truth.rs index e32325083..e4b5a993c 100644 --- a/diskann-tools/src/utils/ground_truth.rs +++ b/diskann-tools/src/utils/ground_truth.rs @@ -93,16 +93,6 @@ pub fn read_labels_and_compute_bitmap( base_label_filename ); - // Print first few base labels for debugging - for (i, label) in base_labels.iter().take(3).enumerate() { - tracing::debug!( - "Base label sample [{}]: doc_id={}, label={}", - i, - label.doc_id, - label.label - ); - } - // Parse queries and evaluate against labels let parsed_queries = read_and_parse_queries(query_label_filename)?; tracing::info!( @@ -111,16 +101,6 @@ pub fn read_labels_and_compute_bitmap( query_label_filename ); - // Print first few queries for debugging - for (i, (query_id, query_expr)) in parsed_queries.iter().take(3).enumerate() { - tracing::debug!( - "Query sample [{}]: query_id={}, expr={:?}", - i, - query_id, - query_expr - ); - } - // using the global threadpool is fine here #[allow(clippy::disallowed_methods)] let query_bitmaps: Vec = parsed_queries @@ -156,17 +136,6 @@ pub fn read_labels_and_compute_bitmap( queries_with_matches ); - // Print per-query match counts - for (i, bitmap) in query_bitmaps.iter().enumerate() { - if i < 10 || bitmap.is_empty() { - tracing::debug!( - "Query {}: {} base vectors matched the filter", - i, - bitmap.len() - ); - } - } - // If no matches, print more diagnostic info if total_matches == 0 { tracing::warn!("WARNING: No base vectors matched any query filters!"); @@ -333,23 +302,7 @@ pub fn compute_ground_truth_from_datafiles< recall_at ); for (query_idx, npq) in ground_truth.iter().enumerate() { - let neighbors: Vec<_> = npq.iter().collect(); - let neighbor_count = neighbors.len(); - - if query_idx < 10 { - // Print top K IDs and distances for first 10 queries - let top_ids: Vec = neighbors.iter().take(10).map(|n| n.id).collect(); - let top_dists: Vec = neighbors.iter().take(10).map(|n| n.distance).collect(); - tracing::debug!( - "Query {}: {} neighbors found. Top IDs: {:?}, Top distances: {:?}", - query_idx, - neighbor_count, - top_ids, - top_dists - ); - } - - if neighbor_count == 0 { + if npq.size() == 0 { tracing::warn!("Query {} has 0 neighbors in ground truth!", query_idx); } } From 97a6cf50501d8db230395222e43fbdccd9860d2c Mon Sep 17 00:00:00 2001 From: sampathrg Date: Fri, 3 Apr 2026 23:09:25 +0530 Subject: [PATCH 45/50] User iterator instead of doing a collect. --- diskann-tools/src/utils/ground_truth.rs | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/diskann-tools/src/utils/ground_truth.rs b/diskann-tools/src/utils/ground_truth.rs index e4b5a993c..b13d923f5 100644 --- a/diskann-tools/src/utils/ground_truth.rs +++ b/diskann-tools/src/utils/ground_truth.rs @@ -26,7 +26,7 @@ use diskann_utils::{ use diskann_vector::{distance::Metric, DistanceFunction}; use itertools::Itertools; use rayon::prelude::*; -use serde_json::{Map, Value}; +use serde_json::{map, Map, Value}; use crate::utils::{search_index_utils, CMDResult, CMDToolError}; @@ -37,10 +37,7 @@ use crate::utils::{search_index_utils, CMDResult, CMDToolError}; /// Cartesian product. Non-object labels are evaluated directly. fn eval_query_with_array_expansion(query_expr: &ASTExpr, label: &Value) -> bool { match label { - Value::Object(map) => { - let entries: Vec<(&String, &Value)> = map.iter().collect(); - eval_map_recursive(query_expr, &entries, Map::new()) - } + Value::Object(map) => eval_map_recursive(query_expr, &mut map.iter(), Map::new()), _ => eval_query_expr(query_expr, label), } } @@ -54,29 +51,29 @@ fn eval_query_with_array_expansion(query_expr: &ASTExpr, label: &Value) -> bool /// * When all fields have been consumed, `eval_query_expr` is called on the accumulated object. fn eval_map_recursive( query_expr: &ASTExpr, - entries: &[(&String, &Value)], + map_iter: &mut map::Iter, mut current: Map, ) -> bool { - match entries { - [] => eval_query_expr(query_expr, &Value::Object(current)), - [(key, Value::Array(arr)), rest @ ..] => { + match map_iter.next() { + None => eval_query_expr(query_expr, &Value::Object(current)), + Some((key, Value::Array(arr))) => { if arr.is_empty() { // Omit this key, matching the original behaviour for empty arrays. - eval_map_recursive(query_expr, rest, current) + eval_map_recursive(query_expr, map_iter, current) } else { for item in arr { let mut branch = current.clone(); branch.insert((*key).clone(), item.clone()); - if eval_map_recursive(query_expr, rest, branch) { + if eval_map_recursive(query_expr, map_iter, branch) { return true; } } false } } - [(key, val), rest @ ..] => { + Some((key, val)) => { current.insert((*key).clone(), (*val).clone()); - eval_map_recursive(query_expr, rest, current) + eval_map_recursive(query_expr, map_iter, current) } } } From 266cd51aa59c9c0a3947254cd84e17cc2d263006 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Fri, 3 Apr 2026 23:19:50 +0530 Subject: [PATCH 46/50] don't use mut reference to iterator --- diskann-tools/src/utils/ground_truth.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/diskann-tools/src/utils/ground_truth.rs b/diskann-tools/src/utils/ground_truth.rs index b13d923f5..10fe570b3 100644 --- a/diskann-tools/src/utils/ground_truth.rs +++ b/diskann-tools/src/utils/ground_truth.rs @@ -37,7 +37,7 @@ use crate::utils::{search_index_utils, CMDResult, CMDToolError}; /// Cartesian product. Non-object labels are evaluated directly. fn eval_query_with_array_expansion(query_expr: &ASTExpr, label: &Value) -> bool { match label { - Value::Object(map) => eval_map_recursive(query_expr, &mut map.iter(), Map::new()), + Value::Object(map) => eval_map_recursive(query_expr, map.iter(), Map::new()), _ => eval_query_expr(query_expr, label), } } @@ -51,7 +51,7 @@ fn eval_query_with_array_expansion(query_expr: &ASTExpr, label: &Value) -> bool /// * When all fields have been consumed, `eval_query_expr` is called on the accumulated object. fn eval_map_recursive( query_expr: &ASTExpr, - map_iter: &mut map::Iter, + mut map_iter: map::Iter, mut current: Map, ) -> bool { match map_iter.next() { @@ -64,7 +64,9 @@ fn eval_map_recursive( for item in arr { let mut branch = current.clone(); branch.insert((*key).clone(), item.clone()); - if eval_map_recursive(query_expr, map_iter, branch) { + + // need to clone here because we want to iterate from the next element for each branch + if eval_map_recursive(query_expr, map_iter.clone(), branch) { return true; } } From a22133a93963a6c9555188adc6d2930902c4917a Mon Sep 17 00:00:00 2001 From: sampathrg Date: Fri, 3 Apr 2026 23:20:27 +0530 Subject: [PATCH 47/50] Add some documentation --- diskann-tools/src/utils/ground_truth.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diskann-tools/src/utils/ground_truth.rs b/diskann-tools/src/utils/ground_truth.rs index 10fe570b3..2ab17c30c 100644 --- a/diskann-tools/src/utils/ground_truth.rs +++ b/diskann-tools/src/utils/ground_truth.rs @@ -65,7 +65,7 @@ fn eval_map_recursive( let mut branch = current.clone(); branch.insert((*key).clone(), item.clone()); - // need to clone here because we want to iterate from the next element for each branch + // need to clone here because we want to iterate from same next element for each branch if eval_map_recursive(query_expr, map_iter.clone(), branch) { return true; } From 8afbcd31e540767ab552b490cc0f1ef62ee55e03 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Tue, 7 Apr 2026 18:03:52 +0530 Subject: [PATCH 48/50] remove need of cloning when using Document::new --- .../src/backend/document_index/benchmark.rs | 9 +++++---- diskann-label-filter/src/document.rs | 8 ++++---- .../src/tests/document_insert_strategy_test.rs | 2 +- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index 61ab8d596..27300a909 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -252,8 +252,9 @@ impl<'a, T> DocumentIndexJob<'a, T> { // Start points are stored at indices num_vectors..num_vectors+frozen_points let medoid_idx = compute_medoid_index(&data)?; let start_point_id = num_vectors as u32; // Start points begin at max_points - let medoid_attrs = attributes.get(medoid_idx).cloned().unwrap_or_default(); - attribute_store.set_element(&start_point_id, &medoid_attrs)?; + let default_attrs = vec![]; + let medoid_attrs = attributes.get(medoid_idx).unwrap_or(&default_attrs); + attribute_store.set_element(&start_point_id, medoid_attrs)?; let doc_provider = DocumentProvider::new(inner_provider, attribute_store); @@ -821,13 +822,13 @@ where async fn build(&self, range: std::ops::Range) -> diskann::ANNResult { let ctx = DefaultContext; for i in range { - let attrs = self.attributes.get(i).cloned().ok_or_else(|| { + let attrs = self.attributes.get(i).ok_or_else(|| { ANNError::message( ANNErrorKind::Opaque, format!("Failed to get attributes at index {}", i), ) })?; - let doc = Document::new(self.data.row(i), attrs); + let doc = Document::new(self.data.row(i), &attrs); self.index .insert(self.strategy, &ctx, &(i as u32), &doc) .await?; diff --git a/diskann-label-filter/src/document.rs b/diskann-label-filter/src/document.rs index 5c817525c..247edff78 100644 --- a/diskann-label-filter/src/document.rs +++ b/diskann-label-filter/src/document.rs @@ -10,11 +10,11 @@ use diskann_utils::reborrow::Reborrow; /// supply diskann with a vector and its attributes pub struct Document<'a, V: ?Sized> { vector: &'a V, - attributes: Vec, + attributes: &'a [Attribute], } impl<'a, V: ?Sized> Document<'a, V> { - pub fn new(vector: &'a V, attributes: Vec) -> Self { + pub fn new(vector: &'a V, attributes: &'a [Attribute]) -> Self { Self { vector, attributes } } @@ -22,8 +22,8 @@ impl<'a, V: ?Sized> Document<'a, V> { self.vector } - pub(crate) fn attributes(&self) -> &Vec { - &self.attributes + pub(crate) fn attributes(&self) -> &'a [Attribute] { + self.attributes } } diff --git a/diskann-label-filter/src/tests/document_insert_strategy_test.rs b/diskann-label-filter/src/tests/document_insert_strategy_test.rs index a27de2c0a..525ba79f8 100644 --- a/diskann-label-filter/src/tests/document_insert_strategy_test.rs +++ b/diskann-label-filter/src/tests/document_insert_strategy_test.rs @@ -96,7 +96,7 @@ fn test_build_query_computer_extracts_vector_from_document() { let doc_accessor = DocumentSearchAccessor::new(inner_accessor); let vector = vec![1.0f32, 2.0, 0.0]; - let doc = Document::new(vector.as_slice(), vec![]); + let doc = Document::new(vector.as_slice(), &[]); let result = as BuildQueryComputer<&Document<'_, [f32]>>>::build_query_computer(&doc_accessor, &doc); From c45204abeacc50f0a46f96ac368e1ef56b466065 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Wed, 8 Apr 2026 09:46:20 +0530 Subject: [PATCH 49/50] Fix clippy error --- diskann-benchmark/src/backend/document_index/benchmark.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index 27300a909..9bd5feed1 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -828,7 +828,7 @@ where format!("Failed to get attributes at index {}", i), ) })?; - let doc = Document::new(self.data.row(i), &attrs); + let doc = Document::new(self.data.row(i), attrs); self.index .insert(self.strategy, &ctx, &(i as u32), &doc) .await?; From 4f88a9e678697f5ff677f154b8b7d1b0cd89c19a Mon Sep 17 00:00:00 2001 From: sampathrg Date: Wed, 8 Apr 2026 21:55:27 +0530 Subject: [PATCH 50/50] Revert changes to ground_truth.rs --- diskann-tools/src/utils/ground_truth.rs | 127 +----------------------- 1 file changed, 2 insertions(+), 125 deletions(-) diff --git a/diskann-tools/src/utils/ground_truth.rs b/diskann-tools/src/utils/ground_truth.rs index 2ab17c30c..883f0c2ec 100644 --- a/diskann-tools/src/utils/ground_truth.rs +++ b/diskann-tools/src/utils/ground_truth.rs @@ -4,7 +4,7 @@ */ use bit_set::BitSet; -use diskann_label_filter::{eval_query_expr, read_and_parse_queries, read_baselabels, ASTExpr}; +use diskann_label_filter::{eval_query_expr, read_and_parse_queries, read_baselabels}; use std::{io::Write, mem::size_of, str::FromStr}; @@ -26,79 +26,18 @@ use diskann_utils::{ use diskann_vector::{distance::Metric, DistanceFunction}; use itertools::Itertools; use rayon::prelude::*; -use serde_json::{map, Map, Value}; use crate::utils::{search_index_utils, CMDResult, CMDToolError}; -/// Evaluates a query expression against a label, expanding array-valued fields by recursion. -/// -/// For each key in the JSON object, if the value is an array the expression is evaluated -/// against one element at a time (any-match semantics) without materialising the full -/// Cartesian product. Non-object labels are evaluated directly. -fn eval_query_with_array_expansion(query_expr: &ASTExpr, label: &Value) -> bool { - match label { - Value::Object(map) => eval_map_recursive(query_expr, map.iter(), Map::new()), - _ => eval_query_expr(query_expr, label), - } -} - -/// Walk `entries` one field at a time, accumulating scalar values into `current`. -/// -/// * Scalar fields are inserted directly and the walk continues with the remaining entries. -/// * Array fields branch once per element; evaluation short-circuits on the first branch -/// that returns `true`. -/// * An empty array is treated as an absent field (preserving the previous behaviour). -/// * When all fields have been consumed, `eval_query_expr` is called on the accumulated object. -fn eval_map_recursive( - query_expr: &ASTExpr, - mut map_iter: map::Iter, - mut current: Map, -) -> bool { - match map_iter.next() { - None => eval_query_expr(query_expr, &Value::Object(current)), - Some((key, Value::Array(arr))) => { - if arr.is_empty() { - // Omit this key, matching the original behaviour for empty arrays. - eval_map_recursive(query_expr, map_iter, current) - } else { - for item in arr { - let mut branch = current.clone(); - branch.insert((*key).clone(), item.clone()); - - // need to clone here because we want to iterate from same next element for each branch - if eval_map_recursive(query_expr, map_iter.clone(), branch) { - return true; - } - } - false - } - } - Some((key, val)) => { - current.insert((*key).clone(), (*val).clone()); - eval_map_recursive(query_expr, map_iter, current) - } - } -} - pub fn read_labels_and_compute_bitmap( base_label_filename: &str, query_label_filename: &str, ) -> CMDResult> { // Read base labels let base_labels = read_baselabels(base_label_filename)?; - tracing::info!( - "Loaded {} base labels from {}", - base_labels.len(), - base_label_filename - ); // Parse queries and evaluate against labels let parsed_queries = read_and_parse_queries(query_label_filename)?; - tracing::info!( - "Loaded {} queries from {}", - parsed_queries.len(), - query_label_filename - ); // using the global threadpool is fine here #[allow(clippy::disallowed_methods)] @@ -107,17 +46,7 @@ pub fn read_labels_and_compute_bitmap( .map(|(_query_id, query_expr)| { let mut bitmap = BitSet::new(); for base_label in base_labels.iter() { - // Handle case where base_label.label is an array - check if any element matches - // Also expand array-valued fields within objects (e.g., {"country": ["AU", "NZ"]}) - let matches = if let Some(array) = base_label.label.as_array() { - array - .iter() - .any(|item| eval_query_with_array_expansion(query_expr, item)) - } else { - eval_query_with_array_expansion(query_expr, &base_label.label) - }; - - if matches { + if eval_query_expr(query_expr, &base_label.label) { bitmap.insert(base_label.doc_id); } } @@ -125,33 +54,6 @@ pub fn read_labels_and_compute_bitmap( }) .collect(); - // Debug: Print match statistics for each query - let total_matches: usize = query_bitmaps.iter().map(|b| b.len()).sum(); - let queries_with_matches = query_bitmaps.iter().filter(|b| !b.is_empty()).count(); - tracing::info!( - "Filter matching summary: {} total matches across {} queries ({} queries have matches)", - total_matches, - query_bitmaps.len(), - queries_with_matches - ); - - // If no matches, print more diagnostic info - if total_matches == 0 { - tracing::warn!("WARNING: No base vectors matched any query filters!"); - tracing::warn!( - "This could indicate a format mismatch between base labels and query filters." - ); - - // Try to identify what keys exist in base labels vs queries - if let Some(first_label) = base_labels.first() { - tracing::warn!( - "First base label (full): doc_id={}, label={}", - first_label.doc_id, - first_label.label - ); - } - } - Ok(query_bitmaps) } @@ -294,31 +196,6 @@ pub fn compute_ground_truth_from_datafiles< assert_ne!(ground_truth.len(), 0, "No ground-truth results computed"); - // Debug: Print top K matches for each query - tracing::info!( - "Ground truth computed for {} queries with recall_at={}", - ground_truth.len(), - recall_at - ); - for (query_idx, npq) in ground_truth.iter().enumerate() { - if npq.size() == 0 { - tracing::warn!("Query {} has 0 neighbors in ground truth!", query_idx); - } - } - - // Summary stats - let total_neighbors: usize = ground_truth.iter().map(|npq| npq.iter().count()).sum(); - let queries_with_neighbors = ground_truth - .iter() - .filter(|npq| npq.iter().count() > 0) - .count(); - tracing::info!( - "Ground truth summary: {} total neighbors, {} queries have neighbors, {} queries have 0 neighbors", - total_neighbors, - queries_with_neighbors, - ground_truth.len() - queries_with_neighbors - ); - if has_vector_filters || has_query_bitmaps { let ground_truth_collection = ground_truth .into_iter()