diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java
index ef1a14e50cdad..13b033a98bec8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java
@@ -60,6 +60,53 @@
@Evolving
public interface ReducibleFunction<I, O> {
+ /**
+ * Generic reducer for parameterized functions (bucket, truncate, etc.).
+ *
+ * If this function is 'reducible' on another function, return the {@link Reducer}.
+ *
+ * This method supports functions with any number of parameters of any type.
+ *
+ * Examples:
+ * <pre>
+ * bucket(4, x) and bucket(2, x):
+ * thisParams = [4], otherParams = [2]
+ * Extract with: thisParams.getInt(0), otherParams.getInt(0)
+ *
+ * truncate(x, 3) and truncate(x, 5):
+ * thisParams = [3], otherParams = [5]
+ * Extract with: thisParams.getInt(0), otherParams.getInt(0)
+ *
+ * hypothetical range_bucket(x, 0L, 100L, 4):
+ * thisParams = [0L, 100L, 4]
+ * Extract with: thisParams.getLong(0), thisParams.getLong(1), thisParams.getInt(2)
+ * </pre>
+ *
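+ * An implementation sketch (illustrative only; the GCD-based bucket logic mirrors the test
+ * BucketFunction in this change, and BucketReducer is a hypothetical Reducer mapping v to v % gcd):
+ * <pre>{@code
+ *   public Reducer<Integer, Integer> reducer(
+ *       ReducibleParameters thisParams,
+ *       ReducibleFunction<?, ?> otherFunction,
+ *       ReducibleParameters otherParams) {
+ *     if (otherFunction == this) {
+ *       int thisBuckets = thisParams.getInt(0);
+ *       int otherBuckets = otherParams.getInt(0);
+ *       int gcd = java.math.BigInteger.valueOf(thisBuckets)
+ *           .gcd(java.math.BigInteger.valueOf(otherBuckets)).intValue();
+ *       if (gcd > 1 && gcd != thisBuckets) {
+ *         return new BucketReducer(gcd);
+ *       }
+ *     }
+ *     return null;
+ *   }
+ * }</pre>
+ *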
+ * @param thisParams parameters for this function
+ * @param otherFunction the other parameterized function
+ * @param otherParams parameters for the other function
+ * @return a reduction function if reducible, null otherwise
+ * @since 4.0.0
+ */
+ default Reducer<I, O> reducer(
+ ReducibleParameters thisParams,
+ ReducibleFunction<?, ?> otherFunction,
+ ReducibleParameters otherParams) {
+ // Default: delegate to the deprecated int-based API for backward compatibility
+ if (thisParams.count() == 1 && otherParams.count() == 1) {
+ try {
+ return reducer(
+ thisParams.getInt(0),
+ otherFunction,
+ otherParams.getInt(0));
+ } catch (ClassCastException ignored) {
+ // Parameters are not ints; fall through and report unsupported below
+ }
+ }
+ throw new UnsupportedOperationException();
+ }
+
/**
* This method is for the bucket function.
*
@@ -79,6 +126,7 @@ public interface ReducibleFunction<I, O> {
* @param otherNumBuckets parameter for the other function
* @return a reduction function if it is reducible, null if not
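+ * @deprecated Use {@link #reducer(ReducibleParameters, ReducibleFunction, ReducibleParameters)} instead.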
*/
+ @Deprecated
default Reducer<I, O> reducer(
int thisNumBuckets,
ReducibleFunction<?, ?> otherBucketFunction,
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java
new file mode 100644
index 0000000000000..4e0e3232a1c12
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connector.catalog.functions;
+
+import org.apache.spark.annotation.Evolving;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Container for reducible function literal parameters.
+ * Provides type-safe access to parameters of various types.
+ *
+ * Examples:
+ * <pre>
+ * bucket(4, col) → ReducibleParameters([4])
+ * truncate(col, 3) → ReducibleParameters([3])
+ * range_bucket(col, 0L, 100L, 10) → ReducibleParameters([0L, 100L, 10])
+ * custom_transform(col, "param") → ReducibleParameters(["param"])
+ * </pre>
+ *
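+ * A usage sketch (values and indices are illustrative, following the range_bucket example above):
+ * <pre>{@code
+ *   ReducibleParameters params = new ReducibleParameters(0L, 100L, 10);
+ *   long lower = params.getLong(0);    // 0L
+ *   long upper = params.getLong(1);    // 100L
+ *   int numBuckets = params.getInt(2); // 10
+ * }</pre>
+ *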
+ * @since 4.0.0
+ */
+@Evolving
+public class ReducibleParameters {
+ private final List<Object> values;
+
+ public ReducibleParameters(List<Object> values) {
+ this.values = values;
+ }
+
+ public ReducibleParameters(Object... values) {
+ this.values = Arrays.asList(values);
+ }
+
+ /**
+ * Get the number of parameters.
+ */
+ public int count() {
+ return values.size();
+ }
+
+ /**
+ * Check if this container has parameters.
+ */
+ public boolean isEmpty() {
+ return values.isEmpty();
+ }
+
+ /**
+ * Get parameter at index as Integer.
+ * @throws ClassCastException if parameter is not an Integer
+ * @throws IndexOutOfBoundsException if index is invalid
+ */
+ public int getInt(int index) {
+ return (Integer) values.get(index);
+ }
+
+ /**
+ * Get parameter at index as Long.
+ * @throws ClassCastException if parameter is not a Long
+ * @throws IndexOutOfBoundsException if index is invalid
+ */
+ public long getLong(int index) {
+ return (Long) values.get(index);
+ }
+
+ /**
+ * Get parameter at index as String.
+ * @throws ClassCastException if parameter is not a String
+ * @throws IndexOutOfBoundsException if index is invalid
+ */
+ public String getString(int index) {
+ return (String) values.get(index);
+ }
+
+ /**
+ * Get parameter at index as Double.
+ * @throws ClassCastException if parameter is not a Double
+ * @throws IndexOutOfBoundsException if index is invalid
+ */
+ public double getDouble(int index) {
+ return (Double) values.get(index);
+ }
+
+ /**
+ * Get parameter at index as Float.
+ * @throws ClassCastException if parameter is not a Float
+ * @throws IndexOutOfBoundsException if index is invalid
+ */
+ public float getFloat(int index) {
+ return (Float) values.get(index);
+ }
+
+ /**
+ * Get raw parameter value at index.
+ */
+ public Object get(int index) {
+ return values.get(index);
+ }
+
+ /**
+ * Get all parameter values as a list.
+ */
+ public List<Object> getAll() {
+ return values;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ ReducibleParameters that = (ReducibleParameters) o;
+ return values.equals(that.values);
+ }
+
+ @Override
+ public int hashCode() {
+ return values.hashCode();
+ }
+
+ @Override
+ public String toString() {
+ return "ReducibleParameters(" + values + ")";
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
index 9041ed15fc501..eaf79bb6475c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction, ScalarFunction}
+import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction, ReducibleParameters, ScalarFunction}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.DataType
@@ -28,35 +28,87 @@ import org.apache.spark.sql.types.DataType
*
* @param function the transform function itself. Spark will use it to decide whether two
* partition transform expressions are compatible.
- * @param numBucketsOpt the number of buckets if the transform is `bucket`. Unset otherwise.
*/
-case class TransformExpression(
- function: BoundFunction,
- children: Seq[Expression],
- numBucketsOpt: Option[Int] = None) extends Expression {
+case class TransformExpression(function: BoundFunction, children: Seq[Expression])
+ extends Expression {
override def nullable: Boolean = true
/**
- * Whether this [[TransformExpression]] has the same semantics as `other`.
- * For instance, `bucket(32, c)` is equal to `bucket(32, d)`, but not to `bucket(16, d)` or
- * `year(c)`.
+ * Extract literal children (constant parameters) from this transform. These are constant
+ * arguments like width in truncate(col, width). Literals are compared when checking if two
+ * transforms are the same.
+ */
+ private lazy val literalChildren: Seq[Literal] =
+ children.collect { case l: Literal => l }
+
+ /**
+ * Whether this [[TransformExpression]] has the same semantics as `other`. For instance,
+ * `bucket(32, c)` is equal to `bucket(32, d)`, but not to `bucket(16, d)` or `year(c)`.
+ * Similarly, `truncate(c, 2)` is equal to `truncate(d, 2)`, but not to `truncate(c, 4)`.
*
* This will be used, for instance, by Spark to determine whether storage-partitioned join can
* be triggered, by comparing partition transforms from both sides of the join and checking
* whether they are compatible.
*
- * @param other the transform expression to compare to
- * @return true if this and `other` has the same semantics w.r.t to transform, false otherwise.
+ * Two transforms are considered the same if:
+ * 1. They have the same canonical function name
+ * 2. They have the same literal arguments (e.g., numBuckets for bucket, width for truncate)
+ *
+ * @param other the transform expression to compare to
+ * @return true if this and `other` have the same semantics w.r.t. transform, false otherwise.
*/
def isSameFunction(other: TransformExpression): Boolean = other match {
- case TransformExpression(otherFunction, _, otherNumBucketsOpt) =>
- function.canonicalName() == otherFunction.canonicalName() &&
- numBucketsOpt == otherNumBucketsOpt
+ case TransformExpression(otherFunction, _) =>
+ val sameFunctionName = function.canonicalName() == otherFunction.canonicalName()
+
+ // Compare literal arguments to ensure transforms with different parameters
+ // (e.g., bucket(32, col) vs bucket(16, col), truncate(col, 2) vs truncate(col, 4))
+ // are not considered the same
+ val otherLiterals = other.literalChildren
+ val sameLiterals = literalChildren.length == otherLiterals.length &&
+ literalChildren.zip(otherLiterals).forall { case (l1, l2) =>
+ l1.equals(l2)
+ }
+
+ sameFunctionName && sameLiterals
case _ =>
false
}
+ /**
+ * Override canonicalized to ensure transforms with the same function and literals are
+ * considered semantically equal, regardless of which specific column references they use.
+ *
+ * This is crucial for storage-partitioned joins: we need bucket(4, tableA.id) and
+ * bucket(4, tableB.id) to be semantically equal so SPJ can be triggered.
+ */
+ override lazy val canonicalized: Expression = {
+ // Canonicalize only the non-literal children (i.e., column references)
+ val canonicalizedReferenceChildren = children.map {
+ case l: Literal => l
+ case other => other.canonicalized
+ }
+ TransformExpression(function, canonicalizedReferenceChildren)
+ }
+
+ /**
+ * Override collectLeaves to only return reference children (columns), not literal parameters.
+ *
+ * For TransformExpression, literal children are metadata about the transform function (e.g.,
+ * numBuckets=4 in bucket(4, col), width=2 in truncate(col, 2)). All consumers of
+ * collectLeaves() expect only column references, not these metadata literals.
+ *
+ */
+ override def collectLeaves(): Seq[Expression] = {
+ children.flatMap {
+ case _: Literal => Seq.empty // Skip literal parameters (metadata)
+ case other => other.collectLeaves() // Include column references
+ }
+ }
+
/**
* Whether this [[TransformExpression]]'s function is compatible with the `other`
* [[TransformExpression]]'s function.
@@ -73,8 +125,8 @@ case class TransformExpression(
} else {
(function, other.function) match {
case (f: ReducibleFunction[_, _], o: ReducibleFunction[_, _]) =>
- val thisReducer = reducer(f, numBucketsOpt, o, other.numBucketsOpt)
- val otherReducer = reducer(o, other.numBucketsOpt, f, numBucketsOpt)
+ val thisReducer = reducer(f, this, o, other)
+ val otherReducer = reducer(o, other, f, this)
thisReducer.isDefined || otherReducer.isDefined
case _ => false
}
@@ -92,22 +144,47 @@ case class TransformExpression(
*/
def reducers(other: TransformExpression): Option[Reducer[_, _]] = {
(function, other.function) match {
- case(e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) =>
- reducer(e1, numBucketsOpt, e2, other.numBucketsOpt)
+ case (e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) =>
+ reducer(e1, this, e2, other)
case _ => None
}
}
- // Return a Reducer for a reducible function on another reducible function
+ /**
+ * Extract all literal parameters from a transform expression.
+ * Returns ReducibleParameters containing the literal values in order.
+ *
+ * Examples:
+ * bucket(4, col) => ReducibleParameters([4])
+ * truncate(col, 3) => ReducibleParameters([3])
+ * days(col) => ReducibleParameters([]) (no literals)
+ */
+ private def extractParameters(expr: TransformExpression): ReducibleParameters = {
+ import scala.jdk.CollectionConverters._
+ val values = expr.literalChildren.map {
+ case Literal(value, _) => value.asInstanceOf[AnyRef]
+ }
+ new ReducibleParameters(values.asJava)
+ }
+
+ /**
+ * Return a Reducer for a reducible function applied on another reducible function.
+ * Handles both parameterized (bucket, truncate) and non-parameterized (days, hours) functions.
+ */
private def reducer(
thisFunction: ReducibleFunction[_, _],
- thisNumBucketsOpt: Option[Int],
+ thisExpr: TransformExpression,
otherFunction: ReducibleFunction[_, _],
- otherNumBucketsOpt: Option[Int]): Option[Reducer[_, _]] = {
- val res = (thisNumBucketsOpt, otherNumBucketsOpt) match {
- case (Some(numBuckets), Some(otherNumBuckets)) =>
- thisFunction.reducer(numBuckets, otherFunction, otherNumBuckets)
- case _ => thisFunction.reducer(otherFunction)
+ otherExpr: TransformExpression): Option[Reducer[_, _]] = {
+ val thisParams = extractParameters(thisExpr)
+ val otherParams = extractParameters(otherExpr)
+
+ val res = if (!thisParams.isEmpty && !otherParams.isEmpty) {
+ // Parameterized functions (bucket, truncate, etc.)
+ thisFunction.reducer(thisParams, otherFunction, otherParams)
+ } else {
+ // Non-parameterized functions (days, hours, etc.)
+ thisFunction.reducer(otherFunction)
}
Option(res)
}
@@ -118,10 +195,7 @@ case class TransformExpression(
copy(children = newChildren)
private lazy val resolvedFunction: Option[Expression] = this match {
- case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) =>
- Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc,
- Seq(Literal(numBuckets)) ++ arguments))
- case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) =>
+ case TransformExpression(scalarFunc: ScalarFunction[_], arguments) =>
Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments))
case _ => None
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
index d747bebd5cfe6..2c44264c6e0b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier}
import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
-import org.apache.spark.sql.connector.expressions.{BucketTransform, Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
+import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue}
import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -115,17 +115,6 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
funCatalogOpt: Option[FunctionCatalog] = None): Option[Expression] = trans match {
case IdentityTransform(ref) =>
Some(resolveRef[NamedExpression](ref, query))
- case BucketTransform(numBuckets, refs, sorted)
- if sorted.isEmpty && refs.length == 1 && refs.forall(_.isInstanceOf[NamedReference]) =>
- val resolvedRefs = refs.map(r => resolveRef[NamedExpression](r, query))
- // Create a dummy reference for `numBuckets` here and use that, together with `refs`, to
- // look up the V2 function.
- val numBucketsRef = AttributeReference("numBuckets", IntegerType, nullable = false)()
- funCatalogOpt.flatMap { catalog =>
- loadV2FunctionOpt(catalog, "bucket", Seq(numBucketsRef) ++ resolvedRefs).map { bound =>
- TransformExpression(bound, resolvedRefs, Some(numBuckets))
- }
- }
case NamedTransform(name, args) =>
val catalystArgs = args.map(toCatalyst(_, query, funCatalogOpt))
funCatalogOpt.flatMap { catalog =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index cc50da1f17fdf..adc4bf48892b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -552,7 +552,9 @@ object KeyedPartitioning {
def supportsExpressions(expressions: Seq[Expression]): Boolean = {
def isSupportedTransform(transform: TransformExpression): Boolean = {
- transform.children.size == 1 && isReference(transform.children.head)
+ // TransformExpression.collectLeaves() only returns column references, not literals.
+ // We need exactly one column reference per transform.
+ transform.collectLeaves().size == 1
}
@tailrec
@@ -1093,7 +1095,13 @@ case class KeyedShuffleSpec(
val newExpressions = partitioning.expressions.zip(keyPositions).map {
case (te: TransformExpression, positionSet) =>
- te.copy(children = te.children.map(_ => clustering(positionSet.head)))
+ // Preserve literal parameters (e.g., numBuckets, truncate width)
+ // while replacing only column references with the new clustering expression
+ val newChildren = te.children.map {
+ case l: Literal => l // Keep literals as-is
+ case _ => clustering(positionSet.head) // Replace column references
+ }
+ te.copy(children = newChildren)
case (_, positionSet) => clustering(positionSet.head)
}
KeyedPartitioning(newExpressions, partitioning.partitionKeys, partitioning.isGrouped)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala
index 02e19dd053f29..9f3c4daa7a6f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.datasources.v2
import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, ResolveTimeZone, TypeCoercion}
-import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder, TransformExpression, V2ExpressionUtils}
+import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder, TransformExpression, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RebalancePartitions, RepartitionByExpression, Sort}
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
@@ -96,9 +96,7 @@ object DistributionAndOrderingUtils {
}
private def resolveTransformExpression(expr: Expression): Expression = expr.transform {
- case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) =>
- V2ExpressionUtils.resolveScalarFunction(scalarFunc, Seq(Literal(numBuckets)) ++ arguments)
- case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) =>
+ case TransformExpression(scalarFunc: ScalarFunction[_], arguments) =>
V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 2a0ab52c36933..66cf2f19bfdb8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.functions.{col, max}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf._
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with ExplainSuiteHelper {
private val functions = Seq(
@@ -126,7 +127,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
val df = sql(s"SELECT * FROM testcat.ns.$table")
val distribution = physical.ClusteredDistribution(
- Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32))))
+ Seq(TransformExpression(BucketFunction, Seq(Literal(32), attr("ts")))))
checkQueryPlan(df, distribution, physical.UnknownPartitioning(0))
}
@@ -138,7 +139,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
val df = sql(s"SELECT * FROM testcat.ns.$table")
val distribution = physical.ClusteredDistribution(
- Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32))))
+ Seq(TransformExpression(BucketFunction, Seq(Literal(32), attr("ts")))))
// Has exactly one partition.
val partitionKeys = Seq(0).map(v => InternalRow.fromSeq(Seq(v)))
@@ -194,13 +195,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
val df = sql(s"SELECT * FROM testcat.ns.$table")
val distribution = physical.ClusteredDistribution(
- Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32))))
+ Seq(TransformExpression(BucketFunction, Seq(Literal(32), attr("ts")))))
checkQueryPlan(df, distribution, physical.UnknownPartitioning(0))
}
}
- test("non-clustered distribution: V2 function with multiple args") {
+ test("clustered distribution: V2 function with multiple args") {
val partitions: Array[Transform] = Array(
Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(2))
)
@@ -216,7 +217,11 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
val distribution = physical.ClusteredDistribution(
Seq(TransformExpression(TruncateFunction, Seq(attr("data"), Literal(2)))))
- checkQueryPlan(df, distribution, physical.UnknownPartitioning(0))
+ // With truncate transform support, KeyedPartitioning should now work
+ val partitionKeys = Seq("aa", "bb", "cc").map(v =>
+ InternalRow(UTF8String.fromString(v)))
+ checkQueryPlan(df, distribution,
+ physical.KeyedPartitioning(distribution.clustering, partitionKeys))
}
/**
@@ -4183,4 +4188,176 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
}
}
}
+
+ // === SPARK-50593: Truncate SPJ support tests ===
+
+ test("SPARK-50593: cross-function truncate vs bucket should NOT trigger SPJ") {
+ val partitions1 = Array(
+ Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(3))
+ )
+ val partitions2 = Array(
+ Expressions.bucket(4, "data")
+ )
+
+ createTable("trunc_cross1", columns, partitions1)
+ sql("INSERT INTO testcat.ns.trunc_cross1 VALUES " +
+ "(0, 'aaa', CAST('2022-01-01' AS timestamp)), " +
+ "(1, 'bbb', CAST('2021-01-01' AS timestamp))")
+
+ createTable("trunc_cross2", columns2, partitions2)
+ sql("INSERT INTO testcat.ns.trunc_cross2 VALUES " +
+ "(1, 5, 'aaa'), " +
+ "(5, 10, 'bbb')")
+
+ withSQLConf(
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+ SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
+
+ val df = sql(
+ s"""
+ |${selectWithMergeJoinHint("trunc_cross1", "trunc_cross2")}
+ |trunc_cross1.id, trunc_cross2.store_id
+ |FROM testcat.ns.trunc_cross1 JOIN testcat.ns.trunc_cross2
+ |ON trunc_cross1.data = trunc_cross2.data
+ |ORDER BY trunc_cross1.id
+ |""".stripMargin)
+
+ // Different functions (truncate vs bucket) are not mutually reducible, so a shuffle
+ // must still be planned.
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ assert(shuffles.nonEmpty,
+ "truncate vs bucket are not compatible - a shuffle should be present, " +
+ "but none was planned")
+ checkAnswer(df, Seq(Row(0, 1), Row(1, 5)))
+ }
+ }
+
+ test("SPARK-50593: truncate(3) vs truncate(5) triggers SPJ via width reducer") {
+ // Exercises the ReducibleParameters-based reducer path end-to-end: truncate widths 3 and 5
+ // are mutually reducible (reduce the larger to the smaller), so SPJ must avoid the shuffle.
+ val table1 = "trunc_three"
+ val table2 = "trunc_five"
+
+ val partitions1 = Array(
+ Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(3)))
+ val partitions2 = Array(
+ Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(5)))
+
+ createTable(table1, columns, partitions1)
+ sql(s"INSERT INTO testcat.ns.$table1 VALUES " +
+ "(0, 'apple', CAST('2022-01-01' AS timestamp)), " +
+ "(1, 'grape', CAST('2021-01-01' AS timestamp)), " +
+ "(2, 'orange', CAST('2020-01-01' AS timestamp))")
+
+ createTable(table2, columns, partitions2)
+ sql(s"INSERT INTO testcat.ns.$table2 VALUES " +
+ "(10, 'apple', CAST('2022-01-01' AS timestamp)), " +
+ "(20, 'grape', CAST('2021-01-01' AS timestamp)), " +
+ "(30, 'orange', CAST('2020-01-01' AS timestamp))")
+
+ withSQLConf(
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+ SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
+
+ val df = sql(
+ s"""
+ |${selectWithMergeJoinHint(table1, table2)}
+ |$table1.id AS left_id, $table2.id AS right_id
+ |FROM testcat.ns.$table1 JOIN testcat.ns.$table2
+ |ON $table1.data = $table2.data
+ |ORDER BY $table1.id
+ |""".stripMargin)
+
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ assert(shuffles.isEmpty,
+ "truncate(3) vs truncate(5) should avoid shuffle via the width reducer, " +
+ "but a shuffle was planned")
+ checkAnswer(df, Seq(Row(0, 10), Row(1, 20), Row(2, 30)))
+ }
+ }
+
+ test("SPARK-50593: TransformExpression.collectLeaves filters out literals") {
+ // bucket(4, col) has children = [Literal(4), col] but collectLeaves should return [col]
+ val col = attr("data")
+ val bucketExpr = TransformExpression(BucketFunction, Seq(Literal(4), col))
+ val leaves = bucketExpr.collectLeaves()
+ assert(leaves.size == 1, s"Expected 1 leaf (column ref), got ${leaves.size}: $leaves")
+ assert(leaves.head.semanticEquals(col),
+ s"Expected leaf to be the column reference, got ${leaves.head}")
+
+ // truncate(col, 3) has children = [col, Literal(3)] but collectLeaves should return [col]
+ val truncExpr = TransformExpression(TruncateFunction, Seq(col, Literal(3)))
+ val truncLeaves = truncExpr.collectLeaves()
+ assert(truncLeaves.size == 1,
+ s"Expected 1 leaf (column ref), got ${truncLeaves.size}: $truncLeaves")
+ assert(truncLeaves.head.semanticEquals(col),
+ s"Expected leaf to be the column reference, got ${truncLeaves.head}")
+
+ // years(col) has children = [col] with no literals
+ val yearsExpr = TransformExpression(YearsFunction, Seq(col))
+ val yearsLeaves = yearsExpr.collectLeaves()
+ assert(yearsLeaves.size == 1,
+ s"Expected 1 leaf for years(), got ${yearsLeaves.size}: $yearsLeaves")
+ }
+
+ test("SPARK-50593: existing bucket SPJ still works with ReducibleParameters API") {
+ // Exercises the new ReducibleParameters-based reducer path end-to-end: bucket(4) and
+ // bucket(2) differ, so SPJ can only avoid the shuffle if BucketFunction's reducer
+ // (now implemented via ReducibleParameters) correctly returns a GCD-based Reducer.
+ val table1 = "bucket_compat1"
+ val table2 = "bucket_compat2"
+
+ val partitions1 = Array(Expressions.bucket(4, "id"))
+ val partitions2 = Array(Expressions.bucket(2, "store_id"))
+
+ createTable(table1, columns, partitions1)
+ sql(s"INSERT INTO testcat.ns.$table1 VALUES " +
+ "(0, 'aaa', CAST('2022-01-01' AS timestamp)), " +
+ "(1, 'bbb', CAST('2021-01-01' AS timestamp)), " +
+ "(2, 'ccc', CAST('2020-01-01' AS timestamp)), " +
+ "(3, 'ddd', CAST('2019-01-01' AS timestamp))")
+
+ createTable(table2, columns2, partitions2)
+ sql(s"INSERT INTO testcat.ns.$table2 VALUES " +
+ "(0, 5, 'aaa'), " +
+ "(1, 10, 'bbb'), " +
+ "(2, 15, 'ccc'), " +
+ "(3, 20, 'ddd')")
+
+ withSQLConf(
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+ SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
+ val df = sql(
+ s"""
+ |${selectWithMergeJoinHint(table1, table2)}
+ |$table1.id, $table2.store_id
+ |FROM testcat.ns.$table1 JOIN testcat.ns.$table2
+ |ON $table1.id = $table2.store_id
+ |ORDER BY $table1.id
+ |""".stripMargin)
+
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ assert(shuffles.isEmpty,
+ "bucket(4) vs bucket(2) should avoid shuffle via the GCD reducer, " +
+ "but a shuffle was planned")
+ checkAnswer(df, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3)))
+ }
+ }
+
+ test("SPARK-50593: ReducibleParameters backward compat - old int API still works via default") {
+ // Verifies the default reducer(ReducibleParameters, ...) implementation correctly
+ // delegates to the deprecated reducer(int, func, int) when a ReducibleFunction only
+ // overrides the old API. This mirrors how Iceberg 1.10.0 (and earlier) ship without
+ // knowledge of ReducibleParameters.
+ val bucketExpr4 = TransformExpression(LegacyBucketFunction, Seq(Literal(4), attr("id")))
+ val bucketExpr2 = TransformExpression(LegacyBucketFunction, Seq(Literal(2), attr("id")))
+
+ val reducer = bucketExpr4.reducers(bucketExpr2)
+ assert(reducer.isDefined, "Expected a reducer for legacy_bucket(4) on legacy_bucket(2)")
+
+ // Verify the returned Reducer actually reduces bucket 4 -> bucket 2 (GCD = 2).
+ // bucket(4, x) produces values in [0, 4); reducing by GCD=2 gives v % 2.
+ val r = reducer.get.asInstanceOf[Reducer[Integer, Integer]]
+ assert(r.reduce(3) == 1, s"Expected reduce(3) == 1, got ${r.reduce(3)}")
+ assert(r.reduce(2) == 0, s"Expected reduce(2) == 0, got ${r.reduce(2)}")
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
index 35102c6893d3b..fc034729330e8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
@@ -213,11 +213,14 @@ object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, In
}
override def reducer(
- thisNumBuckets: Int,
+ thisParams: ReducibleParameters,
otherFunc: ReducibleFunction[_, _],
- otherNumBuckets: Int): Reducer[Int, Int] = {
+ otherParams: ReducibleParameters): Reducer[Int, Int] = {
if (otherFunc == BucketFunction) {
+ val thisNumBuckets = thisParams.getInt(0)
+ val otherNumBuckets = otherParams.getInt(0)
+
val gcd = this.gcd(thisNumBuckets, otherNumBuckets)
if (gcd > 1 && gcd != thisNumBuckets) {
return BucketReducer(gcd)
@@ -235,6 +238,37 @@ case class BucketReducer(divisor: Int) extends Reducer[Int, Int] {
override def displayName(): String = toString
}
+/**
+ * A bucket function that only overrides the deprecated `reducer(int, func, int)` method,
+ * not the new `reducer(ReducibleParameters, func, ReducibleParameters)` method.
+ *
+ * Used to verify that the default implementation of the new method correctly falls back
+ * to the deprecated int-based API, so legacy implementations continue to work.
+ */
+object LegacyBucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] {
+ override def inputTypes(): Array[DataType] = Array(IntegerType, LongType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "legacy_bucket"
+ override def canonicalName(): String = name()
+ override def toString: String = name()
+ override def produceResult(input: InternalRow): Int = {
+ Math.floorMod(input.getLong(1), input.getInt(0))
+ }
+
+ override def reducer(
+ thisNumBuckets: Int,
+ otherFunc: ReducibleFunction[_, _],
+ otherNumBuckets: Int): Reducer[Int, Int] = {
+ if (otherFunc == LegacyBucketFunction) {
+ val gcd = BigInt(thisNumBuckets).gcd(BigInt(otherNumBuckets)).toInt
+ if (gcd > 1 && gcd != thisNumBuckets) {
+ return BucketReducer(gcd)
+ }
+ }
+ null
+ }
+}
+
object UnboundStringSelfFunction extends UnboundFunction {
override def bind(inputType: StructType): BoundFunction = StringSelfFunction
override def description(): String = name()
@@ -253,12 +287,35 @@ object StringSelfFunction extends ScalarFunction[UTF8String] {
}
object UnboundTruncateFunction extends UnboundFunction {
- override def bind(inputType: StructType): BoundFunction = TruncateFunction
+ override def bind(inputType: StructType): BoundFunction = {
+ if (inputType.size == 2) {
+ inputType.head.dataType match {
+ case StringType | BinaryType => TruncateFunction
+ case IntegerType | LongType => IntegerTruncateFunction
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"'truncate' does not support data type: ${inputType.head.dataType}")
+ }
+ } else {
+ throw new UnsupportedOperationException(
+ "'truncate' requires exactly 2 arguments: (column, width)")
+ }
+ }
+
override def description(): String = name()
override def name(): String = "truncate"
}
-object TruncateFunction extends ScalarFunction[UTF8String] {
+/**
+ * Truncate transform for String/Binary types.
+ * Follows Iceberg spec: truncate(str, L) = str[0:L]
+ *
+ * Implements ReducibleFunction: any two widths are compatible; the reducer maps the side
+ * with the larger width down to the smaller width.
+ */
+object TruncateFunction
+ extends ScalarFunction[UTF8String]
+ with ReducibleFunction[UTF8String, UTF8String] {
override def inputTypes(): Array[DataType] = Array(StringType, IntegerType)
override def resultType(): DataType = StringType
override def name(): String = "truncate"
@@ -266,7 +323,52 @@ object TruncateFunction extends ScalarFunction[UTF8String] {
override def toString: String = name()
override def produceResult(input: InternalRow): UTF8String = {
val str = input.getUTF8String(0)
- val length = input.getInt(1)
- str.substring(0, length)
+ val width = input.getInt(1)
+ str.substring(0, width)
+ }
+
+ override def reducer(
+ thisParams: ReducibleParameters,
+ otherFunc: ReducibleFunction[_, _],
+ otherParams: ReducibleParameters): Reducer[UTF8String, UTF8String] = {
+
+ if (otherFunc == TruncateFunction) {
+ val thisWidth = thisParams.getInt(0)
+ val otherWidth = otherParams.getInt(0)
+ val smallerWidth = math.min(thisWidth, otherWidth)
+
+ if (smallerWidth != thisWidth) {
+ return TruncateReducer(smallerWidth)
+ }
+ }
+ null
+ }
+}
+
+case class TruncateReducer(width: Int) extends Reducer[UTF8String, UTF8String] {
+ override def reduce(value: UTF8String): UTF8String = {
+ value.substring(0, width)
+ }
+ override def resultType(): DataType = StringType
+ override def displayName(): String = s"truncate($width)"
+}
+
+/**
+ * Truncate transform for Integer/Long types.
+ * Follows Iceberg spec: truncate(value, W) = value - (((value % W) + W) % W)
+ *
+ * Does not implement ReducibleFunction: different integer truncate widths are treated
+ * as incompatible partition layouts in these tests.
+ */
+object IntegerTruncateFunction extends ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(IntegerType, IntegerType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "truncate"
+ override def canonicalName(): String = name()
+ override def toString: String = name()
+ override def produceResult(input: InternalRow): Int = {
+ val value = input.getInt(0)
+ val width = input.getInt(1)
+ value - (((value % width) + width) % width)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala
index db664b04ef08b..56df2956ccde4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, TransformExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, TransformExpression}
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning, KeyedPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning}
import org.apache.spark.sql.connector.catalog.functions.{BucketFunction, YearsFunction}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -492,7 +492,7 @@ class ProjectedOrderingAndPartitioningSuite
// KP([bucket(32, id)], keys1d) through Project(id as pk) should produce
// KP([bucket(32, pk)], keys1d): the alias is pushed into the bucket's column argument.
val id = AttributeReference("id", IntegerType)()
- val bucketExpr = TransformExpression(BucketFunction, Seq(id), Some(32))
+ val bucketExpr = TransformExpression(BucketFunction, Seq(Literal(32), id))
val keys1d = Seq(InternalRow(0), InternalRow(1), InternalRow(2))
val child = DummyLeafExecWithPartitioning(
output = Seq(id),
@@ -507,7 +507,7 @@ class ProjectedOrderingAndPartitioningSuite
case te: TransformExpression =>
assert(te.isSameFunction(bucketExpr),
"bucket function and numBuckets must be preserved after alias substitution")
- assert(te.children.head.asInstanceOf[Attribute].name === "pk",
+ assert(te.children.collectFirst { case a: Attribute => a }.get.name === "pk",
"bucket's column argument must be rewritten to the aliased attribute")
case other => fail(s"Expected TransformExpression, got $other")
}
@@ -524,7 +524,7 @@ class ProjectedOrderingAndPartitioningSuite
// Result: KP([bucket(32, id)], keys1d, isNarrowed=true, isGrouped=false).
val id = AttributeReference("id", IntegerType)()
val ts = AttributeReference("ts", IntegerType)()
- val bucketExpr = TransformExpression(BucketFunction, Seq(id), Some(32))
+ val bucketExpr = TransformExpression(BucketFunction, Seq(Literal(32), id))
val yearsExpr = TransformExpression(YearsFunction, Seq(ts))
// Projected to position [0] (bucket): (0),(1),(0) -- bucket value 0 appears twice.
val keys2d = Seq(InternalRow(0, 2020), InternalRow(1, 2020), InternalRow(0, 2021))
@@ -539,7 +539,7 @@ class ProjectedOrderingAndPartitioningSuite
kp.expressions.head match {
case te: TransformExpression =>
assert(te.isSameFunction(bucketExpr), "bucket must be the surviving expression")
- assert(te.children.head.asInstanceOf[Attribute].name === "id")
+ assert(te.children.collectFirst { case a: Attribute => a }.get.name === "id")
case other => fail(s"Expected TransformExpression, got $other")
}
assert(kp.isNarrowed, "dropping years(ts) position must mark the KP as narrowed")
@@ -554,7 +554,7 @@ class ProjectedOrderingAndPartitioningSuite
// Result: KP([bucket(32, id), years(ts_alias)], keys2d) -- not narrowed.
val id = AttributeReference("id", IntegerType)()
val ts = AttributeReference("ts", IntegerType)()
- val bucketExpr = TransformExpression(BucketFunction, Seq(id), Some(32))
+ val bucketExpr = TransformExpression(BucketFunction, Seq(Literal(32), id))
val yearsExpr = TransformExpression(YearsFunction, Seq(ts))
val keys2d = Seq(InternalRow(0, 2020), InternalRow(1, 2020), InternalRow(0, 2021))
val child = DummyLeafExecWithPartitioning(
@@ -569,14 +569,14 @@ class ProjectedOrderingAndPartitioningSuite
kp.expressions(0) match {
case te: TransformExpression =>
assert(te.isSameFunction(bucketExpr))
- assert(te.children.head.asInstanceOf[Attribute].name === "id",
+ assert(te.children.collectFirst { case a: Attribute => a }.get.name === "id",
"bucket's argument must remain id (no alias for id in this projection)")
case other => fail(s"Expected TransformExpression at pos 0, got $other")
}
kp.expressions(1) match {
case te: TransformExpression =>
assert(te.isSameFunction(yearsExpr))
- assert(te.children.head.asInstanceOf[Attribute].name === "ts_alias",
+ assert(te.children.collectFirst { case a: Attribute => a }.get.name === "ts_alias",
"years() argument must be rewritten to ts_alias")
case other => fail(s"Expected TransformExpression at pos 1, got $other")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index 1e35985f50491..b3d4aad59bc56 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -40,10 +40,10 @@ import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
class EnsureRequirementsSuite extends SharedSparkSession {
- private val exprA = Literal(1)
- private val exprB = Literal(2)
- private val exprC = Literal(3)
- private val exprD = Literal(4)
+ private val exprA = AttributeReference("a", IntegerType)()
+ private val exprB = AttributeReference("b", IntegerType)()
+ private val exprC = AttributeReference("c", IntegerType)()
+ private val exprD = AttributeReference("d", IntegerType)()
private val EnsureRequirements = new EnsureRequirements()
@@ -1191,11 +1191,11 @@ class EnsureRequirementsSuite extends SharedSparkSession {
}
def bucket(numBuckets: Int, expr: Expression): TransformExpression = {
- TransformExpression(BucketFunction, Seq(expr), Some(numBuckets))
+ TransformExpression(BucketFunction, Seq(Literal(numBuckets), expr))
}
def buckets(numBuckets: Int, expr: Seq[Expression]): TransformExpression = {
- TransformExpression(BucketFunction, expr, Some(numBuckets))
+ TransformExpression(BucketFunction, Seq(Literal(numBuckets)) ++ expr)
}
test("ShufflePartitionIdPassThrough - avoid unnecessary shuffle when children are compatible") {