From 432f5c7fdea0ef0b04265bec6b121cd5aa90e876 Mon Sep 17 00:00:00 2001 From: SubhamSinghal Date: Mon, 11 May 2026 12:26:03 +0530 Subject: [PATCH] Track spill read-back memory in SMJ --- .../sort_merge_join/materializing_stream.rs | 67 ++++-- .../src/joins/sort_merge_join/tests.rs | 223 ++++++++++++++++++ 2 files changed, 275 insertions(+), 15 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs index 069e94d0a9fd6..c7df8ba8586ac 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs @@ -1526,7 +1526,7 @@ impl MaterializingSortMergeJoinStream { /// gathers columns across sources. A null-row sentinel at source index 0 /// handles null right indices (unmatched streamed rows). fn materialize_right_columns( - &self, + &mut self, matched_chunks: &[(usize, UInt64Array, UInt64Array)], total_matched_rows: usize, ) -> Result<Vec<ArrayRef>> { @@ -1541,11 +1541,33 @@ impl MaterializingSortMergeJoinStream { matched_chunks.iter().map(|c| &c.2 as &dyn Array).collect(); as_uint64_array(&compute::concat(&refs)?)?.clone() }; - return fetch_right_columns_by_idxs( + + let spill_read_mem = match &self.buffered_data.batches[first_batch_idx].batch + { + BufferedBatchState::Spilled(_) => { + self.buffered_data.batches[first_batch_idx].size_estimation + } + _ => 0, + }; + + if spill_read_mem > 0 { + self.reservation.grow(spill_read_mem); + self.join_metrics + .peak_mem_used() + .set_max(self.reservation.size()); + } + + let result = fetch_right_columns_by_idxs( &self.buffered_data, first_batch_idx, &combined_right_indices, ); + + if spill_read_mem > 0 { + self.reservation.shrink(spill_read_mem); + } + + return result; } // Multiple source batches: map each buffered_batch_idx to a @@ -1577,20 +1599,31 @@ impl MaterializingSortMergeJoinStream { let mut right_columns =
Vec::with_capacity(num_right_cols); // Read each source batch once (spilled batches require disk I/O). - let source_data: Vec<Option<RecordBatch>> = source_batches - .iter() - .map(|&idx| { - let bb = &self.buffered_data.batches[idx]; - match &bb.batch { - BufferedBatchState::InMemory(batch) => Some(batch.clone()), - BufferedBatchState::Spilled(spill_file) => { - let file = BufReader::new(File::open(spill_file.path()).ok()?); - let reader = StreamReader::try_new(file, None).ok()?; - reader.into_iter().next()?.ok() - } + // Track memory for each spilled batch at the point of deserialization + // so the pool reflects actual usage as it grows. + let mut spill_read_mem: usize = 0; + let mut source_data: Vec<Option<RecordBatch>> = + Vec::with_capacity(source_batches.len()); + for &idx in &source_batches { + let bb = &self.buffered_data.batches[idx]; + match &bb.batch { + BufferedBatchState::InMemory(batch) => { + source_data.push(Some(batch.clone())); } - }) - .collect(); + BufferedBatchState::Spilled(spill_file) => { + let batch_mem = bb.size_estimation; + self.reservation.grow(batch_mem); + self.join_metrics + .peak_mem_used() + .set_max(self.reservation.size()); + spill_read_mem += batch_mem; + + let file = BufReader::new(File::open(spill_file.path())?); + let reader = StreamReader::try_new(file, None)?; + source_data.push(reader.into_iter().next().transpose()?); + } + } + } for col_idx in 0..num_right_cols { let dtype = self.buffered_schema.field(col_idx).data_type(); @@ -1614,6 +1647,10 @@ impl MaterializingSortMergeJoinStream { right_columns.push(interleave(&source_arrays, &interleave_indices)?); } + if spill_read_mem > 0 { + self.reservation.shrink(spill_read_mem); + } + Ok(right_columns) } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs index bc34c351c5e21..3e0237350050d 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs @@
-4724,3 +4724,226 @@ async fn spill_filtered_boundary_loses_outer_rows() -> Result<()> { Ok(()) } + +/// Verifies that `peak_mem_used` reflects spill read-back memory during +/// output materialization. +/// +/// When spilled buffered batches are read back from disk to produce join +/// output, the deserialized data temporarily exists in memory. This test +/// verifies that the read-back is tracked via grow/shrink so the pool +/// accurately reflects the transient spike. +#[tokio::test] +async fn spill_read_back_memory_accounting() -> Result<()> { + use arrow::array::Array; + + let left_batch = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let size_estimation = left_batch.get_array_memory_size() + + Int32Array::from(vec![1, 1]).get_array_memory_size() + + 2usize.next_power_of_two() * size_of::<usize>() + + size_of::<Range<usize>>() + + size_of::<usize>(); + + // Memory limit too small for a full batch — forces spilling. + let memory_limit = size_estimation / 2; + + // All rows share the same join key (b=1) to force multiple buffered + // batches in the same key group — triggering spill read-back during + // output materialization. + let left_batches: Vec<RecordBatch> = (0..4) + .map(|i| { + build_table_i32( + ("a1", &vec![i * 2, i * 2 + 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![100 + i, 101 + i]), + ) + }) + .collect(); + let left = build_table_from_batches(left_batches); + + let right_batches: Vec<RecordBatch> = (0..4) + .map(|i| { + build_table_i32( + ("a2", &vec![i * 2, i * 2 + 1]), + ("b2", &vec![1, 1]), + ("c2", &vec![200 + i, 201 + i]), + ) + }) + .collect(); + let right = build_table_from_batches(right_batches); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?)
as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; + + let session_config = SessionConfig::default().with_batch_size(50); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(Arc::clone(&runtime)), + ); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + Inner, + sort_options, + NullEquality::NullEqualsNothing, + )?; + + let stream = join.execute(0, task_ctx)?; + let result = common::collect(stream).await.unwrap(); + + assert!(!result.is_empty(), "Expected non-empty join result"); + + let metrics = join.metrics().unwrap(); + assert!( + metrics.spill_count().unwrap() > 0, + "Expected spilling to occur" + ); + + // peak_mem_used should reflect the spill read-back: when buffered + // batches are read from disk during output materialization, grow() + // temporarily reserves size_estimation. This pushes peak above what + // join_arrays_mem alone would show. + let peak_mem = metrics + .sum_by_name("peak_mem_used") + .map(|m| m.as_usize()) + .unwrap_or(0); + assert!( + peak_mem >= size_estimation, + "peak_mem_used ({peak_mem}) should be >= size_estimation ({size_estimation}) \ + because spill read-back temporarily loads full batch into memory" + ); + + // All memory must be released (grow/shrink balanced) + assert_eq!( + runtime.memory_pool.reserved(), + 0, + "All memory should be released after join completes" + ); + + Ok(()) +} + +/// Verifies spill read-back memory tracking for the single-source path. +/// +/// When only ONE buffered batch exists for a key group and it's spilled, +/// `fetch_right_columns_by_idxs` reads it back. This test verifies the +/// grow/shrink around that single-batch read. 
+#[tokio::test] +async fn spill_read_back_single_source() -> Result<()> { + use arrow::array::Array; + + let left_batch = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let size_estimation = left_batch.get_array_memory_size() + + Int32Array::from(vec![1, 1]).get_array_memory_size() + + 2usize.next_power_of_two() * size_of::<usize>() + + size_of::<Range<usize>>() + + size_of::<usize>(); + + // Memory limit too small for a full batch — forces spilling. + let memory_limit = size_estimation / 2; + + // Multiple distinct keys so each key group has exactly ONE buffered batch. + // This ensures the single-source path is exercised. + let left_batches: Vec<RecordBatch> = (0..4) + .map(|i| { + build_table_i32( + ("a1", &vec![i * 2, i * 2 + 1]), + ("b1", &vec![i, i]), + ("c1", &vec![100 + i, 101 + i]), + ) + }) + .collect(); + let left = build_table_from_batches(left_batches); + + // One batch per key — each key group has single source + let right_batches: Vec<RecordBatch> = (0..4) + .map(|i| { + build_table_i32( + ("a2", &vec![i * 2, i * 2 + 1]), + ("b2", &vec![i, i]), + ("c2", &vec![200 + i, 201 + i]), + ) + }) + .collect(); + let right = build_table_from_batches(right_batches); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?)
as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; + + let session_config = SessionConfig::default().with_batch_size(50); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(Arc::clone(&runtime)), + ); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + Inner, + sort_options, + NullEquality::NullEqualsNothing, + )?; + + let stream = join.execute(0, task_ctx)?; + let result = common::collect(stream).await.unwrap(); + + assert!(!result.is_empty(), "Expected non-empty join result"); + + let metrics = join.metrics().unwrap(); + assert!( + metrics.spill_count().unwrap() > 0, + "Expected spilling to occur" + ); + + // peak_mem_used should reflect the single-batch read-back + let peak_mem = metrics + .sum_by_name("peak_mem_used") + .map(|m| m.as_usize()) + .unwrap_or(0); + assert!( + peak_mem >= size_estimation, + "peak_mem_used ({peak_mem}) should be >= size_estimation ({size_estimation}) \ + because single-source spill read-back loads full batch" + ); + + // All memory must be released + assert_eq!( + runtime.memory_pool.reserved(), + 0, + "All memory should be released after join completes" + ); + + Ok(()) +}