From 2fdd9d2cbf0ce7fd0d9d31f20519eba985fe49d1 Mon Sep 17 00:00:00 2001 From: akhadka Date: Wed, 6 May 2026 16:13:17 -0700 Subject: [PATCH 1/3] [SPARK-50593][SQL] Generalize ReducibleFunction reducer API with ReducibleParameters container --- .../catalog/functions/ReducibleFunction.java | 43 +++++++++ .../functions/ReducibleParameters.java | 93 +++++++++++++++++++ 2 files changed, 136 insertions(+) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java index ef1a14e50cdad..bfbf186516b10 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -78,7 +78,10 @@ public interface ReducibleFunction { * @param otherBucketFunction the other parameterized function * @param otherNumBuckets parameter for the other function * @return a reduction function if it is reducible, null if not + * @deprecated Use {@link #reducer(ReducibleParameters, ReducibleFunction, ReducibleParameters)} + * for generic parameterized transforms. */ + @Deprecated default Reducer reducer( int thisNumBuckets, ReducibleFunction otherBucketFunction, @@ -103,4 +106,44 @@ default Reducer reducer( default Reducer reducer(ReducibleFunction otherFunction) { throw new UnsupportedOperationException(); } + + /** + * Generic reducer for any parameterized transform function. + *

+ * This extends SPJ support beyond bucket to transforms like truncate, which use + * non-integer parameters or multiple parameters. + *

+ * Example of reducing f_source = truncate(x, 5) on f_target = truncate(x, 3): + *

    + *
  • thisParams = ReducibleParameters([5])
  • + *
  • otherFunction = truncate
  • + *
  • otherParams = ReducibleParameters([3])
  • + *
  • reducer truncates to min(5, 3) = 3
  • + *
+ *

+ * Default implementation provides backward compatibility: if both parameter sets + * contain a single integer, delegates to {@link #reducer(int, ReducibleFunction, int)}. + * + * @param thisParams parameters for this function + * @param otherFunction the other reducible function + * @param otherParams parameters for the other function + * @return a reduction function if it is reducible + * @throws UnsupportedOperationException if not reducible + * @since 4.1.0 + */ + default Reducer reducer( + ReducibleParameters thisParams, + ReducibleFunction otherFunction, + ReducibleParameters otherParams) { + // Backward compatibility: single-int params → delegate to old bucket API + if (thisParams.count() == 1 && otherParams.count() == 1) { + try { + return reducer(thisParams.getInt(0), otherFunction, otherParams.getInt(0)); + } catch (ClassCastException | NumberFormatException ignored) { + // Not int parameters, fall through + } + } + throw new UnsupportedOperationException( + "reducer() with ReducibleParameters not implemented"); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java new file mode 100644 index 0000000000000..03716225fac2d --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connector.catalog.functions; + +import java.util.Collections; +import java.util.List; + +import org.apache.spark.annotation.Evolving; + +/** + * Container for parameters of a {@link ReducibleFunction}. + *

+ * Provides type-safe access to function parameters for generic reducer comparisons, + * enabling SPJ support for any parameterized transform (not just bucket). + *

+ * Examples: + *

    + *
  • bucket(4, x) → ReducibleParameters([4])
  • + *
  • truncate(x, 3) → ReducibleParameters([3])
  • + *
  • bucket(16, x) → ReducibleParameters([16])
  • + *
+ * + * @since 4.1.0 + */ +@Evolving +public class ReducibleParameters { + private final List values; + + public ReducibleParameters(List values) { + this.values = Collections.unmodifiableList(values); + } + + /** Number of parameters. */ + public int count() { + return values.size(); + } + + /** Get raw parameter value at index. */ + public Object get(int index) { + return values.get(index); + } + + /** Get parameter as int. Throws ClassCastException if not numeric. */ + public int getInt(int index) { + return ((Number) values.get(index)).intValue(); + } + + /** Get parameter as long. Throws ClassCastException if not numeric. */ + public long getLong(int index) { + return ((Number) values.get(index)).longValue(); + } + + /** Get parameter as String. */ + public String getString(int index) { + return (String) values.get(index); + } + + /** Get parameter as double. Throws ClassCastException if not numeric. */ + public double getDouble(int index) { + return ((Number) values.get(index)).doubleValue(); + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (!(other instanceof ReducibleParameters)) return false; + return values.equals(((ReducibleParameters) other).values); + } + + @Override + public int hashCode() { + return values.hashCode(); + } + + @Override + public String toString() { + return "ReducibleParameters(" + values + ")"; + } +} From 81d63dd4af0b2f91d526a7ac8c42288b0674b4f8 Mon Sep 17 00:00:00 2001 From: akhadka Date: Wed, 6 May 2026 18:55:00 -0700 Subject: [PATCH 2/3] [SPARK-50593][SQL] Support truncate transform for Storage Partitioned Joins by generalizing parameter handling --- .../catalog/functions/ReducibleFunction.java | 89 ++++++------ .../functions/ReducibleParameters.java | 102 +++++++++---- .../expressions/TransformExpression.scala | 132 +++++++++++++---- .../expressions/V2ExpressionUtils.scala | 13 +- .../plans/physical/partitioning.scala | 12 +- .../v2/DistributionAndOrderingUtils.scala | 6 +- .../KeyGroupedPartitioningSuite.scala | 135 +++++++++++++++++- .../functions/transformFunctions.scala | 83 ++++++++++- ...rojectedOrderingAndPartitioningSuite.scala | 8 +- .../exchange/EnsureRequirementsSuite.scala | 4 +- 10 files changed, 450 insertions(+), 134 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java index bfbf186516b10..13b033a98bec8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -60,6 +60,53 @@ @Evolving public interface ReducibleFunction { + /** + * Generic reducer for parameterized functions (bucket, truncate, etc.). + * + * If this function is 'reducible' on another function, return the {@link Reducer}. + *

+ * This method supports functions with any number of parameters of any type. + *

+ * Examples: + *

    + *
  • bucket(4, x) and bucket(2, x): + *
    thisParams = [4], otherParams = [2] + *
    Extract with: thisParams.getInt(0), otherParams.getInt(0) + *
  • + *
  • truncate(x, 3) and truncate(x, 5): + *
    thisParams = [3], otherParams = [5] + *
    Extract with: thisParams.getInt(0), otherParams.getInt(0) + *
  • + *
  • hypothetical range_bucket(x, 0L, 100L, 4): + *
    thisParams = [0L, 100L, 4] + *
    Extract with: thisParams.getLong(0), thisParams.getLong(1), thisParams.getInt(2) + *
  • + *
+ * + * @param thisParams parameters for this function + * @param otherFunction the other parameterized function + * @param otherParams parameters for the other function + * @return a reduction function if reducible, null otherwise + * @since 4.0.0 + */ + default Reducer reducer( + ReducibleParameters thisParams, + ReducibleFunction otherFunction, + ReducibleParameters otherParams) { + // Default: try old Int-based API for backward compatibility + if (thisParams.count() == 1 && otherParams.count() == 1) { + try { + return reducer( + thisParams.getInt(0), + otherFunction, + otherParams.getInt(0)); + } catch (ClassCastException ignored) { + // Not Int parameters, fall through + } + } + throw new UnsupportedOperationException(); + } + /** * This method is for the bucket function. * @@ -78,8 +125,6 @@ public interface ReducibleFunction { * @param otherBucketFunction the other parameterized function * @param otherNumBuckets parameter for the other function * @return a reduction function if it is reducible, null if not - * @deprecated Use {@link #reducer(ReducibleParameters, ReducibleFunction, ReducibleParameters)} - * for generic parameterized transforms. */ @Deprecated default Reducer reducer( @@ -106,44 +151,4 @@ default Reducer reducer( default Reducer reducer(ReducibleFunction otherFunction) { throw new UnsupportedOperationException(); } - - /** - * Generic reducer for any parameterized transform function. - *

- * This extends SPJ support beyond bucket to transforms like truncate, which use - * non-integer parameters or multiple parameters. - *

- * Example of reducing f_source = truncate(x, 5) on f_target = truncate(x, 3): - *

    - *
  • thisParams = ReducibleParameters([5])
  • - *
  • otherFunction = truncate
  • - *
  • otherParams = ReducibleParameters([3])
  • - *
  • reducer truncates to min(5, 3) = 3
  • - *
- *

- * Default implementation provides backward compatibility: if both parameter sets - * contain a single integer, delegates to {@link #reducer(int, ReducibleFunction, int)}. - * - * @param thisParams parameters for this function - * @param otherFunction the other reducible function - * @param otherParams parameters for the other function - * @return a reduction function if it is reducible - * @throws UnsupportedOperationException if not reducible - * @since 4.1.0 - */ - default Reducer reducer( - ReducibleParameters thisParams, - ReducibleFunction otherFunction, - ReducibleParameters otherParams) { - // Backward compatibility: single-int params → delegate to old bucket API - if (thisParams.count() == 1 && otherParams.count() == 1) { - try { - return reducer(thisParams.getInt(0), otherFunction, otherParams.getInt(0)); - } catch (ClassCastException | NumberFormatException ignored) { - // Not int parameters, fall through - } - } - throw new UnsupportedOperationException( - "reducer() with ReducibleParameters not implemented"); - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java index 03716225fac2d..4e0e3232a1c12 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java @@ -16,69 +16,115 @@ */ package org.apache.spark.sql.connector.catalog.functions; -import java.util.Collections; -import java.util.List; - import org.apache.spark.annotation.Evolving; +import java.util.Arrays; +import java.util.List; /** - * Container for parameters of a {@link ReducibleFunction}. - *

- * Provides type-safe access to function parameters for generic reducer comparisons, - * enabling SPJ support for any parameterized transform (not just bucket). - *

+ * Container for reducible function literal parameters. + * Provides type-safe access to parameters of various types. + * * Examples: *

    - *
  • bucket(4, x) → ReducibleParameters([4])
  • - *
  • truncate(x, 3) → ReducibleParameters([3])
  • - *
  • bucket(16, x) → ReducibleParameters([16])
  • + *
  • bucket(4, col) → ReducibleParameters([4])
  • + *
  • truncate(col, 3) → ReducibleParameters([3])
  • + *
  • range_bucket(col, 0L, 100L, 10) → ReducibleParameters([0L, 100L, 10])
  • + *
  • custom_transform(col, "param") → ReducibleParameters(["param"])
  • *
* - * @since 4.1.0 + * @since 4.0.0 */ @Evolving public class ReducibleParameters { private final List values; public ReducibleParameters(List values) { - this.values = Collections.unmodifiableList(values); + this.values = values; } - /** Number of parameters. */ + public ReducibleParameters(Object... values) { + this.values = Arrays.asList(values); + } + + /** + * Get the number of parameters. + */ public int count() { return values.size(); } - /** Get raw parameter value at index. */ - public Object get(int index) { - return values.get(index); + /** + * Check if this container has parameters. + */ + public boolean isEmpty() { + return values.isEmpty(); } - /** Get parameter as int. Throws ClassCastException if not numeric. */ + /** + * Get parameter at index as Integer. + * @throws ClassCastException if parameter is not an Integer + * @throws IndexOutOfBoundsException if index is invalid + */ public int getInt(int index) { - return ((Number) values.get(index)).intValue(); + return (Integer) values.get(index); } - /** Get parameter as long. Throws ClassCastException if not numeric. */ + /** + * Get parameter at index as Long. + * @throws ClassCastException if parameter is not a Long + * @throws IndexOutOfBoundsException if index is invalid + */ public long getLong(int index) { - return ((Number) values.get(index)).longValue(); + return (Long) values.get(index); } - /** Get parameter as String. */ + /** + * Get parameter at index as String. + * @throws ClassCastException if parameter is not a String + * @throws IndexOutOfBoundsException if index is invalid + */ public String getString(int index) { return (String) values.get(index); } - /** Get parameter as double. Throws ClassCastException if not numeric. */ + /** + * Get parameter at index as Double. + * @throws ClassCastException if parameter is not a Double + * @throws IndexOutOfBoundsException if index is invalid + */ public double getDouble(int index) { - return ((Number) values.get(index)).doubleValue(); + return (Double) values.get(index); + } + + /** + * Get parameter at index as Float. + * @throws ClassCastException if parameter is not a Float + * @throws IndexOutOfBoundsException if index is invalid + */ + public float getFloat(int index) { + return (Float) values.get(index); + } + + /** + * Get raw parameter value at index. + */ + public Object get(int index) { + return values.get(index); + } + + /** + * Get all parameter values as a list. + */ + public List getAll() { + return values; } @Override - public boolean equals(Object other) { - if (this == other) return true; - if (!(other instanceof ReducibleParameters)) return false; - return values.equals(((ReducibleParameters) other).values); + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ReducibleParameters that = (ReducibleParameters) o; + return values.equals(that.values); } @Override diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index 9041ed15fc501..eaf79bb6475c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction, ScalarFunction} +import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction, ReducibleParameters, ScalarFunction} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.DataType @@ -28,35 +28,87 @@ import org.apache.spark.sql.types.DataType * * @param function the transform function itself. Spark will use it to decide whether two * partition transform expressions are compatible. - * @param numBucketsOpt the number of buckets if the transform is `bucket`. Unset otherwise. */ -case class TransformExpression( - function: BoundFunction, - children: Seq[Expression], - numBucketsOpt: Option[Int] = None) extends Expression { +case class TransformExpression(function: BoundFunction, children: Seq[Expression]) + extends Expression { override def nullable: Boolean = true /** - * Whether this [[TransformExpression]] has the same semantics as `other`. - * For instance, `bucket(32, c)` is equal to `bucket(32, d)`, but not to `bucket(16, d)` or - * `year(c)`. + * Extract literal children (constant parameters) from this transform. These are constant + * arguments like width in truncate(col, width). Literals are compared when checking if two + * transforms are the same. + */ + private lazy val literalChildren: Seq[Literal] = + children.collect { case l: Literal => l } + + /** + * Whether this [[TransformExpression]] has the same semantics as `other`. For instance, + * `bucket(32, c)` is equal to `bucket(32, d)`, but not to `bucket(16, d)` or `year(c)`. + * Similarly, `truncate(c, 2)` is equal to `truncate(d, 2)`, but may not to `truncate(c, 4)`. * * This will be used, for instance, by Spark to determine whether storage-partitioned join can * be triggered, by comparing partition transforms from both sides of the join and checking * whether they are compatible. * - * @param other the transform expression to compare to - * @return true if this and `other` has the same semantics w.r.t to transform, false otherwise. + * Two transforms are considered the same if: + * 1. They have the same function name + * 2. They have the same literal arguments (e.g., numBuckets for bucket, width for truncate) + * + * @param other + * the transform expression to compare to + * @return + * true if this and `other` has the same semantics w.r.t to transform, false otherwise. */ def isSameFunction(other: TransformExpression): Boolean = other match { - case TransformExpression(otherFunction, _, otherNumBucketsOpt) => - function.canonicalName() == otherFunction.canonicalName() && - numBucketsOpt == otherNumBucketsOpt + case TransformExpression(otherFunction, _) => + val sameFunctionName = function.canonicalName() == otherFunction.canonicalName() + + // Compare literal arguments to ensure transforms with different parameters + // (e.g., bucket(32, col) vs bucket(16, col), truncate(col, 2) vs truncate(col, 4)) + // are not considered the same + val otherLiterals = other.literalChildren + val sameLiterals = literalChildren.length == otherLiterals.length && + literalChildren.zip(otherLiterals).forall { case (l1, l2) => + l1.equals(l2) + } + + sameFunctionName && sameLiterals case _ => false } + /** + * Override canonicalized to ensure transforms with the same function and literals are + * considered semantically equal, regardless of which specific column references they use. + * + * This is crucial for Storage Partitioned Joins - we need bucket(4, tableA.id) and bucket(4, + * tableB.id) to be semantically equal so SPJ can be triggered. + */ + override lazy val canonicalized: Expression = { + // Canonicalize only the non-literal children (i.e., column references) + val canonicalizedReferenceChildren = children.map { + case l: Literal => l + case other => other.canonicalized + } + TransformExpression(function, canonicalizedReferenceChildren) + } + + /** + * Override collectLeaves to only return reference children (columns), not literal parameters. + * + * For TransformExpression, literal children are metadata about the transform function (e.g., + * numBuckets=4 in bucket(4, col), width=2 in truncate(col, 2)). All consumers of + * collectLeaves() expect only column references, not these metadata literals. + * + */ + override def collectLeaves(): Seq[Expression] = { + children.flatMap { + case _: Literal => Seq.empty // Skip literal parameters (metadata) + case other => other.collectLeaves() // Include column references + } + } + /** * Whether this [[TransformExpression]]'s function is compatible with the `other` * [[TransformExpression]]'s function. @@ -73,8 +125,8 @@ case class TransformExpression( } else { (function, other.function) match { case (f: ReducibleFunction[_, _], o: ReducibleFunction[_, _]) => - val thisReducer = reducer(f, numBucketsOpt, o, other.numBucketsOpt) - val otherReducer = reducer(o, other.numBucketsOpt, f, numBucketsOpt) + val thisReducer = reducer(f, this, o, other) + val otherReducer = reducer(o, other, f, this) thisReducer.isDefined || otherReducer.isDefined case _ => false } @@ -92,22 +144,47 @@ case class TransformExpression( */ def reducers(other: TransformExpression): Option[Reducer[_, _]] = { (function, other.function) match { - case(e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) => - reducer(e1, numBucketsOpt, e2, other.numBucketsOpt) + case (e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) => + reducer(e1, this, e2, other) case _ => None } } - // Return a Reducer for a reducible function on another reducible function + /** + * Extract all literal parameters from a transform expression. + * Returns ReducibleParameters containing the literal values in order. + * + * Examples: + * bucket(4, col) => ReducibleParameters([4]) + * truncate(col, 3) => ReducibleParameters([3]) + * days(col) => ReducibleParameters([]) (no literals) + */ + private def extractParameters(expr: TransformExpression): ReducibleParameters = { + import scala.jdk.CollectionConverters._ + val values = expr.literalChildren.map { + case Literal(value, _) => value.asInstanceOf[AnyRef] + } + new ReducibleParameters(values.asJava) + } + + /** + * Return a Reducer for a reducible function on another reducible function + * Handles both parameterized (bucket, truncate) and non-parameterized (days, hours) functions. + */ private def reducer( thisFunction: ReducibleFunction[_, _], - thisNumBucketsOpt: Option[Int], + thisExpr: TransformExpression, otherFunction: ReducibleFunction[_, _], - otherNumBucketsOpt: Option[Int]): Option[Reducer[_, _]] = { - val res = (thisNumBucketsOpt, otherNumBucketsOpt) match { - case (Some(numBuckets), Some(otherNumBuckets)) => - thisFunction.reducer(numBuckets, otherFunction, otherNumBuckets) - case _ => thisFunction.reducer(otherFunction) + otherExpr: TransformExpression): Option[Reducer[_, _]] = { + val thisParams = extractParameters(thisExpr) + val otherParams = extractParameters(otherExpr) + + val res = if (!thisParams.isEmpty && !otherParams.isEmpty) { + // Parameterized functions (bucket, truncate, etc.) + thisFunction.reducer(thisParams, otherFunction, otherParams) + } else { + // Non-parameterized functions (days, hours, etc.) + thisFunction.reducer(otherFunction) } Option(res) } @@ -118,10 +195,7 @@ case class TransformExpression( copy(children = newChildren) private lazy val resolvedFunction: Option[Expression] = this match { - case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) => - Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc, - Seq(Literal(numBuckets)) ++ arguments)) - case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) => + case TransformExpression(scalarFunc: ScalarFunction[_], arguments) => Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)) case _ => None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala index d747bebd5cfe6..2c44264c6e0b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier} import org.apache.spark.sql.connector.catalog.functions._ import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME -import org.apache.spark.sql.connector.expressions.{BucketTransform, Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform} +import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue} import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.errors.QueryCompilationErrors @@ -115,17 +115,6 @@ object V2ExpressionUtils extends SQLConfHelper with Logging { funCatalogOpt: Option[FunctionCatalog] = None): Option[Expression] = trans match { case IdentityTransform(ref) => Some(resolveRef[NamedExpression](ref, query)) - case BucketTransform(numBuckets, refs, sorted) - if sorted.isEmpty && refs.length == 1 && refs.forall(_.isInstanceOf[NamedReference]) => - val resolvedRefs = refs.map(r => resolveRef[NamedExpression](r, query)) - // Create a dummy reference for `numBuckets` here and use that, together with `refs`, to - // look up the V2 function. - val numBucketsRef = AttributeReference("numBuckets", IntegerType, nullable = false)() - funCatalogOpt.flatMap { catalog => - loadV2FunctionOpt(catalog, "bucket", Seq(numBucketsRef) ++ resolvedRefs).map { bound => - TransformExpression(bound, resolvedRefs, Some(numBuckets)) - } - } case NamedTransform(name, args) => val catalystArgs = args.map(toCatalyst(_, query, funCatalogOpt)) funCatalogOpt.flatMap { catalog => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index cc50da1f17fdf..adc4bf48892b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -552,7 +552,9 @@ object KeyedPartitioning { def supportsExpressions(expressions: Seq[Expression]): Boolean = { def isSupportedTransform(transform: TransformExpression): Boolean = { - transform.children.size == 1 && isReference(transform.children.head) + // TransformExpression.collectLeaves() only returns column references, not literals. + // We need exactly one column reference per transform. + transform.collectLeaves().size == 1 } @tailrec @@ -1093,7 +1095,13 @@ case class KeyedShuffleSpec( val newExpressions = partitioning.expressions.zip(keyPositions).map { case (te: TransformExpression, positionSet) => - te.copy(children = te.children.map(_ => clustering(positionSet.head))) + // Preserve literal parameters (e.g., numBuckets, truncate width) + // while replacing only column references with the new clustering expression + val newChildren = te.children.map { + case l: Literal => l // Keep literals as-is + case _ => clustering(positionSet.head) // Replace column references + } + te.copy(children = newChildren) case (_, positionSet) => clustering(positionSet.head) } KeyedPartitioning(newExpressions, partitioning.partitionKeys, partitioning.isGrouped) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala index 02e19dd053f29..9f3c4daa7a6f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, ResolveTimeZone, TypeCoercion} -import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder, TransformExpression, V2ExpressionUtils} +import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder, TransformExpression, V2ExpressionUtils} import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RebalancePartitions, RepartitionByExpression, Sort} import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} @@ -96,9 +96,7 @@ object DistributionAndOrderingUtils { } private def resolveTransformExpression(expr: Expression): Expression = expr.transform { - case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) => - V2ExpressionUtils.resolveScalarFunction(scalarFunc, Seq(Literal(numBuckets)) ++ arguments) - case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) => + case TransformExpression(scalarFunc: ScalarFunction[_], arguments) => V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 2a0ab52c36933..bd5119a729ce4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.functions.{col, max} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with ExplainSuiteHelper { private val functions = Seq( @@ -126,7 +127,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with val df = sql(s"SELECT * FROM testcat.ns.$table") val distribution = physical.ClusteredDistribution( - Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32)))) + Seq(TransformExpression(BucketFunction, Seq(Literal(32), attr("ts"))))) checkQueryPlan(df, distribution, physical.UnknownPartitioning(0)) } @@ -138,7 +139,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with val df = sql(s"SELECT * FROM testcat.ns.$table") val distribution = physical.ClusteredDistribution( - Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32)))) + Seq(TransformExpression(BucketFunction, Seq(Literal(32), attr("ts"))))) // Has exactly one partition. val partitionKeys = Seq(0).map(v => InternalRow.fromSeq(Seq(v))) @@ -194,13 +195,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with val df = sql(s"SELECT * FROM testcat.ns.$table") val distribution = physical.ClusteredDistribution( - Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32)))) + Seq(TransformExpression(BucketFunction, Seq(Literal(32), attr("ts"))))) checkQueryPlan(df, distribution, physical.UnknownPartitioning(0)) } } - test("non-clustered distribution: V2 function with multiple args") { + test("clustered distribution: V2 function with multiple args") { val partitions: Array[Transform] = Array( Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(2)) ) @@ -216,7 +217,11 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with val distribution = physical.ClusteredDistribution( Seq(TransformExpression(TruncateFunction, Seq(attr("data"), Literal(2))))) - checkQueryPlan(df, distribution, physical.UnknownPartitioning(0)) + // With truncate transform support, KeyedPartitioning should now work + val partitionKeys = Seq("aa", "bb", "cc").map(v => + InternalRow(UTF8String.fromString(v))) + checkQueryPlan(df, distribution, + physical.KeyedPartitioning(distribution.clustering, partitionKeys)) } /** @@ -4183,4 +4188,124 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with } } } + + // === SPARK-50593: Truncate SPJ support tests === + + test("SPARK-50593: cross-function truncate vs bucket should NOT trigger SPJ") { + val partitions1 = Array( + Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(3)) + ) + val partitions2 = Array( + Expressions.bucket(4, "data") + ) + + createTable("trunc_cross1", columns, partitions1) + sql("INSERT INTO testcat.ns.trunc_cross1 VALUES " + + "(0, 'aaa', CAST('2022-01-01' AS timestamp)), " + + "(1, 'bbb', CAST('2021-01-01' AS timestamp))") + + createTable("trunc_cross2", columns2, partitions2) + sql("INSERT INTO testcat.ns.trunc_cross2 VALUES " + + "(1, 5, 'aaa'), " + + "(5, 10, 'bbb')") + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + + val df = sql( + selectWithMergeJoinHint("trunc_cross1", "trunc_cross2") + + "trunc_cross1.id, trunc_cross2.store_id " + + "FROM testcat.ns.trunc_cross1 JOIN testcat.ns.trunc_cross2 " + + "ON trunc_cross1.data = trunc_cross2.data " + + "ORDER BY trunc_cross1.id") + + // Different functions (truncate vs bucket) should NEVER enable SPJ + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.nonEmpty, + "truncate vs bucket should not trigger SPJ, but no shuffles found") + } + } + + test("SPARK-50593: TransformExpression.collectLeaves filters out literals") { + // bucket(4, col) has children = [Literal(4), col] but collectLeaves should return [col] + val col = attr("data") + val bucketExpr = TransformExpression(BucketFunction, Seq(Literal(4), col)) + val leaves = bucketExpr.collectLeaves() + assert(leaves.size == 1, s"Expected 1 leaf (column ref), got ${leaves.size}: $leaves") + assert(leaves.head.semanticEquals(col), + s"Expected leaf to be the column reference, got ${leaves.head}") + + // truncate(col, 3) has children = [col, Literal(3)] but collectLeaves should return [col] + val truncExpr = TransformExpression(TruncateFunction, Seq(col, Literal(3))) + val truncLeaves = truncExpr.collectLeaves() + assert(truncLeaves.size == 1, + s"Expected 1 leaf (column ref), got ${truncLeaves.size}: $truncLeaves") + assert(truncLeaves.head.semanticEquals(col), + s"Expected leaf to be the column reference, got ${truncLeaves.head}") + + // years(col) has children = [col] with no literals + val yearsExpr = TransformExpression(YearsFunction, Seq(col)) + val yearsLeaves = yearsExpr.collectLeaves() + assert(yearsLeaves.size == 1, + s"Expected 1 leaf for years(), got ${yearsLeaves.size}: $yearsLeaves") + } + + test("SPARK-50593: existing bucket SPJ still works with ReducibleParameters API") { + // This test verifies that the migration from reducer(int, func, int) + // to reducer(ReducibleParameters, func, ReducibleParameters) is backward compatible. + // BucketFunction now implements the new API but bucket SPJ should still work. + val table1 = "bucket_compat1" + val table2 = "bucket_compat2" + + val partitions1 = Array(Expressions.bucket(4, "id")) + val partitions2 = Array(Expressions.bucket(4, "store_id")) + + createTable(table1, columns, partitions1) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(0, 'aaa', CAST('2022-01-01' AS timestamp)), " + + "(1, 'bbb', CAST('2021-01-01' AS timestamp)), " + + "(2, 'ccc', CAST('2020-01-01' AS timestamp)), " + + "(3, 'ddd', CAST('2019-01-01' AS timestamp))") + + createTable(table2, columns2, partitions2) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(0, 5, 'aaa'), " + + "(1, 10, 'bbb'), " + + "(2, 15, 'ccc'), " + + "(3, 20, 'ddd')") + + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true") { + val df = sql( + selectWithMergeJoinHint(table1, table2) + + s"$table1.id, $table2.store_id " + + s"FROM testcat.ns.$table1 JOIN testcat.ns.$table2 " + + s"ON $table1.id = $table2.store_id " + + s"ORDER BY $table1.id") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, + "Bucket SPJ should still work after ReducibleParameters migration") + checkAnswer(df, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3))) + } + } + + test("SPARK-50593: ReducibleParameters backward compat - old int API still works via default") { + // The new reducer(ReducibleParameters, func, ReducibleParameters) default implementation + // delegates to the old reducer(int, func, int) for single-int params. + // This verifies bucket(4) vs bucket(2) still produces a reducer via the fallback path. + val bucketExpr4 = TransformExpression(BucketFunction, Seq(Literal(4), attr("id"))) + val bucketExpr2 = TransformExpression(BucketFunction, Seq(Literal(2), attr("id"))) + + // isCompatible should return true (4 and 2 share GCD > 1) + assert(bucketExpr4.isCompatible(bucketExpr2), + "bucket(4) and bucket(2) should be compatible via reducer") + + // reducers() should return a Reducer + val reducer = bucketExpr4.reducers(bucketExpr2) + assert(reducer.isDefined, "Expected a reducer for bucket(4) on bucket(2)") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 35102c6893d3b..8301ad70b0248 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -213,11 +213,14 @@ object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, In } override def reducer( - thisNumBuckets: Int, + thisParams: ReducibleParameters, otherFunc: ReducibleFunction[_, _], - otherNumBuckets: Int): Reducer[Int, Int] = { + otherParams: ReducibleParameters): Reducer[Int, Int] = { if (otherFunc == BucketFunction) { + val thisNumBuckets = thisParams.getInt(0) + val otherNumBuckets = otherParams.getInt(0) + val gcd = this.gcd(thisNumBuckets, otherNumBuckets) if (gcd > 1 && gcd != thisNumBuckets) { return BucketReducer(gcd) @@ -253,12 +256,35 @@ object StringSelfFunction extends ScalarFunction[UTF8String] { } object UnboundTruncateFunction extends UnboundFunction { - override def bind(inputType: StructType): BoundFunction = TruncateFunction + override def bind(inputType: StructType): BoundFunction = { + if (inputType.size == 2) { + inputType.head.dataType match { + case StringType | BinaryType => TruncateFunction + case IntegerType | LongType => IntegerTruncateFunction + case _ => + throw new UnsupportedOperationException( + s"'truncate' does not support data type: ${inputType.head.dataType}") + } + } else { + throw new UnsupportedOperationException( + "'truncate' requires exactly 2 arguments: (column, width)") + } + } + override def description(): String = name() override def name(): String = "truncate" } -object TruncateFunction extends ScalarFunction[UTF8String] { +/** + * Truncate transform for String/Binary types. + * Follows Iceberg spec: truncate(str, L) = str[0:L] + * + * Implements ReducibleFunction: ANY two different widths are compatible. + * The reducer uses the smaller width. + */ +object TruncateFunction + extends ScalarFunction[UTF8String] + with ReducibleFunction[UTF8String, UTF8String] { override def inputTypes(): Array[DataType] = Array(StringType, IntegerType) override def resultType(): DataType = StringType override def name(): String = "truncate" @@ -266,7 +292,52 @@ object TruncateFunction extends ScalarFunction[UTF8String] { override def toString: String = name() override def produceResult(input: InternalRow): UTF8String = { val str = input.getUTF8String(0) - val length = input.getInt(1) - str.substring(0, length) + val width = input.getInt(1) + str.substring(0, width) + } + + override def reducer( + thisParams: ReducibleParameters, + otherFunc: ReducibleFunction[_, _], + otherParams: ReducibleParameters): Reducer[UTF8String, UTF8String] = { + + if (otherFunc == TruncateFunction) { + val thisWidth = thisParams.getInt(0) + val otherWidth = otherParams.getInt(0) + val smallerWidth = math.min(thisWidth, otherWidth) + + if (smallerWidth != thisWidth) { + return TruncateReducer(smallerWidth) + } + } + null + } +} + +case class TruncateReducer(width: Int) extends Reducer[UTF8String, UTF8String] { + override def reduce(value: UTF8String): UTF8String = { + value.substring(0, width) + } + override def resultType(): DataType = StringType + override def displayName(): String = s"truncate($width)" +} + +/** + * Truncate transform for Integer/Long types. + * Follows Iceberg spec: truncate(value, W) = value - (((value % W) + W) % W) + * + * Does NOT implement ReducibleFunction because different integer truncate widths + * produce incompatible partition structures. + */ +object IntegerTruncateFunction extends ScalarFunction[Int] { + override def inputTypes(): Array[DataType] = Array(IntegerType, IntegerType) + override def resultType(): DataType = IntegerType + override def name(): String = "truncate" + override def canonicalName(): String = name() + override def toString: String = name() + override def produceResult(input: InternalRow): Int = { + val value = input.getInt(0) + val width = input.getInt(1) + value - (((value % width) + width) % width) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala index db664b04ef08b..27a934aedcbe3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, TransformExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, TransformExpression} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning, KeyedPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning} import org.apache.spark.sql.connector.catalog.functions.{BucketFunction, YearsFunction} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -492,7 +492,7 @@ class ProjectedOrderingAndPartitioningSuite // KP([bucket(32, id)], keys1d) through Project(id as pk) should produce // KP([bucket(32, pk)], keys1d): the alias is pushed into the bucket's column argument. val id = AttributeReference("id", IntegerType)() - val bucketExpr = TransformExpression(BucketFunction, Seq(id), Some(32)) + val bucketExpr = TransformExpression(BucketFunction, Seq(Literal(32), id)) val keys1d = Seq(InternalRow(0), InternalRow(1), InternalRow(2)) val child = DummyLeafExecWithPartitioning( output = Seq(id), @@ -524,7 +524,7 @@ class ProjectedOrderingAndPartitioningSuite // Result: KP([bucket(32, id)], keys1d, isNarrowed=true, isGrouped=false). val id = AttributeReference("id", IntegerType)() val ts = AttributeReference("ts", IntegerType)() - val bucketExpr = TransformExpression(BucketFunction, Seq(id), Some(32)) + val bucketExpr = TransformExpression(BucketFunction, Seq(Literal(32), id)) val yearsExpr = TransformExpression(YearsFunction, Seq(ts)) // Projected to position [0] (bucket): (0),(1),(0) -- bucket value 0 appears twice. val keys2d = Seq(InternalRow(0, 2020), InternalRow(1, 2020), InternalRow(0, 2021)) @@ -554,7 +554,7 @@ class ProjectedOrderingAndPartitioningSuite // Result: KP([bucket(32, id), years(ts_alias)], keys2d) -- not narrowed. val id = AttributeReference("id", IntegerType)() val ts = AttributeReference("ts", IntegerType)() - val bucketExpr = TransformExpression(BucketFunction, Seq(id), Some(32)) + val bucketExpr = TransformExpression(BucketFunction, Seq(Literal(32), id)) val yearsExpr = TransformExpression(YearsFunction, Seq(ts)) val keys2d = Seq(InternalRow(0, 2020), InternalRow(1, 2020), InternalRow(0, 2021)) val child = DummyLeafExecWithPartitioning( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 1e35985f50491..96a4f3679b9b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -1191,11 +1191,11 @@ class EnsureRequirementsSuite extends SharedSparkSession { } def bucket(numBuckets: Int, expr: Expression): TransformExpression = { - TransformExpression(BucketFunction, Seq(expr), Some(numBuckets)) + TransformExpression(BucketFunction, Seq(Literal(numBuckets), expr)) } def buckets(numBuckets: Int, expr: Seq[Expression]): TransformExpression = { - TransformExpression(BucketFunction, expr, Some(numBuckets)) + TransformExpression(BucketFunction, Seq(Literal(numBuckets)) ++ expr) } test("ShufflePartitionIdPassThrough - avoid unnecessary shuffle when children are compatible") { From c51a582e9fa3787cc97e40eae56936df8d6d6e5b Mon Sep 17 00:00:00 2001 From: akhadka Date: Wed, 13 May 2026 17:17:54 -0700 Subject: [PATCH 3/3] [SPARK-50593][SQL] Strengthen test coverage for truncate SPJ and ReducibleParameters backward compatibility --- .../KeyGroupedPartitioningSuite.scala | 114 +++++++++++++----- .../functions/transformFunctions.scala | 31 +++++ ...rojectedOrderingAndPartitioningSuite.scala | 8 +- .../exchange/EnsureRequirementsSuite.scala | 8 +- 4 files changed, 122 insertions(+), 39 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index bd5119a729ce4..66cf2f19bfdb8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -4210,25 +4210,71 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with "(5, 10, 'bbb')") withSQLConf( - SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", - SQLConf.V2_BUCKETING_ALLOW_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true", SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { val df = sql( - selectWithMergeJoinHint("trunc_cross1", "trunc_cross2") + - "trunc_cross1.id, trunc_cross2.store_id " + - "FROM testcat.ns.trunc_cross1 JOIN testcat.ns.trunc_cross2 " + - "ON trunc_cross1.data = trunc_cross2.data " + - "ORDER BY trunc_cross1.id") + s""" + |${selectWithMergeJoinHint("trunc_cross1", "trunc_cross2")} + |trunc_cross1.id, trunc_cross2.store_id + |FROM testcat.ns.trunc_cross1 JOIN testcat.ns.trunc_cross2 + |ON trunc_cross1.data = trunc_cross2.data + |ORDER BY trunc_cross1.id + |""".stripMargin) - // Different functions (truncate vs bucket) should NEVER enable SPJ + // Different functions (truncate vs bucket) are not mutually reducible, so a shuffle + // must still be planned. val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.nonEmpty, - "truncate vs bucket should not trigger SPJ, but no shuffles found") + "truncate vs bucket are not compatible - a shuffle should be present, " + + "but none was planned") + checkAnswer(df, Seq(Row(0, 1), Row(1, 5))) } } + test("SPARK-50593: truncate(3) vs truncate(5) triggers SPJ via width reducer") { + // Exercises the ReducibleParameters-based reducer path end-to-end: truncate widths 3 and 5 + // are mutually reducible (reduce the larger to the smaller), so SPJ must avoid the shuffle. + val table1 = "trunc_three" + val table2 = "trunc_five" + + val partitions1 = Array( + Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(3))) + val partitions2 = Array( + Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(5))) + + createTable(table1, columns, partitions1) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(0, 'apple', CAST('2022-01-01' AS timestamp)), " + + "(1, 'grape', CAST('2021-01-01' AS timestamp)), " + + "(2, 'orange', CAST('2020-01-01' AS timestamp))") + + createTable(table2, columns, partitions2) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(10, 'apple', CAST('2022-01-01' AS timestamp)), " + + "(20, 'grape', CAST('2021-01-01' AS timestamp)), " + + "(30, 'orange', CAST('2020-01-01' AS timestamp))") + + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + + val df = sql( + s""" + |${selectWithMergeJoinHint(table1, table2)} + |$table1.id AS left_id, $table2.id AS right_id + |FROM testcat.ns.$table1 JOIN testcat.ns.$table2 + |ON $table1.data = $table2.data + |ORDER BY $table1.id + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, + "truncate(3) vs truncate(5) should avoid shuffle via the width reducer, " + + "but a shuffle was planned") + checkAnswer(df, Seq(Row(0, 10), Row(1, 20), Row(2, 30))) + } + } test("SPARK-50593: TransformExpression.collectLeaves filters out literals") { // bucket(4, col) has children = [Literal(4), col] but collectLeaves should return [col] val col = attr("data") @@ -4254,14 +4300,14 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with } test("SPARK-50593: existing bucket SPJ still works with ReducibleParameters API") { - // This test verifies that the migration from reducer(int, func, int) - // to reducer(ReducibleParameters, func, ReducibleParameters) is backward compatible. - // BucketFunction now implements the new API but bucket SPJ should still work. + // Exercises the new ReducibleParameters-based reducer path end-to-end: bucket(4) and + // bucket(2) differ, so SPJ can only avoid the shuffle if BucketFunction's reducer + // (now implemented via ReducibleParameters) correctly returns a GCD-based Reducer. val table1 = "bucket_compat1" val table2 = "bucket_compat2" val partitions1 = Array(Expressions.bucket(4, "id")) - val partitions2 = Array(Expressions.bucket(4, "store_id")) + val partitions2 = Array(Expressions.bucket(2, "store_id")) createTable(table1, columns, partitions1) sql(s"INSERT INTO testcat.ns.$table1 VALUES " + @@ -4278,34 +4324,40 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with "(3, 20, 'ddd')") withSQLConf( - SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true") { + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { val df = sql( - selectWithMergeJoinHint(table1, table2) + - s"$table1.id, $table2.store_id " + - s"FROM testcat.ns.$table1 JOIN testcat.ns.$table2 " + - s"ON $table1.id = $table2.store_id " + - s"ORDER BY $table1.id") + s""" + |${selectWithMergeJoinHint(table1, table2)} + |$table1.id, $table2.store_id + |FROM testcat.ns.$table1 JOIN testcat.ns.$table2 + |ON $table1.id = $table2.store_id + |ORDER BY $table1.id + |""".stripMargin) val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, - "Bucket SPJ should still work after ReducibleParameters migration") + "bucket(4) vs bucket(2) should avoid shuffle via the GCD reducer, " + + "but a shuffle was planned") checkAnswer(df, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3))) } } test("SPARK-50593: ReducibleParameters backward compat - old int API still works via default") { - // The new reducer(ReducibleParameters, func, ReducibleParameters) default implementation - // delegates to the old reducer(int, func, int) for single-int params. - // This verifies bucket(4) vs bucket(2) still produces a reducer via the fallback path. - val bucketExpr4 = TransformExpression(BucketFunction, Seq(Literal(4), attr("id"))) - val bucketExpr2 = TransformExpression(BucketFunction, Seq(Literal(2), attr("id"))) - - // isCompatible should return true (4 and 2 share GCD > 1) - assert(bucketExpr4.isCompatible(bucketExpr2), - "bucket(4) and bucket(2) should be compatible via reducer") + // Verifies the default reducer(ReducibleParameters, ...) implementation correctly + // delegates to the deprecated reducer(int, func, int) when a ReducibleFunction only + // overrides the old API. This mirrors how Iceberg 1.10.0 (and earlier) ship without + // knowledge of ReducibleParameters. + val bucketExpr4 = TransformExpression(LegacyBucketFunction, Seq(Literal(4), attr("id"))) + val bucketExpr2 = TransformExpression(LegacyBucketFunction, Seq(Literal(2), attr("id"))) - // reducers() should return a Reducer val reducer = bucketExpr4.reducers(bucketExpr2) - assert(reducer.isDefined, "Expected a reducer for bucket(4) on bucket(2)") + assert(reducer.isDefined, "Expected a reducer for legacy_bucket(4) on legacy_bucket(2)") + + // Verify the returned Reducer actually reduces bucket 4 -> bucket 2 (GCD = 2). + // bucket(4, x) produces values in [0, 4); reducing by GCD=2 gives v % 2. + val r = reducer.get.asInstanceOf[Reducer[Integer, Integer]] + assert(r.reduce(3) == 1, s"Expected reduce(3) == 1, got ${r.reduce(3)}") + assert(r.reduce(2) == 0, s"Expected reduce(2) == 0, got ${r.reduce(2)}") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 8301ad70b0248..fc034729330e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -238,6 +238,37 @@ case class BucketReducer(divisor: Int) extends Reducer[Int, Int] { override def displayName(): String = toString } +/** + * A bucket function that only overrides the deprecated `reducer(int, func, int)` method, + * not the new `reducer(ReducibleParameters, func, ReducibleParameters)` method. + * + * Used to verify that the default implementation of the new method correctly falls back + * to the deprecated int-based API, so legacy implementations continue to work. + */ +object LegacyBucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] { + override def inputTypes(): Array[DataType] = Array(IntegerType, LongType) + override def resultType(): DataType = IntegerType + override def name(): String = "legacy_bucket" + override def canonicalName(): String = name() + override def toString: String = name() + override def produceResult(input: InternalRow): Int = { + Math.floorMod(input.getLong(1), input.getInt(0)) + } + + override def reducer( + thisNumBuckets: Int, + otherFunc: ReducibleFunction[_, _], + otherNumBuckets: Int): Reducer[Int, Int] = { + if (otherFunc == LegacyBucketFunction) { + val gcd = BigInt(thisNumBuckets).gcd(BigInt(otherNumBuckets)).toInt + if (gcd > 1 && gcd != thisNumBuckets) { + return BucketReducer(gcd) + } + } + null + } +} + object UnboundStringSelfFunction extends UnboundFunction { override def bind(inputType: StructType): BoundFunction = StringSelfFunction override def description(): String = name() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala index 27a934aedcbe3..56df2956ccde4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala @@ -507,7 +507,7 @@ class ProjectedOrderingAndPartitioningSuite case te: TransformExpression => assert(te.isSameFunction(bucketExpr), "bucket function and numBuckets must be preserved after alias substitution") - assert(te.children.head.asInstanceOf[Attribute].name === "pk", + assert(te.children.collectFirst { case a: Attribute => a }.get.name === "pk", "bucket's column argument must be rewritten to the aliased attribute") case other => fail(s"Expected TransformExpression, got $other") } @@ -539,7 +539,7 @@ class ProjectedOrderingAndPartitioningSuite kp.expressions.head match { case te: TransformExpression => assert(te.isSameFunction(bucketExpr), "bucket must be the surviving expression") - assert(te.children.head.asInstanceOf[Attribute].name === "id") + assert(te.children.collectFirst { case a: Attribute => a }.get.name === "id") case other => fail(s"Expected TransformExpression, got $other") } assert(kp.isNarrowed, "dropping years(ts) position must mark the KP as narrowed") @@ -569,14 +569,14 @@ class ProjectedOrderingAndPartitioningSuite kp.expressions(0) match { case te: TransformExpression => assert(te.isSameFunction(bucketExpr)) - assert(te.children.head.asInstanceOf[Attribute].name === "id", + assert(te.children.collectFirst { case a: Attribute => a }.get.name === "id", "bucket's argument must remain id (no alias for id in this projection)") case other => fail(s"Expected TransformExpression at pos 0, got $other") } kp.expressions(1) match { case te: TransformExpression => assert(te.isSameFunction(yearsExpr)) - assert(te.children.head.asInstanceOf[Attribute].name === "ts_alias", + assert(te.children.collectFirst { case a: Attribute => a }.get.name === "ts_alias", "years() argument must be rewritten to ts_alias") case other => fail(s"Expected TransformExpression at pos 1, got $other") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 96a4f3679b9b3..b3d4aad59bc56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -40,10 +40,10 @@ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class EnsureRequirementsSuite extends SharedSparkSession { - private val exprA = Literal(1) - private val exprB = Literal(2) - private val exprC = Literal(3) - private val exprD = Literal(4) + private val exprA = AttributeReference("a", IntegerType)() + private val exprB = AttributeReference("b", IntegerType)() + private val exprC = AttributeReference("c", IntegerType)() + private val exprD = AttributeReference("d", IntegerType)() private val EnsureRequirements = new EnsureRequirements()