Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1526,7 +1526,7 @@ impl MaterializingSortMergeJoinStream {
/// gathers columns across sources. A null-row sentinel at source index 0
/// handles null right indices (unmatched streamed rows).
fn materialize_right_columns(
&self,
&mut self,
matched_chunks: &[(usize, UInt64Array, UInt64Array)],
total_matched_rows: usize,
) -> Result<Vec<ArrayRef>> {
Expand All @@ -1541,11 +1541,33 @@ impl MaterializingSortMergeJoinStream {
matched_chunks.iter().map(|c| &c.2 as &dyn Array).collect();
as_uint64_array(&compute::concat(&refs)?)?.clone()
};
return fetch_right_columns_by_idxs(

let spill_read_mem = match &self.buffered_data.batches[first_batch_idx].batch
{
BufferedBatchState::Spilled(_) => {
self.buffered_data.batches[first_batch_idx].size_estimation
}
_ => 0,
};

if spill_read_mem > 0 {
self.reservation.grow(spill_read_mem);
self.join_metrics
.peak_mem_used()
.set_max(self.reservation.size());
}

let result = fetch_right_columns_by_idxs(
&self.buffered_data,
first_batch_idx,
&combined_right_indices,
);

if spill_read_mem > 0 {
self.reservation.shrink(spill_read_mem);
}

return result;
}

// Multiple source batches: map each buffered_batch_idx to a
Expand Down Expand Up @@ -1577,20 +1599,31 @@ impl MaterializingSortMergeJoinStream {
let mut right_columns = Vec::with_capacity(num_right_cols);

// Read each source batch once (spilled batches require disk I/O).
let source_data: Vec<Option<RecordBatch>> = source_batches
.iter()
.map(|&idx| {
let bb = &self.buffered_data.batches[idx];
match &bb.batch {
BufferedBatchState::InMemory(batch) => Some(batch.clone()),
BufferedBatchState::Spilled(spill_file) => {
let file = BufReader::new(File::open(spill_file.path()).ok()?);
let reader = StreamReader::try_new(file, None).ok()?;
reader.into_iter().next()?.ok()
}
// Track memory for each spilled batch at the point of deserialization
// so the pool reflects actual usage as it grows.
let mut spill_read_mem: usize = 0;
let mut source_data: Vec<Option<RecordBatch>> =
Vec::with_capacity(source_batches.len());
for &idx in &source_batches {
let bb = &self.buffered_data.batches[idx];
match &bb.batch {
BufferedBatchState::InMemory(batch) => {
source_data.push(Some(batch.clone()));
}
})
.collect();
BufferedBatchState::Spilled(spill_file) => {
let batch_mem = bb.size_estimation;
self.reservation.grow(batch_mem);
self.join_metrics
.peak_mem_used()
.set_max(self.reservation.size());
spill_read_mem += batch_mem;

let file = BufReader::new(File::open(spill_file.path())?);
let reader = StreamReader::try_new(file, None)?;
source_data.push(reader.into_iter().next().transpose()?);
}
}
}

for col_idx in 0..num_right_cols {
let dtype = self.buffered_schema.field(col_idx).data_type();
Expand All @@ -1614,6 +1647,10 @@ impl MaterializingSortMergeJoinStream {
right_columns.push(interleave(&source_arrays, &interleave_indices)?);
}

if spill_read_mem > 0 {
self.reservation.shrink(spill_read_mem);
}

Ok(right_columns)
}

Expand Down
223 changes: 223 additions & 0 deletions datafusion/physical-plan/src/joins/sort_merge_join/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4724,3 +4724,226 @@ async fn spill_filtered_boundary_loses_outer_rows() -> Result<()> {

Ok(())
}

/// Verifies that `peak_mem_used` reflects spill read-back memory during
/// output materialization.
///
/// When spilled buffered batches are read back from disk to produce join
/// output, the deserialized data temporarily exists in memory. This test
/// verifies that the read-back is tracked via grow/shrink so the pool
/// accurately reflects the transient spike.
#[tokio::test]
async fn spill_read_back_memory_accounting() -> Result<()> {
    // Trait import required for `get_array_memory_size` below.
    use arrow::array::Array;

    // Reconstruct the operator's per-batch size estimation (batch arrays
    // + join-key array + index-buffer capacity + range + counter) so the
    // memory limit and the peak assertion line up with what it reserves.
    let left_batch = build_table_i32(
        ("a1", &vec![0, 1]),
        ("b1", &vec![1, 1]),
        ("c1", &vec![4, 5]),
    );
    let size_estimation = left_batch.get_array_memory_size()
        + Int32Array::from(vec![1, 1]).get_array_memory_size()
        + 2usize.next_power_of_two() * size_of::<usize>()
        + size_of::<std::ops::Range<usize>>()
        + size_of::<usize>();

    // Memory limit too small for a full batch — forces spilling.
    let memory_limit = size_estimation / 2;

    // All rows share the same join key (b=1) to force multiple buffered
    // batches in the same key group — triggering spill read-back during
    // output materialization.
    let left_batches: Vec<RecordBatch> = (0..4)
        .map(|i| {
            build_table_i32(
                ("a1", &vec![i * 2, i * 2 + 1]),
                ("b1", &vec![1, 1]),
                ("c1", &vec![100 + i, 101 + i]),
            )
        })
        .collect();
    let left = build_table_from_batches(left_batches);

    let right_batches: Vec<RecordBatch> = (0..4)
        .map(|i| {
            build_table_i32(
                ("a2", &vec![i * 2, i * 2 + 1]),
                ("b2", &vec![1, 1]),
                ("c2", &vec![200 + i, 201 + i]),
            )
        })
        .collect();
    let right = build_table_from_batches(right_batches);

    let on = vec![(
        Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
        Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
    )];
    let sort_options = vec![SortOptions::default(); on.len()];

    // Constrained memory pool + OS temp-dir disk manager so the join can
    // (and must) spill to disk.
    let runtime = RuntimeEnvBuilder::new()
        .with_memory_limit(memory_limit, 1.0)
        .with_disk_manager_builder(
            DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory),
        )
        .build_arc()?;

    let session_config = SessionConfig::default().with_batch_size(50);
    let task_ctx = Arc::new(
        TaskContext::default()
            .with_session_config(session_config)
            .with_runtime(Arc::clone(&runtime)),
    );

    let join = join_with_options(
        Arc::clone(&left),
        Arc::clone(&right),
        on.clone(),
        Inner,
        sort_options,
        NullEquality::NullEqualsNothing,
    )?;

    let stream = join.execute(0, task_ctx)?;
    // Propagate stream errors with `?` — the test returns Result, so a
    // panic via unwrap would hide the underlying error.
    let result = common::collect(stream).await?;

    assert!(!result.is_empty(), "Expected non-empty join result");

    let metrics = join.metrics().expect("join should expose metrics");
    assert!(
        metrics
            .spill_count()
            .expect("spill_count metric should be recorded")
            > 0,
        "Expected spilling to occur"
    );

    // peak_mem_used should reflect the spill read-back: when buffered
    // batches are read from disk during output materialization, grow()
    // temporarily reserves size_estimation. This pushes peak above what
    // join_arrays_mem alone would show.
    let peak_mem = metrics
        .sum_by_name("peak_mem_used")
        .map(|m| m.as_usize())
        .unwrap_or(0);
    assert!(
        peak_mem >= size_estimation,
        "peak_mem_used ({peak_mem}) should be >= size_estimation ({size_estimation}) \
        because spill read-back temporarily loads full batch into memory"
    );

    // All memory must be released (grow/shrink balanced)
    assert_eq!(
        runtime.memory_pool.reserved(),
        0,
        "All memory should be released after join completes"
    );

    Ok(())
}

/// Verifies spill read-back memory tracking for the single-source path.
///
/// When only ONE buffered batch exists for a key group and it's spilled,
/// `fetch_right_columns_by_idxs` reads it back. This test verifies the
/// grow/shrink around that single-batch read.
#[tokio::test]
async fn spill_read_back_single_source() -> Result<()> {
    // Trait import required for `get_array_memory_size` below.
    use arrow::array::Array;

    // Reconstruct the operator's per-batch size estimation (batch arrays
    // + join-key array + index-buffer capacity + range + counter) so the
    // memory limit and the peak assertion line up with what it reserves.
    let left_batch = build_table_i32(
        ("a1", &vec![0, 1]),
        ("b1", &vec![1, 1]),
        ("c1", &vec![4, 5]),
    );
    let size_estimation = left_batch.get_array_memory_size()
        + Int32Array::from(vec![1, 1]).get_array_memory_size()
        + 2usize.next_power_of_two() * size_of::<usize>()
        + size_of::<std::ops::Range<usize>>()
        + size_of::<usize>();

    // Memory limit too small for a full batch — forces spilling.
    let memory_limit = size_estimation / 2;

    // Multiple distinct keys so each key group has exactly ONE buffered batch.
    // This ensures the single-source path is exercised.
    let left_batches: Vec<RecordBatch> = (0..4)
        .map(|i| {
            build_table_i32(
                ("a1", &vec![i * 2, i * 2 + 1]),
                ("b1", &vec![i, i]),
                ("c1", &vec![100 + i, 101 + i]),
            )
        })
        .collect();
    let left = build_table_from_batches(left_batches);

    // One batch per key — each key group has single source
    let right_batches: Vec<RecordBatch> = (0..4)
        .map(|i| {
            build_table_i32(
                ("a2", &vec![i * 2, i * 2 + 1]),
                ("b2", &vec![i, i]),
                ("c2", &vec![200 + i, 201 + i]),
            )
        })
        .collect();
    let right = build_table_from_batches(right_batches);

    let on = vec![(
        Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
        Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
    )];
    let sort_options = vec![SortOptions::default(); on.len()];

    // Constrained memory pool + OS temp-dir disk manager so the join can
    // (and must) spill to disk.
    let runtime = RuntimeEnvBuilder::new()
        .with_memory_limit(memory_limit, 1.0)
        .with_disk_manager_builder(
            DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory),
        )
        .build_arc()?;

    let session_config = SessionConfig::default().with_batch_size(50);
    let task_ctx = Arc::new(
        TaskContext::default()
            .with_session_config(session_config)
            .with_runtime(Arc::clone(&runtime)),
    );

    let join = join_with_options(
        Arc::clone(&left),
        Arc::clone(&right),
        on.clone(),
        Inner,
        sort_options,
        NullEquality::NullEqualsNothing,
    )?;

    let stream = join.execute(0, task_ctx)?;
    // Propagate stream errors with `?` — the test returns Result, so a
    // panic via unwrap would hide the underlying error.
    let result = common::collect(stream).await?;

    assert!(!result.is_empty(), "Expected non-empty join result");

    let metrics = join.metrics().expect("join should expose metrics");
    assert!(
        metrics
            .spill_count()
            .expect("spill_count metric should be recorded")
            > 0,
        "Expected spilling to occur"
    );

    // peak_mem_used should reflect the single-batch read-back
    let peak_mem = metrics
        .sum_by_name("peak_mem_used")
        .map(|m| m.as_usize())
        .unwrap_or(0);
    assert!(
        peak_mem >= size_estimation,
        "peak_mem_used ({peak_mem}) should be >= size_estimation ({size_estimation}) \
        because single-source spill read-back loads full batch"
    );

    // All memory must be released
    assert_eq!(
        runtime.memory_pool.reserved(),
        0,
        "All memory should be released after join completes"
    );

    Ok(())
}
Loading