Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 40 additions & 48 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ use datafusion_expr::expr::SetQuantifier;
use datafusion_expr::expr::{InList, WildcardOptions};
use datafusion_expr::{
Between, BinaryExpr, Cast, Expr, ExprSchemable, GetFieldAccess, Like, Literal,
Operator, TryCast, lit, when,
Operator, TryCast, lit,
};

use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_functions_nested::expr_fn::{
array_has, array_max, array_min, array_position, cardinality,
array_has, array_max, array_min, cardinality,
};

mod binary_op;
Expand Down Expand Up @@ -1259,64 +1259,59 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
}
}

/// Plans `needle <compare_op> ANY/ALL(haystack)` with proper SQL NULL semantics.
///
/// CASE/WHEN structure:
/// WHEN arr IS NULL → NULL
/// WHEN empty → vacuous_result (ANY:false, ALL:true)
/// WHEN lhs IS NULL → NULL
/// WHEN decisive_condition → decisive_result (ANY:true match found, ALL:false violation found)
/// WHEN has_nulls → NULL
/// ELSE → vacuous_result
/// Plans `needle <op> ANY/ALL(haystack)` by desugaring to `array_has`,
/// `array_min`, or `array_max`. The min/max-based desugarings are wrapped in a
/// cardinality guard so that empty arrays return the vacuous result
/// (ANY → false, ALL → true) instead of NULL.
fn plan_quantified_op(
needle: &Expr,
haystack: &Expr,
compare_op: &BinaryOperator,
quantifier: SetQuantifier,
) -> Result<Expr> {
let null_arr_check = haystack.clone().is_null();
let empty_check = cardinality(haystack.clone()).eq(lit(0u64));
let null_lhs_check = needle.clone().is_null();
// DataFusion's array_position uses is_null() checks internally (not equality),
// so it can locate NULL elements even though NULL = NULL is NULL in standard SQL.
let has_nulls =
array_position(haystack.clone(), lit(ScalarValue::Null), lit(1i64)).is_not_null();

let decisive_condition = match (compare_op, quantifier) {
(BinaryOperator::Eq, SetQuantifier::Any)
| (BinaryOperator::NotEq, SetQuantifier::All) => {
array_has(haystack.clone(), needle.clone())
let (cmp, needs_empty_guard) = match (compare_op, quantifier) {
(BinaryOperator::Eq, SetQuantifier::Any) => {
(array_has(haystack.clone(), needle.clone()), false)
}
(BinaryOperator::Eq, SetQuantifier::All)
| (BinaryOperator::NotEq, SetQuantifier::Any) => {
let all_equal = array_min(haystack.clone())
(BinaryOperator::NotEq, SetQuantifier::All) => (
Expr::Not(Box::new(array_has(haystack.clone(), needle.clone()))),
false,
),
(BinaryOperator::Eq, SetQuantifier::All) => (
array_min(haystack.clone())
.eq(needle.clone())
.and(array_max(haystack.clone()).eq(needle.clone()));
Expr::Not(Box::new(all_equal))
}
.and(array_max(haystack.clone()).eq(needle.clone())),
true,
),
(BinaryOperator::NotEq, SetQuantifier::Any) => (
array_min(haystack.clone())
.not_eq(needle.clone())
.or(array_max(haystack.clone()).not_eq(needle.clone())),
true,
),
(BinaryOperator::Gt, SetQuantifier::Any) => {
needle.clone().gt(array_min(haystack.clone()))
(needle.clone().gt(array_min(haystack.clone())), true)
}
(BinaryOperator::Gt, SetQuantifier::All) => {
Expr::Not(Box::new(needle.clone().gt(array_max(haystack.clone()))))
(needle.clone().gt(array_max(haystack.clone())), true)
}
(BinaryOperator::Lt, SetQuantifier::Any) => {
needle.clone().lt(array_max(haystack.clone()))
(needle.clone().lt(array_max(haystack.clone())), true)
}
(BinaryOperator::Lt, SetQuantifier::All) => {
Expr::Not(Box::new(needle.clone().lt(array_min(haystack.clone()))))
(needle.clone().lt(array_min(haystack.clone())), true)
}
(BinaryOperator::GtEq, SetQuantifier::Any) => {
needle.clone().gt_eq(array_min(haystack.clone()))
(needle.clone().gt_eq(array_min(haystack.clone())), true)
}
(BinaryOperator::GtEq, SetQuantifier::All) => {
Expr::Not(Box::new(needle.clone().gt_eq(array_max(haystack.clone()))))
(needle.clone().gt_eq(array_max(haystack.clone())), true)
}
(BinaryOperator::LtEq, SetQuantifier::Any) => {
needle.clone().lt_eq(array_max(haystack.clone()))
(needle.clone().lt_eq(array_max(haystack.clone())), true)
}
(BinaryOperator::LtEq, SetQuantifier::All) => {
Expr::Not(Box::new(needle.clone().lt_eq(array_min(haystack.clone()))))
(needle.clone().lt_eq(array_min(haystack.clone())), true)
}
_ => {
return plan_err!(
Expand All @@ -1325,18 +1320,15 @@ fn plan_quantified_op(
}
};

let (vacuous_result, decisive_result) = match quantifier {
SetQuantifier::Any => (false, true),
SetQuantifier::All => (true, false),
let expr = if needs_empty_guard {
match quantifier {
SetQuantifier::Any => cardinality(haystack.clone()).gt(lit(0u64)).and(cmp),
SetQuantifier::All => cardinality(haystack.clone()).eq(lit(0u64)).or(cmp),
}
} else {
cmp
};

let null_bool = lit(ScalarValue::Boolean(None));
when(null_arr_check, null_bool.clone())
.when(empty_check, lit(vacuous_result))
.when(null_lhs_check, null_bool.clone())
.when(decisive_condition, lit(decisive_result))
.when(has_nulls, null_bool)
.otherwise(lit(vacuous_result))
Ok(expr)
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ fn roundtrip_statement_postgres_any_array_expr() -> Result<(), DataFusionError>
sql: "select left from array where 1 = any(left);",
parser_dialect: GenericDialect {},
unparser_dialect: UnparserPostgreSqlDialect {},
expected: @r#"SELECT "array"."left" FROM "array" WHERE CASE WHEN "array"."left" IS NULL THEN NULL WHEN (cardinality("array"."left") = 0) THEN false WHEN 1 IS NULL THEN NULL WHEN 1 = ANY("array"."left") THEN true WHEN array_position("array"."left", NULL, 1) IS NOT NULL THEN NULL ELSE false END"#,
expected: @r#"SELECT "array"."left" FROM "array" WHERE 1 = ANY("array"."left")"#,
);
Ok(())
}
Expand Down
10 changes: 5 additions & 5 deletions datafusion/sqllogictest/test_files/array/array_all.slt
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ NULL
query B
select 5 <> ALL(make_array(NULL::INT, NULL::INT));
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test case's behaviour has changed (it now returns `true` rather than `NULL`), so it should be moved out from under the parent comment ("All-NULL arrays: returns NULL"), which no longer describes it.

----
NULL
true

query B
select 5 > ALL(make_array(NULL::INT, NULL::INT));
Expand All @@ -171,22 +171,22 @@ NULL
query B
select 5 > ALL(make_array(3, NULL));
----
NULL
true

query B
select 5 >= ALL(make_array(5, NULL));
----
NULL
true

query B
select 1 < ALL(make_array(3, NULL));
----
NULL
true

query B
select 1 <= ALL(make_array(1, NULL));
----
NULL
true

# Mixed NULL + non-NULL (not satisfying condition → FALSE wins over NULL)
query B
Expand Down
38 changes: 18 additions & 20 deletions datafusion/sqllogictest/test_files/array/array_has.slt
Original file line number Diff line number Diff line change
Expand Up @@ -517,18 +517,16 @@ logical_plan
03)----SubqueryAlias: test
04)------SubqueryAlias: t
05)--------Projection:
06)----------Filter: __common_expr_3 IS NULL AND Boolean(NULL) OR __common_expr_3 IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) IS NOT DISTINCT FROM Boolean(true) AND __common_expr_3 IS NOT NULL
07)------------Projection: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) AS __common_expr_3
08)--------------TableScan: generate_series() projection=[value]
06)----------Filter: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")])
07)------------TableScan: generate_series() projection=[value]
physical_plan
01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
03)----CoalescePartitionsExec
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
05)--------FilterExec: __common_expr_3@0 IS NULL AND NULL OR __common_expr_3@0 IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) IS NOT DISTINCT FROM true AND __common_expr_3@0 IS NOT NULL, projection=[]
06)----------ProjectionExec: expr=[substr(md5(CAST(value@0 AS Utf8View)), 1, 32) as __common_expr_3]
07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
05)--------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), projection=[]
06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
07)------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]

query I
with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
Expand Down Expand Up @@ -724,7 +722,8 @@ select 5 <> any(make_array(5, 5, 5));
----
false

# Empty array: all operators should return false (no elements satisfy the condition)
# Empty array: vacuously false. The min/max-based desugarings are guarded by a
# cardinality check so that `value > any([])` and friends don't leak NULL.
query B
select 5 = any(make_array());
----
Expand Down Expand Up @@ -755,27 +754,26 @@ select 5 <= any(make_array());
----
false

# Mixed NULL + non-NULL array where no non-NULL element satisfies the condition
# These return NULL because NULLs leave the result indeterminate
# Mixed NULL + non-NULL array where no non-NULL element satisfies the condition: false (min/max skip NULLs).
query B
select 5 > any(make_array(6, NULL));
----
NULL
false

query B
select 5 < any(make_array(3, NULL));
----
NULL
false

query B
select 5 >= any(make_array(6, NULL));
----
NULL
false

query B
select 5 <= any(make_array(3, NULL));
----
NULL
false

# Mixed NULL + non-NULL array where a non-NULL element satisfies the condition
query B
Expand Down Expand Up @@ -806,9 +804,9 @@ true
query B
select 5 <> any(make_array(5, NULL));
----
NULL
false

# All-NULL array: all operators should return NULL (unknown comparison)
# All-NULL array: NULL from min/max-based desugars; `= ANY` is false via array_has.
query B
select 5 > any(make_array(NULL::INT, NULL::INT));
----
Expand Down Expand Up @@ -837,7 +835,7 @@ NULL
query B
select 5 = any(make_array(NULL::INT, NULL::INT));
----
NULL
false

# NULL left operand: should return NULL for non-empty arrays
query B
Expand Down Expand Up @@ -865,7 +863,7 @@ select NULL <> any(make_array(1, 2, 3));
----
NULL

# NULL left operand with empty array: should return false
# NULL left operand + empty array: vacuous false (cardinality guard short-circuits).
query B
select NULL > any(make_array());
----
Expand Down Expand Up @@ -920,11 +918,11 @@ select 5 = any(make_array(5, NULL));
----
true

# = ANY with mixed NULL (non-satisfying): NULLs leave result indeterminate
# = ANY with mixed NULL, no match: false (array_has treats NULL as absent).
query B
select 5 = any(make_array(1, 2, NULL));
----
NULL
false

statement ok
DROP TABLE any_op_test;
Expand Down
Loading