diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteAsOfJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteAsOfJoin.scala index aedccf6b86395..6f10cad108f72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteAsOfJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteAsOfJoin.scala @@ -48,43 +48,53 @@ import org.apache.spark.sql.catalyst.rules._ * }}} */ object RewriteAsOfJoin extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithNewOutput { - case j @ AsOfJoin(left, right, asOfCondition, condition, joinType, orderExpression, _) => - val conditionWithOuterReference = - condition.map(And(_, asOfCondition)).getOrElse(asOfCondition).transformUp { - case a: AttributeReference if left.outputSet.contains(a) => - OuterReference(a) - } - val filtered = Filter(conditionWithOuterReference, right) + def apply(plan: LogicalPlan): LogicalPlan = { + // When the sort-merge AS-OF join operator is enabled, skip this rewrite + // so that the AsOfJoin logical node reaches the planner intact and + // AsOfJoinSelection can produce a dedicated physical operator. + // This conf-based gating (rather than excludedRules) is used because + // the planner strategy and this optimizer rule must be kept in sync: + // if the rewrite runs, the planner never sees AsOfJoin. 
+ if (conf.sortMergeAsOfJoinEnabled) return plan - val orderExpressionWithOuterReference = orderExpression.transformUp { - case a: AttributeReference if left.outputSet.contains(a) => - OuterReference(a) + plan.transformUpWithNewOutput { + case j @ AsOfJoin(left, right, asOfCondition, condition, joinType, orderExpression, _) => + val conditionWithOuterReference = + condition.map(And(_, asOfCondition)).getOrElse(asOfCondition).transformUp { + case a: AttributeReference if left.outputSet.contains(a) => + OuterReference(a) } - val rightStruct = CreateStruct(right.output) - val nearestRight = MinBy(rightStruct, orderExpressionWithOuterReference) - .toAggregateExpression() - val aggExpr = Alias(nearestRight, "__nearest_right__")() - val aggregate = Aggregate(Seq.empty, Seq(aggExpr), filtered) + val filtered = Filter(conditionWithOuterReference, right) - val projectWithScalarSubquery = Project( - left.output :+ Alias(ScalarSubquery(aggregate, left.output), "__right__")(), - left) + val orderExpressionWithOuterReference = orderExpression.transformUp { + case a: AttributeReference if left.outputSet.contains(a) => + OuterReference(a) + } + val rightStruct = CreateStruct(right.output) + val nearestRight = MinBy(rightStruct, orderExpressionWithOuterReference) + .toAggregateExpression() + val aggExpr = Alias(nearestRight, "__nearest_right__")() + val aggregate = Aggregate(Seq.empty, Seq(aggExpr), filtered) - val filterRight = joinType match { - case LeftOuter => projectWithScalarSubquery - case _ => - Filter(IsNotNull(projectWithScalarSubquery.output.last), projectWithScalarSubquery) - } + val projectWithScalarSubquery = Project( + left.output :+ Alias(ScalarSubquery(aggregate, left.output), "__right__")(), + left) - val project = Project( - left.output ++ right.output.zipWithIndex.map { - case (out, idx) => - Alias(GetStructField(filterRight.output.last, idx), out.name)() - }, - filterRight) - val attrMapping = j.output.zip(project.output) + val filterRight = joinType match { 
+ case LeftOuter => projectWithScalarSubquery + case _ => + Filter(IsNotNull(projectWithScalarSubquery.output.last), projectWithScalarSubquery) + } + + val project = Project( + left.output ++ right.output.zipWithIndex.map { + case (out, idx) => + Alias(GetStructField(filterRight.output.last, idx), out.name)() + }, + filterRight) + val attrMapping = j.output.zip(project.output) - project -> attrMapping + project -> attrMapping + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5ed831f20f394..ebba58736d9cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -870,6 +870,17 @@ object SQLConf { .booleanConf .createWithDefault(true) + val SORT_MERGE_AS_OF_JOIN_ENABLED = + buildConf("spark.sql.join.sortMergeAsOfJoin.enabled") + .doc("When true, use a dedicated sort-merge physical operator for AS-OF joins " + + "instead of rewriting to a correlated subquery. 
The sort-merge operator evaluates " + + "AS-OF joins in O(N log N) by co-sorting both sides on the as-of key and scanning " + + "in a single pass.") + .version("4.3.0") + .withBindingPolicy(ConfigBindingPolicy.SESSION) + .booleanConf + .createWithDefault(false) + val REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION = buildConf("spark.sql.requireAllClusterKeysForCoPartition") .internal() @@ -7960,6 +7971,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN) + def sortMergeAsOfJoinEnabled: Boolean = getConf(SORT_MERGE_AS_OF_JOIN_ENABLED) + def enableRadixSort: Boolean = getConf(RADIX_SORT_ENABLED) def isParquetSchemaMergingEnabled: Boolean = getConf(PARQUET_SCHEMA_MERGING_ENABLED) diff --git a/sql/core/benchmarks/AsOfJoinBenchmark-jdk21-results.txt b/sql/core/benchmarks/AsOfJoinBenchmark-jdk21-results.txt new file mode 100644 index 0000000000000..8e21b5ef31de7 --- /dev/null +++ b/sql/core/benchmarks/AsOfJoinBenchmark-jdk21-results.txt @@ -0,0 +1,19 @@ +================================================================================================ +AS-OF Join Benchmark +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1010-azure +AMD EPYC 7763 64-Core Processor +AS-OF Join (left=10000, right=10000, groups=100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------------- +Correlated subquery (baseline) 37494 37697 287 0.0 3749357.9 1.0X +Sort-merge AS-OF join 62 85 19 0.2 6236.4 601.2X + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1010-azure +AMD EPYC 7763 64-Core Processor +AS-OF Join no equi-key (left=10000, right=10000): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
+-------------------------------------------------------------------------------------------------------------------------------- +Correlated subquery (baseline) 23664 23684 27 0.0 2366425.3 1.0X +Sort-merge AS-OF join 1773 1795 32 0.0 177269.8 13.3X + + diff --git a/sql/core/benchmarks/AsOfJoinBenchmark-jdk25-results.txt b/sql/core/benchmarks/AsOfJoinBenchmark-jdk25-results.txt new file mode 100644 index 0000000000000..d2c13cfe69955 --- /dev/null +++ b/sql/core/benchmarks/AsOfJoinBenchmark-jdk25-results.txt @@ -0,0 +1,19 @@ +================================================================================================ +AS-OF Join Benchmark +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1010-azure +AMD EPYC 7763 64-Core Processor +AS-OF Join (left=10000, right=10000, groups=100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------------- +Correlated subquery (baseline) 37549 37690 199 0.0 3754903.0 1.0X +Sort-merge AS-OF join 56 73 14 0.2 5550.4 676.5X + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1010-azure +AMD EPYC 7763 64-Core Processor +AS-OF Join no equi-key (left=10000, right=10000): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------------- +Correlated subquery (baseline) 22100 22128 40 0.0 2209969.1 1.0X +Sort-merge AS-OF join 1614 1624 13 0.0 161417.4 13.7X + + diff --git a/sql/core/benchmarks/AsOfJoinBenchmark-results.txt b/sql/core/benchmarks/AsOfJoinBenchmark-results.txt new file mode 100644 index 0000000000000..55794a9296017 --- /dev/null +++ b/sql/core/benchmarks/AsOfJoinBenchmark-results.txt @@ -0,0 +1,19 @@ 
+================================================================================================ +AS-OF Join Benchmark +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1010-azure +AMD EPYC 7763 64-Core Processor +AS-OF Join (left=10000, right=10000, groups=100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------------- +Correlated subquery (baseline) 37391 37469 110 0.0 3739110.1 1.0X +Sort-merge AS-OF join 59 70 9 0.2 5918.1 631.8X + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1010-azure +AMD EPYC 7763 64-Core Processor +AS-OF Join no equi-key (left=10000, right=10000): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------------- +Correlated subquery (baseline) 24013 24033 28 0.0 2401313.2 1.0X +Sort-merge AS-OF join 1713 1719 9 0.0 171343.3 14.0X + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 7e7f839037175..96a3e3179f541 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -49,6 +49,7 @@ class SparkPlanner(val session: SparkSession, val experimentalMethods: Experimen Window :: WindowGroupLimit :: JoinSelection :: + AsOfJoinSelection :: InMemoryScans :: SparkScripts :: Pipelines :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 92818c12bfa09..c06d23cb5177b 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -178,6 +178,149 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Supports both equi-joins and non-equi-joins. * Supports only inner like joins. */ + /** + * Plans AS-OF joins using a dedicated sort-merge operator when the + * conf is enabled. + */ + object AsOfJoinSelection extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case j @ AsOfJoin(left, right, asOfCondition, condition, joinType, + orderExpression, _) if conf.sortMergeAsOfJoinEnabled => + val (leftKeys, rightKeys, residual) = condition match { + case Some(cond) => extractEquiJoinKeys(cond, left, right) + case None => (Seq.empty[Expression], Seq.empty[Expression], None) + } + val (leftAsOf, rightAsOf) = extractAsOfExprs( + asOfCondition, orderExpression, left, right) + + joins.SortMergeAsOfJoinExec( + leftKeys, rightKeys, leftAsOf, rightAsOf, + asOfCondition, orderExpression, joinType, residual, + planLater(left), planLater(right)) :: Nil + case _ => Nil + } + + /** + * Extract equi-join key pairs and residual (non-equi) condition + * from a conjunction. Only EqualTo is treated as equi-key; + * EqualNullSafe is excluded because the Scanner does not implement + * null-safe comparison semantics. 
+ */ + private def extractEquiJoinKeys( + condition: Expression, + left: LogicalPlan, + right: LogicalPlan): (Seq[Expression], Seq[Expression], Option[Expression]) = { + val leftKeys = + new scala.collection.mutable.ArrayBuffer[Expression]() + val rightKeys = + new scala.collection.mutable.ArrayBuffer[Expression]() + val residuals = + new scala.collection.mutable.ArrayBuffer[Expression]() + + flattenAnd(condition).foreach { + case EqualTo(l, r) + if l.references.subsetOf(left.outputSet) && + r.references.subsetOf(right.outputSet) => + leftKeys += l; rightKeys += r + case EqualTo(l, r) + if r.references.subsetOf(left.outputSet) && + l.references.subsetOf(right.outputSet) => + leftKeys += r; rightKeys += l + case other => + residuals += other + } + val residual = residuals.reduceOption(And) + (leftKeys.toSeq, rightKeys.toSeq, residual) + } + + private def flattenAnd(expr: Expression): Seq[Expression] = expr match { + case And(l, r) => flattenAnd(l) ++ flattenAnd(r) + case other => Seq(other) + } + + private def extractAsOfExprs( + asOfCondition: Expression, + orderExpression: Expression, + left: LogicalPlan, + right: LogicalPlan): (Expression, Expression) = { + val leftAttrs = left.outputSet + val rightAttrs = right.outputSet + + def find(expr: Expression): Option[(Expression, Expression)] = expr match { + case GreaterThanOrEqual(l, r) + if l.references.subsetOf(leftAttrs) && r.references.subsetOf(rightAttrs) => + Some((l, r)) + case GreaterThan(l, r) + if l.references.subsetOf(leftAttrs) && r.references.subsetOf(rightAttrs) => + Some((l, r)) + case LessThanOrEqual(l, r) + if l.references.subsetOf(leftAttrs) && r.references.subsetOf(rightAttrs) => + Some((l, r)) + case LessThan(l, r) + if l.references.subsetOf(leftAttrs) && r.references.subsetOf(rightAttrs) => + Some((l, r)) + case GreaterThanOrEqual(l, r) + if l.references.subsetOf(rightAttrs) && r.references.subsetOf(leftAttrs) => + Some((r, l)) + case GreaterThan(l, r) + if l.references.subsetOf(rightAttrs) && 
r.references.subsetOf(leftAttrs) => + Some((r, l)) + case LessThanOrEqual(l, r) + if l.references.subsetOf(rightAttrs) && r.references.subsetOf(leftAttrs) => + Some((r, l)) + case LessThan(l, r) + if l.references.subsetOf(rightAttrs) && r.references.subsetOf(leftAttrs) => + Some((r, l)) + case And(l, r) => find(l).orElse(find(r)) + case _ => None + } + + find(asOfCondition).orElse { + // For Nearest direction, asOfCondition may be TrueLiteral. + // Extract as-of keys from the orderExpression instead. + findFromOrder(orderExpression, leftAttrs, rightAttrs) + }.getOrElse { + // Last resort: find attributes from orderExpression + val allAttrs = orderExpression.collect { + case a: AttributeReference => a + } + val leftExpr = allAttrs.find(a => leftAttrs.contains(a)).getOrElse { + throw new IllegalStateException( + "Cannot extract left as-of key from AS-OF join condition " + + s"or order expression: $asOfCondition / $orderExpression") + } + val rightExpr = allAttrs.find(a => rightAttrs.contains(a)).getOrElse { + throw new IllegalStateException( + "Cannot extract right as-of key from AS-OF join condition " + + s"or order expression: $asOfCondition / $orderExpression") + } + (leftExpr, rightExpr) + } + } + + /** Extract as-of key pair from orderExpression (distance metric). 
*/ + private def findFromOrder( + expr: Expression, + leftAttrs: AttributeSet, + rightAttrs: AttributeSet): Option[(Expression, Expression)] = { + expr match { + // Backward: Subtract(leftAsOf, rightAsOf) + // Forward: Subtract(rightAsOf, leftAsOf) + case Subtract(l, r, _) + if l.references.subsetOf(leftAttrs) && + r.references.subsetOf(rightAttrs) => + Some((l, r)) + case Subtract(l, r, _) + if l.references.subsetOf(rightAttrs) && + r.references.subsetOf(leftAttrs) => + Some((r, l)) + // Nearest: If(GT(left, right), Sub(left, right), Sub(right, left)) + case If(_, thenExpr, _) => findFromOrder(thenExpr, leftAttrs, rightAttrs) + case _ => None + } + } + } + object JoinSelection extends Strategy with JoinSelectionHelper { private val hintErrorHandler = conf.hintErrorHandler diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeAsOfJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeAsOfJoinExec.scala new file mode 100644 index 0000000000000..dc0540b68ffd5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeAsOfJoinExec.scala @@ -0,0 +1,426 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} + +/** + * Performs an AS-OF join using sort-merge. Both sides are co-partitioned + * by the equi-join keys and sorted by (equi-join keys, as-of key). + * For each left row, we scan the right side to find the nearest match + * satisfying the as-of condition. + * + * Note: When there are no equi-keys, both sides are collected into a + * single partition (AllTuples). The right side is fully buffered in + * memory, so this operator is not suitable for large right-side tables + * without equi-keys. 
+ */ +case class SortMergeAsOfJoinExec( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + leftAsOfExpr: Expression, + rightAsOfExpr: Expression, + asOfCondition: Expression, + orderExpression: Expression, + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryExecNode { + + override lazy val metrics: Map[String, SQLMetric] = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, + "number of output rows")) + + override def output: Seq[Attribute] = joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case _ => + left.output ++ right.output + } + + override def outputOrdering: Seq[SortOrder] = { + // Output preserves left-side ordering (equi-keys + as-of key) + left.outputOrdering + } + + override def requiredChildDistribution: Seq[Distribution] = { + if (leftKeys.isEmpty) { + AllTuples :: AllTuples :: Nil + } else { + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + val leftOrdering = leftKeys.map(SortOrder(_, Ascending)) :+ + SortOrder(leftAsOfExpr, Ascending) + val rightOrdering = rightKeys.map(SortOrder(_, Ascending)) :+ + SortOrder(rightAsOfExpr, Ascending) + leftOrdering :: rightOrdering :: Nil + } + + override def outputPartitioning: Partitioning = left.outputPartitioning + + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + val scanner = new SortMergeAsOfJoinScanner( + leftIter, + rightIter, + left.output, + right.output, + leftKeys, + rightKeys, + asOfCondition, + orderExpression, + joinType, + condition, + numOutputRows + ) + // Register cleanup to release the right-side buffer on task completion + TaskContext.get().addTaskCompletionListener[Unit](_ => scanner.close()) + scanner.iterator + } + } + + override 
protected def withNewChildrenInternal( + newLeft: SparkPlan, + newRight: SparkPlan): SortMergeAsOfJoinExec = { + copy(left = newLeft, right = newRight) + } +} + +/** + * Performs the sort-merge AS-OF join scan. + * + * Both inputs are sorted by (equi-keys, as-of key) ascending. For each + * left row within an equi-key group, we find the right row that satisfies + * the as-of condition and minimizes the order expression (distance). + * + * Since the right side is sorted by as-of key within each group, for + * backward joins we scan right-to-left and stop at the first match + * (exploiting sort order for early termination). + */ +private[joins] class SortMergeAsOfJoinScanner( + leftIter: Iterator[InternalRow], + rightIter: Iterator[InternalRow], + leftOutput: Seq[Attribute], + rightOutput: Seq[Attribute], + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + asOfCondition: Expression, + orderExpression: Expression, + joinType: JoinType, + residualCondition: Option[Expression], + numOutputRows: SQLMetric) { + + private val joinedOutput = leftOutput ++ rightOutput + private val joinedRow = new JoinedRow() + private val resultProjection = + UnsafeProjection.create(joinedOutput, joinedOutput) + + // Bound expressions for evaluating conditions on joined rows + private val boundAsOfCond = bindReference(asOfCondition, joinedOutput) + private val boundOrderExpr = bindReference(orderExpression, joinedOutput) + private val boundResidualCond = + residualCondition.map(bindReference(_, joinedOutput)) + + // Key ordering for equi-join keys + private val equiKeyOrdering: Option[BaseOrdering] = + if (leftKeys.nonEmpty) { + val keyAttributes = leftKeys.zipWithIndex.map { case (key, i) => + AttributeReference(s"key_$i", key.dataType, key.nullable)() + } + Some(GenerateOrdering.generate( + keyAttributes.map(SortOrder(_, Ascending)), keyAttributes)) + } else { + None + } + + // Projections to extract equi-keys for comparison + private val leftKeyProj = 
UnsafeProjection.create(leftKeys, leftOutput) + private val rightKeyProj = UnsafeProjection.create(rightKeys, rightOutput) + + // Ordering for the distance metric + private val distanceOrdering = + TypeUtils.getInterpretedOrdering(orderExpression.dataType) + + // Determine scan direction based on the as-of condition. + // Backward (left >= right): best match is at end of sorted buffer -> right-to-left + // Forward (left <= right): best match is at start -> left-to-right + // Nearest / unknown: left-to-right (works correctly, just no early termination + // guarantee for the "as-of not satisfied" shortcut) + private val scanRightToLeft: Boolean = { + def isBackward(expr: Expression): Boolean = expr match { + case GreaterThanOrEqual(_, _) => true + case GreaterThan(_, _) => true + case And(l, _) => isBackward(l) + case _ => false + } + isBackward(asOfCondition) + } + + // Null row for LeftOuter when no match is found + private val nullRightRow = new GenericInternalRow(rightOutput.length) + + // Right-side buffer: holds right rows for the current equi-key group. + // Rows are sorted by as-of key ascending (guaranteed by requiredChildOrdering). + private val rightGroupBuffer = new ArrayBuffer[InternalRow]() + private var rightGroupKey: UnsafeRow = _ + private var rightPeek: InternalRow = _ + private var rightDone: Boolean = !rightIter.hasNext + + // Initialize: read first right row + if (!rightDone) { + rightPeek = rightIter.next().copy() + } + + /** Release resources held by this scanner. 
*/ + def close(): Unit = { + rightGroupBuffer.clear() + rightGroupBuffer.trimToSize() + } + + def iterator: Iterator[InternalRow] = new Iterator[InternalRow] { + private var nextRow: InternalRow = _ + private val leftIterBuffered = leftIter.buffered + + override def hasNext: Boolean = { + if (nextRow != null) return true + nextRow = findNext() + nextRow != null + } + + override def next(): InternalRow = { + if (!hasNext) throw new NoSuchElementException + val result = nextRow + nextRow = null + result + } + + private def findNext(): InternalRow = { + while (leftIterBuffered.hasNext) { + val leftRow = leftIterBuffered.next() + val leftKey = leftKeyProj(leftRow).copy() + + // Advance right side to the matching equi-key group + advanceRightTo(leftKey) + + // Search for best match exploiting sort order + val bestMatch = findBestInGroup(leftRow) + + if (bestMatch != null) { + numOutputRows += 1 + joinedRow.withLeft(leftRow).withRight(bestMatch) + return resultProjection(joinedRow).copy() + } else if (joinType == LeftOuter) { + numOutputRows += 1 + joinedRow.withLeft(leftRow).withRight(nullRightRow) + return resultProjection(joinedRow).copy() + } + // Inner join: no match, skip + } + null + } + } + + /** + * Advance the right side so that rightGroupBuffer contains all right + * rows whose equi-key matches `leftKey`. + */ + private def advanceRightTo(leftKey: UnsafeRow): Unit = { + equiKeyOrdering match { + case None => + // No equi-keys: buffer all right rows once. + // WARNING: This loads the entire right partition into memory. 
+ if (rightGroupBuffer.isEmpty && !rightDone) { + bufferAllRight() + } + case Some(ordering) => + // Check if current buffer already matches + if (rightGroupKey != null && + ordering.compare(leftKey, rightGroupKey) == 0) { + return + } + + // Skip right rows with keys < leftKey + while (!rightDone && rightPeek != null) { + val rightKey = rightKeyProj(rightPeek) + val cmp = ordering.compare(leftKey, rightKey) + if (cmp > 0) { + rightPeek = if (rightIter.hasNext) { + rightIter.next().copy() + } else { + rightDone = true; null + } + } else if (cmp == 0) { + bufferRightGroup(leftKey, ordering) + return + } else { + rightGroupBuffer.clear() + rightGroupKey = null + return + } + } + rightGroupBuffer.clear() + rightGroupKey = null + } + } + + /** Buffer all right rows with the same equi-key as leftKey. */ + private def bufferRightGroup( + leftKey: UnsafeRow, ordering: BaseOrdering): Unit = { + rightGroupBuffer.clear() + rightGroupKey = leftKey.copy() + + while (!rightDone && rightPeek != null) { + val rightKey = rightKeyProj(rightPeek) + if (ordering.compare(leftKey, rightKey) == 0) { + rightGroupBuffer += rightPeek + rightPeek = if (rightIter.hasNext) { + rightIter.next().copy() + } else { + rightDone = true; null + } + } else { + return + } + } + } + + /** Buffer all remaining right rows (no equi-keys case). */ + private def bufferAllRight(): Unit = { + rightGroupBuffer.clear() + if (rightPeek != null) { + rightGroupBuffer += rightPeek + rightPeek = null + } + while (rightIter.hasNext) { + rightGroupBuffer += rightIter.next().copy() + } + rightDone = true + } + + /** + * Find the best matching right row for the given left row within the + * current group buffer. + * + * The buffer is sorted by as-of key ascending. 
The scan direction is + * chosen based on where the best match is expected: + * - Backward (left >= right): best match near the end -> right-to-left + * - Forward (left <= right): best match near the start -> left-to-right + * - Nearest: full scan needed (left-to-right, stop when distance + * increases after finding a match) + */ + private def findBestInGroup(leftRow: InternalRow): InternalRow = { + if (scanRightToLeft) { + findBestRightToLeft(leftRow) + } else { + findBestLeftToRight(leftRow) + } + } + + /** Scan from end to start (optimal for Backward joins). */ + private def findBestRightToLeft(leftRow: InternalRow): InternalRow = { + var bestMatch: InternalRow = null + var bestDistance: Any = null + + var i = rightGroupBuffer.size - 1 + while (i >= 0) { + val rightRow = rightGroupBuffer(i) + joinedRow.withLeft(leftRow).withRight(rightRow) + + val asOfSatisfied = boundAsOfCond.eval(joinedRow) + if (asOfSatisfied != null && asOfSatisfied.asInstanceOf[Boolean]) { + val residualSatisfied = boundResidualCond.forall { cond => + val result = cond.eval(joinedRow) + result != null && result.asInstanceOf[Boolean] + } + if (residualSatisfied) { + val distance = boundOrderExpr.eval(joinedRow) + if (distance != null) { + if (bestMatch == null) { + bestMatch = rightRow + bestDistance = distance + } else if (distanceOrdering.lt(distance, bestDistance)) { + bestMatch = rightRow + bestDistance = distance + } else { + return bestMatch + } + } + } + } else if (bestMatch != null) { + return bestMatch + } + i -= 1 + } + bestMatch + } + + /** Scan from start to end (optimal for Forward/Nearest joins). 
*/ + private def findBestLeftToRight(leftRow: InternalRow): InternalRow = { + var bestMatch: InternalRow = null + var bestDistance: Any = null + + var i = 0 + while (i < rightGroupBuffer.size) { + val rightRow = rightGroupBuffer(i) + joinedRow.withLeft(leftRow).withRight(rightRow) + + val asOfSatisfied = boundAsOfCond.eval(joinedRow) + if (asOfSatisfied != null && asOfSatisfied.asInstanceOf[Boolean]) { + val residualSatisfied = boundResidualCond.forall { cond => + val result = cond.eval(joinedRow) + result != null && result.asInstanceOf[Boolean] + } + if (residualSatisfied) { + val distance = boundOrderExpr.eval(joinedRow) + if (distance != null) { + if (bestMatch == null) { + bestMatch = rightRow + bestDistance = distance + } else if (distanceOrdering.lt(distance, bestDistance)) { + bestMatch = rightRow + bestDistance = distance + } else { + return bestMatch + } + } + } + } else if (bestMatch != null) { + return bestMatch + } + i += 1 + } + bestMatch + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SortMergeAsOfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SortMergeAsOfJoinSuite.scala new file mode 100644 index 0000000000000..d4ab781589218 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SortMergeAsOfJoinSuite.scala @@ -0,0 +1,420 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.joins.SortMergeAsOfJoinExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ + +class SortMergeAsOfJoinSuite extends QueryTest + with SharedSparkSession + with AdaptiveSparkPlanHelper { + + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.SORT_MERGE_AS_OF_JOIN_ENABLED.key, "true") + } + + override def afterAll(): Unit = { + spark.conf.unset(SQLConf.SORT_MERGE_AS_OF_JOIN_ENABLED.key) + super.afterAll() + } + + def prepareForAsOfJoin(): (classic.DataFrame, classic.DataFrame) = { + val schema1 = StructType( + StructField("a", IntegerType, false) :: + StructField("b", StringType, false) :: + StructField("left_val", StringType, false) :: Nil) + val rowSeq1: List[Row] = List( + Row(1, "x", "a"), Row(5, "y", "b"), Row(10, "z", "c")) + val df1 = spark.createDataFrame(rowSeq1.asJava, schema1) + + val schema2 = StructType( + StructField("a", IntegerType) :: + StructField("b", StringType) :: + StructField("right_val", IntegerType) :: Nil) + val rowSeq2: List[Row] = List( + Row(1, "v", 1), Row(2, "w", 2), Row(3, "x", 3), + Row(6, "y", 6), Row(7, "z", 7)) + val df2 = spark.createDataFrame(rowSeq2.asJava, schema2) + + (df1, df2) + } + + test("uses SortMergeAsOfJoinExec physical operator") { + val (df1, df2) = prepareForAsOfJoin() + val result = df1.joinAsOf( + 
df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward") + val plan = result.queryExecution.executedPlan + assert(collectWithSubqueries(plan) { + case _: SortMergeAsOfJoinExec => true + }.nonEmpty, s"Expected SortMergeAsOfJoinExec in plan:\n$plan") + } + + test("backward join - simple") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(1, "x", "a", 1, "v", 1), + Row(5, "y", "b", 3, "x", 3), + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("backward join - usingColumns") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq("b"), + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("backward join - left outer") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq("b"), + joinType = "leftouter", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(1, "x", "a", null, null, null), + Row(5, "y", "b", null, null, null), + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("forward join") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "forward"), + Seq( + Row(1, "x", "a", 1, "v", 1), + Row(5, "y", "b", 6, "y", 6), + Row(10, "z", "c", null, null, null) + ).filter(_.get(3) != null) // inner join: no match for 10 + ) + } + + test("nearest join") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), 
usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "nearest"), + Seq( + Row(1, "x", "a", 1, "v", 1), + Row(5, "y", "b", 6, "y", 6), + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("backward join - tolerance = 1") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", + tolerance = functions.lit(1), + allowExactMatches = true, direction = "backward"), + Seq( + Row(1, "x", "a", 1, "v", 1), + Row(10, "z", "c", null, null, null) + ).filter(_.get(3) != null) + ) + } + + test("backward join - allowExactMatches = false") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = false, direction = "backward"), + Seq( + // left.a=1: no right row with a < 1 → no match + // left.a=5: right.a=3 (3 < 5) → match + Row(5, "y", "b", 3, "x", 3), + // left.a=10: right.a=7 (7 < 10) → match + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("empty left side") { + val (_, df2) = prepareForAsOfJoin() + val emptyDf = spark.createDataFrame( + java.util.Collections.emptyList[Row](), + StructType( + StructField("a", IntegerType, false) :: + StructField("b", StringType, false) :: + StructField("left_val", StringType, false) :: Nil)) + checkAnswer( + emptyDf.joinAsOf( + df2, emptyDf.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq.empty + ) + } + + test("empty right side") { + val (df1, _) = prepareForAsOfJoin() + val emptyDf = spark.createDataFrame( + java.util.Collections.emptyList[Row](), + StructType( + StructField("a", IntegerType) :: + StructField("b", StringType) :: + StructField("right_val", IntegerType) :: Nil)) + // Inner join: no matches possible + checkAnswer( + df1.joinAsOf( + emptyDf, 
df1.col("a"), emptyDf.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq.empty + ) + // Left outer: all left rows with null right + checkAnswer( + df1.joinAsOf( + emptyDf, df1.col("a"), emptyDf.col("a"), usingColumns = Seq.empty, + joinType = "leftouter", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(1, "x", "a", null, null, null), + Row(5, "y", "b", null, null, null), + Row(10, "z", "c", null, null, null) + ) + ) + } + + test("null as-of keys") { + val schema1 = StructType( + StructField("a", IntegerType, true) :: + StructField("val", StringType) :: Nil) + val schema2 = StructType( + StructField("a", IntegerType, true) :: + StructField("val", StringType) :: Nil) + val df1 = spark.createDataFrame( + List(Row(null, "x"), Row(3, "y"), Row(7, "z")).asJava, schema1) + val df2 = spark.createDataFrame( + List(Row(1, "a"), Row(null, "b"), Row(5, "c")).asJava, schema2) + // Null as-of keys should not match anything (as-of condition + // evaluates to null for null inputs) + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "leftouter", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(null, "x", null, null), + Row(3, "y", 1, "a"), + Row(7, "z", 5, "c") + ) + ) + } + + test("multiple rows with same equi-key") { + val schema1 = StructType( + StructField("grp", StringType) :: + StructField("ts", IntegerType) :: Nil) + val schema2 = StructType( + StructField("grp", StringType) :: + StructField("ts", IntegerType) :: + StructField("val", StringType) :: Nil) + val df1 = spark.createDataFrame( + List( + Row("A", 5), Row("A", 10), Row("A", 15), + Row("B", 3), Row("B", 8) + ).asJava, schema1) + val df2 = spark.createDataFrame( + List( + Row("A", 2, "a1"), Row("A", 7, "a2"), Row("A", 12, "a3"), + Row("B", 1, "b1"), Row("B", 6, "b2"), Row("B", 10, "b3") + ).asJava, schema2) + 
checkAnswer( + df1.joinAsOf( + df2, df1.col("ts"), df2.col("ts"), usingColumns = Seq("grp"), + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row("A", 5, "A", 2, "a1"), + Row("A", 10, "A", 7, "a2"), + Row("A", 15, "A", 12, "a3"), + Row("B", 3, "B", 1, "b1"), + Row("B", 8, "B", 6, "b2") + ) + ) + } + + test("long type as-of key") { + val schema1 = StructType( + StructField("ts", LongType) :: + StructField("val", StringType) :: Nil) + val schema2 = StructType( + StructField("ts", LongType) :: + StructField("val", StringType) :: Nil) + val df1 = spark.createDataFrame( + List(Row(100L, "a"), Row(200L, "b"), Row(300L, "c")).asJava, schema1) + val df2 = spark.createDataFrame( + List(Row(50L, "x"), Row(150L, "y"), Row(250L, "z")).asJava, schema2) + checkAnswer( + df1.joinAsOf( + df2, df1.col("ts"), df2.col("ts"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(100L, "a", 50L, "x"), + Row(200L, "b", 150L, "y"), + Row(300L, "c", 250L, "z") + ) + ) + } + + test("double type as-of key") { + val schema1 = StructType( + StructField("ts", DoubleType) :: + StructField("val", StringType) :: Nil) + val schema2 = StructType( + StructField("ts", DoubleType) :: + StructField("val", StringType) :: Nil) + val df1 = spark.createDataFrame( + List(Row(1.5, "a"), Row(3.0, "b"), Row(5.5, "c")).asJava, schema1) + val df2 = spark.createDataFrame( + List(Row(1.0, "x"), Row(2.5, "y"), Row(4.0, "z")).asJava, schema2) + checkAnswer( + df1.joinAsOf( + df2, df1.col("ts"), df2.col("ts"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(1.5, "a", 1.0, "x"), + Row(3.0, "b", 2.5, "y"), + Row(5.5, "c", 4.0, "z") + ) + ) + } + + test("conf disabled falls back to correlated subquery rewrite") { + val (df1, df2) = prepareForAsOfJoin() + withSQLConf(SQLConf.SORT_MERGE_AS_OF_JOIN_ENABLED.key 
-> "false") { + val result = df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward") + val plan = result.queryExecution.executedPlan + assert(collectWithSubqueries(plan) { + case _: SortMergeAsOfJoinExec => true + }.isEmpty, "Should NOT use SortMergeAsOfJoinExec when conf is disabled") + // Results should still be correct + checkAnswer(result, Seq( + Row(1, "x", "a", 1, "v", 1), + Row(5, "y", "b", 3, "x", 3), + Row(10, "z", "c", 7, "z", 7) + )) + } + } + + test("self join") { + val schema = StructType( + StructField("ts", IntegerType) :: + StructField("val", StringType) :: Nil) + val df = spark.createDataFrame( + List(Row(1, "a"), Row(3, "b"), Row(5, "c")).asJava, schema) + checkAnswer( + df.joinAsOf( + df, df.col("ts"), df.col("ts"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(1, "a", 1, "a"), + Row(3, "b", 3, "b"), + Row(5, "c", 5, "c") + ) + ) + } + + test("no equi-key - all rows in single partition") { + val schema1 = StructType( + StructField("ts", IntegerType) :: + StructField("val", StringType) :: Nil) + val schema2 = StructType( + StructField("ts", IntegerType) :: + StructField("val", StringType) :: Nil) + val df1 = spark.createDataFrame( + List(Row(2, "a"), Row(5, "b"), Row(9, "c")).asJava, schema1) + val df2 = spark.createDataFrame( + List(Row(1, "x"), Row(4, "y"), Row(7, "z")).asJava, schema2) + checkAnswer( + df1.joinAsOf( + df2, df1.col("ts"), df2.col("ts"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(2, "a", 1, "x"), + Row(5, "b", 4, "y"), + Row(9, "c", 7, "z") + ) + ) + } + + test("forward join - left outer with no match") { + val schema1 = StructType( + StructField("ts", IntegerType) :: + StructField("val", StringType) :: Nil) + val schema2 = StructType( + 
StructField("ts", IntegerType) :: + StructField("val", StringType) :: Nil) + val df1 = spark.createDataFrame( + List(Row(1, "a"), Row(5, "b"), Row(10, "c")).asJava, schema1) + val df2 = spark.createDataFrame( + List(Row(3, "x"), Row(7, "y")).asJava, schema2) + checkAnswer( + df1.joinAsOf( + df2, df1.col("ts"), df2.col("ts"), usingColumns = Seq.empty, + joinType = "leftouter", tolerance = null, + allowExactMatches = true, direction = "forward"), + Seq( + Row(1, "a", 3, "x"), + Row(5, "b", 7, "y"), + Row(10, "c", null, null) // no right row >= 10 + ) + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AsOfJoinBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AsOfJoinBenchmark.scala new file mode 100644 index 0000000000000..c593376504b84 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AsOfJoinBenchmark.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.classic
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Benchmark to measure AS-OF join performance: sort-merge operator vs correlated subquery.
+ * To run this benchmark:
+ * {{{
+ * 1. build/sbt "sql/Test/runMain <this class>"
+ * 2. generate result:
+ * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/Test/runMain <this class>"
+ * Results will be written to
+ * "benchmarks/AsOfJoinBenchmark-results.txt".
+ * }}}
+ */
+object AsOfJoinBenchmark extends SqlBasedBenchmark {
+
+ private def doAsOfJoin(
+ left: classic.DataFrame,
+ right: classic.DataFrame,
+ usingColumns: Seq[String]): Unit = {
+ left.joinAsOf(
+ right, left.col("ts"), right.col("ts"),
+ usingColumns = usingColumns,
+ joinType = "inner", tolerance = null,
+ allowExactMatches = true, direction = "backward"
+ ).noop()
+ }
+
+ private def asOfJoinBenchmark(
+ leftRows: Int,
+ rightRows: Int,
+ numGroups: Int): Unit = {
+ val left: classic.DataFrame = spark.range(leftRows).select(
+ (col("id") % numGroups).as("group_id"),
+ col("id").as("ts"),
+ lit("left_val").as("left_val")
+ ).toDF().asInstanceOf[classic.DataFrame]
+ val right: classic.DataFrame = spark.range(rightRows).select(
+ (col("id") % numGroups).as("group_id"),
+ (col("id") * 3 / 2).as("ts"),
+ lit("right_val").as("right_val")
+ ).toDF().asInstanceOf[classic.DataFrame]
+
+ val benchmark = new Benchmark(
+ s"AS-OF Join (left=$leftRows, right=$rightRows, groups=$numGroups)",
+ leftRows,
+ output = output)
+
+ benchmark.addCase("Correlated subquery (baseline)") { _ =>
+ withSQLConf(
+ SQLConf.SORT_MERGE_AS_OF_JOIN_ENABLED.key -> "false",
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+ doAsOfJoin(left, right, Seq("group_id"))
+ }
+ }
+
+ benchmark.addCase("Sort-merge AS-OF join") { _ =>
+ withSQLConf(
+ SQLConf.SORT_MERGE_AS_OF_JOIN_ENABLED.key -> "true",
+ 
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + doAsOfJoin(left, right, Seq("group_id")) + } + } + + benchmark.run() + } + + private def asOfJoinNoEquiKeyBenchmark( + leftRows: Int, rightRows: Int): Unit = { + val left: classic.DataFrame = spark.range(leftRows).select( + col("id").as("ts"), + lit("left_val").as("left_val") + ).toDF().asInstanceOf[classic.DataFrame] + val right: classic.DataFrame = spark.range(rightRows).select( + (col("id") * 3 / 2).as("ts"), + lit("right_val").as("right_val") + ).toDF().asInstanceOf[classic.DataFrame] + + val benchmark = new Benchmark( + s"AS-OF Join no equi-key (left=$leftRows, right=$rightRows)", + leftRows, + output = output) + + benchmark.addCase("Correlated subquery (baseline)") { _ => + withSQLConf( + SQLConf.SORT_MERGE_AS_OF_JOIN_ENABLED.key -> "false", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + doAsOfJoin(left, right, Seq.empty) + } + } + + benchmark.addCase("Sort-merge AS-OF join") { _ => + withSQLConf( + SQLConf.SORT_MERGE_AS_OF_JOIN_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + doAsOfJoin(left, right, Seq.empty) + } + } + + benchmark.run() + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("AS-OF Join Benchmark") { + // 10K left x 10K right, 100 groups — both paths feasible + asOfJoinBenchmark(leftRows = 10000, rightRows = 10000, numGroups = 100) + // No equi-key: 10K x 10K + asOfJoinNoEquiKeyBenchmark(leftRows = 10000, rightRows = 10000) + } + } +}