diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index cc50da1f17fdf..f331cd124759f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -691,6 +691,12 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) * `HashPartitioning(B.key2)`. It is also worth noting that `partitionings` * in this collection do not need to be equivalent, which is useful for * Outer Join operators. + * + * [[KeyedPartitioning]]s within a `PartitioningCollection` describe the same physical partitioning + * and so must share the same `partitionKeys` reference, differing only in their `expressions` (with + * matching arity). Use [[PartitioningCollection.fromPartitionings]] to build a collection from + * independently-computed partitionings (e.g. join `outputPartitioning`); it interns `partitionKeys` + * references (including across nested collections) so the invariant holds. */ case class PartitioningCollection(partitionings: Seq[Partitioning]) extends Expression with Partitioning with Unevaluable { @@ -699,6 +705,26 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) partitionings.map(_.numPartitions).distinct.length == 1, s"PartitioningCollection requires all of its partitionings have the same numPartitions.") + checkKeyedPartitioningInvariant() + + private def checkKeyedPartitioningInvariant(): Unit = { + var first: KeyedPartitioning = null + foreach { + case k: KeyedPartitioning => + if (first == null) { + first = k + } else { + require(k.expressions.length == first.expressions.length, + "All KeyedPartitionings in a PartitioningCollection must have matching expression " + + "arity") + require(k.partitionKeys eq first.partitionKeys, + "All KeyedPartitionings in a PartitioningCollection must share the same " + + "partitionKeys reference") + } + case _ => + } + } + override def children: Seq[Expression] = partitionings.collect { case expr: Expression => expr } @@ -730,6 +756,36 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) super.legacyWithNewChildren(newChildren).asInstanceOf[PartitioningCollection] } +object PartitioningCollection { + /** + * Builds a [[PartitioningCollection]], unifying the `partitionKeys` reference across all + * [[KeyedPartitioning]]s (including those in nested collections). Use this when combining + * independently-computed partitionings (e.g. join `outputPartitioning`) where + * `KeyedPartitioning.partitionKeys` are structurally equal but may not be reference-equal. + * + * Note: this can't be implemented with `TreeNode.transform`. + */ + def fromPartitionings(partitionings: Seq[Partitioning]): PartitioningCollection = { + var canonicalKeys: Seq[InternalRowComparableWrapper] = null + def intern(p: Partitioning): Partitioning = p match { + case k: KeyedPartitioning => + if (canonicalKeys == null) { + canonicalKeys = k.partitionKeys + k + } else if (k.partitionKeys ne canonicalKeys) { + require(k.partitionKeys == canonicalKeys, + "All KeyedPartitionings in a PartitioningCollection must have equal partitionKeys") + k.copy(partitionKeys = canonicalKeys) + } else { + k + } + case pc: PartitioningCollection => new PartitioningCollection(pc.partitionings.map(intern)) + case other => other + } + new PartitioningCollection(partitionings.map(intern)) + } +} + /** * Represents a partitioning where rows are collected, transformed and broadcasted to each * node in the cluster. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala index 1f2b1d0a585d6..b37e1b258e9bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala @@ -40,7 +40,7 @@ trait PartitioningPreservingUnaryExecNode extends UnaryExecNode (projectedKPs ++ projectedOthers).take(aliasCandidateLimit) match { case Seq() => UnknownPartitioning(child.outputPartitioning.numPartitions) case Seq(p) => p - case ps => PartitioningCollection(ps) + case ps => PartitioningCollection.fromPartitionings(ps) } } @@ -88,22 +88,15 @@ trait PartitioningPreservingUnaryExecNode extends UnaryExecNode * * The resulting [[KeyedPartitioning]]s are the cross-product of the per-position alternatives * restricted to the projectable positions. All share the same `partitionKeys` object (projected - * to the same subset of positions), preserving the invariant required by [[GroupPartitionsExec]]. + * to the same subset of positions), preserving the invariant required by + * [[PartitioningCollection]]. */ private def projectKeyedPartitionings( kps: Seq[KeyedPartitioning]): LazyList[KeyedPartitioning] = { if (kps.isEmpty) return LazyList.empty + // All input KPs share the same `partitionKeys` reference and matching arity by the + // [[PartitioningCollection]] invariant (the only producer of multi-KP inputs here). val numPositions = kps.head.expressions.length - // The function assumes all input KPs share the same `partitionKeys`, which implies matching - // expression arity. This invariant is asserted by [[GroupPartitionsExec]] and is established - // by the constructors of [[PartitioningCollection]] feeding this method (a join's - // `PartitioningCollection(left.outputPartitioning, right.outputPartitioning)` combines KPs - // that have been aligned by [[EnsureRequirements]] to the same join keys). If the invariant - // is ever violated upstream, fail early with a clear message instead of throwing an opaque - // `IndexOutOfBoundsException` from `kp.expressions(i)` below. - assert(kps.tail.forall(_.expressions.length == numPositions), - s"All input KeyedPartitionings must share the same expression arity, " + - s"but got: ${kps.map(_.expressions.length).mkString(", ")}.") val alternativesPerPosition: IndexedSeq[LazyList[Expression]] = if (hasAlias) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala index 264a0e954936f..4d87be6622939 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala @@ -67,24 +67,14 @@ case class GroupPartitionsExec( override def outputPartitioning: Partitioning = { child.outputPartitioning match { case p: Partitioning with Expression => - // There can be multiple `KeyedPartitioning` in an output partitioning of a join, but they - // can only differ in `expressions`. `partitionKeys` must match so we can calculate it only - // once via `groupedPartitions`. - - val keyedPartitionings = p.collect { case k: KeyedPartitioning => k } - if (keyedPartitionings.size > 1) { - val first = keyedPartitionings.head - keyedPartitionings.tail.foreach { k => - assert(k.partitionKeys == first.partitionKeys, - "All KeyedPartitioning nodes must have identical partition keys") - } - } - + // There can be multiple `KeyedPartitioning`s in an output partitioning of a join, but they + // can only differ in `expressions`; their `partitionKeys` reference is shared (enforced by + // `PartitioningCollection`), so `groupedPartitions` is computed only once. + val partitionKeys = groupedPartitions.map(_._1) p.transform { case k: KeyedPartitioning => val projectedExpressions = joinKeyPositions.fold(k.expressions)(_.map(k.expressions)) - KeyedPartitioning(projectedExpressions, groupedPartitions.map(_._1), - isGrouped = isGrouped) + KeyedPartitioning(projectedExpressions, partitionKeys, isGrouped = isGrouped) }.asInstanceOf[Partitioning] case o => o } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index e4f18c9144dda..2881aeac55d89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -84,7 +84,7 @@ case class BroadcastHashJoinExec private( // constructor prevents that. case p :: Nil => p - case ps => PartitioningCollection(ps) + case ps => PartitioningCollection.fromPartitionings(ps) } case _ => streamedPlan.outputPartitioning } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala index f363156c81e54..3fb968bfea7a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala @@ -46,7 +46,8 @@ trait ShuffledJoin extends JoinCodegenSupport { override def outputPartitioning: Partitioning = joinType match { case _: InnerLike => - PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + PartitioningCollection.fromPartitionings( + Seq(left.outputPartitioning, right.outputPartitioning)) case LeftOuter | LeftSingle => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala index 71a7d4cf56e13..9eca04c985913 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala @@ -242,7 +242,8 @@ case class StreamingSymmetricHashJoinExec( override def outputPartitioning: Partitioning = joinType match { case _: InnerLike => - PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + PartitioningCollection.fromPartitionings( + Seq(left.outputPartitioning, right.outputPartitioning)) case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala index a38570924620a..a70baece77844 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala @@ -387,7 +387,7 @@ class ProjectedOrderingAndPartitioningSuite val y = AttributeReference("y", IntegerType)() val yAlias = AttributeReference("y_alias", IntegerType)() val keys2d = Seq(InternalRow(1, 1), InternalRow(1, 2), InternalRow(2, 1), InternalRow(2, 2)) - val childPartitioning = PartitioningCollection(Seq( + val childPartitioning = PartitioningCollection.fromPartitionings(Seq( KeyedPartitioning(Seq(x, y), keys2d), KeyedPartitioning(Seq(x, yAlias), keys2d))) val child = DummyLeafExecWithPartitioning( @@ -587,27 +587,20 @@ class ProjectedOrderingAndPartitioningSuite } } - test("SPARK-46367: mixed-arity KeyedPartitionings in input fail with a clear assertion") { - // The function assumes all input KPs share the same arity (the invariant asserted by - // `GroupPartitionsExec`). Without the assert below, indexing `kp.expressions(i)` for - // `i >= kp.expressions.length` would throw an opaque `IndexOutOfBoundsException`. The assert - // surfaces the real cause -- an upstream node violated the invariant -- so the bug can be - // fixed at the producer. + test("SPARK-46367: mixed-arity KeyedPartitionings rejected by PartitioningCollection") { + // PartitioningCollection enforces matching expression arity (and shared partitionKeys + // references) across all its KeyedPartitionings, so the invariant required by + // `AliasAwareOutputExpression` cannot be violated by the input. val x = AttributeReference("x", IntegerType)() val y = AttributeReference("y", IntegerType)() val keys2d = Seq(InternalRow(1, 1), InternalRow(2, 2)) val keys1d = Seq(InternalRow(1), InternalRow(2)) - val child = DummyLeafExecWithPartitioning( - output = Seq(x, y), - partitioning = PartitioningCollection(Seq( + val e = intercept[IllegalArgumentException] { + PartitioningCollection.fromPartitionings(Seq( KeyedPartitioning(Seq(x, y), keys2d), - KeyedPartitioning(Seq(x), keys1d)))) - val project = ProjectExec(Seq(x), child) - val e = intercept[AssertionError] { - project.outputPartitioning + KeyedPartitioning(Seq(x), keys1d))) } - assert(e.getMessage.contains("All input KeyedPartitionings must share the same expression " + - "arity")) + assert(e.getMessage.contains("partitionKeys")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala index 5d2adeb0c00af..51951d68cc606 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala @@ -97,7 +97,7 @@ class GroupPartitionsExecSuite extends SharedSparkSession { val leftKP = KeyedPartitioning(Seq(exprA), partitionKeys) val rightKP = KeyedPartitioning(Seq(exprB), partitionKeys) val child = DummySparkPlan( - outputPartitioning = PartitioningCollection(Seq(leftKP, rightKP)), + outputPartitioning = PartitioningCollection.fromPartitionings(Seq(leftKP, rightKP)), outputOrdering = Seq(SortOrder(exprA, Ascending, sameOrderExpressions = Seq(exprB)))) val gpe = GroupPartitionsExec(child) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 1e35985f50491..74b706bce34f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -821,7 +821,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprB) :: Nil, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = PartitioningCollection(Seq( + outputPartitioning = PartitioningCollection.fromPartitionings(Seq( KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, Seq.empty), KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, Seq.empty)) ) @@ -1050,7 +1050,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { // With partition collections plan1 = new DummySparkPlanWithBatchScanChild(outputPartitioning = - PartitioningCollection( + PartitioningCollection.fromPartitionings( Seq(KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, leftPartValues), KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, leftPartValues)) ) @@ -1077,13 +1077,13 @@ class EnsureRequirementsSuite extends SharedSparkSession { // Nested partition collections plan2 = new DummySparkPlanWithBatchScanChild(outputPartitioning = - PartitioningCollection( + PartitioningCollection.fromPartitionings( Seq( - PartitioningCollection( + PartitioningCollection.fromPartitionings( Seq( KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues), KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues))), - PartitioningCollection( + PartitioningCollection.fromPartitionings( Seq( KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues), KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues))) @@ -1539,7 +1539,7 @@ private case class DummyBothKPBinaryExec(left: SparkPlan, right: SparkPlan) override def output: Seq[Attribute] = left.output ++ right.output override def outputOrdering: Seq[SortOrder] = left.outputOrdering override def outputPartitioning: Partitioning = - PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + PartitioningCollection.fromPartitionings(Seq(left.outputPartitioning, right.outputPartitioning)) override protected def doExecute(): RDD[InternalRow] = null override protected def withNewChildrenInternal( newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =