Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,43 +48,53 @@ import org.apache.spark.sql.catalyst.rules._
* }}}
*/
object RewriteAsOfJoin extends Rule[LogicalPlan] {
  /**
   * Rewrites every [[AsOfJoin]] in `plan` into a projection over the left side that carries a
   * correlated scalar subquery selecting the nearest right row.
   *
   * The subquery filters the right side by the join condition (combined with the as-of
   * condition), with left-side attributes turned into [[OuterReference]]s, and aggregates it
   * with `MinBy(struct(right.output), orderExpression)`. For join types other than `LeftOuter`,
   * left rows whose subquery result is null are filtered out. The struct is then unpacked into
   * columns named after the right output.
   *
   * When `conf.sortMergeAsOfJoinEnabled` is true the plan is returned unchanged.
   *
   * @param plan the logical plan to rewrite
   * @return the rewritten plan, or `plan` itself when the sort-merge AS-OF join is enabled
   */
  def apply(plan: LogicalPlan): LogicalPlan = {
    // When the sort-merge AS-OF join operator is enabled, skip this rewrite
    // so that the AsOfJoin logical node reaches the planner intact and
    // AsOfJoinSelection can produce a dedicated physical operator.
    // This conf-based gating (rather than excludedRules) is used because
    // the planner strategy and this optimizer rule must be kept in sync:
    // if the rewrite runs, the planner never sees AsOfJoin.
    if (conf.sortMergeAsOfJoinEnabled) {
      plan
    } else {
      plan.transformUpWithNewOutput {
        case j @ AsOfJoin(left, right, asOfCondition, condition, joinType, orderExpression, _) =>
          // Left-side attributes become outer references because the right side is evaluated
          // inside a subquery correlated with the left row.
          val conditionWithOuterReference =
            condition.map(And(_, asOfCondition)).getOrElse(asOfCondition).transformUp {
              case a: AttributeReference if left.outputSet.contains(a) =>
                OuterReference(a)
            }
          val filtered = Filter(conditionWithOuterReference, right)

          val orderExpressionWithOuterReference = orderExpression.transformUp {
            case a: AttributeReference if left.outputSet.contains(a) =>
              OuterReference(a)
          }
          // Pack the whole right row into a struct so a single MinBy can return it.
          val rightStruct = CreateStruct(right.output)
          val nearestRight = MinBy(rightStruct, orderExpressionWithOuterReference)
            .toAggregateExpression()
          val aggExpr = Alias(nearestRight, "__nearest_right__")()
          val aggregate = Aggregate(Seq.empty, Seq(aggExpr), filtered)

          val projectWithScalarSubquery = Project(
            left.output :+ Alias(ScalarSubquery(aggregate, left.output), "__right__")(),
            left)

          // LeftOuter keeps left rows without a match; every other join type drops them.
          val filterRight = joinType match {
            case LeftOuter => projectWithScalarSubquery
            case _ =>
              Filter(IsNotNull(projectWithScalarSubquery.output.last), projectWithScalarSubquery)
          }

          // Unpack the struct back into one column per right output attribute.
          val project = Project(
            left.output ++ right.output.zipWithIndex.map {
              case (out, idx) =>
                Alias(GetStructField(filterRight.output.last, idx), out.name)()
            },
            filterRight)
          val attrMapping = j.output.zip(project.output)

          project -> attrMapping
      }
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,17 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

// Gates the dedicated sort-merge AS-OF join path. The flag is read in two places that must
// stay in sync: RewriteAsOfJoin returns the plan unchanged when it is true, and
// AsOfJoinSelection only plans an AsOfJoin node when it is true. Disabled by default.
val SORT_MERGE_AS_OF_JOIN_ENABLED =
buildConf("spark.sql.join.sortMergeAsOfJoin.enabled")
.doc("When true, use a dedicated sort-merge physical operator for AS-OF joins " +
"instead of rewriting to a correlated subquery. The sort-merge operator evaluates " +
"AS-OF joins in O(N log N) by co-sorting both sides on the as-of key and scanning " +
"in a single pass.")
.version("4.3.0")
.withBindingPolicy(ConfigBindingPolicy.SESSION)
.booleanConf
.createWithDefault(false)

val REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION =
buildConf("spark.sql.requireAllClusterKeysForCoPartition")
.internal()
Expand Down Expand Up @@ -7960,6 +7971,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN)

// Whether AS-OF joins use the dedicated sort-merge operator; read by both RewriteAsOfJoin
// (to skip the subquery rewrite) and AsOfJoinSelection (to plan the physical operator).
def sortMergeAsOfJoinEnabled: Boolean = getConf(SORT_MERGE_AS_OF_JOIN_ENABLED)

def enableRadixSort: Boolean = getConf(RADIX_SORT_ENABLED)

def isParquetSchemaMergingEnabled: Boolean = getConf(PARQUET_SCHEMA_MERGING_ENABLED)
Expand Down
19 changes: 19 additions & 0 deletions sql/core/benchmarks/AsOfJoinBenchmark-jdk21-results.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
================================================================================================
AS-OF Join Benchmark
================================================================================================

OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1010-azure
AMD EPYC 7763 64-Core Processor
AS-OF Join (left=10000, right=10000, groups=100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
--------------------------------------------------------------------------------------------------------------------------------
Correlated subquery (baseline) 37494 37697 287 0.0 3749357.9 1.0X
Sort-merge AS-OF join 62 85 19 0.2 6236.4 601.2X

OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1010-azure
AMD EPYC 7763 64-Core Processor
AS-OF Join no equi-key (left=10000, right=10000): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
--------------------------------------------------------------------------------------------------------------------------------
Correlated subquery (baseline) 23664 23684 27 0.0 2366425.3 1.0X
Sort-merge AS-OF join 1773 1795 32 0.0 177269.8 13.3X


19 changes: 19 additions & 0 deletions sql/core/benchmarks/AsOfJoinBenchmark-jdk25-results.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
================================================================================================
AS-OF Join Benchmark
================================================================================================

OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1010-azure
AMD EPYC 7763 64-Core Processor
AS-OF Join (left=10000, right=10000, groups=100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
--------------------------------------------------------------------------------------------------------------------------------
Correlated subquery (baseline) 37549 37690 199 0.0 3754903.0 1.0X
Sort-merge AS-OF join 56 73 14 0.2 5550.4 676.5X

OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1010-azure
AMD EPYC 7763 64-Core Processor
AS-OF Join no equi-key (left=10000, right=10000): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
--------------------------------------------------------------------------------------------------------------------------------
Correlated subquery (baseline) 22100 22128 40 0.0 2209969.1 1.0X
Sort-merge AS-OF join 1614 1624 13 0.0 161417.4 13.7X


19 changes: 19 additions & 0 deletions sql/core/benchmarks/AsOfJoinBenchmark-results.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
================================================================================================
AS-OF Join Benchmark
================================================================================================

OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1010-azure
AMD EPYC 7763 64-Core Processor
AS-OF Join (left=10000, right=10000, groups=100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
--------------------------------------------------------------------------------------------------------------------------------
Correlated subquery (baseline) 37391 37469 110 0.0 3739110.1 1.0X
Sort-merge AS-OF join 59 70 9 0.2 5918.1 631.8X

OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1010-azure
AMD EPYC 7763 64-Core Processor
AS-OF Join no equi-key (left=10000, right=10000): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
--------------------------------------------------------------------------------------------------------------------------------
Correlated subquery (baseline) 24013 24033 28 0.0 2401313.2 1.0X
Sort-merge AS-OF join 1713 1719 9 0.0 171343.3 14.0X


Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class SparkPlanner(val session: SparkSession, val experimentalMethods: Experimen
Window ::
WindowGroupLimit ::
JoinSelection ::
AsOfJoinSelection ::
InMemoryScans ::
SparkScripts ::
Pipelines ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,149 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* Supports both equi-joins and non-equi-joins.
* Supports only inner like joins.
*/
/**
 * Plans AS-OF joins using a dedicated sort-merge operator when the
 * conf is enabled.
 *
 * The strategy only matches [[AsOfJoin]] nodes while `conf.sortMergeAsOfJoinEnabled` is true;
 * in every other case it returns `Nil`.
 *
 * NOTE(review): the scaladoc immediately preceding this one ("Supports only inner like
 * joins.") does not describe this object and looks like it belongs to `JoinSelection`
 * below -- confirm and move this object so that comment stays attached to its owner.
 */
object AsOfJoinSelection extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case AsOfJoin(left, right, asOfCondition, condition, joinType, orderExpression, _)
        if conf.sortMergeAsOfJoinEnabled =>
      val (leftKeys, rightKeys, residual) = condition match {
        case Some(cond) => extractEquiJoinKeys(cond, left, right)
        case None => (Seq.empty[Expression], Seq.empty[Expression], None)
      }
      val (leftAsOf, rightAsOf) = extractAsOfExprs(
        asOfCondition, orderExpression, left, right)

      joins.SortMergeAsOfJoinExec(
        leftKeys, rightKeys, leftAsOf, rightAsOf,
        asOfCondition, orderExpression, joinType, residual,
        planLater(left), planLater(right)) :: Nil
    case _ => Nil
  }

  /**
   * Extract equi-join key pairs and residual (non-equi) condition
   * from a conjunction. Only EqualTo is treated as equi-key;
   * EqualNullSafe is excluded because the Scanner does not implement
   * null-safe comparison semantics.
   *
   * An `EqualTo` with a side that references no attribute (e.g. `right.x = 1`) is kept in
   * the residual: an empty reference set is a subset of either side's output, so it would
   * otherwise be mistaken for a join key and produce a constant key expression.
   *
   * @param condition the (possibly conjunctive) join condition
   * @param left      left child of the AS-OF join
   * @param right     right child of the AS-OF join
   * @return left keys, right keys (pairwise aligned), and the conjunction of all remaining
   *         predicates, if any
   */
  private def extractEquiJoinKeys(
      condition: Expression,
      left: LogicalPlan,
      right: LogicalPlan): (Seq[Expression], Seq[Expression], Option[Expression]) = {
    val leftAttrs = left.outputSet
    val rightAttrs = right.outputSet

    val (keyPairs, residuals) =
      flattenAnd(condition).partitionMap[(Expression, Expression), Expression] {
        case equality @ EqualTo(l, r) if l.references.nonEmpty && r.references.nonEmpty =>
          // Same-side or mixed-side equalities cannot be keys and stay in the residual.
          splitBySide(l, r, leftAttrs, rightAttrs).toLeft(equality)
        case other => Right(other)
      }

    val (leftKeys, rightKeys) = keyPairs.unzip
    (leftKeys, rightKeys, residuals.reduceOption(And))
  }

  /** Splits a tree of nested `And`s into its conjuncts, left to right. */
  private def flattenAnd(expr: Expression): Seq[Expression] = expr match {
    case And(l, r) => flattenAnd(l) ++ flattenAnd(r)
    case other => Seq(other)
  }

  /**
   * Orders two operands as (left-side expression, right-side expression).
   *
   * @return `Some((a, b))` when `a` only references `leftAttrs` and `b` only `rightAttrs`,
   *         `Some((b, a))` for the mirrored case, `None` otherwise. The direct orientation
   *         is tried first, so it wins when an operand has no references at all.
   */
  private def splitBySide(
      a: Expression,
      b: Expression,
      leftAttrs: AttributeSet,
      rightAttrs: AttributeSet): Option[(Expression, Expression)] = {
    if (a.references.subsetOf(leftAttrs) && b.references.subsetOf(rightAttrs)) {
      Some((a, b))
    } else if (a.references.subsetOf(rightAttrs) && b.references.subsetOf(leftAttrs)) {
      Some((b, a))
    } else {
      None
    }
  }

  /**
   * If `expr` is one of `>=`, `>`, `<=`, `<` comparing a left-side expression with a
   * right-side expression, returns the operands as (left-side, right-side).
   */
  private def inequalitySides(
      expr: Expression,
      leftAttrs: AttributeSet,
      rightAttrs: AttributeSet): Option[(Expression, Expression)] = expr match {
    case GreaterThanOrEqual(l, r) => splitBySide(l, r, leftAttrs, rightAttrs)
    case GreaterThan(l, r) => splitBySide(l, r, leftAttrs, rightAttrs)
    case LessThanOrEqual(l, r) => splitBySide(l, r, leftAttrs, rightAttrs)
    case LessThan(l, r) => splitBySide(l, r, leftAttrs, rightAttrs)
    case _ => None
  }

  /**
   * Determines the (left, right) as-of key expressions of an AS-OF join.
   *
   * Sources are tried in this order:
   *  1. the first cross-side inequality in `asOfCondition` whose two operands both occur in
   *     `orderExpression`;
   *  2. the operands of the distance computed by `orderExpression` (see `findFromOrder`);
   *  3. the first cross-side inequality in `asOfCondition`, unverified;
   *  4. the first left and first right attribute referenced by `orderExpression`.
   *
   * Step 1 exists because not every inequality in `asOfCondition` compares the raw as-of
   * keys: a tolerance bound such as `right >= left - tolerance` would otherwise yield
   * `left - tolerance` as the left key.
   * NOTE(review): this presumably matters for the Nearest direction with a tolerance, where
   * the condition may consist of tolerance bounds only -- confirm against how AsOfJoin
   * builds its condition.
   *
   * @throws IllegalStateException if no source yields an expression for one of the sides
   */
  private def extractAsOfExprs(
      asOfCondition: Expression,
      orderExpression: Expression,
      left: LogicalPlan,
      right: LogicalPlan): (Expression, Expression) = {
    val leftAttrs = left.outputSet
    val rightAttrs = right.outputSet

    val candidates =
      flattenAnd(asOfCondition).flatMap(inequalitySides(_, leftAttrs, rightAttrs))

    def occursInOrder(e: Expression): Boolean =
      orderExpression.exists(_.semanticEquals(e))

    candidates
      .find { case (l, r) => occursInOrder(l) && occursInOrder(r) }
      .orElse {
        // For Nearest direction, asOfCondition may be TrueLiteral.
        // Extract as-of keys from the orderExpression instead.
        findFromOrder(orderExpression, leftAttrs, rightAttrs)
      }
      .orElse(candidates.headOption)
      .getOrElse {
        // Last resort: find attributes from orderExpression
        val orderAttrs = orderExpression.collect {
          case a: AttributeReference => a
        }
        def firstAttrOf(side: String, attrs: AttributeSet): Expression =
          orderAttrs.find(a => attrs.contains(a)).getOrElse {
            throw new IllegalStateException(
              s"Cannot extract $side as-of key from AS-OF join condition " +
                s"or order expression: $asOfCondition / $orderExpression")
          }
        (firstAttrOf("left", leftAttrs), firstAttrOf("right", rightAttrs))
      }
  }

  /**
   * Extract as-of key pair from orderExpression (distance metric).
   *
   * Recognised shapes:
   *  - Backward: Subtract(leftAsOf, rightAsOf)
   *  - Forward: Subtract(rightAsOf, leftAsOf)
   *  - Nearest: If(GT(left, right), Sub(left, right), Sub(right, left))
   *
   * For the `If` shape the "then" branch is inspected first; if it is not a recognisable
   * `Subtract`, the operands of the `If` predicate are used when it is a cross-side
   * inequality.
   */
  private def findFromOrder(
      expr: Expression,
      leftAttrs: AttributeSet,
      rightAttrs: AttributeSet): Option[(Expression, Expression)] = expr match {
    case Subtract(l, r, _) => splitBySide(l, r, leftAttrs, rightAttrs)
    case If(predicate, thenExpr, _) =>
      findFromOrder(thenExpr, leftAttrs, rightAttrs)
        .orElse(inequalitySides(predicate, leftAttrs, rightAttrs))
    case _ => None
  }
}

object JoinSelection extends Strategy with JoinSelectionHelper {
private val hintErrorHandler = conf.hintErrorHandler

Expand Down
Loading