From 9ef629ec8183e0ee4fd25cfda5913bcc38a594d2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 19:33:12 +0800
Subject: [PATCH 01/43] modify slf4j dependency
---
omnioperator/omniop-native-reader/java/pom.xml | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/omnioperator/omniop-native-reader/java/pom.xml b/omnioperator/omniop-native-reader/java/pom.xml
index e7ddfe6c3..bdad33a3a 100644
--- a/omnioperator/omniop-native-reader/java/pom.xml
+++ b/omnioperator/omniop-native-reader/java/pom.xml
@@ -8,13 +8,13 @@
com.huawei.boostkit
boostkit-omniop-native-reader
jar
- 3.3.1-1.6.0
+ 3.4.3-1.6.0
BoostKit Spark Native Sql Engine Extension With OmniOperator
2.12
- 3.3.1
+ 3.4.3
FALSE
../cpp/
../cpp/build/releases/
@@ -35,8 +35,8 @@
org.slf4j
- slf4j-api
- 1.7.32
+ slf4j-simple
+ 1.7.36
junit
@@ -132,4 +132,4 @@
-
\ No newline at end of file
+
--
Gitee
From 6895d78bf89db14198d89e5fda097664c7d6b0cc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 19:36:23 +0800
Subject: [PATCH 02/43] modify version and spark_version
---
omnioperator/omniop-spark-extension/pom.xml | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/pom.xml b/omnioperator/omniop-spark-extension/pom.xml
index 4376a89be..e949cea2d 100644
--- a/omnioperator/omniop-spark-extension/pom.xml
+++ b/omnioperator/omniop-spark-extension/pom.xml
@@ -8,13 +8,13 @@
com.huawei.kunpeng
boostkit-omniop-spark-parent
pom
- 3.3.1-1.6.0
+ 3.4.3-1.6.0
BoostKit Spark Native Sql Engine Extension Parent Pom
2.12.10
2.12
- 3.3.1
+ 3.4.3
3.2.2
UTF-8
UTF-8
--
Gitee
From 0cc54a9911f8cdb1b90914588cafc1ef8b30f317 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 19:45:53 +0800
Subject: [PATCH 03/43] modify version and spark_version
---
omnioperator/omniop-spark-extension/java/pom.xml | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/pom.xml b/omnioperator/omniop-spark-extension/java/pom.xml
index 9cc1b9d25..a40b415fd 100644
--- a/omnioperator/omniop-spark-extension/java/pom.xml
+++ b/omnioperator/omniop-spark-extension/java/pom.xml
@@ -7,7 +7,7 @@
com.huawei.kunpeng
boostkit-omniop-spark-parent
- 3.3.1-1.6.0
+ 3.4.3-1.6.0
../pom.xml
@@ -52,7 +52,7 @@
com.huawei.boostkit
boostkit-omniop-native-reader
- 3.3.1-1.6.0
+ 3.4.3-1.6.0
junit
@@ -247,7 +247,7 @@
true
false
${project.basedir}/src/main/scala
- ${project.basedir}/src/test/scala
+ ${project.basedir}/src/test/scala
${user.dir}/scalastyle-config.xml
${project.basedir}/target/scalastyle-output.xml
${project.build.sourceEncoding}
@@ -335,7 +335,7 @@
org.scalatest
scalatest-maven-plugin
- false
+ false
${project.build.directory}/surefire-reports
.
@@ -352,4 +352,4 @@
-
\ No newline at end of file
+
--
Gitee
From 630028282846f9e3cd28474e7c7a047e24ccd3ee Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 19:50:41 +0800
Subject: [PATCH 04/43] modify SubqueryExpression parameter count and add new
 method
---
.../apache/spark/sql/catalyst/expressions/runtimefilter.scala | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala
index 0a5d509b0..85192fc36 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala
@@ -39,7 +39,7 @@ case class RuntimeFilterSubquery(
exprId: ExprId = NamedExpression.newExprId,
hint: Option[HintInfo] = None)
extends SubqueryExpression(
- filterCreationSidePlan, Seq(filterApplicationSideExp), exprId, Seq.empty)
+ filterCreationSidePlan, Seq(filterApplicationSideExp), exprId, Seq.empty, hint)
with Unevaluable
with UnaryLike[Expression] {
@@ -74,6 +74,8 @@ case class RuntimeFilterSubquery(
override protected def withNewChildInternal(newChild: Expression): RuntimeFilterSubquery =
copy(filterApplicationSideExp = newChild)
+
+ override def withNewHint(hint: Option[HintInfo]): RuntimeFilterSubquery = copy(hint = hint)
}
/**
--
Gitee
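Note on the change above: in Spark 3.4.x, SubqueryExpression carries an Option[HintInfo] constructor argument and its subclasses provide withNewHint, which is what the added override supplies. A minimal sketch of how a rule could use the new hook on the patched RuntimeFilterSubquery, assuming Spark 3.4.x catalyst on the classpath; the helper name is illustrative only:

    import org.apache.spark.sql.catalyst.expressions.RuntimeFilterSubquery
    import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo}

    // Re-attach a broadcast join-strategy hint to a runtime-filter subquery.
    def withBroadcastHint(s: RuntimeFilterSubquery): RuntimeFilterSubquery =
      s.withNewHint(Some(HintInfo(strategy = Some(BROADCAST))))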
From a93c1f9a86dad08d23068ba5cbf065dcf5ea75bb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 19:51:40 +0800
Subject: [PATCH 05/43] add and remove Expression patterns
---
.../sql/catalyst/Tree/TreePatterns.scala | 20 +++++++++++++++----
1 file changed, 16 insertions(+), 4 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala
index ea2712447..ef17a0740 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala
@@ -25,7 +25,8 @@ object TreePattern extends Enumeration {
// Expression patterns (alphabetically ordered)
val AGGREGATE_EXPRESSION = Value(0)
val ALIAS: Value = Value
- val AND_OR: Value = Value
+// val AND_OR: Value = Value
+ val AND: Value = Value
val ARRAYS_ZIP: Value = Value
val ATTRIBUTE_REFERENCE: Value = Value
val APPEND_COLUMNS: Value = Value
@@ -58,6 +59,7 @@ object TreePattern extends Enumeration {
val JSON_TO_STRUCT: Value = Value
val LAMBDA_FUNCTION: Value = Value
val LAMBDA_VARIABLE: Value = Value
+ val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value // spark3.4.3
val LATERAL_SUBQUERY: Value = Value
val LIKE_FAMLIY: Value = Value
val LIST_SUBQUERY: Value = Value
@@ -69,27 +71,33 @@ object TreePattern extends Enumeration {
val NULL_CHECK: Value = Value
val NULL_LITERAL: Value = Value
val SERIALIZE_FROM_OBJECT: Value = Value
+ val OR: Value = Value // spark3.4.3
val OUTER_REFERENCE: Value = Value
+ val PARAMETER: Value = Value // spark3.4.3
+ val PARAMETERIZED_QUERY: Value = Value // spark3.4.3
val PIVOT: Value = Value
val PLAN_EXPRESSION: Value = Value
val PYTHON_UDF: Value = Value
val REGEXP_EXTRACT_FAMILY: Value = Value
val REGEXP_REPLACE: Value = Value
val RUNTIME_REPLACEABLE: Value = Value
- val RUNTIME_FILTER_EXPRESSION: Value = Value
- val RUNTIME_FILTER_SUBQUERY: Value = Value
+ val RUNTIME_FILTER_EXPRESSION: Value = Value // removed in spark3.4.3
+ val RUNTIME_FILTER_SUBQUERY: Value = Value // removed in spark3.4.3
val SCALAR_SUBQUERY: Value = Value
val SCALAR_SUBQUERY_REFERENCE: Value = Value
val SCALA_UDF: Value = Value
+ val SESSION_WINDOW: Value = Value // spark3.4.3
val SORT: Value = Value
val SUBQUERY_ALIAS: Value = Value
- val SUBQUERY_WRAPPER: Value = Value
+ val SUBQUERY_WRAPPER: Value = Value // removed in spark3.4.3
val SUM: Value = Value
val TIME_WINDOW: Value = Value
val TIME_ZONE_AWARE_EXPRESSION: Value = Value
val TRUE_OR_FALSE_LITERAL: Value = Value
val WINDOW_EXPRESSION: Value = Value
+ val WINDOW_TIME: Value = Value // spark3.4.3
val UNARY_POSITIVE: Value = Value
+ val UNPIVOT: Value = Value // spark3.4.3
val UPDATE_FIELDS: Value = Value
val UPPER_OR_LOWER: Value = Value
val UP_CAST: Value = Value
@@ -119,6 +127,7 @@ object TreePattern extends Enumeration {
val UNION: Value = Value
val UNRESOLVED_RELATION: Value = Value
val UNRESOLVED_WITH: Value = Value
+ val TEMP_RESOLVED_COLUMN: Value = Value // spark3.4.3
val TYPED_FILTER: Value = Value
val WINDOW: Value = Value
val WITH_WINDOW_DEFINITION: Value = Value
@@ -127,6 +136,7 @@ object TreePattern extends Enumeration {
val UNRESOLVED_ALIAS: Value = Value
val UNRESOLVED_ATTRIBUTE: Value = Value
val UNRESOLVED_DESERIALIZER: Value = Value
+ val UNRESOLVED_HAVING: Value = Value // spark3.4.3
val UNRESOLVED_ORDINAL: Value = Value
val UNRESOLVED_FUNCTION: Value = Value
val UNRESOLVED_HINT: Value = Value
@@ -135,6 +145,8 @@ object TreePattern extends Enumeration {
// Unresolved Plan patterns (Alphabetically ordered)
val UNRESOLVED_SUBQUERY_COLUMN_ALIAS: Value = Value
val UNRESOLVED_FUNC: Value = Value
+ val UNRESOLVED_TABLE_VALUED_FUNCTION: Value = Value // spark3.4.3
+ val UNRESOLVED_TVF_ALIASES: Value = Value // spark3.4.3
// Execution expression patterns (alphabetically ordered)
val IN_SUBQUERY_EXEC: Value = Value
--
Gitee
From ac1473dc6e4104405caf6e71896aba689429d931 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 19:52:34 +0800
Subject: [PATCH 06/43] fix parameter count
---
.../sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala
index 9e4029025..f6ebd716d 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala
@@ -61,7 +61,7 @@ object RewriteSelfJoinInInPredicate extends Rule[LogicalPlan] with PredicateHelp
case f: Filter =>
f transformExpressions {
case in @ InSubquery(_, listQuery @ ListQuery(Project(projectList,
- Join(left, right, Inner, Some(joinCond), _)), _, _, _, _))
+ Join(left, right, Inner, Some(joinCond), _)), _, _, _, _, _))
if left.canonicalized ne right.canonicalized =>
val attrMapping = AttributeMap(right.output.zip(left.output))
val subCondExprs = splitConjunctivePredicates(joinCond transform {
--
Gitee
From 6efd40b5b7f3123935f75a3b4f209890122f0299 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 19:55:09 +0800
Subject: [PATCH 07/43] add normalized plan, adjust parameters and modify method
 toInternalError
---
.../spark/sql/execution/QueryExecution.scala | 30 ++++++++++++++-----
1 file changed, 23 insertions(+), 7 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index ef33a84de..37f3db5d6 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -105,6 +105,23 @@ class QueryExecution(
case other => other
}
+ // The plan that has been normalized by custom rules, so that it's more likely to hit cache.
+ lazy val normalized: LogicalPlan = {
+ val normalizationRules = sparkSession.sessionState.planNormalizationRules
+ if (normalizationRules.isEmpty) {
+ commandExecuted
+ } else {
+ val planChangeLogger = new PlanChangeLogger[LogicalPlan]()
+ val normalized = normalizationRules.foldLeft(commandExecuted) { (p, rule) =>
+ val result = rule.apply(p)
+ planChangeLogger.logRule(rule.ruleName, p, result)
+ result
+ }
+ planChangeLogger.logBatch("Plan Normalization", commandExecuted, normalized)
+ normalized
+ }
+ } // Spark3.4.3
+
lazy val withCachedData: LogicalPlan = sparkSession.withActive {
assertAnalyzed()
assertSupported()
@@ -227,7 +244,7 @@ class QueryExecution(
// output mode does not matter since there is no `Sink`.
new IncrementalExecution(
sparkSession, logical, OutputMode.Append(), "",
- UUID.randomUUID, UUID.randomUUID, 0, OffsetSeqMetadata(0, 0))
+ UUID.randomUUID, UUID.randomUUID, 0, None, OffsetSeqMetadata(0, 0))
} else {
this
}
@@ -494,11 +511,10 @@ object QueryExecution {
*/
private[sql] def toInternalError(msg: String, e: Throwable): Throwable = e match {
case e @ (_: java.lang.NullPointerException | _: java.lang.AssertionError) =>
- new SparkException(
- errorClass = "INTERNAL_ERROR",
- messageParameters = Array(msg +
- " Please, fill a bug report in, and provide the full stack trace."),
- cause = e)
+ SparkException.internalError(
+ msg + " You hit a bug in Spark or the Spark plugins you use. Please, report this bug " +
+ "to the corresponding communities or vendors, and provide the full stack trace.",
+ e)
case e: Throwable =>
e
}
@@ -513,4 +529,4 @@ object QueryExecution {
case e: Throwable => throw toInternalError(msg, e)
}
}
-}
\ No newline at end of file
+}
--
Gitee
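Note on the change above: SparkException.internalError (available since Spark 3.4) builds the INTERNAL_ERROR exception that the 3.3-era code assembled by hand. A minimal sketch of the same wrapping pattern outside QueryExecution, assuming Spark 3.4.x on the classpath; the helper name is illustrative only:

    import org.apache.spark.SparkException

    // Wrap unexpected runtime failures into an INTERNAL_ERROR SparkException,
    // mirroring what the patched toInternalError delegates to.
    def wrapUnexpected(msg: String, e: Throwable): Throwable = e match {
      case e @ (_: NullPointerException | _: AssertionError) =>
        SparkException.internalError(
          msg + " Please report this bug and provide the full stack trace.", e)
      case other => other
    }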
From 817c0425ce4efb7d7adea3725ac75454662ccc2c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 19:55:32 +0800
Subject: [PATCH 08/43] modify parameters
---
.../sql/execution/adaptive/PlanAdaptiveSubqueries.scala | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
index b5a1ad375..dfdbe2c70 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
@@ -30,11 +30,11 @@ case class PlanAdaptiveSubqueries(
def apply(plan: SparkPlan): SparkPlan = {
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(
SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY, RUNTIME_FILTER_SUBQUERY)) {
- case expressions.ScalarSubquery(_, _, exprId, _) =>
+ case expressions.ScalarSubquery(_, _, exprId, _, _, _) =>
val subquery = SubqueryExec.createForScalarSubquery(
s"subquery#${exprId.id}", subqueryMap(exprId.id))
execution.ScalarSubquery(subquery, exprId)
- case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _)) =>
+ case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _, _)) =>
val expr = if (values.length == 1) {
values.head
} else {
@@ -47,7 +47,7 @@ case class PlanAdaptiveSubqueries(
val subquery = SubqueryExec(s"subquery#${exprId.id}", subqueryMap(exprId.id))
InSubqueryExec(expr, subquery, exprId, shouldBroadcast = true)
case expressions.DynamicPruningSubquery(value, buildPlan,
- buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId) =>
+ buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId, _) =>
val name = s"dynamicpruning#${exprId.id}"
val subquery = SubqueryAdaptiveBroadcastExec(name, broadcastKeyIndex, onlyInBroadcast,
buildPlan, buildKeys, subqueryMap(exprId.id))
--
Gitee
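Note on the change above: the extra wildcards reflect the wider constructors in Spark 3.4.x, where ScalarSubquery, ListQuery and DynamicPruningSubquery gained fields such as a hint. A minimal sketch of an extractor that compiles against those arities, assuming Spark 3.4.x catalyst on the classpath; the helper is illustrative only:

    import org.apache.spark.sql.catalyst.expressions.{Expression, ListQuery, ScalarSubquery}

    // Identify the subquery kinds handled by PlanAdaptiveSubqueries.
    def describeSubquery(e: Expression): String = e match {
      case ScalarSubquery(_, _, exprId, _, _, _) => s"scalar subquery #${exprId.id}"
      case ListQuery(_, _, exprId, _, _, _)      => s"list query #${exprId.id}"
      case _                                     => "not a subquery"
    }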
From 7013a77760c297ada613d17245d95a7e5aa94e59 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 19:56:07 +0800
Subject: [PATCH 09/43] modify filePath
---
.../execution/datasources/parquet/OmniParquetFileFormat.scala | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala
index 9504b34d1..3d61f873d 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala
@@ -80,7 +80,7 @@ class OmniParquetFileFormat extends FileFormat with DataSourceRegister with Logg
(file: PartitionedFile) => {
assert(file.partitionValues.numFields == partitionSchema.size)
- val filePath = new Path(new URI(file.filePath))
+ val filePath = file.toPath
val split =
new org.apache.parquet.hadoop.ParquetInputSplit(
filePath,
--
Gitee
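Note on the change above: in Spark 3.4.x, PartitionedFile.filePath is a SparkPath rather than a String, so readers resolve the Hadoop Path through file.toPath instead of new Path(new URI(file.filePath)). A minimal sketch, assuming Spark 3.4.x on the classpath; the helper name is illustrative only:

    import org.apache.hadoop.fs.Path
    import org.apache.spark.sql.execution.datasources.PartitionedFile

    // Resolve the Hadoop Path of a file split the Spark 3.4 way.
    def hadoopPathOf(file: PartitionedFile): Path = file.toPath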
From 902deae2b0092bbd810588a185a13d54b82f8b97 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 19:57:33 +0800
Subject: [PATCH 10/43] modify filePath and ORC predicate pushdown
---
.../datasources/orc/OmniOrcFileFormat.scala | 30 ++++++++++++-------
1 file changed, 20 insertions(+), 10 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala
index 334800f51..807369004 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala
@@ -118,17 +118,27 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ
(file: PartitionedFile) => {
val conf = broadcastedConf.value.value
- val filePath = new Path(new URI(file.filePath))
- val isPPDSafeValue = isPPDSafe(filters, dataSchema).reduceOption(_ && _)
+// val filePath = new Path(new URI(file.filePath.urlEncoded))
+ val filePath = file.toPath
- // ORC predicate pushdown
- if (orcFilterPushDown && filters.nonEmpty && isPPDSafeValue.getOrElse(false)) {
- OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach {
- fileSchema => OrcFilters.createFilter(fileSchema, filters).foreach { f =>
- OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
- }
- }
- }
+ val fs = filePath.getFileSystem(conf)
+ val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
+ val orcSchema =
+ Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions))(_.getSchema)
+ val isPPDSafeValue = isPPDSafe(filters, dataSchema).reduceOption(_ && _)
+ val resultedColPruneInfo = OrcUtils.requestedColumnIds(
+ isCaseSensitive, dataSchema, requiredSchema, orcSchema, conf)
+
+
+ // ORC predicate pushdown
+ if (orcFilterPushDown && filters.nonEmpty && isPPDSafeValue.getOrElse(false)) {
+ val fileSchema = OrcUtils.toCatalystSchema(orcSchema)
+ // OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach { fileSchema =>
+ OrcFilters.createFilter(fileSchema, filters).foreach { f =>
+ OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
+ // }
+ }
+ }
val taskConf = new Configuration(conf)
val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty)
--
Gitee
From a1de18295a282d69537f355e84c0a70f929656f7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 19:59:47 +0800
Subject: [PATCH 11/43] modify parameters, remove PromotePrecision, modify
 CastBase to Cast
---
.../spark/expression/OmniExpressionAdaptor.scala | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
index d1f911a8c..f7ed75fa3 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
@@ -76,7 +76,7 @@ object OmniExpressionAdaptor extends Logging {
}
}
- private def unsupportedCastCheck(expr: Expression, cast: CastBase): Unit = {
+ private def unsupportedCastCheck(expr: Expression, cast: Cast): Unit = {
def doSupportCastToString(dataType: DataType): Boolean = {
dataType.isInstanceOf[DecimalType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[IntegerType] ||
dataType.isInstanceOf[LongType] || dataType.isInstanceOf[DateType] || dataType.isInstanceOf[DoubleType] ||
@@ -160,8 +160,8 @@ object OmniExpressionAdaptor extends Logging {
throw new UnsupportedOperationException(s"Unsupported datatype for MakeDecimal: ${makeDecimal.child.dataType}")
}
- case promotePrecision: PromotePrecision =>
- rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap)
+// case promotePrecision: PromotePrecision =>
+// rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap)
case sub: Subtract =>
new JSONObject().put("exprType", "BINARY")
@@ -296,7 +296,7 @@ object OmniExpressionAdaptor extends Logging {
.put(rewriteToOmniJsonExpressionLiteralJsonObject(subString.len, exprsIndexMap)))
// Cast
- case cast: CastBase =>
+ case cast: Cast =>
unsupportedCastCheck(expr, cast)
cast.child.dataType match {
case NullType =>
@@ -588,10 +588,10 @@ object OmniExpressionAdaptor extends Logging {
rewriteToOmniJsonExpressionLiteralJsonObject(children.head, exprsIndexMap)
} else {
children.head match {
- case base: CastBase if base.child.dataType.isInstanceOf[NullType] =>
+ case base: Cast if base.child.dataType.isInstanceOf[NullType] =>
rewriteToOmniJsonExpressionLiteralJsonObject(children(1), exprsIndexMap)
case _ => children(1) match {
- case base: CastBase if base.child.dataType.isInstanceOf[NullType] =>
+ case base: Cast if base.child.dataType.isInstanceOf[NullType] =>
rewriteToOmniJsonExpressionLiteralJsonObject(children.head, exprsIndexMap)
case _ =>
new JSONObject().put("exprType", "FUNCTION")
--
Gitee
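Note on the change above: Spark 3.4.x merges CastBase into Cast and removes PromotePrecision, so the expression adaptor now matches Cast directly. A minimal sketch of such a match, assuming Spark 3.4.x catalyst on the classpath; the helper is illustrative only:

    import org.apache.spark.sql.catalyst.expressions.{Cast, Expression}
    import org.apache.spark.sql.types.NullType

    // True when the expression is a cast whose input is the NULL type.
    def isCastFromNull(e: Expression): Boolean = e match {
      case c: Cast => c.child.dataType.isInstanceOf[NullType]
      case _       => false
    }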
From c86a67d2da6bdf99ea66b87b7d3bb7e1acbdd3ed Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 20:00:23 +0800
Subject: [PATCH 12/43] modify parameters
---
.../sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala
index aaa244cdf..e1c620e1c 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala
@@ -43,7 +43,7 @@ class MergeSubqueryFiltersSuite extends PlanTest {
}
private def extractorExpression(cteIndex: Int, output: Seq[Attribute], fieldIndex: Int) = {
- GetStructField(ScalarSubquery(CTERelationRef(cteIndex, _resolved = true, output)), fieldIndex)
+ GetStructField(ScalarSubquery(CTERelationRef(cteIndex, _resolved = true, output, isStreaming = false)), fieldIndex)
.as("scalarsubquery()")
}
--
Gitee
From 2dcdc3e5cd912af7ec0bcefda0e77594504400d5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 20:00:35 +0800
Subject: [PATCH 13/43] modify parameters
---
.../spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala
index 1b5baa230..c4435379f 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala
@@ -643,7 +643,7 @@ object MergeSubqueryFilters extends Rule[LogicalPlan] {
val subqueryCTE = header.plan.asInstanceOf[CTERelationDef]
GetStructField(
ScalarSubquery(
- CTERelationRef(subqueryCTE.id, _resolved = true, subqueryCTE.output),
+ CTERelationRef(subqueryCTE.id, _resolved = true, subqueryCTE.output, subqueryCTE.isStreaming),
exprId = ssr.exprId),
ssr.headerIndex)
} else {
--
Gitee
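Note on the change above: CTERelationRef gained an isStreaming field in Spark 3.4.x, so references rebuilt from a CTERelationDef forward the flag from the definition (or pass isStreaming = false in batch-only tests, as the previous patch does). A minimal sketch, assuming Spark 3.4.x catalyst on the classpath; the helper is illustrative only:

    import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, CTERelationRef}

    // Rebuild a resolved reference to an existing CTE definition.
    def refFor(cte: CTERelationDef): CTERelationRef =
      CTERelationRef(cte.id, _resolved = true, cte.output, cte.isStreaming)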
From 3c57000ec57c50a6d683458a61026e68e5747b25 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 20:04:31 +0800
Subject: [PATCH 14/43] modify injectBloomFilter, injectInSubqueryFilter,
 isSelectiveFilterOverScan, hasDynamicPruningSubquery, hasInSubquery, tryInjectRuntimeFilter
---
.../optimizer/InjectRuntimeFilter.scala | 67 ++++++++++++++-----
1 file changed, 51 insertions(+), 16 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
index 812c387bc..e6c266242 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
@@ -27,6 +27,8 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import com.huawei.boostkit.spark.ColumnarPluginConfig
+import scala.annotation.tailrec
+
/**
* Insert a filter on one side of the join if the other side has a selective predicate.
* The filter could be an IN subquery (converted to a semi join), a bloom filter, or something
@@ -85,12 +87,12 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
val bloomFilterAgg =
if (rowCount.isDefined && rowCount.get.longValue > 0L) {
new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)),
- Literal(rowCount.get.longValue))
+ rowCount.get.longValue)
} else {
new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)))
}
- val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None)
- val alias = Alias(aggExp, "bloomFilter")()
+
+ val alias = Alias(bloomFilterAgg.toAggregateExpression(), "bloomFilter")()
val aggregate =
ConstantFolding(ColumnPruning(Aggregate(Nil, Seq(alias), filterCreationSidePlan)))
val bloomFilterSubquery = if (canReuseExchange) {
@@ -112,7 +114,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
require(filterApplicationSideExp.dataType == filterCreationSideExp.dataType)
val actualFilterKeyExpr = mayWrapWithHash(filterCreationSideExp)
val alias = Alias(actualFilterKeyExpr, actualFilterKeyExpr.toString)()
- val aggregate = Aggregate(Seq(alias), Seq(alias), filterCreationSidePlan)
+ val aggregate =
+ ColumnPruning(Aggregate(Seq(filterCreationSideExp), Seq(alias), filterCreationSidePlan))
if (!canBroadcastBySize(aggregate, conf)) {
// Skip the InSubquery filter if the size of `aggregate` is beyond broadcast join threshold,
// i.e., the semi-join will be a shuffled join, which is not worthwhile.
@@ -129,13 +132,39 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
* do not add a subquery that might have an expensive computation
*/
private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = {
- val ret = plan match {
- case PhysicalOperation(_, filters, child) if child.isInstanceOf[LeafNode] =>
- filters.forall(isSimpleExpression) &&
- filters.exists(isLikelySelective)
+ @tailrec
+ def isSelective(
+ p: LogicalPlan,
+ predicateReference: AttributeSet,
+ hasHitFilter: Boolean,
+ hasHitSelectiveFilter: Boolean): Boolean = p match {
+ case Project(projectList, child) =>
+ if (hasHitFilter) {
+ // We need to make sure all expressions referenced by filter predicates are simple
+ // expressions.
+ val referencedExprs = projectList.filter(predicateReference.contains)
+ referencedExprs.forall(isSimpleExpression) &&
+ isSelective(
+ child,
+ referencedExprs.map(_.references).foldLeft(AttributeSet.empty)(_ ++ _),
+ hasHitFilter,
+ hasHitSelectiveFilter)
+ } else {
+ assert(predicateReference.isEmpty && !hasHitSelectiveFilter)
+ isSelective(child, predicateReference, hasHitFilter, hasHitSelectiveFilter)
+ }
+ case Filter(condition, child) =>
+ isSimpleExpression(condition) && isSelective(
+ child,
+ predicateReference ++ condition.references,
+ hasHitFilter = true,
+ hasHitSelectiveFilter = hasHitSelectiveFilter || isLikelySelective(condition))
+ case _: LeafNode => hasHitSelectiveFilter
case _ => false
}
- !plan.isStreaming && ret
+
+ !plan.isStreaming &&
+ isSelective(plan, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false)
}
private def isSimpleExpression(e: Expression): Boolean = {
@@ -184,8 +213,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
/**
* Check that:
- * - The filterApplicationSideJoinExp can be pushed down through joins and aggregates (ie the
- * expression references originate from a single leaf node)
+ * - The filterApplicationSideJoinExp can be pushed down through joins, aggregates and windows
+ * (ie the expression references originate from a single leaf node)
* - The filter creation side has a selective predicate
* - The current join is a shuffle join or a broadcast join that has a shuffle below it
* - The max filterApplicationSide scan size is greater than a configurable threshold
@@ -218,9 +247,9 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
leftKey: Expression,
rightKey: Expression): Boolean = {
(left, right) match {
- case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan), _) =>
+ case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _, _), plan), _) =>
pruningKey.fastEquals(leftKey) || hasDynamicPruningSubquery(plan, right, leftKey, rightKey)
- case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan)) =>
+ case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _, _), plan)) =>
pruningKey.fastEquals(rightKey) ||
hasDynamicPruningSubquery(left, plan, leftKey, rightKey)
case _ => false
@@ -251,10 +280,10 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
rightKey: Expression): Boolean = {
(left, right) match {
case (Filter(InSubquery(Seq(key),
- ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _), _) =>
+ ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _, _)), _), _) =>
key.fastEquals(leftKey) || key.fastEquals(new Murmur3Hash(Seq(leftKey)))
case (_, Filter(InSubquery(Seq(key),
- ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _)) =>
+ ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _, _)), _)) =>
key.fastEquals(rightKey) || key.fastEquals(new Murmur3Hash(Seq(rightKey)))
case _ => false
}
@@ -299,7 +328,13 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
case s: Subquery if s.correlated => plan
case _ if !conf.runtimeFilterSemiJoinReductionEnabled &&
!conf.runtimeFilterBloomFilterEnabled => plan
- case _ => tryInjectRuntimeFilter(plan)
+ case _ =>
+ val newPlan = tryInjectRuntimeFilter(plan)
+ if (conf.runtimeFilterSemiJoinReductionEnabled && !plan.fastEquals(newPlan)) {
+ RewritePredicateSubquery(newPlan)
+ } else {
+ newPlan
+ }
}
}
\ No newline at end of file
--
Gitee
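Note on the change above: AggregateFunction.toAggregateExpression() produces the Complete, non-distinct AggregateExpression that the 3.3-era code built explicitly, which is why the patch drops the manual AggregateExpression(..., Complete, ...) wrapper. A minimal sketch, assuming Spark 3.4.x catalyst on the classpath; the literal is only a placeholder key expression:

    import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, XxHash64}
    import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate

    // Build the bloom-filter aggregate column the same way the patched rule does.
    val bloomAgg = new BloomFilterAggregate(new XxHash64(Seq(Literal(42L))))
    val bloomCol = Alias(bloomAgg.toAggregateExpression(), "bloomFilter")()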
From ae03791184939a9f51756a057b2f2de618f7d623 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 20:05:15 +0800
Subject: [PATCH 15/43] modify metastore.uris and hive.db
---
.../java/src/test/resources/HiveResource.properties | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties b/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties
index 89eabe8e6..099e28e8d 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties
+++ b/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties
@@ -2,11 +2,13 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
#
-hive.metastore.uris=thrift://server1:9083
+#hive.metastore.uris=thrift://server1:9083
+hive.metastore.uris=thrift://OmniOperator:9083
spark.sql.warehouse.dir=/user/hive/warehouse
spark.memory.offHeap.size=8G
spark.sql.codegen.wholeStage=false
spark.sql.extensions=com.huawei.boostkit.spark.ColumnarPlugin
spark.shuffle.manager=org.apache.spark.shuffle.sort.OmniColumnarShuffleManager
spark.sql.orc.impl=native
-hive.db=tpcds_bin_partitioned_orc_2
\ No newline at end of file
+#hive.db=tpcds_bin_partitioned_orc_2
+hive.db=tpcds_bin_partitioned_varchar_orc_2
--
Gitee
From 2a06713e82ae19b1607369f856dae7b59e6dcd25 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 20:06:35 +0800
Subject: [PATCH 16/43] modify "test child max row" limit(0) to limit(1)
---
.../spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
index f83edb9ca..ea52aca62 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
@@ -61,7 +61,7 @@ class CombiningLimitsSuite extends PlanTest {
comparePlans(optimized1, expected1)
// test child max row > limit.
- val query2 = testRelation.select().groupBy()(count(1)).limit(0).analyze
+ val query2 = testRelation2.select($"x").groupBy($"x")(count(1)).limit(1).analyze
val optimized2 = Optimize.execute(query2)
comparePlans(optimized2, query2)
--
Gitee
From 5fd470d2e77d38f3de7a3bdb12d15fc789a4adf2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 20:07:43 +0800
Subject: [PATCH 17/43] modify type of path
---
.../spark/sql/execution/ColumnarFileSourceScanExec.scala | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala
index 800dcf1a0..4b70db452 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala
@@ -479,10 +479,11 @@ abstract class BaseColumnarFileSourceScanExec(
}
}.groupBy { f =>
BucketingUtils
- .getBucketId(new Path(f.filePath).getName)
- .getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.filePath))
+ .getBucketId(f.toPath.getName)
+ .getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.urlEncodedPath))
}
+
val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) {
val bucketSet = optionalBucketSet.get
filesGroupedToBuckets.filter {
--
Gitee
From deb446159bb9483b44cb7e7595145d021b820c96 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 20:08:38 +0800
Subject: [PATCH 18/43] modify traits of ColumnarProjectExec
---
.../execution/ColumnarBasicPhysicalOperators.scala | 12 +++++++-----
1 file changed, 7 insertions(+), 5 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala
index 486369843..630034bd7 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala
@@ -16,10 +16,10 @@
*/
package org.apache.spark.sql.execution
-
import java.util.concurrent.TimeUnit.NANOSECONDS
import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP
+
import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._
import com.huawei.boostkit.spark.util.OmniAdaptorUtil
import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs
@@ -41,11 +41,13 @@ import org.apache.spark.sql.execution.vectorized.OmniColumnVector
import org.apache.spark.sql.expression.ColumnarExpressionConverter
import org.apache.spark.sql.types.{LongType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.sql.catalyst.plans.{AliasAwareOutputExpression, AliasAwareQueryOutputOrdering}
+
case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
extends UnaryExecNode
- with AliasAwareOutputPartitioning
- with AliasAwareOutputOrdering {
+ with AliasAwareOutputExpression
+ with AliasAwareQueryOutputOrdering[SparkPlan] {
override def supportsColumnar: Boolean = true
@@ -267,8 +269,8 @@ case class ColumnarConditionProjectExec(projectList: Seq[NamedExpression],
condition: Expression,
child: SparkPlan)
extends UnaryExecNode
- with AliasAwareOutputPartitioning
- with AliasAwareOutputOrdering {
+ with AliasAwareOutputExpression
+ with AliasAwareQueryOutputOrdering[SparkPlan] {
override def supportsColumnar: Boolean = true
--
Gitee
From ccfc3230bd3912eab482e2851db0026f582da553 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 20:12:41 +0800
Subject: [PATCH 19/43] modify "SPARK-30953" ,"SPARK-31658"
ADAPTIVE_EXECUTION_FORCE_APPLY,"SPARK-32932" ADAPTIVE_EXECUTION_ENABLED
---
.../ColumnarAdaptiveQueryExecSuite.scala | 47 +++++++++++--------
1 file changed, 28 insertions(+), 19 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala
index c34ff5bb1..c0be72f31 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala
@@ -19,17 +19,15 @@ package org.apache.spark.sql.execution.adaptive
import java.io.File
import java.net.URI
-
import org.apache.logging.log4j.Level
import org.scalatest.PrivateMethodTester
import org.scalatest.time.SpanSugar._
-
import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
-import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnaryExecNode, UnionExec}
+import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnionExec}
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
@@ -37,7 +35,7 @@ import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, Exchange, REPARTITION_BY_COL, REPARTITION_BY_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec}
import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
-import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
+import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLExecutionStart}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
@@ -1122,13 +1120,21 @@ class AdaptiveQueryExecSuite
test("SPARK-30953: InsertAdaptiveSparkPlan should apply AQE on child plan of write commands") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
+ var plan: SparkPlan = null
+ val listener = new QueryExecutionListener {
+ override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+ plan = qe.executedPlan
+ }
+ override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+ }
+ spark.listenerManager.register(listener)
withTable("t1") {
- val plan = sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").queryExecution.executedPlan
- assert(plan.isInstanceOf[CommandResultExec])
- val commandResultExec = plan.asInstanceOf[CommandResultExec]
- assert(commandResultExec.commandPhysicalPlan.isInstanceOf[DataWritingCommandExec])
- assert(commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec]
- .child.isInstanceOf[AdaptiveSparkPlanExec])
+ val format = classOf[NoopDataSource].getName
+ Seq((0, 1)).toDF("x", "y").write.format(format).mode("overwrite").save()
+ sparkContext.listenerBus.waitUntilEmpty()
+ assert(plan.isInstanceOf[V2TableWriteExec])
+ assert(plan.asInstanceOf[V2TableWriteExec].child.isInstanceOf[AdaptiveSparkPlanExec])
+ spark.listenerManager.unregister(listener)
}
}
}
@@ -1172,15 +1178,14 @@ class AdaptiveQueryExecSuite
test("SPARK-31658: SQL UI should show write commands") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
+ SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "false") {
withTable("t1") {
- var checkDone = false
+ var commands: Seq[SparkPlanInfo] = Seq.empty
val listener = new SparkListener {
override def onOtherEvent(event: SparkListenerEvent): Unit = {
event match {
- case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) =>
- assert(planInfo.nodeName == "Execute CreateDataSourceTableAsSelectCommand")
- checkDone = true
+ case start: SparkListenerSQLExecutionStart =>
+ commands = commands ++ Seq(start.sparkPlanInfo)
case _ => // ignore other events
}
}
@@ -1189,7 +1194,12 @@ class AdaptiveQueryExecSuite
try {
sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect()
spark.sparkContext.listenerBus.waitUntilEmpty()
- assert(checkDone)
+ assert(commands.size == 3)
+ assert(commands.head.nodeName == "Execute CreateDataSourceTableAsSelectCommand")
+ assert(commands(1).nodeName == "Execute InsertIntoHadoopFsRelationCommand")
+ assert(commands(1).children.size == 1)
+ assert(commands(1).children.head.nodeName == "WriteFiles")
+ assert(commands(2).nodeName == "CommandResult")
} finally {
spark.sparkContext.removeSparkListener(listener)
}
@@ -1574,7 +1584,7 @@ class AdaptiveQueryExecSuite
test("SPARK-32932: Do not use local shuffle read at final stage on write command") {
withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString,
SQLConf.SHUFFLE_PARTITIONS.key -> "5",
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
val data = for (
i <- 1L to 10L;
j <- 1L to 3L
@@ -1584,9 +1594,8 @@ class AdaptiveQueryExecSuite
var noLocalread: Boolean = false
val listener = new QueryExecutionListener {
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
- qe.executedPlan match {
+ stripAQEPlan(qe.executedPlan) match {
case plan@(_: DataWritingCommandExec | _: V2TableWriteExec) =>
- assert(plan.asInstanceOf[UnaryExecNode].child.isInstanceOf[AdaptiveSparkPlanExec])
noLocalread = collect(plan) {
case exec: AQEShuffleReadExec if exec.isLocalRead => exec
}.isEmpty
--
Gitee
From 6c2b46f1878f86bdf60559ed19b947aa818af850 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 20:15:01 +0800
Subject: [PATCH 20/43] expect SparkException instead of NullPointerException in CastSuite
---
.../sql/catalyst/expressions/CastSuite.scala | 55 +++++++++++++------
1 file changed, 37 insertions(+), 18 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index e6b786c2a..329295bac 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.SparkException
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.execution.ColumnarSparkPlanTest
import org.apache.spark.sql.types.{DataType, Decimal}
@@ -64,7 +65,8 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception = intercept[Exception](
result.collect().toSeq.head.getByte(0)
)
- assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast null as short") {
@@ -72,7 +74,8 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception = intercept[Exception](
result.collect().toSeq.head.getShort(0)
)
- assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast null as int") {
@@ -80,7 +83,8 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception = intercept[Exception](
result.collect().toSeq.head.getInt(0)
)
- assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast null as long") {
@@ -88,7 +92,8 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception = intercept[Exception](
result.collect().toSeq.head.getLong(0)
)
- assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast null as float") {
@@ -96,7 +101,8 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception = intercept[Exception](
result.collect().toSeq.head.getFloat(0)
)
- assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast null as double") {
@@ -104,7 +110,8 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception = intercept[Exception](
result.collect().toSeq.head.getDouble(0)
)
- assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast null as date") {
@@ -154,13 +161,15 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception4 = intercept[Exception](
result4.collect().toSeq.head.getBoolean(0)
)
- assert(exception4.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception4.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception4.isInstanceOf[SparkException], s"sql: ${sql}")
val result5 = spark.sql("select cast('test' as boolean);")
val exception5 = intercept[Exception](
result5.collect().toSeq.head.getBoolean(0)
)
- assert(exception5.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception5.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception5.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast boolean to string") {
@@ -182,13 +191,15 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception2 = intercept[Exception](
result2.collect().toSeq.head.getByte(0)
)
- assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}")
val result3 = spark.sql("select cast('false' as byte);")
val exception3 = intercept[Exception](
result3.collect().toSeq.head.getByte(0)
)
- assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast byte to string") {
@@ -210,13 +221,15 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception2 = intercept[Exception](
result2.collect().toSeq.head.getShort(0)
)
- assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}")
val result3 = spark.sql("select cast('false' as short);")
val exception3 = intercept[Exception](
result3.collect().toSeq.head.getShort(0)
)
- assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast short to string") {
@@ -238,13 +251,15 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception2 = intercept[Exception](
result2.collect().toSeq.head.getInt(0)
)
- assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}")
val result3 = spark.sql("select cast('false' as int);")
val exception3 = intercept[Exception](
result3.collect().toSeq.head.getInt(0)
)
- assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast int to string") {
@@ -266,13 +281,15 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception2 = intercept[Exception](
result2.collect().toSeq.head.getLong(0)
)
- assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}")
val result3 = spark.sql("select cast('false' as long);")
val exception3 = intercept[Exception](
result3.collect().toSeq.head.getLong(0)
)
- assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast long to string") {
@@ -298,7 +315,8 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception3 = intercept[Exception](
result3.collect().toSeq.head.getFloat(0)
)
- assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast float to string") {
@@ -324,7 +342,8 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception3 = intercept[Exception](
result3.collect().toSeq.head.getDouble(0)
)
- assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast double to string") {
--
Gitee
From f7e5c5b235131fb3eeae814e330ccffb099042fa Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 20:16:57 +0800
Subject: [PATCH 21/43] modify AQEOptimizer and optimizeQueryStage
---
.../adaptive/AdaptiveSparkPlanExec.scala | 32 +++++++++++++------
1 file changed, 22 insertions(+), 10 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 2599a1410..e5c8fc0ab 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -19,13 +19,11 @@ package org.apache.spark.sql.execution.adaptive
import java.util
import java.util.concurrent.LinkedBlockingQueue
-
import scala.collection.JavaConverters._
import scala.collection.concurrent.TrieMap
import scala.collection.mutable
import scala.concurrent.ExecutionContext
import scala.util.control.NonFatal
-
import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
@@ -35,12 +33,13 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
+import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
import org.apache.spark.sql.execution.bucketing.DisableUnnecessaryBucketedScan
import org.apache.spark.sql.execution.exchange._
-import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric}
+import org.apache.spark.sql.execution.ui.{SQLPlanMetric, SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.{SparkFatalException, ThreadUtils}
@@ -83,7 +82,9 @@ case class AdaptiveSparkPlanExec(
@transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()
// The logical plan optimizer for re-optimizing the current logical plan.
- @transient private val optimizer = new AQEOptimizer(conf)
+ @transient private val optimizer = new AQEOptimizer(conf,
+ session.sessionState.adaptiveRulesHolder.runtimeOptimizerRules
+ )
// `EnsureRequirements` may remove user-specified repartition and assume the query plan won't
// change its output partitioning. This assumption is not true in AQE. Here we check the
@@ -122,7 +123,7 @@ case class AdaptiveSparkPlanExec(
RemoveRedundantSorts,
DisableUnnecessaryBucketedScan,
OptimizeSkewedJoin(ensureRequirements)
- ) ++ context.session.sessionState.queryStagePrepRules
+ ) ++ context.session.sessionState.adaptiveRulesHolder.queryStagePrepRules
}
// A list of physical optimizer rules to be applied to a new stage before its execution. These
@@ -152,7 +153,13 @@ case class AdaptiveSparkPlanExec(
)
private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): SparkPlan = {
- val optimized = queryStageOptimizerRules.foldLeft(plan) { case (latestPlan, rule) =>
+ val rules = if (isFinalStage &&
+ !conf.getConf(SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS)) {
+ queryStageOptimizerRules.filterNot(_.isInstanceOf[AQEShuffleReadRule])
+ } else {
+ queryStageOptimizerRules
+ }
+ val optimized = rules.foldLeft(plan) { case (latestPlan, rule) =>
val applied = rule.apply(latestPlan)
val result = rule match {
case _: AQEShuffleReadRule if !applied.fastEquals(latestPlan) =>
@@ -187,7 +194,7 @@ case class AdaptiveSparkPlanExec(
@volatile private var currentPhysicalPlan = initialPlan
- private var isFinalPlan = false
+ @volatile private var _isFinalPlan = false
private var currentStageId = 0
@@ -204,6 +211,8 @@ case class AdaptiveSparkPlanExec(
def executedPlan: SparkPlan = currentPhysicalPlan
+ def isFinalPlan: Boolean = _isFinalPlan
+
override def conf: SQLConf = context.session.sessionState.conf
override def output: Seq[Attribute] = inputPlan.output
@@ -222,6 +231,8 @@ case class AdaptiveSparkPlanExec(
.map(_.toLong).filter(SQLExecution.getQueryExecution(_) eq context.qe)
}
+  def finalPhysicalPlan: SparkPlan = withFinalPlanUpdate(identity) // spark 3.4.3
+
private def getFinalPhysicalPlan(): SparkPlan = lock.synchronized {
if (isFinalPlan) return currentPhysicalPlan
@@ -309,7 +320,8 @@ case class AdaptiveSparkPlanExec(
val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
if (newCost < origCost ||
(newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
- logOnLevel(s"Plan changed from $currentPhysicalPlan to $newPhysicalPlan")
+ logOnLevel("Plan changed:\n" +
+ sideBySide(currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n"))
cleanUpTempTags(newPhysicalPlan)
currentPhysicalPlan = newPhysicalPlan
currentLogicalPlan = newLogicalPlan
@@ -325,7 +337,7 @@ case class AdaptiveSparkPlanExec(
optimizeQueryStage(result.newPlan, isFinalStage = true),
postStageCreationRules(supportsColumnar),
Some((planChangeLogger, "AQE Post Stage Creation")))
- isFinalPlan = true
+ _isFinalPlan = true
executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
currentPhysicalPlan
}
@@ -339,7 +351,7 @@ case class AdaptiveSparkPlanExec(
if (!isSubquery && currentPhysicalPlan.exists(_.subqueries.nonEmpty)) {
getExecutionId.foreach(onUpdatePlan(_, Seq.empty))
}
- logOnLevel(s"Final plan: $currentPhysicalPlan")
+ logOnLevel(s"Final plan: \n$currentPhysicalPlan")
}
override def executeCollect(): Array[InternalRow] = {
--
Gitee
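The hunk above centres on the new rule selection in optimizeQueryStage. Below is a minimal standalone sketch of that logic in plain Scala; Rule, ShuffleReadRule, CoalesceRule and the boolean flag are toy stand-ins for the Spark rule classes and the SQLConf entry, not the real APIs.

    object OptimizeQueryStageSketch {
      // Toy stand-ins for SparkPlan optimizer rules; plans are plain strings here.
      trait Rule { def apply(plan: String): String }
      class ShuffleReadRule extends Rule { def apply(plan: String): String = s"$plan -> shuffleRead" }
      class CoalesceRule extends Rule { def apply(plan: String): String = s"$plan -> coalesce" }

      val queryStageOptimizerRules: Seq[Rule] = Seq(new ShuffleReadRule, new CoalesceRule)
      // Stand-in for reading ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS from conf.
      val applyFinalStageShuffleOptimizations = false

      def optimizeQueryStage(plan: String, isFinalStage: Boolean): String = {
        val rules =
          if (isFinalStage && !applyFinalStageShuffleOptimizations) {
            // Final stage with the flag off: drop shuffle-read rules, as in the patch.
            queryStageOptimizerRules.filterNot(_.isInstanceOf[ShuffleReadRule])
          } else {
            queryStageOptimizerRules
          }
        rules.foldLeft(plan) { case (latestPlan, rule) => rule.apply(latestPlan) }
      }

      def main(args: Array[String]): Unit = {
        println(optimizeQueryStage("scan", isFinalStage = true))  // scan -> coalesce
        println(optimizeQueryStage("scan", isFinalStage = false)) // scan -> shuffleRead -> coalesce
      }
    }
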
From 905f5e86e14f0fcba2db39e3e31c025dda96c3a8 Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Thu, 17 Oct 2024 06:58:44 +0000
Subject: [PATCH 22/43] update omnioperator/omniop-native-reader/java/pom.xml.
modify spark_version and logger dependency
Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com>
---
omnioperator/omniop-native-reader/java/pom.xml | 1 +
1 file changed, 1 insertion(+)
diff --git a/omnioperator/omniop-native-reader/java/pom.xml b/omnioperator/omniop-native-reader/java/pom.xml
index bdad33a3a..773ed2c39 100644
--- a/omnioperator/omniop-native-reader/java/pom.xml
+++ b/omnioperator/omniop-native-reader/java/pom.xml
@@ -34,6 +34,7 @@
1.6.0
+
org.slf4j
slf4j-simple
1.7.36
--
Gitee
From 999221a97b98dcb398b7cf3a6fa5f071737a16b1 Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Thu, 17 Oct 2024 07:10:15 +0000
Subject: [PATCH 23/43] restore PromotePrecision mode and modify class CastBase
to Cast
Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com>
---
.../boostkit/spark/expression/OmniExpressionAdaptor.scala | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
index f7ed75fa3..be1538172 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
@@ -160,8 +160,8 @@ object OmniExpressionAdaptor extends Logging {
throw new UnsupportedOperationException(s"Unsupported datatype for MakeDecimal: ${makeDecimal.child.dataType}")
}
-// case promotePrecision: PromotePrecision =>
-// rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap)
+ case promotePrecision: PromotePrecision =>
+ rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap)
case sub: Subtract =>
new JSONObject().put("exprType", "BINARY")
--
Gitee
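The restored match case treats PromotePrecision as a transparent wrapper: the adaptor simply rewrites its child. A small self-contained illustration with toy expression types follows; FieldRef and the JSON string are illustrative assumptions, not Spark's classes or the real Omni JSON layout.

    object PromotePrecisionSketch {
      sealed trait Expr
      case class FieldRef(name: String) extends Expr
      case class PromotePrecision(child: Expr) extends Expr

      def rewriteToJson(expr: Expr): String = expr match {
        // The restored case: PromotePrecision adds no semantics of its own,
        // so the rewrite just recurses into the wrapped child.
        case PromotePrecision(child) => rewriteToJson(child)
        case FieldRef(name)          => s"""{"exprType":"FIELD_REFERENCE","name":"$name"}"""
      }

      def main(args: Array[String]): Unit = {
        // Wrapping in PromotePrecision does not change the generated JSON.
        println(rewriteToJson(PromotePrecision(FieldRef("price"))))
        println(rewriteToJson(FieldRef("price")))
      }
    }
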
From 30dd5a213b420c8b3f7ffef99adb5105de8b9097 Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Thu, 17 Oct 2024 07:15:21 +0000
Subject: [PATCH 24/43] create analysis directory
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../src/main/scala/org/apache/spark/sql/catalyst/analysis/.keep | 0
1 file changed, 0 insertions(+), 0 deletions(-)
create mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/.keep
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/.keep b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/.keep
new file mode 100644
index 000000000..e69de29bb
--
Gitee
From 2e17b148b13b72ba5a89b20989abfefc99d3e8a3 Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Thu, 17 Oct 2024 07:25:39 +0000
Subject: [PATCH 25/43] add DecimalPrecision.scala
Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com>
---
.../apache/spark/sql/catalyst/analysis/.keep | 0
.../catalyst/analysis/DecimalPrecision.scala | 347 ++++++++++++++++++
2 files changed, 347 insertions(+)
delete mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/.keep
create mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/.keep b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/.keep
deleted file mode 100644
index e69de29bb..000000000
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
new file mode 100644
index 000000000..5a04f02ce
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
@@ -0,0 +1,347 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.Literal._
+import org.apache.spark.sql.types._
+
+
+// scalastyle:off
+/**
+ * Calculates and propagates precision for fixed-precision decimals. Hive has a number of
+ * rules for this based on the SQL standard and MS SQL:
+ * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
+ * https://msdn.microsoft.com/en-us/library/ms190476.aspx
+ *
+ * In particular, if we have expressions e1 and e2 with precision/scale p1/s1 and p2/s2
+ * respectively, then the following operations have the following precision / scale:
+ *
+ * Operation Result Precision Result Scale
+ * ------------------------------------------------------------------------
+ * e1 + e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2)
+ * e1 - e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2)
+ * e1 * e2 p1 + p2 + 1 s1 + s2
+ * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1)
+ * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2)
+ * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2)
+ *
+ * When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale
+ * needed are out of the range of available values, the scale is reduced up to 6, in order to
+ * prevent the truncation of the integer part of the decimals.
+ *
+ * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited
+ * precision, do the math on unlimited-precision numbers, then introduce casts back to the
+ * required fixed precision. This allows us to do all rounding and overflow handling in the
+ * cast-to-fixed-precision operator.
+ *
+ * In addition, when mixing non-decimal types with decimals, we use the following rules:
+ * - BYTE gets turned into DECIMAL(3, 0)
+ * - SHORT gets turned into DECIMAL(5, 0)
+ * - INT gets turned into DECIMAL(10, 0)
+ * - LONG gets turned into DECIMAL(20, 0)
+ * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE
+ * - Literals INT and LONG get turned into DECIMAL with the precision strictly needed by the value
+ */
+// scalastyle:on
+object DecimalPrecision extends TypeCoercionRule {
+ import scala.math.{max, min}
+
+ private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
+
+ // Returns the wider decimal type that's wider than both of them
+ def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = {
+ widerDecimalType(d1.precision, d1.scale, d2.precision, d2.scale)
+ }
+ // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
+ def widerDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
+ val scale = max(s1, s2)
+ val range = max(p1 - s1, p2 - s2)
+ DecimalType.bounded(range + scale, scale)
+ }
+
+ private def promotePrecision(e: Expression, dataType: DataType): Expression = {
+ PromotePrecision(Cast(e, dataType))
+ }
+
+ override def transform: PartialFunction[Expression, Expression] = {
+ decimalAndDecimal()
+ .orElse(integralAndDecimalLiteral)
+ .orElse(nondecimalAndDecimal(conf.literalPickMinimumPrecision))
+ }
+
+ private[catalyst] def decimalAndDecimal(): PartialFunction[Expression, Expression] = {
+ decimalAndDecimal(conf.decimalOperationsAllowPrecisionLoss, !conf.ansiEnabled)
+ }
+
+ /** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */
+ private[catalyst] def decimalAndDecimal(allowPrecisionLoss: Boolean, nullOnOverflow: Boolean)
+ : PartialFunction[Expression, Expression] = {
+ // Skip nodes whose children have not been resolved yet
+ case e if !e.childrenResolved => e
+
+    // Skip nodes that are already promoted
+ case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e
+
+ case a @ Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
+ val resultScale = max(s1, s2)
+ val resultType = if (allowPrecisionLoss) {
+ DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
+ resultScale)
+ } else {
+ DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
+ }
+ CheckOverflow(
+ a.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)),
+ resultType, nullOnOverflow)
+
+ case s @ Subtract(e1 @ DecimalType.Expression(p1, s1),
+ e2 @ DecimalType.Expression(p2, s2), _) =>
+ val resultScale = max(s1, s2)
+ val resultType = if (allowPrecisionLoss) {
+ DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
+ resultScale)
+ } else {
+ DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
+ }
+ CheckOverflow(
+ s.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)),
+ resultType, nullOnOverflow)
+
+ case m @ Multiply(
+ e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
+ val resultType = if (allowPrecisionLoss) {
+ DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2)
+ } else {
+ DecimalType.bounded(p1 + p2 + 1, s1 + s2)
+ }
+ val widerType = widerDecimalType(p1, s1, p2, s2)
+ CheckOverflow(
+ m.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
+ resultType, nullOnOverflow)
+
+ case d @ Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
+ val resultType = if (allowPrecisionLoss) {
+ // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
+ // Scale: max(6, s1 + p2 + 1)
+ val intDig = p1 - s1 + s2
+ val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1)
+ val prec = intDig + scale
+ DecimalType.adjustPrecisionScale(prec, scale)
+ } else {
+ var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2)
+ var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1))
+ val diff = (intDig + decDig) - DecimalType.MAX_SCALE
+ if (diff > 0) {
+ decDig -= diff / 2 + 1
+ intDig = DecimalType.MAX_SCALE - decDig
+ }
+ DecimalType.bounded(intDig + decDig, decDig)
+ }
+ val widerType = widerDecimalType(p1, s1, p2, s2)
+ CheckOverflow(
+ d.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
+ resultType, nullOnOverflow)
+
+ case r @ Remainder(
+ e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
+ val resultType = if (allowPrecisionLoss) {
+ DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+ } else {
+ DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+ }
+ // resultType may have lower precision, so we cast them into wider type first.
+ val widerType = widerDecimalType(p1, s1, p2, s2)
+ CheckOverflow(
+ r.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
+ resultType, nullOnOverflow)
+
+ case p @ Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
+ val resultType = if (allowPrecisionLoss) {
+ DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+ } else {
+ DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+ }
+ // resultType may have lower precision, so we cast them into wider type first.
+ val widerType = widerDecimalType(p1, s1, p2, s2)
+ CheckOverflow(
+ p.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
+ resultType, nullOnOverflow)
+
+ case expr @ IntegralDivide(
+ e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
+ val widerType = widerDecimalType(p1, s1, p2, s2)
+ val promotedExpr = expr.copy(
+ left = promotePrecision(e1, widerType),
+ right = promotePrecision(e2, widerType))
+ if (expr.dataType.isInstanceOf[DecimalType]) {
+ // This follows division rule
+ val intDig = p1 - s1 + s2
+ // No precision loss can happen as the result scale is 0.
+ // Overflow can happen only in the promote precision of the operands, but if none of them
+ // overflows in that phase, no overflow can happen, but CheckOverflow is needed in order
+ // to return a decimal with the proper scale and precision
+ CheckOverflow(promotedExpr, DecimalType.bounded(intDig, 0), nullOnOverflow)
+ } else {
+ promotedExpr
+ }
+
+ case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
+ e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+ val resultType = widerDecimalType(p1, s1, p2, s2)
+ val newE1 = if (e1.dataType == resultType) e1 else Cast(e1, resultType)
+ val newE2 = if (e2.dataType == resultType) e2 else Cast(e2, resultType)
+ b.makeCopy(Array(newE1, newE2))
+ }
+
+ /**
+ * Strength reduction for comparing integral expressions with decimal literals.
+ * 1. int_col > decimal_literal => int_col > floor(decimal_literal)
+ * 2. int_col >= decimal_literal => int_col >= ceil(decimal_literal)
+ * 3. int_col < decimal_literal => int_col < ceil(decimal_literal)
+ * 4. int_col <= decimal_literal => int_col <= floor(decimal_literal)
+ * 5. decimal_literal > int_col => ceil(decimal_literal) > int_col
+ * 6. decimal_literal >= int_col => floor(decimal_literal) >= int_col
+ * 7. decimal_literal < int_col => floor(decimal_literal) < int_col
+ * 8. decimal_literal <= int_col => ceil(decimal_literal) <= int_col
+ *
+ * Note that technically this is an "optimization" and should go into the optimizer. However,
+ * by the time the optimizer runs, these comparison expressions would be pretty hard to pattern
+ * match because there are multiple (at least 2) levels of casts involved.
+ *
+ * There are a lot more possible rules we can implement, but we don't do them
+ * because we are not sure how common they are.
+ */
+ private val integralAndDecimalLiteral: PartialFunction[Expression, Expression] = {
+
+ case GreaterThan(i @ IntegralType(), DecimalLiteral(value)) =>
+ if (DecimalLiteral.smallerThanSmallestLong(value)) {
+ TrueLiteral
+ } else if (DecimalLiteral.largerThanLargestLong(value)) {
+ FalseLiteral
+ } else {
+ GreaterThan(i, Literal(value.floor.toLong))
+ }
+
+ case GreaterThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) =>
+ if (DecimalLiteral.smallerThanSmallestLong(value)) {
+ TrueLiteral
+ } else if (DecimalLiteral.largerThanLargestLong(value)) {
+ FalseLiteral
+ } else {
+ GreaterThanOrEqual(i, Literal(value.ceil.toLong))
+ }
+
+ case LessThan(i @ IntegralType(), DecimalLiteral(value)) =>
+ if (DecimalLiteral.smallerThanSmallestLong(value)) {
+ FalseLiteral
+ } else if (DecimalLiteral.largerThanLargestLong(value)) {
+ TrueLiteral
+ } else {
+ LessThan(i, Literal(value.ceil.toLong))
+ }
+
+ case LessThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) =>
+ if (DecimalLiteral.smallerThanSmallestLong(value)) {
+ FalseLiteral
+ } else if (DecimalLiteral.largerThanLargestLong(value)) {
+ TrueLiteral
+ } else {
+ LessThanOrEqual(i, Literal(value.floor.toLong))
+ }
+
+ case GreaterThan(DecimalLiteral(value), i @ IntegralType()) =>
+ if (DecimalLiteral.smallerThanSmallestLong(value)) {
+ FalseLiteral
+ } else if (DecimalLiteral.largerThanLargestLong(value)) {
+ TrueLiteral
+ } else {
+ GreaterThan(Literal(value.ceil.toLong), i)
+ }
+
+ case GreaterThanOrEqual(DecimalLiteral(value), i @ IntegralType()) =>
+ if (DecimalLiteral.smallerThanSmallestLong(value)) {
+ FalseLiteral
+ } else if (DecimalLiteral.largerThanLargestLong(value)) {
+ TrueLiteral
+ } else {
+ GreaterThanOrEqual(Literal(value.floor.toLong), i)
+ }
+
+ case LessThan(DecimalLiteral(value), i @ IntegralType()) =>
+ if (DecimalLiteral.smallerThanSmallestLong(value)) {
+ TrueLiteral
+ } else if (DecimalLiteral.largerThanLargestLong(value)) {
+ FalseLiteral
+ } else {
+ LessThan(Literal(value.floor.toLong), i)
+ }
+
+ case LessThanOrEqual(DecimalLiteral(value), i @ IntegralType()) =>
+ if (DecimalLiteral.smallerThanSmallestLong(value)) {
+ TrueLiteral
+ } else if (DecimalLiteral.largerThanLargestLong(value)) {
+ FalseLiteral
+ } else {
+ LessThanOrEqual(Literal(value.ceil.toLong), i)
+ }
+ }
+
+ /**
+ * Type coercion for BinaryOperator in which one side is a non-decimal numeric, and the other
+ * side is a decimal.
+ */
+ private def nondecimalAndDecimal(literalPickMinimumPrecision: Boolean)
+ : PartialFunction[Expression, Expression] = {
+ // Promote integers inside a binary expression with fixed-precision decimals to decimals,
+ // and fixed-precision decimals in an expression with floats / doubles to doubles
+ case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
+ (left, right) match {
+ // Promote literal integers inside a binary expression with fixed-precision decimals to
+ // decimals. The precision and scale are the ones strictly needed by the integer value.
+ // Requiring more precision than necessary may lead to a useless loss of precision.
+ // Consider the following example: multiplying a column which is DECIMAL(38, 18) by 2.
+ // If we use the default precision and scale for the integer type, 2 is considered a
+ // DECIMAL(10, 0). According to the rules, the result would be DECIMAL(38 + 10 + 1, 18),
+ // which is out of range and therefore it will become DECIMAL(38, 7), leading to
+ // potentially loosing 11 digits of the fractional part. Using only the precision needed
+ // by the Literal, instead, the result would be DECIMAL(38 + 1 + 1, 18), which would
+ // become DECIMAL(38, 16), safely having a much lower precision loss.
+ case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] &&
+ l.dataType.isInstanceOf[IntegralType] &&
+ literalPickMinimumPrecision =>
+ b.makeCopy(Array(Cast(l, DecimalType.fromLiteral(l)), r))
+ case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] &&
+ r.dataType.isInstanceOf[IntegralType] &&
+ literalPickMinimumPrecision =>
+ b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r))))
+ // Promote integers inside a binary expression with fixed-precision decimals to decimals,
+ // and fixed-precision decimals in an expression with floats / doubles to doubles
+ case (l @ IntegralType(), r @ DecimalType.Expression(_, _)) =>
+ b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r))
+ case (l @ DecimalType.Expression(_, _), r @ IntegralType()) =>
+ b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType))))
+ case (l, r @ DecimalType.Expression(_, _)) if isFloat(l.dataType) =>
+ b.makeCopy(Array(l, Cast(r, DoubleType)))
+ case (l @ DecimalType.Expression(_, _), r) if isFloat(r.dataType) =>
+ b.makeCopy(Array(Cast(l, DoubleType), r))
+ case _ => b
+ }
+ }
+
+}
\ No newline at end of file
--
Gitee
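The result-type table in the DecimalPrecision scaladoc added above can be checked with a few lines of plain Scala. This sketch models only the unbounded formulas for +, - and *; the bounding to the maximum precision and the allowPrecisionLoss adjustment performed by DecimalType.bounded and DecimalType.adjustPrecisionScale are deliberately left out.

    object DecimalPrecisionRulesSketch {
      import scala.math.max

      final case class Dec(precision: Int, scale: Int)

      // e1 + e2 / e1 - e2: precision = max(s1, s2) + max(p1 - s1, p2 - s2) + 1, scale = max(s1, s2)
      def addOrSubtract(d1: Dec, d2: Dec): Dec = {
        val scale = max(d1.scale, d2.scale)
        Dec(max(d1.precision - d1.scale, d2.precision - d2.scale) + scale + 1, scale)
      }

      // e1 * e2: precision = p1 + p2 + 1, scale = s1 + s2
      def multiply(d1: Dec, d2: Dec): Dec =
        Dec(d1.precision + d2.precision + 1, d1.scale + d2.scale)

      def main(args: Array[String]): Unit = {
        println(addOrSubtract(Dec(10, 2), Dec(5, 3))) // Dec(12,3): max(8,2) + 3 + 1 = 12
        println(multiply(Dec(10, 2), Dec(5, 3)))      // Dec(16,5)
      }
    }
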
From 6f5460f564982ca139b58f4ac125dfe22d1fb87c Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Thu, 17 Oct 2024 07:30:53 +0000
Subject: [PATCH 26/43] add decimalExpressions
Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com>
---
.../expressions/decimalExpressions.scala | 234 ++++++++++++++++++
1 file changed, 234 insertions(+)
create mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
new file mode 100644
index 000000000..ac02f03c2
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+/**
+ * Return the unscaled Long value of a Decimal, assuming it fits in a Long.
+ * Note: this expression is internal and created only by the optimizer,
+ * we don't need to do type check for it.
+ */
+case class UnscaledValue(child: Expression) extends UnaryExpression with NullIntolerant {
+
+ override def dataType: DataType = LongType
+ override def toString: String = s"UnscaledValue($child)"
+
+ protected override def nullSafeEval(input: Any): Any =
+ input.asInstanceOf[Decimal].toUnscaledLong
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()")
+ }
+
+ override protected def withNewChildInternal(newChild: Expression): UnscaledValue =
+ copy(child = newChild)
+}
+
+/**
+ * Create a Decimal from an unscaled Long value.
+ * Note: this expression is internal and created only by the optimizer,
+ * we don't need to do type check for it.
+ */
+case class MakeDecimal(
+ child: Expression,
+ precision: Int,
+ scale: Int,
+ nullOnOverflow: Boolean) extends UnaryExpression with NullIntolerant {
+
+ def this(child: Expression, precision: Int, scale: Int) = {
+ this(child, precision, scale, !SQLConf.get.ansiEnabled)
+ }
+
+ override def dataType: DataType = DecimalType(precision, scale)
+ override def nullable: Boolean = child.nullable || nullOnOverflow
+ override def toString: String = s"MakeDecimal($child,$precision,$scale)"
+
+ protected override def nullSafeEval(input: Any): Any = {
+ val longInput = input.asInstanceOf[Long]
+ val result = new Decimal()
+ if (nullOnOverflow) {
+ result.setOrNull(longInput, precision, scale)
+ } else {
+ result.set(longInput, precision, scale)
+ }
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, eval => {
+ val setMethod = if (nullOnOverflow) {
+ "setOrNull"
+ } else {
+ "set"
+ }
+ val setNull = if (nullable) {
+ s"${ev.isNull} = ${ev.value} == null;"
+ } else {
+ ""
+ }
+ s"""
+ |${ev.value} = (new Decimal()).$setMethod($eval, $precision, $scale);
+ |$setNull
+ |""".stripMargin
+ })
+ }
+
+ override protected def withNewChildInternal(newChild: Expression): MakeDecimal =
+ copy(child = newChild)
+}
+
+object MakeDecimal {
+ def apply(child: Expression, precision: Int, scale: Int): MakeDecimal = {
+ new MakeDecimal(child, precision, scale)
+ }
+}
+
+/**
+ * An expression used to wrap the children when promote the precision of DecimalType to avoid
+ * promote multiple times.
+ */
+case class PromotePrecision(child: Expression) extends UnaryExpression {
+ override def dataType: DataType = child.dataType
+ override def eval(input: InternalRow): Any = child.eval(input)
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ child.genCode(ctx)
+ override def prettyName: String = "promote_precision"
+ override def sql: String = child.sql
+ override lazy val canonicalized: Expression = child.canonicalized
+
+ override protected def withNewChildInternal(newChild: Expression): Expression =
+ copy(child = newChild)
+}
+
+/**
+ * Rounds the decimal to given scale and check whether the decimal can fit in provided precision
+ * or not. If not, if `nullOnOverflow` is `true`, it returns `null`; otherwise an
+ * `ArithmeticException` is thrown.
+ */
+case class CheckOverflow(
+ child: Expression,
+ dataType: DecimalType,
+ nullOnOverflow: Boolean) extends UnaryExpression with SupportQueryContext {
+
+ override def nullable: Boolean = true
+
+ override def nullSafeEval(input: Any): Any =
+ input.asInstanceOf[Decimal].toPrecision(
+ dataType.precision,
+ dataType.scale,
+ Decimal.ROUND_HALF_UP,
+ nullOnOverflow,
+ queryContext)
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val errorContextCode = if (nullOnOverflow) {
+ "\"\""
+ } else {
+ ctx.addReferenceObj("errCtx", queryContext)
+ }
+ nullSafeCodeGen(ctx, ev, eval => {
+ // scalastyle:off line.size.limit
+ s"""
+ |${ev.value} = $eval.toPrecision(
+ | ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow, $errorContextCode);
+ |${ev.isNull} = ${ev.value} == null;
+ """.stripMargin
+ // scalastyle:on line.size.limit
+ })
+ }
+
+ override def toString: String = s"CheckOverflow($child, $dataType)"
+
+ override def sql: String = child.sql
+
+ override protected def withNewChildInternal(newChild: Expression): CheckOverflow =
+ copy(child = newChild)
+
+ override def initQueryContext(): String = if (nullOnOverflow) {
+ ""
+ } else {
+ origin.context
+ }
+}
+
+// A variant `CheckOverflow`, which treats null as overflow. This is necessary in `Sum`.
+case class CheckOverflowInSum(
+ child: Expression,
+ dataType: DecimalType,
+ nullOnOverflow: Boolean,
+ queryContext: String = "") extends UnaryExpression {
+
+ override def nullable: Boolean = true
+
+ override def eval(input: InternalRow): Any = {
+ val value = child.eval(input)
+ if (value == null) {
+ if (nullOnOverflow) null
+ else throw QueryExecutionErrors.overflowInSumOfDecimalError(queryContext)
+ } else {
+ value.asInstanceOf[Decimal].toPrecision(
+ dataType.precision,
+ dataType.scale,
+ Decimal.ROUND_HALF_UP,
+ nullOnOverflow,
+ queryContext)
+ }
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val childGen = child.genCode(ctx)
+ val errorContextCode = if (nullOnOverflow) {
+ "\"\""
+ } else {
+ ctx.addReferenceObj("errCtx", queryContext)
+ }
+ val nullHandling = if (nullOnOverflow) {
+ ""
+ } else {
+ s"throw QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode);"
+ }
+ // scalastyle:off line.size.limit
+ val code = code"""
+ |${childGen.code}
+ |boolean ${ev.isNull} = ${childGen.isNull};
+ |Decimal ${ev.value} = null;
+ |if (${childGen.isNull}) {
+ | $nullHandling
+ |} else {
+ | ${ev.value} = ${childGen.value}.toPrecision(
+ | ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow, $errorContextCode);
+ | ${ev.isNull} = ${ev.value} == null;
+ |}
+ |""".stripMargin
+ // scalastyle:on line.size.limit
+
+ ev.copy(code = code)
+ }
+
+ override def toString: String = s"CheckOverflowInSum($child, $dataType, $nullOnOverflow)"
+
+ override def sql: String = child.sql
+
+ override protected def withNewChildInternal(newChild: Expression): CheckOverflowInSum =
+ copy(child = newChild)
+}
\ No newline at end of file
--
Gitee
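As a rough illustration of what UnscaledValue and MakeDecimal compute, here is a plain-Scala sketch built on java.math.BigDecimal instead of Spark's Decimal; the overflow check only approximates the Decimal.set / Decimal.setOrNull behaviour and the names are stand-ins.

    object MakeDecimalSketch {
      import java.math.{BigDecimal => JBigDecimal}

      // Corresponds to UnscaledValue: the unscaled Long behind a decimal value.
      def unscaledValue(d: JBigDecimal): Long = d.unscaledValue().longValueExact()

      // Corresponds to MakeDecimal: rebuild a decimal from an unscaled Long, returning
      // None on overflow when nullOnOverflow is set, otherwise throwing.
      def makeDecimal(unscaled: Long, precision: Int, scale: Int,
                      nullOnOverflow: Boolean): Option[JBigDecimal] = {
        val result = JBigDecimal.valueOf(unscaled, scale)
        if (result.precision() <= precision) Some(result)
        else if (nullOnOverflow) None // mirrors the setOrNull path
        else throw new ArithmeticException(s"Decimal($precision,$scale) overflow") // mirrors set
      }

      def main(args: Array[String]): Unit = {
        println(unscaledValue(new JBigDecimal("12.34")))                               // 1234
        println(makeDecimal(1234L, precision = 4, scale = 2, nullOnOverflow = true))   // Some(12.34)
        println(makeDecimal(123456L, precision = 4, scale = 2, nullOnOverflow = true)) // None
      }
    }
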
From 853985b46efab210acc49d04614113e3f7eead03 Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Thu, 17 Oct 2024 07:48:33 +0000
Subject: [PATCH 27/43] update decimalExpressions
Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com>
---
.../expressions/decimalExpressions.scala | 17 +++++++++--------
1 file changed, 9 insertions(+), 8 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index ac02f03c2..2b6d3ff33 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.SQLQueryContext
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -138,7 +139,7 @@ case class CheckOverflow(
dataType.scale,
Decimal.ROUND_HALF_UP,
nullOnOverflow,
- queryContext)
+ getContextOrNull())
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val errorContextCode = if (nullOnOverflow) {
@@ -164,10 +165,10 @@ case class CheckOverflow(
override protected def withNewChildInternal(newChild: Expression): CheckOverflow =
copy(child = newChild)
- override def initQueryContext(): String = if (nullOnOverflow) {
- ""
+ override def initQueryContext(): Option[SQLQueryContext] = if (nullOnOverflow) {
+    None
} else {
- origin.context
+    Some(origin.context)
}
}
@@ -176,7 +177,7 @@ case class CheckOverflowInSum(
child: Expression,
dataType: DecimalType,
nullOnOverflow: Boolean,
- queryContext: String = "") extends UnaryExpression {
+ context: SQLQueryContext) extends UnaryExpression {
override def nullable: Boolean = true
@@ -184,14 +185,14 @@ case class CheckOverflowInSum(
val value = child.eval(input)
if (value == null) {
if (nullOnOverflow) null
- else throw QueryExecutionErrors.overflowInSumOfDecimalError(queryContext)
+ else throw QueryExecutionErrors.overflowInSumOfDecimalError(context)
} else {
value.asInstanceOf[Decimal].toPrecision(
dataType.precision,
dataType.scale,
Decimal.ROUND_HALF_UP,
nullOnOverflow,
- queryContext)
+ context)
}
}
@@ -200,7 +201,7 @@ case class CheckOverflowInSum(
val errorContextCode = if (nullOnOverflow) {
"\"\""
} else {
- ctx.addReferenceObj("errCtx", queryContext)
+ ctx.addReferenceObj("errCtx", context)
}
val nullHandling = if (nullOnOverflow) {
""
--
Gitee
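The patch switches the query-context plumbing from a plain String to Option[SQLQueryContext]. A stand-in sketch of the intended semantics follows (SQLQueryContext here is a placeholder case class, not Spark's type): the context is kept only when an overflow is allowed to raise an error, that is when nullOnOverflow is false, and getContextOrNull collapses the absent case to null for the evaluation path.

    object QueryContextSketch {
      final case class SQLQueryContext(fragment: String)

      def initQueryContext(nullOnOverflow: Boolean,
                           originContext: SQLQueryContext): Option[SQLQueryContext] =
        if (nullOnOverflow) None else Some(originContext)

      // Mirrors the idea behind getContextOrNull: None becomes null for error reporting.
      def getContextOrNull(ctx: Option[SQLQueryContext]): SQLQueryContext = ctx.orNull

      def main(args: Array[String]): Unit = {
        val origin = SQLQueryContext("CAST(col AS DECIMAL(10,2))")
        println(initQueryContext(nullOnOverflow = true, originContext = origin))  // None
        println(initQueryContext(nullOnOverflow = false, originContext = origin)) // Some(...)
        println(getContextOrNull(None))                                           // null
      }
    }
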
From c3519a4694f21d114b1fb16a5fc31a8c31c161fb Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Thu, 17 Oct 2024 08:19:51 +0000
Subject: [PATCH 28/43] update injectRuntimeFilter
Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com>
---
.../optimizer/InjectRuntimeFilter.scala | 59 ++++---------------
1 file changed, 12 insertions(+), 47 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
index e6c266242..b758b69e6 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
@@ -27,8 +27,6 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import com.huawei.boostkit.spark.ColumnarPluginConfig
-import scala.annotation.tailrec
-
/**
* Insert a filter on one side of the join if the other side has a selective predicate.
* The filter could be an IN subquery (converted to a semi join), a bloom filter, or something
@@ -87,12 +85,12 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
val bloomFilterAgg =
if (rowCount.isDefined && rowCount.get.longValue > 0L) {
new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)),
- rowCount.get.longValue)
+          Literal(rowCount.get.longValue))
} else {
new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)))
}
-
- val alias = Alias(bloomFilterAgg.toAggregateExpression(), "bloomFilter")()
+ val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None)
+ val alias = Alias(aggExp, "bloomFilter")()
val aggregate =
ConstantFolding(ColumnPruning(Aggregate(Nil, Seq(alias), filterCreationSidePlan)))
val bloomFilterSubquery = if (canReuseExchange) {
@@ -114,8 +112,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
require(filterApplicationSideExp.dataType == filterCreationSideExp.dataType)
val actualFilterKeyExpr = mayWrapWithHash(filterCreationSideExp)
val alias = Alias(actualFilterKeyExpr, actualFilterKeyExpr.toString)()
- val aggregate =
- ColumnPruning(Aggregate(Seq(filterCreationSideExp), Seq(alias), filterCreationSidePlan))
+ val aggregate = Aggregate(Seq(alias), Seq(alias), filterCreationSidePlan)
if (!canBroadcastBySize(aggregate, conf)) {
// Skip the InSubquery filter if the size of `aggregate` is beyond broadcast join threshold,
// i.e., the semi-join will be a shuffled join, which is not worthwhile.
@@ -132,39 +129,13 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
* do not add a subquery that might have an expensive computation
*/
private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = {
- @tailrec
- def isSelective(
- p: LogicalPlan,
- predicateReference: AttributeSet,
- hasHitFilter: Boolean,
- hasHitSelectiveFilter: Boolean): Boolean = p match {
- case Project(projectList, child) =>
- if (hasHitFilter) {
- // We need to make sure all expressions referenced by filter predicates are simple
- // expressions.
- val referencedExprs = projectList.filter(predicateReference.contains)
- referencedExprs.forall(isSimpleExpression) &&
- isSelective(
- child,
- referencedExprs.map(_.references).foldLeft(AttributeSet.empty)(_ ++ _),
- hasHitFilter,
- hasHitSelectiveFilter)
- } else {
- assert(predicateReference.isEmpty && !hasHitSelectiveFilter)
- isSelective(child, predicateReference, hasHitFilter, hasHitSelectiveFilter)
- }
- case Filter(condition, child) =>
- isSimpleExpression(condition) && isSelective(
- child,
- predicateReference ++ condition.references,
- hasHitFilter = true,
- hasHitSelectiveFilter = hasHitSelectiveFilter || isLikelySelective(condition))
- case _: LeafNode => hasHitSelectiveFilter
+ val ret = plan match {
+      case PhysicalOperation(_, filters, child) if child.isInstanceOf[LeafNode] =>
+ filters.forall(isSimpleExpression) &&
+ filters.exists(isLikelySelective)
case _ => false
}
-
- !plan.isStreaming &&
- isSelective(plan, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false)
+ !plan.isStreaming && ret
}
private def isSimpleExpression(e: Expression): Boolean = {
@@ -213,8 +184,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
/**
* Check that:
- * - The filterApplicationSideJoinExp can be pushed down through joins, aggregates and windows
- * (ie the expression references originate from a single leaf node)
+ * - The filterApplicationSideJoinExp can be pushed down through joins and aggregates (ie the
+ * expression references originate from a single leaf node)
* - The filter creation side has a selective predicate
* - The current join is a shuffle join or a broadcast join that has a shuffle below it
* - The max filterApplicationSide scan size is greater than a configurable threshold
@@ -328,13 +299,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
case s: Subquery if s.correlated => plan
case _ if !conf.runtimeFilterSemiJoinReductionEnabled &&
!conf.runtimeFilterBloomFilterEnabled => plan
- case _ =>
- val newPlan = tryInjectRuntimeFilter(plan)
- if (conf.runtimeFilterSemiJoinReductionEnabled && !plan.fastEquals(newPlan)) {
- RewritePredicateSubquery(newPlan)
- } else {
- newPlan
- }
+ case _ => tryInjectRuntimeFilter(plan)
}
}
\ No newline at end of file
--
Gitee
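The reworked isSelectiveFilterOverScan replaces the recursive walk with a single PhysicalOperation-style match over projects and filters above a leaf scan. A toy model of that check is sketched below; the Plan classes and the two predicate helpers are simplified assumptions, not Spark's LogicalPlan API.

    object SelectiveScanSketch {
      sealed trait Plan { def isStreaming: Boolean = false }
      case class Leaf(table: String) extends Plan
      case class Filter(condition: String, child: Plan) extends Plan
      case class Project(child: Plan) extends Plan

      // Mimics PhysicalOperation.unapply: peel projects and filters down to the leaf.
      private def collectFilters(p: Plan, acc: List[String] = Nil): (List[String], Plan) = p match {
        case Project(child)      => collectFilters(child, acc)
        case Filter(cond, child) => collectFilters(child, cond :: acc)
        case leaf: Leaf          => (acc, leaf)
      }

      private def isSimpleExpression(cond: String): Boolean = !cond.contains("udf(")
      private def isLikelySelective(cond: String): Boolean = cond.contains("=")

      def isSelectiveFilterOverScan(plan: Plan): Boolean = {
        val (filters, _) = collectFilters(plan)
        val ret = filters.forall(isSimpleExpression) && filters.exists(isLikelySelective)
        !plan.isStreaming && ret
      }

      def main(args: Array[String]): Unit = {
        println(isSelectiveFilterOverScan(Project(Filter("a = 1", Leaf("t"))))) // true
        println(isSelectiveFilterOverScan(Leaf("t")))                           // false (no filter)
        println(isSelectiveFilterOverScan(Filter("udf(a) = 1", Leaf("t"))))     // false (not simple)
      }
    }
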
From 6eeee26dc0015796f8a6ef7eec40af0d2bb4c3c9 Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Thu, 17 Oct 2024 08:23:52 +0000
Subject: [PATCH 29/43] update TreePatterns
Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com>
---
.../sql/catalyst/Tree/TreePatterns.scala | 30 +++++++++----------
1 file changed, 15 insertions(+), 15 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala
index ef17a0740..4d8c246ab 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala
@@ -25,7 +25,7 @@ object TreePattern extends Enumeration {
// Expression patterns (alphabetically ordered)
val AGGREGATE_EXPRESSION = Value(0)
val ALIAS: Value = Value
-// val AND_OR: Value = Value
+ val AND_OR: Value = Value
val AND: Value = Value
val ARRAYS_ZIP: Value = Value
val ATTRIBUTE_REFERENCE: Value = Value
@@ -59,7 +59,7 @@ object TreePattern extends Enumeration {
val JSON_TO_STRUCT: Value = Value
val LAMBDA_FUNCTION: Value = Value
val LAMBDA_VARIABLE: Value = Value
- val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value // spark3.4.3
+ val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value
val LATERAL_SUBQUERY: Value = Value
val LIKE_FAMLIY: Value = Value
val LIST_SUBQUERY: Value = Value
@@ -71,33 +71,33 @@ object TreePattern extends Enumeration {
val NULL_CHECK: Value = Value
val NULL_LITERAL: Value = Value
val SERIALIZE_FROM_OBJECT: Value = Value
- val OR: Value = Value // spark3.4.3
+ val OR: Value = Value
val OUTER_REFERENCE: Value = Value
- val PARAMETER: Value = Value // spark3.4.3
- val PARAMETERIZED_QUERY: Value = Value // spark3.4.3
+ val PARAMETER: Value = Value
+ val PARAMETERIZED_QUERY: Value = Value
val PIVOT: Value = Value
val PLAN_EXPRESSION: Value = Value
val PYTHON_UDF: Value = Value
val REGEXP_EXTRACT_FAMILY: Value = Value
val REGEXP_REPLACE: Value = Value
val RUNTIME_REPLACEABLE: Value = Value
-  val RUNTIME_FILTER_EXPRESSION: Value = Value // removed in spark3.4.3
-  val RUNTIME_FILTER_SUBQUERY: Value = Value // removed in spark3.4.3
+ val RUNTIME_FILTER_EXPRESSION: Value = Value
+ val RUNTIME_FILTER_SUBQUERY: Value = Value
val SCALAR_SUBQUERY: Value = Value
val SCALAR_SUBQUERY_REFERENCE: Value = Value
val SCALA_UDF: Value = Value
- val SESSION_WINDOW: Value = Value // spark3.4.3
+ val SESSION_WINDOW: Value = Value
val SORT: Value = Value
val SUBQUERY_ALIAS: Value = Value
-  val SUBQUERY_WRAPPER: Value = Value // removed in spark3.4.3
+ val SUBQUERY_WRAPPER: Value = Value
val SUM: Value = Value
val TIME_WINDOW: Value = Value
val TIME_ZONE_AWARE_EXPRESSION: Value = Value
val TRUE_OR_FALSE_LITERAL: Value = Value
val WINDOW_EXPRESSION: Value = Value
-  val WINDOW_TIME: Value = Value // spark3.4.3
+ val WINDOW_TIME: Value = Value
val UNARY_POSITIVE: Value = Value
- val UNPIVOT: Value = Value // spark3.4.3
+ val UNPIVOT: Value = Value
val UPDATE_FIELDS: Value = Value
val UPPER_OR_LOWER: Value = Value
val UP_CAST: Value = Value
@@ -127,7 +127,7 @@ object TreePattern extends Enumeration {
val UNION: Value = Value
val UNRESOLVED_RELATION: Value = Value
val UNRESOLVED_WITH: Value = Value
- val TEMP_RESOLVED_COLUMN: Value = Value // spark3.4.3
+ val TEMP_RESOLVED_COLUMN: Value = Value
val TYPED_FILTER: Value = Value
val WINDOW: Value = Value
val WITH_WINDOW_DEFINITION: Value = Value
@@ -136,7 +136,7 @@ object TreePattern extends Enumeration {
val UNRESOLVED_ALIAS: Value = Value
val UNRESOLVED_ATTRIBUTE: Value = Value
val UNRESOLVED_DESERIALIZER: Value = Value
- val UNRESOLVED_HAVING: Value = Value // spark3.4.3
+ val UNRESOLVED_HAVING: Value = Value
val UNRESOLVED_ORDINAL: Value = Value
val UNRESOLVED_FUNCTION: Value = Value
val UNRESOLVED_HINT: Value = Value
@@ -145,8 +145,8 @@ object TreePattern extends Enumeration {
// Unresolved Plan patterns (Alphabetically ordered)
val UNRESOLVED_SUBQUERY_COLUMN_ALIAS: Value = Value
val UNRESOLVED_FUNC: Value = Value
- val UNRESOLVED_TABLE_VALUED_FUNCTION: Value = Value // spark3.4.3
- val UNRESOLVED_TVF_ALIASES: Value = Value // spark3.4.3
+ val UNRESOLVED_TABLE_VALUED_FUNCTION: Value = Value
+ val UNRESOLVED_TVF_ALIASES: Value = Value
// Execution expression patterns (alphabetically ordered)
val IN_SUBQUERY_EXEC: Value = Value
--
Gitee
From cd9fe00c2c3ed2adb232da4d47a160965f4ce3d3 Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Thu, 17 Oct 2024 08:31:04 +0000
Subject: [PATCH 30/43] update AdaptiveSparkPlanExec
Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com>
---
.../adaptive/AdaptiveSparkPlanExec.scala | 24 ++++++-------------
1 file changed, 7 insertions(+), 17 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index e5c8fc0ab..b5d638138 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.adaptive
import java.util
import java.util.concurrent.LinkedBlockingQueue
+
import scala.collection.JavaConverters._
import scala.collection.concurrent.TrieMap
import scala.collection.mutable
import scala.concurrent.ExecutionContext
import scala.util.control.NonFatal
+
import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
@@ -33,7 +35,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
@@ -153,13 +154,7 @@ case class AdaptiveSparkPlanExec(
)
private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): SparkPlan = {
- val rules = if (isFinalStage &&
- !conf.getConf(SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS)) {
- queryStageOptimizerRules.filterNot(_.isInstanceOf[AQEShuffleReadRule])
- } else {
- queryStageOptimizerRules
- }
- val optimized = rules.foldLeft(plan) { case (latestPlan, rule) =>
+ val optimized = queryStageOptimizerRules.foldLeft(plan) { case (latestPlan, rule) =>
val applied = rule.apply(latestPlan)
val result = rule match {
case _: AQEShuffleReadRule if !applied.fastEquals(latestPlan) =>
@@ -194,7 +189,7 @@ case class AdaptiveSparkPlanExec(
@volatile private var currentPhysicalPlan = initialPlan
- @volatile private var _isFinalPlan = false
+  private var isFinalPlan = false
private var currentStageId = 0
@@ -211,8 +206,6 @@ case class AdaptiveSparkPlanExec(
def executedPlan: SparkPlan = currentPhysicalPlan
- def isFinalPlan: Boolean = _isFinalPlan
-
override def conf: SQLConf = context.session.sessionState.conf
override def output: Seq[Attribute] = inputPlan.output
@@ -231,8 +224,6 @@ case class AdaptiveSparkPlanExec(
.map(_.toLong).filter(SQLExecution.getQueryExecution(_) eq context.qe)
}
-  def finalPhysicalPlan: SparkPlan = withFinalPlanUpdate(identity) // spark 3.4.3
-
private def getFinalPhysicalPlan(): SparkPlan = lock.synchronized {
if (isFinalPlan) return currentPhysicalPlan
@@ -320,8 +311,7 @@ case class AdaptiveSparkPlanExec(
val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
if (newCost < origCost ||
(newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
- logOnLevel("Plan changed:\n" +
- sideBySide(currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n"))
+      logOnLevel(s"Plan changed from $currentPhysicalPlan to $newPhysicalPlan")
cleanUpTempTags(newPhysicalPlan)
currentPhysicalPlan = newPhysicalPlan
currentLogicalPlan = newLogicalPlan
@@ -337,7 +327,7 @@ case class AdaptiveSparkPlanExec(
optimizeQueryStage(result.newPlan, isFinalStage = true),
postStageCreationRules(supportsColumnar),
Some((planChangeLogger, "AQE Post Stage Creation")))
- _isFinalPlan = true
+ isFinalPlan = true
executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
currentPhysicalPlan
}
@@ -351,7 +341,7 @@ case class AdaptiveSparkPlanExec(
if (!isSubquery && currentPhysicalPlan.exists(_.subqueries.nonEmpty)) {
getExecutionId.foreach(onUpdatePlan(_, Seq.empty))
}
- logOnLevel(s"Final plan: \n$currentPhysicalPlan")
+ logOnLevel(s"Final plan: $currentPhysicalPlan")
}
override def executeCollect(): Array[InternalRow] = {
--
Gitee
From cac18688e9ec1ed567dd8521b44f27c5bee02d2c Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Thu, 17 Oct 2024 08:35:09 +0000
Subject: [PATCH 31/43] update ColumnarBasicPhysicalOperators
Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com>
---
.../spark/sql/execution/ColumnarBasicPhysicalOperators.scala | 2 --
1 file changed, 2 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala
index 630034bd7..b6dd2ab51 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution
import java.util.concurrent.TimeUnit.NANOSECONDS
import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP
-
import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._
import com.huawei.boostkit.spark.util.OmniAdaptorUtil
import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs
@@ -43,7 +42,6 @@ import org.apache.spark.sql.types.{LongType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.sql.catalyst.plans.{AliasAwareOutputExpression, AliasAwareQueryOutputOrdering}
-
case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
extends UnaryExecNode
with AliasAwareOutputExpression
--
Gitee
From 5d435130ec8b3c618391d55d5ddcfc5d90e66263 Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Thu, 17 Oct 2024 08:37:20 +0000
Subject: [PATCH 32/43] update QueryExecution
Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com>
---
.../spark/sql/execution/QueryExecution.scala | 19 +------------------
1 file changed, 1 insertion(+), 18 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 37f3db5d6..e49c88191 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -105,23 +105,6 @@ class QueryExecution(
case other => other
}
- // The plan that has been normalized by custom rules, so that it's more likely to hit cache.
- lazy val normalized: LogicalPlan = {
- val normalizationRules = sparkSession.sessionState.planNormalizationRules
- if (normalizationRules.isEmpty) {
- commandExecuted
- } else {
- val planChangeLogger = new PlanChangeLogger[LogicalPlan]()
- val normalized = normalizationRules.foldLeft(commandExecuted) { (p, rule) =>
- val result = rule.apply(p)
- planChangeLogger.logRule(rule.ruleName, p, result)
- result
- }
- planChangeLogger.logBatch("Plan Normalization", commandExecuted, normalized)
- normalized
- }
- } // Spark3.4.3
-
lazy val withCachedData: LogicalPlan = sparkSession.withActive {
assertAnalyzed()
assertSupported()
@@ -529,4 +512,4 @@ object QueryExecution {
case e: Throwable => throw toInternalError(msg, e)
}
}
-}
+}
\ No newline at end of file
--
Gitee
From a14e7de692a395ad21cabbcff492c03f84d1d301 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?=
<14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 14 Oct 2024 19:45:53 +0800
Subject: [PATCH 33/43] modify version and spark_version
modify SubqueryExpression parameter count and add new method
add and remove Expression patterns
fix parameter count
add normalized, adjust parameters and modify method toInternalError
modify parameters
modify filePath
modify filePath and ORC predicate pushdown
modify parameters, remove PromotePrecision, modify class CastBase to Cast
modify parameters
modify parameters
modify injectBloomFilter, injectInSubqueryFilter, isSelectiveFilterOverScan, hasDynamicPruningSubquery, hasInSubquery, tryInjectRuntimeFilter
modify metastore.uris and hive.db
modify "test child max row" limit(0) to limit(1)
modify type of path
modify traits of ColumnarProjectExec
modify "SPARK-30953", "SPARK-31658" ADAPTIVE_EXECUTION_FORCE_APPLY, "SPARK-32932" ADAPTIVE_EXECUTION_ENABLED
modify spark exception
modify AQEOptimizer and optimizeQueryStage
---
.../omniop-spark-extension/java/pom.xml | 10 +--
.../expression/OmniExpressionAdaptor.scala | 12 ++--
.../sql/catalyst/Tree/TreePatterns.scala | 20 ++++--
.../catalyst/expressions/runtimefilter.scala | 4 +-
.../optimizer/InjectRuntimeFilter.scala | 67 ++++++++++++++-----
.../optimizer/MergeSubqueryFilters.scala | 2 +-
.../RewriteSelfJoinInInPredicate.scala | 2 +-
.../ColumnarBasicPhysicalOperators.scala | 12 ++--
.../ColumnarFileSourceScanExec.scala | 5 +-
.../spark/sql/execution/QueryExecution.scala | 30 +++++++--
.../adaptive/AdaptiveSparkPlanExec.scala | 32 ++++++---
.../adaptive/PlanAdaptiveSubqueries.scala | 6 +-
.../datasources/orc/OmniOrcFileFormat.scala | 30 ++++++---
.../parquet/OmniParquetFileFormat.scala | 2 +-
.../test/resources/HiveResource.properties | 6 +-
.../sql/catalyst/expressions/CastSuite.scala | 55 ++++++++++-----
.../optimizer/CombiningLimitsSuite.scala | 2 +-
.../optimizer/MergeSubqueryFiltersSuite.scala | 2 +-
.../ColumnarAdaptiveQueryExecSuite.scala | 47 +++++++------
omnioperator/omniop-spark-extension/pom.xml | 4 +-
20 files changed, 235 insertions(+), 115 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/pom.xml b/omnioperator/omniop-spark-extension/java/pom.xml
index 9cc1b9d25..a40b415fd 100644
--- a/omnioperator/omniop-spark-extension/java/pom.xml
+++ b/omnioperator/omniop-spark-extension/java/pom.xml
@@ -7,7 +7,7 @@
com.huawei.kunpeng
boostkit-omniop-spark-parent
- 3.3.1-1.6.0
+ 3.4.3-1.6.0
../pom.xml
@@ -52,7 +52,7 @@
com.huawei.boostkit
boostkit-omniop-native-reader
- 3.3.1-1.6.0
+ 3.4.3-1.6.0
junit
@@ -247,7 +247,7 @@
true
false
${project.basedir}/src/main/scala
- ${project.basedir}/src/test/scala
+ ${project.basedir}/src/test/scala
${user.dir}/scalastyle-config.xml
${project.basedir}/target/scalastyle-output.xml
${project.build.sourceEncoding}
@@ -335,7 +335,7 @@
org.scalatest
scalatest-maven-plugin
- false
+ false
${project.build.directory}/surefire-reports
.
@@ -352,4 +352,4 @@
-
\ No newline at end of file
+
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
index d1f911a8c..f7ed75fa3 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
@@ -76,7 +76,7 @@ object OmniExpressionAdaptor extends Logging {
}
}
- private def unsupportedCastCheck(expr: Expression, cast: CastBase): Unit = {
+ private def unsupportedCastCheck(expr: Expression, cast: Cast): Unit = {
def doSupportCastToString(dataType: DataType): Boolean = {
dataType.isInstanceOf[DecimalType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[IntegerType] ||
dataType.isInstanceOf[LongType] || dataType.isInstanceOf[DateType] || dataType.isInstanceOf[DoubleType] ||
@@ -160,8 +160,8 @@ object OmniExpressionAdaptor extends Logging {
throw new UnsupportedOperationException(s"Unsupported datatype for MakeDecimal: ${makeDecimal.child.dataType}")
}
- case promotePrecision: PromotePrecision =>
- rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap)
+// case promotePrecision: PromotePrecision =>
+// rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap)
case sub: Subtract =>
new JSONObject().put("exprType", "BINARY")
@@ -296,7 +296,7 @@ object OmniExpressionAdaptor extends Logging {
.put(rewriteToOmniJsonExpressionLiteralJsonObject(subString.len, exprsIndexMap)))
// Cast
- case cast: CastBase =>
+ case cast: Cast =>
unsupportedCastCheck(expr, cast)
cast.child.dataType match {
case NullType =>
@@ -588,10 +588,10 @@ object OmniExpressionAdaptor extends Logging {
rewriteToOmniJsonExpressionLiteralJsonObject(children.head, exprsIndexMap)
} else {
children.head match {
- case base: CastBase if base.child.dataType.isInstanceOf[NullType] =>
+ case base: Cast if base.child.dataType.isInstanceOf[NullType] =>
rewriteToOmniJsonExpressionLiteralJsonObject(children(1), exprsIndexMap)
case _ => children(1) match {
- case base: CastBase if base.child.dataType.isInstanceOf[NullType] =>
+ case base: Cast if base.child.dataType.isInstanceOf[NullType] =>
rewriteToOmniJsonExpressionLiteralJsonObject(children.head, exprsIndexMap)
case _ =>
new JSONObject().put("exprType", "FUNCTION")
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala
index ea2712447..ef17a0740 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala
@@ -25,7 +25,8 @@ object TreePattern extends Enumeration {
// Expression patterns (alphabetically ordered)
val AGGREGATE_EXPRESSION = Value(0)
val ALIAS: Value = Value
- val AND_OR: Value = Value
+// val AND_OR: Value = Value
+ val AND: Value = Value
val ARRAYS_ZIP: Value = Value
val ATTRIBUTE_REFERENCE: Value = Value
val APPEND_COLUMNS: Value = Value
@@ -58,6 +59,7 @@ object TreePattern extends Enumeration {
val JSON_TO_STRUCT: Value = Value
val LAMBDA_FUNCTION: Value = Value
val LAMBDA_VARIABLE: Value = Value
+ val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value // spark3.4.3
val LATERAL_SUBQUERY: Value = Value
val LIKE_FAMLIY: Value = Value
val LIST_SUBQUERY: Value = Value
@@ -69,27 +71,33 @@ object TreePattern extends Enumeration {
val NULL_CHECK: Value = Value
val NULL_LITERAL: Value = Value
val SERIALIZE_FROM_OBJECT: Value = Value
+ val OR: Value = Value // spark3.4.3
val OUTER_REFERENCE: Value = Value
+ val PARAMETER: Value = Value // spark3.4.3
+ val PARAMETERIZED_QUERY: Value = Value // spark3.4.3
val PIVOT: Value = Value
val PLAN_EXPRESSION: Value = Value
val PYTHON_UDF: Value = Value
val REGEXP_EXTRACT_FAMILY: Value = Value
val REGEXP_REPLACE: Value = Value
val RUNTIME_REPLACEABLE: Value = Value
- val RUNTIME_FILTER_EXPRESSION: Value = Value
- val RUNTIME_FILTER_SUBQUERY: Value = Value
+ val RUNTIME_FILTER_EXPRESSION: Value = Value // removed in spark3.4.3
+ val RUNTIME_FILTER_SUBQUERY: Value = Value // removed in spark3.4.3
val SCALAR_SUBQUERY: Value = Value
val SCALAR_SUBQUERY_REFERENCE: Value = Value
val SCALA_UDF: Value = Value
+ val SESSION_WINDOW: Value = Value // spark3.4.3
val SORT: Value = Value
val SUBQUERY_ALIAS: Value = Value
- val SUBQUERY_WRAPPER: Value = Value
+ val SUBQUERY_WRAPPER: Value = Value // removed in spark3.4.3
val SUM: Value = Value
val TIME_WINDOW: Value = Value
val TIME_ZONE_AWARE_EXPRESSION: Value = Value
val TRUE_OR_FALSE_LITERAL: Value = Value
val WINDOW_EXPRESSION: Value = Value
+ val WINDOW_TIME: Value = Value // spark3.4.3
val UNARY_POSITIVE: Value = Value
+ val UNPIVOT: Value = Value // spark3.4.3
val UPDATE_FIELDS: Value = Value
val UPPER_OR_LOWER: Value = Value
val UP_CAST: Value = Value
@@ -119,6 +127,7 @@ object TreePattern extends Enumeration {
val UNION: Value = Value
val UNRESOLVED_RELATION: Value = Value
val UNRESOLVED_WITH: Value = Value
+ val TEMP_RESOLVED_COLUMN: Value = Value // spark3.4.3
val TYPED_FILTER: Value = Value
val WINDOW: Value = Value
val WITH_WINDOW_DEFINITION: Value = Value
@@ -127,6 +136,7 @@ object TreePattern extends Enumeration {
val UNRESOLVED_ALIAS: Value = Value
val UNRESOLVED_ATTRIBUTE: Value = Value
val UNRESOLVED_DESERIALIZER: Value = Value
+ val UNRESOLVED_HAVING: Value = Value // spark3.4.3
val UNRESOLVED_ORDINAL: Value = Value
val UNRESOLVED_FUNCTION: Value = Value
val UNRESOLVED_HINT: Value = Value
@@ -135,6 +145,8 @@ object TreePattern extends Enumeration {
// Unresolved Plan patterns (Alphabetically ordered)
val UNRESOLVED_SUBQUERY_COLUMN_ALIAS: Value = Value
val UNRESOLVED_FUNC: Value = Value
+ val UNRESOLVED_TABLE_VALUED_FUNCTION: Value = Value // spark3.4.3
+ val UNRESOLVED_TVF_ALIASES: Value = Value // spark3.4.3
// Execution expression patterns (alphabetically ordered)
val IN_SUBQUERY_EXEC: Value = Value
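Note on the TreePattern additions above: they are plain Enumeration members that Spark 3.4.3 consults when pruning tree traversals. A minimal, self-contained sketch of that idea (stand-in types only, not Spark's actual TreeNode API):

    object TreePatternDemo extends App {
      object TreePattern extends Enumeration {
        type TreePattern = Value
        val AND, OR, SCALAR_SUBQUERY, UNPIVOT = Value
      }
      import TreePattern._

      // Hypothetical node type; real plans expose nodePatterns on TreeNode.
      case class Node(patterns: Set[TreePattern], children: Seq[Node] = Nil) {
        // A pruned transform first asks whether any requested pattern can occur
        // in this subtree before visiting it.
        def containsAnyPattern(ps: TreePattern*): Boolean =
          ps.exists(patterns.contains) || children.exists(_.containsAnyPattern(ps: _*))
      }

      val plan = Node(Set(AND), Seq(Node(Set(SCALAR_SUBQUERY))))
      println(plan.containsAnyPattern(SCALAR_SUBQUERY, UNPIVOT)) // true
      println(plan.containsAnyPattern(OR))                       // false
    }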
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala
index 0a5d509b0..85192fc36 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala
@@ -39,7 +39,7 @@ case class RuntimeFilterSubquery(
exprId: ExprId = NamedExpression.newExprId,
hint: Option[HintInfo] = None)
extends SubqueryExpression(
- filterCreationSidePlan, Seq(filterApplicationSideExp), exprId, Seq.empty)
+ filterCreationSidePlan, Seq(filterApplicationSideExp), exprId, Seq.empty, hint)
with Unevaluable
with UnaryLike[Expression] {
@@ -74,6 +74,8 @@ case class RuntimeFilterSubquery(
override protected def withNewChildInternal(newChild: Expression): RuntimeFilterSubquery =
copy(filterApplicationSideExp = newChild)
+
+ override def withNewHint(hint: Option[HintInfo]): RuntimeFilterSubquery = copy(hint = hint)
}
/**
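The hunk above adapts RuntimeFilterSubquery to the Spark 3.4 SubqueryExpression contract, which carries a hint and requires withNewHint. A minimal sketch of the copy-based pattern with stand-in types (HintInfo and the trait below are simplified, not Spark's classes):

    case class HintInfo(strategy: String)

    trait WithHint {
      def hint: Option[HintInfo]
      def withNewHint(hint: Option[HintInfo]): WithHint
    }

    case class DemoSubquery(body: String, hint: Option[HintInfo] = None) extends WithHint {
      // Case-class copy gives the immutable "replace one field" update used above.
      override def withNewHint(hint: Option[HintInfo]): DemoSubquery = copy(hint = hint)
    }

    object HintDemo extends App {
      val q = DemoSubquery("filterCreationSidePlan")
      println(q.withNewHint(Some(HintInfo("BROADCAST"))))
    }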
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
index 812c387bc..e6c266242 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
@@ -27,6 +27,8 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import com.huawei.boostkit.spark.ColumnarPluginConfig
+import scala.annotation.tailrec
+
/**
* Insert a filter on one side of the join if the other side has a selective predicate.
* The filter could be an IN subquery (converted to a semi join), a bloom filter, or something
@@ -85,12 +87,12 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
val bloomFilterAgg =
if (rowCount.isDefined && rowCount.get.longValue > 0L) {
new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)),
- Literal(rowCount.get.longValue))
+ rowCount.get.longValue)
} else {
new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)))
}
- val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None)
- val alias = Alias(aggExp, "bloomFilter")()
+
+ val alias = Alias(bloomFilterAgg.toAggregateExpression(), "bloomFilter")()
val aggregate =
ConstantFolding(ColumnPruning(Aggregate(Nil, Seq(alias), filterCreationSidePlan)))
val bloomFilterSubquery = if (canReuseExchange) {
@@ -112,7 +114,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
require(filterApplicationSideExp.dataType == filterCreationSideExp.dataType)
val actualFilterKeyExpr = mayWrapWithHash(filterCreationSideExp)
val alias = Alias(actualFilterKeyExpr, actualFilterKeyExpr.toString)()
- val aggregate = Aggregate(Seq(alias), Seq(alias), filterCreationSidePlan)
+ val aggregate =
+ ColumnPruning(Aggregate(Seq(filterCreationSideExp), Seq(alias), filterCreationSidePlan))
if (!canBroadcastBySize(aggregate, conf)) {
// Skip the InSubquery filter if the size of `aggregate` is beyond broadcast join threshold,
// i.e., the semi-join will be a shuffled join, which is not worthwhile.
@@ -129,13 +132,39 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
* do not add a subquery that might have an expensive computation
*/
private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = {
- val ret = plan match {
- case PhysicalOperation(_, filters, child) if child.isInstanceOf[LeafNode] =>
- filters.forall(isSimpleExpression) &&
- filters.exists(isLikelySelective)
+ @tailrec
+ def isSelective(
+ p: LogicalPlan,
+ predicateReference: AttributeSet,
+ hasHitFilter: Boolean,
+ hasHitSelectiveFilter: Boolean): Boolean = p match {
+ case Project(projectList, child) =>
+ if (hasHitFilter) {
+ // We need to make sure all expressions referenced by filter predicates are simple
+ // expressions.
+ val referencedExprs = projectList.filter(predicateReference.contains)
+ referencedExprs.forall(isSimpleExpression) &&
+ isSelective(
+ child,
+ referencedExprs.map(_.references).foldLeft(AttributeSet.empty)(_ ++ _),
+ hasHitFilter,
+ hasHitSelectiveFilter)
+ } else {
+ assert(predicateReference.isEmpty && !hasHitSelectiveFilter)
+ isSelective(child, predicateReference, hasHitFilter, hasHitSelectiveFilter)
+ }
+ case Filter(condition, child) =>
+ isSimpleExpression(condition) && isSelective(
+ child,
+ predicateReference ++ condition.references,
+ hasHitFilter = true,
+ hasHitSelectiveFilter = hasHitSelectiveFilter || isLikelySelective(condition))
+ case _: LeafNode => hasHitSelectiveFilter
case _ => false
}
- !plan.isStreaming && ret
+
+ !plan.isStreaming &&
+ isSelective(plan, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false)
}
private def isSimpleExpression(e: Expression): Boolean = {
@@ -184,8 +213,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
/**
* Check that:
- * - The filterApplicationSideJoinExp can be pushed down through joins and aggregates (ie the
- * expression references originate from a single leaf node)
+ * - The filterApplicationSideJoinExp can be pushed down through joins, aggregates and windows
+ * (ie the expression references originate from a single leaf node)
* - The filter creation side has a selective predicate
* - The current join is a shuffle join or a broadcast join that has a shuffle below it
* - The max filterApplicationSide scan size is greater than a configurable threshold
@@ -218,9 +247,9 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
leftKey: Expression,
rightKey: Expression): Boolean = {
(left, right) match {
- case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan), _) =>
+ case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _, _), plan), _) =>
pruningKey.fastEquals(leftKey) || hasDynamicPruningSubquery(plan, right, leftKey, rightKey)
- case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan)) =>
+ case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _, _), plan)) =>
pruningKey.fastEquals(rightKey) ||
hasDynamicPruningSubquery(left, plan, leftKey, rightKey)
case _ => false
@@ -251,10 +280,10 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
rightKey: Expression): Boolean = {
(left, right) match {
case (Filter(InSubquery(Seq(key),
- ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _), _) =>
+ ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _, _)), _), _) =>
key.fastEquals(leftKey) || key.fastEquals(new Murmur3Hash(Seq(leftKey)))
case (_, Filter(InSubquery(Seq(key),
- ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _)) =>
+ ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _, _)), _)) =>
key.fastEquals(rightKey) || key.fastEquals(new Murmur3Hash(Seq(rightKey)))
case _ => false
}
@@ -299,7 +328,13 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
case s: Subquery if s.correlated => plan
case _ if !conf.runtimeFilterSemiJoinReductionEnabled &&
!conf.runtimeFilterBloomFilterEnabled => plan
- case _ => tryInjectRuntimeFilter(plan)
+ case _ =>
+ val newPlan = tryInjectRuntimeFilter(plan)
+ if (conf.runtimeFilterSemiJoinReductionEnabled && !plan.fastEquals(newPlan)) {
+ RewritePredicateSubquery(newPlan)
+ } else {
+ newPlan
+ }
}
}
\ No newline at end of file
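The rewritten isSelectiveFilterOverScan walks Project/Filter chains down to a leaf, requiring every referenced expression to be simple and at least one filter to be likely selective. A minimal sketch of the same tail-recursive shape over a toy plan ADT (not Spark's LogicalPlan):

    import scala.annotation.tailrec

    // Toy nodes standing in for Spark's Project / Filter / LeafNode.
    sealed trait Plan
    case class Project(simpleProjection: Boolean, child: Plan) extends Plan
    case class Filter(simple: Boolean, selective: Boolean, child: Plan) extends Plan
    case object Leaf extends Plan

    object SelectivityDemo extends App {
      // Accept only plans where everything on the way down is simple and at
      // least one filter above the leaf is likely selective.
      @tailrec
      def isSelective(p: Plan, hasHitSelectiveFilter: Boolean = false): Boolean = p match {
        case Project(simple, child) => simple && isSelective(child, hasHitSelectiveFilter)
        case Filter(simple, selective, child) =>
          simple && isSelective(child, hasHitSelectiveFilter || selective)
        case Leaf => hasHitSelectiveFilter
      }

      println(isSelective(Project(simpleProjection = true,
        Filter(simple = true, selective = true, Leaf))))                  // true
      println(isSelective(Filter(simple = true, selective = false, Leaf))) // false
    }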
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala
index 1b5baa230..c4435379f 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala
@@ -643,7 +643,7 @@ object MergeSubqueryFilters extends Rule[LogicalPlan] {
val subqueryCTE = header.plan.asInstanceOf[CTERelationDef]
GetStructField(
ScalarSubquery(
- CTERelationRef(subqueryCTE.id, _resolved = true, subqueryCTE.output),
+ CTERelationRef(subqueryCTE.id, _resolved = true, subqueryCTE.output, subqueryCTE.isStreaming),
exprId = ssr.exprId),
ssr.headerIndex)
} else {
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala
index 9e4029025..f6ebd716d 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala
@@ -61,7 +61,7 @@ object RewriteSelfJoinInInPredicate extends Rule[LogicalPlan] with PredicateHelp
case f: Filter =>
f transformExpressions {
case in @ InSubquery(_, listQuery @ ListQuery(Project(projectList,
- Join(left, right, Inner, Some(joinCond), _)), _, _, _, _))
+ Join(left, right, Inner, Some(joinCond), _)), _, _, _, _, _))
if left.canonicalized ne right.canonicalized =>
val attrMapping = AttributeMap(right.output.zip(left.output))
val subCondExprs = splitConjunctivePredicates(joinCond transform {
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala
index 486369843..630034bd7 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala
@@ -16,10 +16,10 @@
*/
package org.apache.spark.sql.execution
-
import java.util.concurrent.TimeUnit.NANOSECONDS
import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP
+
import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._
import com.huawei.boostkit.spark.util.OmniAdaptorUtil
import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs
@@ -41,11 +41,13 @@ import org.apache.spark.sql.execution.vectorized.OmniColumnVector
import org.apache.spark.sql.expression.ColumnarExpressionConverter
import org.apache.spark.sql.types.{LongType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.sql.catalyst.plans.{AliasAwareOutputExpression, AliasAwareQueryOutputOrdering}
+
case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
extends UnaryExecNode
- with AliasAwareOutputPartitioning
- with AliasAwareOutputOrdering {
+ with AliasAwareOutputExpression
+ with AliasAwareQueryOutputOrdering[SparkPlan] {
override def supportsColumnar: Boolean = true
@@ -267,8 +269,8 @@ case class ColumnarConditionProjectExec(projectList: Seq[NamedExpression],
condition: Expression,
child: SparkPlan)
extends UnaryExecNode
- with AliasAwareOutputPartitioning
- with AliasAwareOutputOrdering {
+ with AliasAwareOutputExpression
+ with AliasAwareQueryOutputOrdering[SparkPlan] {
override def supportsColumnar: Boolean = true
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala
index 800dcf1a0..4b70db452 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala
@@ -479,10 +479,11 @@ abstract class BaseColumnarFileSourceScanExec(
}
}.groupBy { f =>
BucketingUtils
- .getBucketId(new Path(f.filePath).getName)
- .getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.filePath))
+ .getBucketId(f.toPath.getName)
+ .getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.urlEncodedPath))
}
+
val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) {
val bucketSet = optionalBucketSet.get
filesGroupedToBuckets.filter {
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index ef33a84de..37f3db5d6 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -105,6 +105,23 @@ class QueryExecution(
case other => other
}
+ // The plan that has been normalized by custom rules, so that it's more likely to hit cache.
+ lazy val normalized: LogicalPlan = {
+ val normalizationRules = sparkSession.sessionState.planNormalizationRules
+ if (normalizationRules.isEmpty) {
+ commandExecuted
+ } else {
+ val planChangeLogger = new PlanChangeLogger[LogicalPlan]()
+ val normalized = normalizationRules.foldLeft(commandExecuted) { (p, rule) =>
+ val result = rule.apply(p)
+ planChangeLogger.logRule(rule.ruleName, p, result)
+ result
+ }
+ planChangeLogger.logBatch("Plan Normalization", commandExecuted, normalized)
+ normalized
+ }
+ } // Spark3.4.3
+
lazy val withCachedData: LogicalPlan = sparkSession.withActive {
assertAnalyzed()
assertSupported()
@@ -227,7 +244,7 @@ class QueryExecution(
// output mode does not matter since there is no `Sink`.
new IncrementalExecution(
sparkSession, logical, OutputMode.Append(), "",
- UUID.randomUUID, UUID.randomUUID, 0, OffsetSeqMetadata(0, 0))
+ UUID.randomUUID, UUID.randomUUID, 0, None, OffsetSeqMetadata(0, 0))
} else {
this
}
@@ -494,11 +511,10 @@ object QueryExecution {
*/
private[sql] def toInternalError(msg: String, e: Throwable): Throwable = e match {
case e @ (_: java.lang.NullPointerException | _: java.lang.AssertionError) =>
- new SparkException(
- errorClass = "INTERNAL_ERROR",
- messageParameters = Array(msg +
- " Please, fill a bug report in, and provide the full stack trace."),
- cause = e)
+ SparkException.internalError(
+ msg + " You hit a bug in Spark or the Spark plugins you use. Please, report this bug " +
+ "to the corresponding communities or vendors, and provide the full stack trace.",
+ e)
case e: Throwable =>
e
}
@@ -513,4 +529,4 @@ object QueryExecution {
case e: Throwable => throw toInternalError(msg, e)
}
}
-}
\ No newline at end of file
+}
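The new `normalized` value simply folds the session's plan-normalization rules over the executed command plan, logging each rule that changes it. A minimal sketch of that fold, with plain strings standing in for LogicalPlan and println standing in for PlanChangeLogger:

    object NormalizationDemo extends App {
      // Stand-in: a rule is just a named plan-to-plan function.
      final case class Rule(name: String, transform: String => String)

      def normalize(plan: String, rules: Seq[Rule]): String =
        rules.foldLeft(plan) { (p, rule) =>
          val result = rule.transform(p)
          if (result != p) println(s"Rule ${rule.name} changed the plan")
          result
        }

      val rules = Seq(
        Rule("TrimTempAliases", _.replace(" AS tmp", "")),
        Rule("NoOp", p => p))
      println(normalize("Project(a AS tmp) <- Scan(t)", rules))
    }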
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 2599a1410..e5c8fc0ab 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -19,13 +19,11 @@ package org.apache.spark.sql.execution.adaptive
import java.util
import java.util.concurrent.LinkedBlockingQueue
-
import scala.collection.JavaConverters._
import scala.collection.concurrent.TrieMap
import scala.collection.mutable
import scala.concurrent.ExecutionContext
import scala.util.control.NonFatal
-
import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
@@ -35,12 +33,13 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
+import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
import org.apache.spark.sql.execution.bucketing.DisableUnnecessaryBucketedScan
import org.apache.spark.sql.execution.exchange._
-import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric}
+import org.apache.spark.sql.execution.ui.{SQLPlanMetric, SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.{SparkFatalException, ThreadUtils}
@@ -83,7 +82,9 @@ case class AdaptiveSparkPlanExec(
@transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()
// The logical plan optimizer for re-optimizing the current logical plan.
- @transient private val optimizer = new AQEOptimizer(conf)
+ @transient private val optimizer = new AQEOptimizer(conf,
+ session.sessionState.adaptiveRulesHolder.runtimeOptimizerRules
+ )
// `EnsureRequirements` may remove user-specified repartition and assume the query plan won't
// change its output partitioning. This assumption is not true in AQE. Here we check the
@@ -122,7 +123,7 @@ case class AdaptiveSparkPlanExec(
RemoveRedundantSorts,
DisableUnnecessaryBucketedScan,
OptimizeSkewedJoin(ensureRequirements)
- ) ++ context.session.sessionState.queryStagePrepRules
+ ) ++ context.session.sessionState.adaptiveRulesHolder.queryStagePrepRules
}
// A list of physical optimizer rules to be applied to a new stage before its execution. These
@@ -152,7 +153,13 @@ case class AdaptiveSparkPlanExec(
)
private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): SparkPlan = {
- val optimized = queryStageOptimizerRules.foldLeft(plan) { case (latestPlan, rule) =>
+ val rules = if (isFinalStage &&
+ !conf.getConf(SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS)) {
+ queryStageOptimizerRules.filterNot(_.isInstanceOf[AQEShuffleReadRule])
+ } else {
+ queryStageOptimizerRules
+ }
+ val optimized = rules.foldLeft(plan) { case (latestPlan, rule) =>
val applied = rule.apply(latestPlan)
val result = rule match {
case _: AQEShuffleReadRule if !applied.fastEquals(latestPlan) =>
@@ -187,7 +194,7 @@ case class AdaptiveSparkPlanExec(
@volatile private var currentPhysicalPlan = initialPlan
- private var isFinalPlan = false
+ @volatile private var _isFinalPlan = false
private var currentStageId = 0
@@ -204,6 +211,8 @@ case class AdaptiveSparkPlanExec(
def executedPlan: SparkPlan = currentPhysicalPlan
+ def isFinalPlan: Boolean = _isFinalPlan
+
override def conf: SQLConf = context.session.sessionState.conf
override def output: Seq[Attribute] = inputPlan.output
@@ -222,6 +231,8 @@ case class AdaptiveSparkPlanExec(
.map(_.toLong).filter(SQLExecution.getQueryExecution(_) eq context.qe)
}
+ def finalPhysicalPlan: SparkPlan = withFinalPlanUpdate(identity) // spark 3.4.3
+
private def getFinalPhysicalPlan(): SparkPlan = lock.synchronized {
if (isFinalPlan) return currentPhysicalPlan
@@ -309,7 +320,8 @@ case class AdaptiveSparkPlanExec(
val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
if (newCost < origCost ||
(newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
- logOnLevel(s"Plan changed from $currentPhysicalPlan to $newPhysicalPlan")
+ logOnLevel("Plan changed:\n" +
+ sideBySide(currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n"))
cleanUpTempTags(newPhysicalPlan)
currentPhysicalPlan = newPhysicalPlan
currentLogicalPlan = newLogicalPlan
@@ -325,7 +337,7 @@ case class AdaptiveSparkPlanExec(
optimizeQueryStage(result.newPlan, isFinalStage = true),
postStageCreationRules(supportsColumnar),
Some((planChangeLogger, "AQE Post Stage Creation")))
- isFinalPlan = true
+ _isFinalPlan = true
executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
currentPhysicalPlan
}
@@ -339,7 +351,7 @@ case class AdaptiveSparkPlanExec(
if (!isSubquery && currentPhysicalPlan.exists(_.subqueries.nonEmpty)) {
getExecutionId.foreach(onUpdatePlan(_, Seq.empty))
}
- logOnLevel(s"Final plan: $currentPhysicalPlan")
+ logOnLevel(s"Final plan: \n$currentPhysicalPlan")
}
override def executeCollect(): Array[InternalRow] = {
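optimizeQueryStage now drops the AQE shuffle-read rules for the final stage when the new config is disabled, so the final stage keeps its output partitioning. A minimal sketch of that filterNot-then-fold shape with stand-in rule types (not Spark's AQEShuffleReadRule hierarchy):

    object FinalStageDemo extends App {
      sealed trait StageRule { def apply(plan: String): String }
      // Stand-in for an AQEShuffleReadRule subclass.
      case object CoalesceShuffleRead extends StageRule { def apply(p: String) = s"CoalescedRead($p)" }
      case object RemoveRedundantSort extends StageRule { def apply(p: String) = p.stripPrefix("Sort|") }

      val allRules: Seq[StageRule] = Seq(RemoveRedundantSort, CoalesceShuffleRead)

      def optimizeStage(plan: String, isFinalStage: Boolean, applyFinalStageShuffleOpts: Boolean): String = {
        val rules =
          if (isFinalStage && !applyFinalStageShuffleOpts)
            allRules.filterNot(_ == CoalesceShuffleRead) // keep the final stage's partitioning stable
          else allRules
        rules.foldLeft(plan)((p, r) => r.apply(p))
      }

      println(optimizeStage("Sort|Shuffle(t)", isFinalStage = true, applyFinalStageShuffleOpts = false))
    }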
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
index b5a1ad375..dfdbe2c70 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
@@ -30,11 +30,11 @@ case class PlanAdaptiveSubqueries(
def apply(plan: SparkPlan): SparkPlan = {
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(
SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY, RUNTIME_FILTER_SUBQUERY)) {
- case expressions.ScalarSubquery(_, _, exprId, _) =>
+ case expressions.ScalarSubquery(_, _, exprId, _, _, _) =>
val subquery = SubqueryExec.createForScalarSubquery(
s"subquery#${exprId.id}", subqueryMap(exprId.id))
execution.ScalarSubquery(subquery, exprId)
- case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _)) =>
+ case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _, _)) =>
val expr = if (values.length == 1) {
values.head
} else {
@@ -47,7 +47,7 @@ case class PlanAdaptiveSubqueries(
val subquery = SubqueryExec(s"subquery#${exprId.id}", subqueryMap(exprId.id))
InSubqueryExec(expr, subquery, exprId, shouldBroadcast = true)
case expressions.DynamicPruningSubquery(value, buildPlan,
- buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId) =>
+ buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId, _) =>
val name = s"dynamicpruning#${exprId.id}"
val subquery = SubqueryAdaptiveBroadcastExec(name, broadcastKeyIndex, onlyInBroadcast,
buildPlan, buildKeys, subqueryMap(exprId.id))
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala
index 334800f51..807369004 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala
@@ -118,17 +118,27 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ
(file: PartitionedFile) => {
val conf = broadcastedConf.value.value
- val filePath = new Path(new URI(file.filePath))
- val isPPDSafeValue = isPPDSafe(filters, dataSchema).reduceOption(_ && _)
+// val filePath = new Path(new URI(file.filePath.urlEncoded))
+ val filePath = file.toPath
- // ORC predicate pushdown
- if (orcFilterPushDown && filters.nonEmpty && isPPDSafeValue.getOrElse(false)) {
- OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach {
- fileSchema => OrcFilters.createFilter(fileSchema, filters).foreach { f =>
- OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
- }
- }
- }
+ val fs = filePath.getFileSystem(conf)
+ val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
+ val orcSchema =
+ Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions))(_.getSchema)
+ val isPPDSafeValue = isPPDSafe(filters, dataSchema).reduceOption(_ && _)
+ val resultedColPruneInfo = OrcUtils.requestedColumnIds(
+ isCaseSensitive, dataSchema, requiredSchema, orcSchema, conf)
+
+
+ // ORC predicate pushdown
+ if (orcFilterPushDown && filters.nonEmpty && isPPDSafeValue.getOrElse(false)) {
+ val fileSchema = OrcUtils.toCatalystSchema(orcSchema)
+ // OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach { fileSchema =>
+ OrcFilters.createFilter(fileSchema, filters).foreach { f =>
+ OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
+ // }
+ }
+ }
val taskConf = new Configuration(conf)
val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty)
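The rewritten ORC reader path opens the file itself to obtain the schema instead of going through OrcUtils.readCatalystSchema. A hedged, standalone sketch of the same calls (the file path is assumed to exist; Reader is assumed Closeable, as implied by the tryWithResource call above):

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path
    import org.apache.orc.OrcFile

    object OrcSchemaDemo extends App {
      val conf = new Configuration()
      val filePath = new Path("/tmp/example.orc") // hypothetical input file
      val fs = filePath.getFileSystem(conf)
      val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
      // Same pattern as the hunk above, with an explicit close instead of Utils.tryWithResource.
      val reader = OrcFile.createReader(filePath, readerOptions)
      try {
        val orcSchema = reader.getSchema // org.apache.orc.TypeDescription
        println(orcSchema.toString)
      } finally reader.close()
    }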
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala
index 9504b34d1..3d61f873d 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala
@@ -80,7 +80,7 @@ class OmniParquetFileFormat extends FileFormat with DataSourceRegister with Logg
(file: PartitionedFile) => {
assert(file.partitionValues.numFields == partitionSchema.size)
- val filePath = new Path(new URI(file.filePath))
+ val filePath = file.toPath
val split =
new org.apache.parquet.hadoop.ParquetInputSplit(
filePath,
diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties b/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties
index 89eabe8e6..099e28e8d 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties
+++ b/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties
@@ -2,11 +2,13 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
#
-hive.metastore.uris=thrift://server1:9083
+#hive.metastore.uris=thrift://server1:9083
+hive.metastore.uris=thrift://OmniOperator:9083
spark.sql.warehouse.dir=/user/hive/warehouse
spark.memory.offHeap.size=8G
spark.sql.codegen.wholeStage=false
spark.sql.extensions=com.huawei.boostkit.spark.ColumnarPlugin
spark.shuffle.manager=org.apache.spark.shuffle.sort.OmniColumnarShuffleManager
spark.sql.orc.impl=native
-hive.db=tpcds_bin_partitioned_orc_2
\ No newline at end of file
+#hive.db=tpcds_bin_partitioned_orc_2
+hive.db=tpcds_bin_partitioned_varchar_orc_2
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index e6b786c2a..329295bac 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.SparkException
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.execution.ColumnarSparkPlanTest
import org.apache.spark.sql.types.{DataType, Decimal}
@@ -64,7 +65,8 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception = intercept[Exception](
result.collect().toSeq.head.getByte(0)
)
- assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast null as short") {
@@ -72,7 +74,8 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception = intercept[Exception](
result.collect().toSeq.head.getShort(0)
)
- assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast null as int") {
@@ -80,7 +83,8 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception = intercept[Exception](
result.collect().toSeq.head.getInt(0)
)
- assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast null as long") {
@@ -88,7 +92,8 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception = intercept[Exception](
result.collect().toSeq.head.getLong(0)
)
- assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast null as float") {
@@ -96,7 +101,8 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception = intercept[Exception](
result.collect().toSeq.head.getFloat(0)
)
- assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast null as double") {
@@ -104,7 +110,8 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception = intercept[Exception](
result.collect().toSeq.head.getDouble(0)
)
- assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast null as date") {
@@ -154,13 +161,15 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception4 = intercept[Exception](
result4.collect().toSeq.head.getBoolean(0)
)
- assert(exception4.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception4.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception4.isInstanceOf[SparkException], s"sql: ${sql}")
val result5 = spark.sql("select cast('test' as boolean);")
val exception5 = intercept[Exception](
result5.collect().toSeq.head.getBoolean(0)
)
- assert(exception5.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception5.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception5.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast boolean to string") {
@@ -182,13 +191,15 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception2 = intercept[Exception](
result2.collect().toSeq.head.getByte(0)
)
- assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}")
val result3 = spark.sql("select cast('false' as byte);")
val exception3 = intercept[Exception](
result3.collect().toSeq.head.getByte(0)
)
- assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast byte to string") {
@@ -210,13 +221,15 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception2 = intercept[Exception](
result2.collect().toSeq.head.getShort(0)
)
- assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}")
val result3 = spark.sql("select cast('false' as short);")
val exception3 = intercept[Exception](
result3.collect().toSeq.head.getShort(0)
)
- assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast short to string") {
@@ -238,13 +251,15 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception2 = intercept[Exception](
result2.collect().toSeq.head.getInt(0)
)
- assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}")
val result3 = spark.sql("select cast('false' as int);")
val exception3 = intercept[Exception](
result3.collect().toSeq.head.getInt(0)
)
- assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast int to string") {
@@ -266,13 +281,15 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception2 = intercept[Exception](
result2.collect().toSeq.head.getLong(0)
)
- assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}")
val result3 = spark.sql("select cast('false' as long);")
val exception3 = intercept[Exception](
result3.collect().toSeq.head.getLong(0)
)
- assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast long to string") {
@@ -298,7 +315,8 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception3 = intercept[Exception](
result3.collect().toSeq.head.getFloat(0)
)
- assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast float to string") {
@@ -324,7 +342,8 @@ class CastSuite extends ColumnarSparkPlanTest {
val exception3 = intercept[Exception](
result3.collect().toSeq.head.getDouble(0)
)
- assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+ assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}")
}
test("cast double to string") {
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
index f83edb9ca..ea52aca62 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
@@ -61,7 +61,7 @@ class CombiningLimitsSuite extends PlanTest {
comparePlans(optimized1, expected1)
// test child max row > limit.
- val query2 = testRelation.select().groupBy()(count(1)).limit(0).analyze
+ val query2 = testRelation2.select($"x").groupBy($"x")(count(1)).limit(1).analyze
val optimized2 = Optimize.execute(query2)
comparePlans(optimized2, query2)
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala
index aaa244cdf..e1c620e1c 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala
@@ -43,7 +43,7 @@ class MergeSubqueryFiltersSuite extends PlanTest {
}
private def extractorExpression(cteIndex: Int, output: Seq[Attribute], fieldIndex: Int) = {
- GetStructField(ScalarSubquery(CTERelationRef(cteIndex, _resolved = true, output)), fieldIndex)
+ GetStructField(ScalarSubquery(CTERelationRef(cteIndex, _resolved = true, output, isStreaming = false)), fieldIndex)
.as("scalarsubquery()")
}
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala
index c34ff5bb1..c0be72f31 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala
@@ -19,17 +19,15 @@ package org.apache.spark.sql.execution.adaptive
import java.io.File
import java.net.URI
-
import org.apache.logging.log4j.Level
import org.scalatest.PrivateMethodTester
import org.scalatest.time.SpanSugar._
-
import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
-import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnaryExecNode, UnionExec}
+import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnionExec}
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
@@ -37,7 +35,7 @@ import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, Exchange, REPARTITION_BY_COL, REPARTITION_BY_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec}
import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
-import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
+import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLExecutionStart}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
@@ -1122,13 +1120,21 @@ class AdaptiveQueryExecSuite
test("SPARK-30953: InsertAdaptiveSparkPlan should apply AQE on child plan of write commands") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
+ var plan: SparkPlan = null
+ val listener = new QueryExecutionListener {
+ override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+ plan = qe.executedPlan
+ }
+ override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+ }
+ spark.listenerManager.register(listener)
withTable("t1") {
- val plan = sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").queryExecution.executedPlan
- assert(plan.isInstanceOf[CommandResultExec])
- val commandResultExec = plan.asInstanceOf[CommandResultExec]
- assert(commandResultExec.commandPhysicalPlan.isInstanceOf[DataWritingCommandExec])
- assert(commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec]
- .child.isInstanceOf[AdaptiveSparkPlanExec])
+ val format = classOf[NoopDataSource].getName
+ Seq((0, 1)).toDF("x", "y").write.format(format).mode("overwrite").save()
+ sparkContext.listenerBus.waitUntilEmpty()
+ assert(plan.isInstanceOf[V2TableWriteExec])
+ assert(plan.asInstanceOf[V2TableWriteExec].child.isInstanceOf[AdaptiveSparkPlanExec])
+ spark.listenerManager.unregister(listener)
}
}
}
@@ -1172,15 +1178,14 @@ class AdaptiveQueryExecSuite
test("SPARK-31658: SQL UI should show write commands") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
+ SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "false") {
withTable("t1") {
- var checkDone = false
+ var commands: Seq[SparkPlanInfo] = Seq.empty
val listener = new SparkListener {
override def onOtherEvent(event: SparkListenerEvent): Unit = {
event match {
- case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) =>
- assert(planInfo.nodeName == "Execute CreateDataSourceTableAsSelectCommand")
- checkDone = true
+ case start: SparkListenerSQLExecutionStart =>
+ commands = commands ++ Seq(start.sparkPlanInfo)
case _ => // ignore other events
}
}
@@ -1189,7 +1194,12 @@ class AdaptiveQueryExecSuite
try {
sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect()
spark.sparkContext.listenerBus.waitUntilEmpty()
- assert(checkDone)
+ assert(commands.size == 3)
+ assert(commands.head.nodeName == "Execute CreateDataSourceTableAsSelectCommand")
+ assert(commands(1).nodeName == "Execute InsertIntoHadoopFsRelationCommand")
+ assert(commands(1).children.size == 1)
+ assert(commands(1).children.head.nodeName == "WriteFiles")
+ assert(commands(2).nodeName == "CommandResult")
} finally {
spark.sparkContext.removeSparkListener(listener)
}
@@ -1574,7 +1584,7 @@ class AdaptiveQueryExecSuite
test("SPARK-32932: Do not use local shuffle read at final stage on write command") {
withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString,
SQLConf.SHUFFLE_PARTITIONS.key -> "5",
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
val data = for (
i <- 1L to 10L;
j <- 1L to 3L
@@ -1584,9 +1594,8 @@ class AdaptiveQueryExecSuite
var noLocalread: Boolean = false
val listener = new QueryExecutionListener {
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
- qe.executedPlan match {
+ stripAQEPlan(qe.executedPlan) match {
case plan@(_: DataWritingCommandExec | _: V2TableWriteExec) =>
- assert(plan.asInstanceOf[UnaryExecNode].child.isInstanceOf[AdaptiveSparkPlanExec])
noLocalread = collect(plan) {
case exec: AQEShuffleReadExec if exec.isLocalRead => exec
}.isEmpty
diff --git a/omnioperator/omniop-spark-extension/pom.xml b/omnioperator/omniop-spark-extension/pom.xml
index 4376a89be..e949cea2d 100644
--- a/omnioperator/omniop-spark-extension/pom.xml
+++ b/omnioperator/omniop-spark-extension/pom.xml
@@ -8,13 +8,13 @@
com.huawei.kunpeng
boostkit-omniop-spark-parent
pom
- 3.3.1-1.6.0
+ 3.4.3-1.6.0
BoostKit Spark Native Sql Engine Extension Parent Pom
2.12.10
2.12
- 3.3.1
+ 3.4.3
3.2.2
UTF-8
UTF-8
--
Gitee
From f12d2c02f024f683fe998ea31042ff35dfefd573 Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Thu, 17 Oct 2024 06:58:44 +0000
Subject: [PATCH 34/43] update omnioperator/omniop-native-reader/java/pom.xml.
modify spark_version and logger dependency
Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com>
---
.../omniop-native-reader/java/pom.xml | 1 +
.../expression/OmniExpressionAdaptor.scala | 4 +-
.../sql/catalyst/Tree/TreePatterns.scala | 30 +-
.../catalyst/analysis/DecimalPrecision.scala | 347 ++++++++++++++++++
.../expressions/decimalExpressions.scala | 235 ++++++++++++
.../optimizer/InjectRuntimeFilter.scala | 59 +--
.../ColumnarBasicPhysicalOperators.scala | 2 -
.../spark/sql/execution/QueryExecution.scala | 19 +-
.../adaptive/AdaptiveSparkPlanExec.scala | 24 +-
9 files changed, 620 insertions(+), 101 deletions(-)
create mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
create mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
diff --git a/omnioperator/omniop-native-reader/java/pom.xml b/omnioperator/omniop-native-reader/java/pom.xml
index bdad33a3a..773ed2c39 100644
--- a/omnioperator/omniop-native-reader/java/pom.xml
+++ b/omnioperator/omniop-native-reader/java/pom.xml
@@ -34,6 +34,7 @@
1.6.0
+
org.slf4j
slf4j-simple
1.7.36
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
index f7ed75fa3..be1538172 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
@@ -160,8 +160,8 @@ object OmniExpressionAdaptor extends Logging {
throw new UnsupportedOperationException(s"Unsupported datatype for MakeDecimal: ${makeDecimal.child.dataType}")
}
-// case promotePrecision: PromotePrecision =>
-// rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap)
+ case promotePrecision: PromotePrecision =>
+ rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap)
case sub: Subtract =>
new JSONObject().put("exprType", "BINARY")
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala
index ef17a0740..4d8c246ab 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala
@@ -25,7 +25,7 @@ object TreePattern extends Enumeration {
// Expression patterns (alphabetically ordered)
val AGGREGATE_EXPRESSION = Value(0)
val ALIAS: Value = Value
-// val AND_OR: Value = Value
+ val AND_OR: Value = Value
val AND: Value = Value
val ARRAYS_ZIP: Value = Value
val ATTRIBUTE_REFERENCE: Value = Value
@@ -59,7 +59,7 @@ object TreePattern extends Enumeration {
val JSON_TO_STRUCT: Value = Value
val LAMBDA_FUNCTION: Value = Value
val LAMBDA_VARIABLE: Value = Value
- val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value // spark3.4.3
+ val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value
val LATERAL_SUBQUERY: Value = Value
val LIKE_FAMLIY: Value = Value
val LIST_SUBQUERY: Value = Value
@@ -71,33 +71,33 @@ object TreePattern extends Enumeration {
val NULL_CHECK: Value = Value
val NULL_LITERAL: Value = Value
val SERIALIZE_FROM_OBJECT: Value = Value
- val OR: Value = Value // spark3.4.3
+ val OR: Value = Value
val OUTER_REFERENCE: Value = Value
- val PARAMETER: Value = Value // spark3.4.3
- val PARAMETERIZED_QUERY: Value = Value // spark3.4.3
+ val PARAMETER: Value = Value
+ val PARAMETERIZED_QUERY: Value = Value
val PIVOT: Value = Value
val PLAN_EXPRESSION: Value = Value
val PYTHON_UDF: Value = Value
val REGEXP_EXTRACT_FAMILY: Value = Value
val REGEXP_REPLACE: Value = Value
val RUNTIME_REPLACEABLE: Value = Value
- val RUNTIME_FILTER_EXPRESSION: Value = Value // removed in spark3.4.3
- val RUNTIME_FILTER_SUBQUERY: Value = Value // removed in spark3.4.3
+ val RUNTIME_FILTER_EXPRESSION: Value = Value
+ val RUNTIME_FILTER_SUBQUERY: Value = Value
val SCALAR_SUBQUERY: Value = Value
val SCALAR_SUBQUERY_REFERENCE: Value = Value
val SCALA_UDF: Value = Value
- val SESSION_WINDOW: Value = Value // spark3.4.3
+ val SESSION_WINDOW: Value = Value
val SORT: Value = Value
val SUBQUERY_ALIAS: Value = Value
- val SUBQUERY_WRAPPER: Value = Value // removed in spark3.4.3
+ val SUBQUERY_WRAPPER: Value = Value
val SUM: Value = Value
val TIME_WINDOW: Value = Value
val TIME_ZONE_AWARE_EXPRESSION: Value = Value
val TRUE_OR_FALSE_LITERAL: Value = Value
val WINDOW_EXPRESSION: Value = Value
- val WINDOW_TIME: Value = Value // spark3.4.3
+ val WINDOW_TIME: Value = Value
val UNARY_POSITIVE: Value = Value
- val UNPIVOT: Value = Value // spark3.4.3
+ val UNPIVOT: Value = Value
val UPDATE_FIELDS: Value = Value
val UPPER_OR_LOWER: Value = Value
val UP_CAST: Value = Value
@@ -127,7 +127,7 @@ object TreePattern extends Enumeration {
val UNION: Value = Value
val UNRESOLVED_RELATION: Value = Value
val UNRESOLVED_WITH: Value = Value
- val TEMP_RESOLVED_COLUMN: Value = Value // spark3.4.3
+ val TEMP_RESOLVED_COLUMN: Value = Value
val TYPED_FILTER: Value = Value
val WINDOW: Value = Value
val WITH_WINDOW_DEFINITION: Value = Value
@@ -136,7 +136,7 @@ object TreePattern extends Enumeration {
val UNRESOLVED_ALIAS: Value = Value
val UNRESOLVED_ATTRIBUTE: Value = Value
val UNRESOLVED_DESERIALIZER: Value = Value
- val UNRESOLVED_HAVING: Value = Value // spark3.4.3
+ val UNRESOLVED_HAVING: Value = Value
val UNRESOLVED_ORDINAL: Value = Value
val UNRESOLVED_FUNCTION: Value = Value
val UNRESOLVED_HINT: Value = Value
@@ -145,8 +145,8 @@ object TreePattern extends Enumeration {
// Unresolved Plan patterns (Alphabetically ordered)
val UNRESOLVED_SUBQUERY_COLUMN_ALIAS: Value = Value
val UNRESOLVED_FUNC: Value = Value
- val UNRESOLVED_TABLE_VALUED_FUNCTION: Value = Value // spark3.4.3
- val UNRESOLVED_TVF_ALIASES: Value = Value // spark3.4.3
+ val UNRESOLVED_TABLE_VALUED_FUNCTION: Value = Value
+ val UNRESOLVED_TVF_ALIASES: Value = Value
// Execution expression patterns (alphabetically ordered)
val IN_SUBQUERY_EXEC: Value = Value
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
new file mode 100644
index 000000000..5a04f02ce
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
@@ -0,0 +1,347 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.Literal._
+import org.apache.spark.sql.types._
+
+
+// scalastyle:off
+/**
+ * Calculates and propagates precision for fixed-precision decimals. Hive has a number of
+ * rules for this based on the SQL standard and MS SQL:
+ * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
+ * https://msdn.microsoft.com/en-us/library/ms190476.aspx
+ *
+ * In particular, if we have expressions e1 and e2 with precision/scale p1/s1 and p2/s2
+ * respectively, then the following operations have the following precision / scale:
+ *
+ * Operation Result Precision Result Scale
+ * ------------------------------------------------------------------------
+ * e1 + e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2)
+ * e1 - e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2)
+ * e1 * e2 p1 + p2 + 1 s1 + s2
+ * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1)
+ * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2)
+ * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2)
+ *
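+ * A worked example of the addition rule: for e1 of DECIMAL(10, 2) and e2 of DECIMAL(8, 4),
+ * the result scale is max(2, 4) = 4 and the result precision is
+ * max(10 - 2, 8 - 4) + 4 + 1 = 13, so e1 + e2 is DECIMAL(13, 4).
+ *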
+ * When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale
+ * needed are out of the range of available values, the scale is reduced up to 6, in order to
+ * prevent the truncation of the integer part of the decimals.
+ *
+ * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited
+ * precision, do the math on unlimited-precision numbers, then introduce casts back to the
+ * required fixed precision. This allows us to do all rounding and overflow handling in the
+ * cast-to-fixed-precision operator.
+ *
+ * In addition, when mixing non-decimal types with decimals, we use the following rules:
+ * - BYTE gets turned into DECIMAL(3, 0)
+ * - SHORT gets turned into DECIMAL(5, 0)
+ * - INT gets turned into DECIMAL(10, 0)
+ * - LONG gets turned into DECIMAL(20, 0)
+ * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE
+ * - Literals INT and LONG get turned into DECIMAL with the precision strictly needed by the value
+ */
+// scalastyle:on
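+// For example, under the rules above, DECIMAL(10, 2) + DECIMAL(5, 3) yields scale max(2, 3) = 3 and
+// precision max(10 - 2, 5 - 3) + 3 + 1 = 12, i.e. DECIMAL(12, 3), while DECIMAL(10, 2) * DECIMAL(5, 3)
+// yields DECIMAL(10 + 5 + 1, 2 + 3) = DECIMAL(16, 5).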
+object DecimalPrecision extends TypeCoercionRule {
+ import scala.math.{max, min}
+
+ private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
+
+ // Returns the wider decimal type that's wider than both of them
+ def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = {
+ widerDecimalType(d1.precision, d1.scale, d2.precision, d2.scale)
+ }
+ // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
+ def widerDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
+ val scale = max(s1, s2)
+ val range = max(p1 - s1, p2 - s2)
+ DecimalType.bounded(range + scale, scale)
+ }
+
+ private def promotePrecision(e: Expression, dataType: DataType): Expression = {
+ PromotePrecision(Cast(e, dataType))
+ }
+
+ override def transform: PartialFunction[Expression, Expression] = {
+ decimalAndDecimal()
+ .orElse(integralAndDecimalLiteral)
+ .orElse(nondecimalAndDecimal(conf.literalPickMinimumPrecision))
+ }
+
+ private[catalyst] def decimalAndDecimal(): PartialFunction[Expression, Expression] = {
+ decimalAndDecimal(conf.decimalOperationsAllowPrecisionLoss, !conf.ansiEnabled)
+ }
+
+ /** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */
+ private[catalyst] def decimalAndDecimal(allowPrecisionLoss: Boolean, nullOnOverflow: Boolean)
+ : PartialFunction[Expression, Expression] = {
+ // Skip nodes whose children have not been resolved yet
+ case e if !e.childrenResolved => e
+
+    // Skip nodes that have already been promoted
+ case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e
+
+ case a @ Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
+ val resultScale = max(s1, s2)
+ val resultType = if (allowPrecisionLoss) {
+ DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
+ resultScale)
+ } else {
+ DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
+ }
+ CheckOverflow(
+ a.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)),
+ resultType, nullOnOverflow)
+
+ case s @ Subtract(e1 @ DecimalType.Expression(p1, s1),
+ e2 @ DecimalType.Expression(p2, s2), _) =>
+ val resultScale = max(s1, s2)
+ val resultType = if (allowPrecisionLoss) {
+ DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
+ resultScale)
+ } else {
+ DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
+ }
+ CheckOverflow(
+ s.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)),
+ resultType, nullOnOverflow)
+
+ case m @ Multiply(
+ e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
+ val resultType = if (allowPrecisionLoss) {
+ DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2)
+ } else {
+ DecimalType.bounded(p1 + p2 + 1, s1 + s2)
+ }
+ val widerType = widerDecimalType(p1, s1, p2, s2)
+ CheckOverflow(
+ m.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
+ resultType, nullOnOverflow)
+
+ case d @ Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
+ val resultType = if (allowPrecisionLoss) {
+ // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
+ // Scale: max(6, s1 + p2 + 1)
+ val intDig = p1 - s1 + s2
+ val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1)
+ val prec = intDig + scale
+ DecimalType.adjustPrecisionScale(prec, scale)
+ } else {
+ var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2)
+ var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1))
+ val diff = (intDig + decDig) - DecimalType.MAX_SCALE
+ if (diff > 0) {
+ decDig -= diff / 2 + 1
+ intDig = DecimalType.MAX_SCALE - decDig
+ }
+ DecimalType.bounded(intDig + decDig, decDig)
+ }
+ val widerType = widerDecimalType(p1, s1, p2, s2)
+ CheckOverflow(
+ d.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
+ resultType, nullOnOverflow)
+
+ case r @ Remainder(
+ e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
+ val resultType = if (allowPrecisionLoss) {
+ DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+ } else {
+ DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+ }
+ // resultType may have lower precision, so we cast them into wider type first.
+ val widerType = widerDecimalType(p1, s1, p2, s2)
+ CheckOverflow(
+ r.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
+ resultType, nullOnOverflow)
+
+ case p @ Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
+ val resultType = if (allowPrecisionLoss) {
+ DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+ } else {
+ DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+ }
+ // resultType may have lower precision, so we cast them into wider type first.
+ val widerType = widerDecimalType(p1, s1, p2, s2)
+ CheckOverflow(
+ p.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
+ resultType, nullOnOverflow)
+
+ case expr @ IntegralDivide(
+ e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
+ val widerType = widerDecimalType(p1, s1, p2, s2)
+ val promotedExpr = expr.copy(
+ left = promotePrecision(e1, widerType),
+ right = promotePrecision(e2, widerType))
+ if (expr.dataType.isInstanceOf[DecimalType]) {
+ // This follows division rule
+ val intDig = p1 - s1 + s2
+ // No precision loss can happen as the result scale is 0.
+      // Overflow can happen only when promoting the precision of the operands; if none of them
+      // overflows in that phase, no overflow can happen later, but CheckOverflow is still needed
+      // in order to return a decimal with the proper scale and precision.
+ CheckOverflow(promotedExpr, DecimalType.bounded(intDig, 0), nullOnOverflow)
+ } else {
+ promotedExpr
+ }
+
+ case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
+ e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+ val resultType = widerDecimalType(p1, s1, p2, s2)
+ val newE1 = if (e1.dataType == resultType) e1 else Cast(e1, resultType)
+ val newE2 = if (e2.dataType == resultType) e2 else Cast(e2, resultType)
+ b.makeCopy(Array(newE1, newE2))
+ }
+
+ /**
+ * Strength reduction for comparing integral expressions with decimal literals.
+ * 1. int_col > decimal_literal => int_col > floor(decimal_literal)
+ * 2. int_col >= decimal_literal => int_col >= ceil(decimal_literal)
+ * 3. int_col < decimal_literal => int_col < ceil(decimal_literal)
+ * 4. int_col <= decimal_literal => int_col <= floor(decimal_literal)
+ * 5. decimal_literal > int_col => ceil(decimal_literal) > int_col
+ * 6. decimal_literal >= int_col => floor(decimal_literal) >= int_col
+ * 7. decimal_literal < int_col => floor(decimal_literal) < int_col
+ * 8. decimal_literal <= int_col => ceil(decimal_literal) <= int_col
+ *
+ * Note that technically this is an "optimization" and should go into the optimizer. However,
+ * by the time the optimizer runs, these comparison expressions would be pretty hard to pattern
+ * match because there are multiple (at least 2) levels of casts involved.
+ *
+ * There are a lot more possible rules we can implement, but we don't do them
+ * because we are not sure how common they are.
+ */
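+  // For example, by rule 1 above, `int_col > 2.5` is rewritten to `int_col > 2` since floor(2.5) = 2,
+  // while a decimal literal outside the Long value range folds the comparison to a boolean literal.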
+ private val integralAndDecimalLiteral: PartialFunction[Expression, Expression] = {
+
+ case GreaterThan(i @ IntegralType(), DecimalLiteral(value)) =>
+ if (DecimalLiteral.smallerThanSmallestLong(value)) {
+ TrueLiteral
+ } else if (DecimalLiteral.largerThanLargestLong(value)) {
+ FalseLiteral
+ } else {
+ GreaterThan(i, Literal(value.floor.toLong))
+ }
+
+ case GreaterThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) =>
+ if (DecimalLiteral.smallerThanSmallestLong(value)) {
+ TrueLiteral
+ } else if (DecimalLiteral.largerThanLargestLong(value)) {
+ FalseLiteral
+ } else {
+ GreaterThanOrEqual(i, Literal(value.ceil.toLong))
+ }
+
+ case LessThan(i @ IntegralType(), DecimalLiteral(value)) =>
+ if (DecimalLiteral.smallerThanSmallestLong(value)) {
+ FalseLiteral
+ } else if (DecimalLiteral.largerThanLargestLong(value)) {
+ TrueLiteral
+ } else {
+ LessThan(i, Literal(value.ceil.toLong))
+ }
+
+ case LessThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) =>
+ if (DecimalLiteral.smallerThanSmallestLong(value)) {
+ FalseLiteral
+ } else if (DecimalLiteral.largerThanLargestLong(value)) {
+ TrueLiteral
+ } else {
+ LessThanOrEqual(i, Literal(value.floor.toLong))
+ }
+
+ case GreaterThan(DecimalLiteral(value), i @ IntegralType()) =>
+ if (DecimalLiteral.smallerThanSmallestLong(value)) {
+ FalseLiteral
+ } else if (DecimalLiteral.largerThanLargestLong(value)) {
+ TrueLiteral
+ } else {
+ GreaterThan(Literal(value.ceil.toLong), i)
+ }
+
+ case GreaterThanOrEqual(DecimalLiteral(value), i @ IntegralType()) =>
+ if (DecimalLiteral.smallerThanSmallestLong(value)) {
+ FalseLiteral
+ } else if (DecimalLiteral.largerThanLargestLong(value)) {
+ TrueLiteral
+ } else {
+ GreaterThanOrEqual(Literal(value.floor.toLong), i)
+ }
+
+ case LessThan(DecimalLiteral(value), i @ IntegralType()) =>
+ if (DecimalLiteral.smallerThanSmallestLong(value)) {
+ TrueLiteral
+ } else if (DecimalLiteral.largerThanLargestLong(value)) {
+ FalseLiteral
+ } else {
+ LessThan(Literal(value.floor.toLong), i)
+ }
+
+ case LessThanOrEqual(DecimalLiteral(value), i @ IntegralType()) =>
+ if (DecimalLiteral.smallerThanSmallestLong(value)) {
+ TrueLiteral
+ } else if (DecimalLiteral.largerThanLargestLong(value)) {
+ FalseLiteral
+ } else {
+ LessThanOrEqual(Literal(value.ceil.toLong), i)
+ }
+ }
+
+ /**
+ * Type coercion for BinaryOperator in which one side is a non-decimal numeric, and the other
+ * side is a decimal.
+ */
+ private def nondecimalAndDecimal(literalPickMinimumPrecision: Boolean)
+ : PartialFunction[Expression, Expression] = {
+ // Promote integers inside a binary expression with fixed-precision decimals to decimals,
+ // and fixed-precision decimals in an expression with floats / doubles to doubles
+ case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
+ (left, right) match {
+ // Promote literal integers inside a binary expression with fixed-precision decimals to
+ // decimals. The precision and scale are the ones strictly needed by the integer value.
+ // Requiring more precision than necessary may lead to a useless loss of precision.
+ // Consider the following example: multiplying a column which is DECIMAL(38, 18) by 2.
+ // If we use the default precision and scale for the integer type, 2 is considered a
+ // DECIMAL(10, 0). According to the rules, the result would be DECIMAL(38 + 10 + 1, 18),
+ // which is out of range and therefore it will become DECIMAL(38, 7), leading to
+      // potentially losing 11 digits of the fractional part. Using only the precision needed
+ // by the Literal, instead, the result would be DECIMAL(38 + 1 + 1, 18), which would
+ // become DECIMAL(38, 16), safely having a much lower precision loss.
+ case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] &&
+ l.dataType.isInstanceOf[IntegralType] &&
+ literalPickMinimumPrecision =>
+ b.makeCopy(Array(Cast(l, DecimalType.fromLiteral(l)), r))
+ case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] &&
+ r.dataType.isInstanceOf[IntegralType] &&
+ literalPickMinimumPrecision =>
+ b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r))))
+ // Promote integers inside a binary expression with fixed-precision decimals to decimals,
+ // and fixed-precision decimals in an expression with floats / doubles to doubles
+ case (l @ IntegralType(), r @ DecimalType.Expression(_, _)) =>
+ b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r))
+ case (l @ DecimalType.Expression(_, _), r @ IntegralType()) =>
+ b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType))))
+ case (l, r @ DecimalType.Expression(_, _)) if isFloat(l.dataType) =>
+ b.makeCopy(Array(l, Cast(r, DoubleType)))
+ case (l @ DecimalType.Expression(_, _), r) if isFloat(r.dataType) =>
+ b.makeCopy(Array(Cast(l, DoubleType), r))
+ case _ => b
+ }
+ }
+
+}
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
new file mode 100644
index 000000000..2b6d3ff33
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.SQLQueryContext
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+/**
+ * Return the unscaled Long value of a Decimal, assuming it fits in a Long.
+ * Note: this expression is internal and created only by the optimizer; we don't need to do
+ * type checks for it.
+ */
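+// For example, Decimal 12.34 with scale 2 has the unscaled Long value 1234.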
+case class UnscaledValue(child: Expression) extends UnaryExpression with NullIntolerant {
+
+ override def dataType: DataType = LongType
+ override def toString: String = s"UnscaledValue($child)"
+
+ protected override def nullSafeEval(input: Any): Any =
+ input.asInstanceOf[Decimal].toUnscaledLong
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()")
+ }
+
+ override protected def withNewChildInternal(newChild: Expression): UnscaledValue =
+ copy(child = newChild)
+}
+
+/**
+ * Create a Decimal from an unscaled Long value.
+ * Note: this expression is internal and created only by the optimizer; we don't need to do
+ * type checks for it.
+ */
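+// For example, an unscaled input of 1234 with precision 4 and scale 2 produces Decimal 12.34, i.e. the
+// inverse of UnscaledValue above.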
+case class MakeDecimal(
+ child: Expression,
+ precision: Int,
+ scale: Int,
+ nullOnOverflow: Boolean) extends UnaryExpression with NullIntolerant {
+
+ def this(child: Expression, precision: Int, scale: Int) = {
+ this(child, precision, scale, !SQLConf.get.ansiEnabled)
+ }
+
+ override def dataType: DataType = DecimalType(precision, scale)
+ override def nullable: Boolean = child.nullable || nullOnOverflow
+ override def toString: String = s"MakeDecimal($child,$precision,$scale)"
+
+ protected override def nullSafeEval(input: Any): Any = {
+ val longInput = input.asInstanceOf[Long]
+ val result = new Decimal()
+ if (nullOnOverflow) {
+ result.setOrNull(longInput, precision, scale)
+ } else {
+ result.set(longInput, precision, scale)
+ }
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, eval => {
+ val setMethod = if (nullOnOverflow) {
+ "setOrNull"
+ } else {
+ "set"
+ }
+ val setNull = if (nullable) {
+ s"${ev.isNull} = ${ev.value} == null;"
+ } else {
+ ""
+ }
+ s"""
+ |${ev.value} = (new Decimal()).$setMethod($eval, $precision, $scale);
+ |$setNull
+ |""".stripMargin
+ })
+ }
+
+ override protected def withNewChildInternal(newChild: Expression): MakeDecimal =
+ copy(child = newChild)
+}
+
+object MakeDecimal {
+ def apply(child: Expression, precision: Int, scale: Int): MakeDecimal = {
+ new MakeDecimal(child, precision, scale)
+ }
+}
+
+/**
+ * An expression used to wrap the children when promoting the precision of DecimalType, to avoid
+ * promoting it multiple times.
+ */
+case class PromotePrecision(child: Expression) extends UnaryExpression {
+ override def dataType: DataType = child.dataType
+ override def eval(input: InternalRow): Any = child.eval(input)
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ child.genCode(ctx)
+ override def prettyName: String = "promote_precision"
+ override def sql: String = child.sql
+ override lazy val canonicalized: Expression = child.canonicalized
+
+ override protected def withNewChildInternal(newChild: Expression): Expression =
+ copy(child = newChild)
+}
+
+/**
+ * Rounds the decimal to the given scale and checks whether the decimal can fit in the provided
+ * precision. If it cannot, `null` is returned when `nullOnOverflow` is `true`; otherwise an
+ * `ArithmeticException` is thrown.
+ */
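+// For example, Decimal 123.456 checked against DecimalType(6, 2) is rounded (ROUND_HALF_UP) to 123.46
+// and fits, while against DecimalType(4, 2) it overflows, so the result is null or an exception
+// depending on `nullOnOverflow`.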
+case class CheckOverflow(
+ child: Expression,
+ dataType: DecimalType,
+ nullOnOverflow: Boolean) extends UnaryExpression with SupportQueryContext {
+
+ override def nullable: Boolean = true
+
+ override def nullSafeEval(input: Any): Any =
+ input.asInstanceOf[Decimal].toPrecision(
+ dataType.precision,
+ dataType.scale,
+ Decimal.ROUND_HALF_UP,
+ nullOnOverflow,
+ getContextOrNull())
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val errorContextCode = if (nullOnOverflow) {
+ "\"\""
+ } else {
+ ctx.addReferenceObj("errCtx", queryContext)
+ }
+ nullSafeCodeGen(ctx, ev, eval => {
+ // scalastyle:off line.size.limit
+ s"""
+ |${ev.value} = $eval.toPrecision(
+ | ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow, $errorContextCode);
+ |${ev.isNull} = ${ev.value} == null;
+ """.stripMargin
+ // scalastyle:on line.size.limit
+ })
+ }
+
+ override def toString: String = s"CheckOverflow($child, $dataType)"
+
+ override def sql: String = child.sql
+
+ override protected def withNewChildInternal(newChild: Expression): CheckOverflow =
+ copy(child = newChild)
+
+ override def initQueryContext(): Option[SQLQueryContext] = if (nullOnOverflow) {
+    None
+  } else {
+    Some(origin.context)
+ }
+}
+
+// A variant of `CheckOverflow`, which treats null as overflow. This is necessary in `Sum`.
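+// A null child value here means a partial sum has already overflowed, so with nullOnOverflow = false
+// (ANSI mode) the overflow error must be raised rather than treating the null as a missing input value.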
+case class CheckOverflowInSum(
+ child: Expression,
+ dataType: DecimalType,
+ nullOnOverflow: Boolean,
+ context: SQLQueryContext) extends UnaryExpression {
+
+ override def nullable: Boolean = true
+
+ override def eval(input: InternalRow): Any = {
+ val value = child.eval(input)
+ if (value == null) {
+ if (nullOnOverflow) null
+ else throw QueryExecutionErrors.overflowInSumOfDecimalError(context)
+ } else {
+ value.asInstanceOf[Decimal].toPrecision(
+ dataType.precision,
+ dataType.scale,
+ Decimal.ROUND_HALF_UP,
+ nullOnOverflow,
+ context)
+ }
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val childGen = child.genCode(ctx)
+ val errorContextCode = if (nullOnOverflow) {
+ "\"\""
+ } else {
+ ctx.addReferenceObj("errCtx", context)
+ }
+ val nullHandling = if (nullOnOverflow) {
+ ""
+ } else {
+ s"throw QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode);"
+ }
+ // scalastyle:off line.size.limit
+ val code = code"""
+ |${childGen.code}
+ |boolean ${ev.isNull} = ${childGen.isNull};
+ |Decimal ${ev.value} = null;
+ |if (${childGen.isNull}) {
+ | $nullHandling
+ |} else {
+ | ${ev.value} = ${childGen.value}.toPrecision(
+ | ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow, $errorContextCode);
+ | ${ev.isNull} = ${ev.value} == null;
+ |}
+ |""".stripMargin
+ // scalastyle:on line.size.limit
+
+ ev.copy(code = code)
+ }
+
+ override def toString: String = s"CheckOverflowInSum($child, $dataType, $nullOnOverflow)"
+
+ override def sql: String = child.sql
+
+ override protected def withNewChildInternal(newChild: Expression): CheckOverflowInSum =
+ copy(child = newChild)
+}
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
index e6c266242..b758b69e6 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
@@ -27,8 +27,6 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import com.huawei.boostkit.spark.ColumnarPluginConfig
-import scala.annotation.tailrec
-
/**
* Insert a filter on one side of the join if the other side has a selective predicate.
* The filter could be an IN subquery (converted to a semi join), a bloom filter, or something
@@ -87,12 +85,12 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
val bloomFilterAgg =
if (rowCount.isDefined && rowCount.get.longValue > 0L) {
new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)),
- rowCount.get.longValue)
+ Literal(rowCount.get.longValue)
} else {
new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)))
}
-
- val alias = Alias(bloomFilterAgg.toAggregateExpression(), "bloomFilter")()
+ val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None)
+ val alias = Alias(aggExp, "bloomFilter")()
val aggregate =
ConstantFolding(ColumnPruning(Aggregate(Nil, Seq(alias), filterCreationSidePlan)))
val bloomFilterSubquery = if (canReuseExchange) {
@@ -114,8 +112,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
require(filterApplicationSideExp.dataType == filterCreationSideExp.dataType)
val actualFilterKeyExpr = mayWrapWithHash(filterCreationSideExp)
val alias = Alias(actualFilterKeyExpr, actualFilterKeyExpr.toString)()
- val aggregate =
- ColumnPruning(Aggregate(Seq(filterCreationSideExp), Seq(alias), filterCreationSidePlan))
+ val aggregate = Aggregate(Seq(alias), Seq(alias), filterCreationSidePlan)
if (!canBroadcastBySize(aggregate, conf)) {
// Skip the InSubquery filter if the size of `aggregate` is beyond broadcast join threshold,
// i.e., the semi-join will be a shuffled join, which is not worthwhile.
@@ -132,39 +129,13 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
* do not add a subquery that might have an expensive computation
*/
private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = {
- @tailrec
- def isSelective(
- p: LogicalPlan,
- predicateReference: AttributeSet,
- hasHitFilter: Boolean,
- hasHitSelectiveFilter: Boolean): Boolean = p match {
- case Project(projectList, child) =>
- if (hasHitFilter) {
- // We need to make sure all expressions referenced by filter predicates are simple
- // expressions.
- val referencedExprs = projectList.filter(predicateReference.contains)
- referencedExprs.forall(isSimpleExpression) &&
- isSelective(
- child,
- referencedExprs.map(_.references).foldLeft(AttributeSet.empty)(_ ++ _),
- hasHitFilter,
- hasHitSelectiveFilter)
- } else {
- assert(predicateReference.isEmpty && !hasHitSelectiveFilter)
- isSelective(child, predicateReference, hasHitFilter, hasHitSelectiveFilter)
- }
- case Filter(condition, child) =>
- isSimpleExpression(condition) && isSelective(
- child,
- predicateReference ++ condition.references,
- hasHitFilter = true,
- hasHitSelectiveFilter = hasHitSelectiveFilter || isLikelySelective(condition))
- case _: LeafNode => hasHitSelectiveFilter
+ val ret = plan match {
+ case PhysicalOperation(_, filters, child) if child.isInstance[LeafNode] =>
+ filters.forall(isSimpleExpression) &&
+ filters.exists(isLikelySelective)
case _ => false
}
-
- !plan.isStreaming &&
- isSelective(plan, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false)
+ !plan.isStreaming && ret
}
private def isSimpleExpression(e: Expression): Boolean = {
@@ -213,8 +184,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
/**
* Check that:
- * - The filterApplicationSideJoinExp can be pushed down through joins, aggregates and windows
- * (ie the expression references originate from a single leaf node)
+ * - The filterApplicationSideJoinExp can be pushed down through joins and aggregates (ie the
+ * expression references originate from a single leaf node)
* - The filter creation side has a selective predicate
* - The current join is a shuffle join or a broadcast join that has a shuffle below it
* - The max filterApplicationSide scan size is greater than a configurable threshold
@@ -328,13 +299,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
case s: Subquery if s.correlated => plan
case _ if !conf.runtimeFilterSemiJoinReductionEnabled &&
!conf.runtimeFilterBloomFilterEnabled => plan
- case _ =>
- val newPlan = tryInjectRuntimeFilter(plan)
- if (conf.runtimeFilterSemiJoinReductionEnabled && !plan.fastEquals(newPlan)) {
- RewritePredicateSubquery(newPlan)
- } else {
- newPlan
- }
+ case _ => tryInjectRuntimeFilter(plan)
}
}
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala
index 630034bd7..b6dd2ab51 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution
import java.util.concurrent.TimeUnit.NANOSECONDS
import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP
-
import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._
import com.huawei.boostkit.spark.util.OmniAdaptorUtil
import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs
@@ -43,7 +42,6 @@ import org.apache.spark.sql.types.{LongType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.sql.catalyst.plans.{AliasAwareOutputExpression, AliasAwareQueryOutputOrdering}
-
case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
extends UnaryExecNode
with AliasAwareOutputExpression
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 37f3db5d6..e49c88191 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -105,23 +105,6 @@ class QueryExecution(
case other => other
}
- // The plan that has been normalized by custom rules, so that it's more likely to hit cache.
- lazy val normalized: LogicalPlan = {
- val normalizationRules = sparkSession.sessionState.planNormalizationRules
- if (normalizationRules.isEmpty) {
- commandExecuted
- } else {
- val planChangeLogger = new PlanChangeLogger[LogicalPlan]()
- val normalized = normalizationRules.foldLeft(commandExecuted) { (p, rule) =>
- val result = rule.apply(p)
- planChangeLogger.logRule(rule.ruleName, p, result)
- result
- }
- planChangeLogger.logBatch("Plan Normalization", commandExecuted, normalized)
- normalized
- }
- } // Spark3.4.3
-
lazy val withCachedData: LogicalPlan = sparkSession.withActive {
assertAnalyzed()
assertSupported()
@@ -529,4 +512,4 @@ object QueryExecution {
case e: Throwable => throw toInternalError(msg, e)
}
}
-}
+}
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index e5c8fc0ab..b5d638138 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.adaptive
import java.util
import java.util.concurrent.LinkedBlockingQueue
+
import scala.collection.JavaConverters._
import scala.collection.concurrent.TrieMap
import scala.collection.mutable
import scala.concurrent.ExecutionContext
import scala.util.control.NonFatal
+
import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
@@ -33,7 +35,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
@@ -153,13 +154,7 @@ case class AdaptiveSparkPlanExec(
)
private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): SparkPlan = {
- val rules = if (isFinalStage &&
- !conf.getConf(SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS)) {
- queryStageOptimizerRules.filterNot(_.isInstanceOf[AQEShuffleReadRule])
- } else {
- queryStageOptimizerRules
- }
- val optimized = rules.foldLeft(plan) { case (latestPlan, rule) =>
+ val optimized = queryStageOptimizerRules.foldLeft(plan) { case (latestPlan, rule) =>
val applied = rule.apply(latestPlan)
val result = rule match {
case _: AQEShuffleReadRule if !applied.fastEquals(latestPlan) =>
@@ -194,7 +189,7 @@ case class AdaptiveSparkPlanExec(
@volatile private var currentPhysicalPlan = initialPlan
- @volatile private var _isFinalPlan = false
+ volatile private var isFinalPlan = false
private var currentStageId = 0
@@ -211,8 +206,6 @@ case class AdaptiveSparkPlanExec(
def executedPlan: SparkPlan = currentPhysicalPlan
- def isFinalPlan: Boolean = _isFinalPlan
-
override def conf: SQLConf = context.session.sessionState.conf
override def output: Seq[Attribute] = inputPlan.output
@@ -231,8 +224,6 @@ case class AdaptiveSparkPlanExec(
.map(_.toLong).filter(SQLExecution.getQueryExecution(_) eq context.qe)
}
- def finalPhysicalPlan: SparkPlan = withFinalPlanUpdate(identity) // saprk 3.4.3
-
private def getFinalPhysicalPlan(): SparkPlan = lock.synchronized {
if (isFinalPlan) return currentPhysicalPlan
@@ -320,8 +311,7 @@ case class AdaptiveSparkPlanExec(
val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
if (newCost < origCost ||
(newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
- logOnLevel("Plan changed:\n" +
- sideBySide(currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n"))
+ logOnLevel("Plan changed from $currentPhysicalPlan to $newPhysicalPlan")
cleanUpTempTags(newPhysicalPlan)
currentPhysicalPlan = newPhysicalPlan
currentLogicalPlan = newLogicalPlan
@@ -337,7 +327,7 @@ case class AdaptiveSparkPlanExec(
optimizeQueryStage(result.newPlan, isFinalStage = true),
postStageCreationRules(supportsColumnar),
Some((planChangeLogger, "AQE Post Stage Creation")))
- _isFinalPlan = true
+ isFinalPlan = true
executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
currentPhysicalPlan
}
@@ -351,7 +341,7 @@ case class AdaptiveSparkPlanExec(
if (!isSubquery && currentPhysicalPlan.exists(_.subqueries.nonEmpty)) {
getExecutionId.foreach(onUpdatePlan(_, Seq.empty))
}
- logOnLevel(s"Final plan: \n$currentPhysicalPlan")
+ logOnLevel(s"Final plan: $currentPhysicalPlan")
}
override def executeCollect(): Array[InternalRow] = {
--
Gitee
From 40acdcc48e193b6d3386b17b1182ab75b0dc1d8e Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Fri, 18 Oct 2024 03:24:45 +0000
Subject: [PATCH 35/43] fix EOF bug
Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com>
---
.../spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
index b758b69e6..328be083f 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
@@ -85,8 +85,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
val bloomFilterAgg =
if (rowCount.isDefined && rowCount.get.longValue > 0L) {
new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)),
- Literal(rowCount.get.longValue)
- } else {
+ Literal(rowCount.get.longValue))
+ else {
new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)))
}
val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None)
--
Gitee
From 9cc62a176cb499370ca287fce5c40625821c3e33 Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Fri, 18 Oct 2024 03:26:00 +0000
Subject: [PATCH 36/43] fix bug
Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com>
---
.../spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index b5d638138..a1a5d09f1 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -189,7 +189,7 @@ case class AdaptiveSparkPlanExec(
@volatile private var currentPhysicalPlan = initialPlan
- volatile private var isFinalPlan = false
+ private var isFinalPlan = false
private var currentStageId = 0
--
Gitee
From 83cc6aaf2fcb0e1c4cceda8953f154639a7d54d4 Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Fri, 18 Oct 2024 03:51:10 +0000
Subject: [PATCH 37/43] fix EOF bug
Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com>
---
.../spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
index 328be083f..895e7c3d7 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
@@ -86,7 +86,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
if (rowCount.isDefined && rowCount.get.longValue > 0L) {
new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)),
Literal(rowCount.get.longValue))
- else {
+ } else {
new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)))
}
val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None)
--
Gitee
From d49c55ec71923ae0d70e73a995814ddd5afb953f Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Fri, 18 Oct 2024 04:12:46 +0000
Subject: [PATCH 38/43] fix bug
Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com>
---
.../spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
index 895e7c3d7..2dd8efe58 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
@@ -130,7 +130,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
*/
private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = {
val ret = plan match {
- case PhysicalOperation(_, filters, child) if child.isInstance[LeafNode] =>
+ case PhysicalOperation(_, filters, child) if child.isInstanceOf[LeafNode] =>
filters.forall(isSimpleExpression) &&
filters.exists(isLikelySelective)
case _ => false
--
Gitee
From 25e9fb1f0bd4c8a88fa37933a170f71b6520ccf4 Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Sun, 20 Oct 2024 19:07:54 +0800
Subject: [PATCH 39/43] add normalized method
---
.../spark/sql/execution/QueryExecution.scala | 19 ++++++++++++++++++-
1 file changed, 18 insertions(+), 1 deletion(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index e49c88191..3630b0f2e 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -105,12 +105,29 @@ class QueryExecution(
case other => other
}
+ // The plan that has been normalized by custom rules, so that it's more likely to hit cache.
+ lazy val normalized: LogicalPlan = {
+ val normalizationRules = sparkSession.sessionState.planNormalizationRules
+ if (normalizationRules.isEmpty) {
+ commandExecuted
+ } else {
+ val planChangeLogger = new PlanChangeLogger[LogicalPlan]()
+ val normalized = normalizationRules.foldLeft(commandExecuted) { (p, rule) =>
+ val result = rule.apply(p)
+ planChangeLogger.logRule(rule.ruleName, p, result)
+ result
+ }
+ planChangeLogger.logBatch("Plan Normalization", commandExecuted, normalized)
+ normalized
+ }
+ }
+
lazy val withCachedData: LogicalPlan = sparkSession.withActive {
assertAnalyzed()
assertSupported()
// clone the plan to avoid sharing the plan instance between different stages like analyzing,
// optimizing and planning.
- sparkSession.sharedState.cacheManager.useCachedData(commandExecuted.clone())
+ sparkSession.sharedState.cacheManager.useCachedData(normalized.clone())
}
def assertCommandExecuted(): Unit = commandExecuted
--
Gitee
From ff0aafbe64859794ae67914e76af78b6b326c2be Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Sun, 20 Oct 2024 19:14:55 +0800
Subject: [PATCH 40/43] fix UT bug
---
.../optimizer/InjectRuntimeFilter.scala | 59 +++++++++++++++----
1 file changed, 46 insertions(+), 13 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
index b758b69e6..a6124f5e3 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
@@ -84,15 +84,14 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
val rowCount = filterCreationSidePlan.stats.rowCount
val bloomFilterAgg =
if (rowCount.isDefined && rowCount.get.longValue > 0L) {
- new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)),
- Literal(rowCount.get.longValue)
+ new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)), rowCount.get.longValue)
} else {
new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)))
}
- val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None)
- val alias = Alias(aggExp, "bloomFilter")()
+
+ val alias = Alias(bloomFilterAgg.toAggregateExpression(), "bloomFilter")()
val aggregate =
- ConstantFolding(ColumnPruning(Aggregate(Nil, Seq(alias), filterCreationSidePlan)))
+ ColumnPruning(Aggregate(Seq(filterCreationSideExp), Seq(alias), filterCreationSidePlan))
val bloomFilterSubquery = if (canReuseExchange) {
// Try to reuse the results of broadcast exchange.
RuntimeFilterSubquery(filterApplicationSideExp, aggregate, filterCreationSideExp)
@@ -129,13 +128,39 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
* do not add a subquery that might have an expensive computation
*/
private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = {
- val ret = plan match {
- case PhysicalOperation(_, filters, child) if child.isInstance[LeafNode] =>
- filters.forall(isSimpleExpression) &&
- filters.exists(isLikelySelective)
+ @tailrec
+ def isSelective(
+ p: LogicalPlan,
+ predicateReference: AttributeSet,
+ hasHitFilter: Boolean,
+ hasHitSelectiveFilter: Boolean): Boolean = p match {
+ case Project(projectList, child) =>
+ if (hasHitFilter) {
+ // We need to make sure all expressions referenced by filter predicates are simple
+ // expressions.
+ val referencedExprs = projectList.filter(predicateReference.contains)
+ referencedExprs.forall(isSimpleExpression) &&
+ isSelective(
+ child,
+ referencedExprs.map(_.references).foldLeft(AttributeSet.empty)(_ ++ _),
+ hasHitFilter,
+ hasHitSelectiveFilter)
+ } else {
+ assert(predicateReference.isEmpty && !hasHitSelectiveFilter)
+ isSelective(child, predicateReference, hasHitFilter, hasHitSelectiveFilter)
+ }
+ case Filter(condition, child) =>
+ isSimpleExpression(condition) && isSelective(
+ child,
+ predicateReference ++ condition.references,
+ hasHitFilter = true,
+ hasHitSelectiveFilter = hasHitSelectiveFilter || isLikelySelective(condition))
+ case _: LeafNode => hasHitSelectiveFilter
case _ => false
}
- !plan.isStreaming && ret
+
+ !plan.isStreaming &&
+ isSelective(plan, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false)
}
private def isSimpleExpression(e: Expression): Boolean = {
@@ -153,6 +178,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
plan.exists {
case Join(left, right, _, _, hint) => isProbablyShuffleJoin(left, right, hint)
case _: Aggregate => true
+ case _: Window => true
case _ => false
}
}
@@ -184,8 +210,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
/**
* Check that:
- * - The filterApplicationSideJoinExp can be pushed down through joins and aggregates (ie the
- * expression references originate from a single leaf node)
+ * - The filterApplicationSideJoinExp can be pushed down through joins, aggregates and windows
+ * (ie the expression references originate from a single leaf node)
* - The filter creation side has a selective predicate
* - The current join is a shuffle join or a broadcast join that has a shuffle below it
* - The max filterApplicationSide scan size is greater than a configurable threshold
@@ -212,6 +238,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
}
// This checks if there is already a DPP filter, as this rule is called just after DPP.
+ @tailrec
def hasDynamicPruningSubquery(
left: LogicalPlan,
right: LogicalPlan,
@@ -299,7 +326,13 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
case s: Subquery if s.correlated => plan
case _ if !conf.runtimeFilterSemiJoinReductionEnabled &&
!conf.runtimeFilterBloomFilterEnabled => plan
- case _ => tryInjectRuntimeFilter(plan)
+ case _ =>
+ val newPlan = tryInjectRuntimeFilter(plan)
+ if (conf.runtimeFilterSemiJoinReductionEnabled && !plan.fastEquals(newPlan)) {
+ RewritePredicateSubquery(newPlan)
+ } else {
+ newPlan
+ }
}
}
\ No newline at end of file
--
Gitee
From 2e880e09f5f25e466daaa2937df59bc7b0bca9e6 Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 21 Oct 2024 10:57:53 +0800
Subject: [PATCH 41/43] fix BUG
---
.../spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index b5d638138..a1a5d09f1 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -189,7 +189,7 @@ case class AdaptiveSparkPlanExec(
@volatile private var currentPhysicalPlan = initialPlan
- volatile private var isFinalPlan = false
+ private var isFinalPlan = false
private var currentStageId = 0
--
Gitee
From e2222aab325779d2d80b79e901834af759d6b2d9 Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Mon, 21 Oct 2024 11:09:39 +0800
Subject: [PATCH 42/43] fix BUG
---
.../spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala | 2 ++
1 file changed, 2 insertions(+)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
index a6124f5e3..c4a94f5db 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
@@ -27,6 +27,8 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import com.huawei.boostkit.spark.ColumnarPluginConfig
+import scala.annotation.tailrec
+
/**
* Insert a filter on one side of the join if the other side has a selective predicate.
* The filter could be an IN subquery (converted to a semi join), a bloom filter, or something
--
Gitee
From 8a74794acde7c58155dcaa22f9c1334a0fed1435 Mon Sep 17 00:00:00 2001
From: wangwei <14424757+wiwimao@user.noreply.gitee.com>
Date: Tue, 22 Oct 2024 22:44:21 +0800
Subject: [PATCH 43/43] refactor orc push down filter
---
.../spark/jni/OrcColumnarBatchScanReader.java | 445 ++++++++++++------
.../orc/OmniOrcColumnarBatchReader.java | 161 +++----
.../datasources/orc/OmniOrcFileFormat.scala | 107 +----
...OrcColumnarBatchJniReaderDataTypeTest.java | 2 +-
...ColumnarBatchJniReaderNotPushDownTest.java | 2 +-
...OrcColumnarBatchJniReaderPushDownTest.java | 2 +-
...BatchJniReaderSparkORCNotPushDownTest.java | 2 +-
...narBatchJniReaderSparkORCPushDownTest.java | 2 +-
.../jni/OrcColumnarBatchJniReaderTest.java | 134 +++---
9 files changed, 437 insertions(+), 420 deletions(-)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java
index 73438aa43..064f50416 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java
+++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java
@@ -22,59 +22,78 @@ import com.huawei.boostkit.scan.jni.OrcColumnarBatchJniReader;
import nova.hetu.omniruntime.type.DataType;
import nova.hetu.omniruntime.vector.*;
+import org.apache.spark.sql.catalyst.util.CharVarcharUtils;
import org.apache.spark.sql.catalyst.util.RebaseDateTime;
-import org.apache.hadoop.hive.ql.io.sarg.ExpressionTree;
-import org.apache.hadoop.hive.ql.io.sarg.PredicateLeaf;
-import org.apache.orc.OrcFile.ReaderOptions;
-import org.apache.orc.Reader.Options;
+import org.apache.spark.sql.sources.And;
+import org.apache.spark.sql.sources.EqualTo;
+import org.apache.spark.sql.sources.Filter;
+import org.apache.spark.sql.sources.GreaterThan;
+import org.apache.spark.sql.sources.GreaterThanOrEqual;
+import org.apache.spark.sql.sources.In;
+import org.apache.spark.sql.sources.IsNotNull;
+import org.apache.spark.sql.sources.IsNull;
+import org.apache.spark.sql.sources.LessThan;
+import org.apache.spark.sql.sources.LessThanOrEqual;
+import org.apache.spark.sql.sources.Not;
+import org.apache.spark.sql.sources.Or;
+import org.apache.spark.sql.types.BooleanType;
+import org.apache.spark.sql.types.DateType;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.DoubleType;
+import org.apache.spark.sql.types.IntegerType;
+import org.apache.spark.sql.types.LongType;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.ShortType;
+import org.apache.spark.sql.types.StringType;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
import org.json.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import java.math.BigDecimal;
import java.net.URI;
-import java.sql.Date;
+import java.time.LocalDate;
+import java.text.DateFormat;
+import java.text.ParseException;
+import java.text.SimpleDateFormat;
import java.util.ArrayList;
-import java.util.List;
+import java.util.Arrays;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import java.util.TimeZone;
public class OrcColumnarBatchScanReader {
private static final Logger LOGGER = LoggerFactory.getLogger(OrcColumnarBatchScanReader.class);
+ private boolean nativeSupportTimestampRebase;
+ private static final Pattern CHAR_TYPE = Pattern.compile("char\\(\\s*(\\d+)\\s*\\)");
+
+ private static final int MAX_LEAF_THRESHOLD = 256;
+
public long reader;
public long recordReader;
public long batchReader;
- public int[] colsToGet;
- public int realColsCnt;
- public ArrayList fildsNames;
+ // All ORC fieldNames
+ public ArrayList allFieldsNames;
- public ArrayList colToInclu;
+    // Indicates which columns to read (an entry of 0 in colsToGet means the column is read)
+ public int[] colsToGet;
- public String[] requiredfieldNames;
+ // Actual columns to read
+ public ArrayList includedColumns;
- public int[] precisionArray;
+    // index of the current predicate leaf while building the pushdown filter JSON
+ private int leafIndex;
- public int[] scaleArray;
+ // spark required schema
+ private StructType requiredSchema;
public OrcColumnarBatchJniReader jniReader;
public OrcColumnarBatchScanReader() {
jniReader = new OrcColumnarBatchJniReader();
- fildsNames = new ArrayList();
- }
-
- public JSONObject getSubJson(ExpressionTree node) {
- JSONObject jsonObject = new JSONObject();
- jsonObject.put("op", node.getOperator().ordinal());
- if (node.getOperator().toString().equals("LEAF")) {
- jsonObject.put("leaf", node.toString());
- return jsonObject;
- }
- ArrayList child = new ArrayList();
- for (ExpressionTree childNode : node.getChildren()) {
- JSONObject rtnJson = getSubJson(childNode);
- child.add(rtnJson);
- }
- jsonObject.put("child", child);
- return jsonObject;
+ allFieldsNames = new ArrayList();
}
public String padZeroForDecimals(String [] decimalStrArray, int decimalScale) {
@@ -86,98 +105,15 @@ public class OrcColumnarBatchScanReader {
return String.format("%1$-" + decimalScale + "s", decimalVal).replace(' ', '0');
}
- public int getPrecision(String colname) {
- for (int i = 0; i < requiredfieldNames.length; i++) {
- if (colname.equals(requiredfieldNames[i])) {
- return precisionArray[i];
- }
- }
-
- return -1;
- }
-
- public int getScale(String colname) {
- for (int i = 0; i < requiredfieldNames.length; i++) {
- if (colname.equals(requiredfieldNames[i])) {
- return scaleArray[i];
- }
- }
-
- return -1;
- }
-
- public JSONObject getLeavesJson(List leaves) {
- JSONObject jsonObjectList = new JSONObject();
- for (int i = 0; i < leaves.size(); i++) {
- PredicateLeaf pl = leaves.get(i);
- JSONObject jsonObject = new JSONObject();
- jsonObject.put("op", pl.getOperator().ordinal());
- jsonObject.put("name", pl.getColumnName());
- jsonObject.put("type", pl.getType().ordinal());
- if (pl.getLiteral() != null) {
- if (pl.getType() == PredicateLeaf.Type.DATE) {
- jsonObject.put("literal", ((int)Math.ceil(((Date)pl.getLiteral()).getTime()* 1.0/3600/24/1000)) + "");
- } else if (pl.getType() == PredicateLeaf.Type.DECIMAL) {
- int decimalP = getPrecision(pl.getColumnName());
- int decimalS = getScale(pl.getColumnName());
- String[] spiltValues = pl.getLiteral().toString().split("\\.");
- if (decimalS == 0) {
- jsonObject.put("literal", spiltValues[0] + " " + decimalP + " " + decimalS);
- } else {
- String scalePadZeroStr = padZeroForDecimals(spiltValues, decimalS);
- jsonObject.put("literal", spiltValues[0] + "." + scalePadZeroStr + " " + decimalP + " " + decimalS);
- }
- } else {
- jsonObject.put("literal", pl.getLiteral().toString());
- }
- } else {
- jsonObject.put("literal", "");
- }
- if ((pl.getLiteralList() != null) && (pl.getLiteralList().size() != 0)){
- List lst = new ArrayList<>();
- for (Object ob : pl.getLiteralList()) {
- if (ob == null) {
- lst.add(null);
- continue;
- }
- if (pl.getType() == PredicateLeaf.Type.DECIMAL) {
- int decimalP = getPrecision(pl.getColumnName());
- int decimalS = getScale(pl.getColumnName());
- String[] spiltValues = ob.toString().split("\\.");
- if (decimalS == 0) {
- lst.add(spiltValues[0] + " " + decimalP + " " + decimalS);
- } else {
- String scalePadZeroStr = padZeroForDecimals(spiltValues, decimalS);
- lst.add(spiltValues[0] + "." + scalePadZeroStr + " " + decimalP + " " + decimalS);
- }
- } else if (pl.getType() == PredicateLeaf.Type.DATE) {
- lst.add(((int)Math.ceil(((Date)ob).getTime()* 1.0/3600/24/1000)) + "");
- } else {
- lst.add(ob.toString());
- }
- }
- jsonObject.put("literalList", lst);
- } else {
- jsonObject.put("literalList", new ArrayList());
- }
- jsonObjectList.put("leaf-" + i, jsonObject);
- }
- return jsonObjectList;
- }
-
/**
* Init Orc reader.
*
* @param uri split file path
- * @param options split file options
*/
- public long initializeReaderJava(URI uri, ReaderOptions options) {
+ public long initializeReaderJava(URI uri) {
JSONObject job = new JSONObject();
- if (options.getOrcTail() == null) {
- job.put("serializedTail", "");
- } else {
- job.put("serializedTail", options.getOrcTail().getSerializedTail().toString());
- }
+
+ job.put("serializedTail", "");
job.put("tailLocation", 9223372036854775807L);
job.put("scheme", uri.getScheme() == null ? "" : uri.getScheme());
@@ -185,37 +121,36 @@ public class OrcColumnarBatchScanReader {
job.put("port", uri.getPort());
job.put("path", uri.getPath() == null ? "" : uri.getPath());
- reader = jniReader.initializeReader(job, fildsNames);
+ reader = jniReader.initializeReader(job, allFieldsNames);
return reader;
}
/**
* Init Orc RecordReader.
*
- * @param options split file options
+ * @param offset split file offset
+ * @param length split file read length
+   * @param pushedFilter the filter pushed down to the native reader
+   * @param requiredSchema the schema of the columns to read from the native reader
*/
- public long initializeRecordReaderJava(Options options) {
+ public long initializeRecordReaderJava(long offset, long length, Filter pushedFilter, StructType requiredSchema) {
+ this.requiredSchema = requiredSchema;
JSONObject job = new JSONObject();
- if (options.getInclude() == null) {
- job.put("include", "");
- } else {
- job.put("include", options.getInclude().toString());
- }
- job.put("offset", options.getOffset());
- job.put("length", options.getLength());
- // When the number of pushedFilters > hive.CNF_COMBINATIONS_THRESHOLD, the expression is rewritten to
- // 'YES_NO_NULL'. Under the circumstances, filter push down will be skipped.
- if (options.getSearchArgument() != null
- && !options.getSearchArgument().toString().contains("YES_NO_NULL")) {
- LOGGER.debug("SearchArgument: {}", options.getSearchArgument().toString());
- JSONObject jsonexpressionTree = getSubJson(options.getSearchArgument().getExpression());
- job.put("expressionTree", jsonexpressionTree);
- JSONObject jsonleaves = getLeavesJson(options.getSearchArgument().getLeaves());
- job.put("leaves", jsonleaves);
- }
- job.put("includedColumns", colToInclu.toArray());
+ job.put("offset", offset);
+ job.put("length", length);
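+      // Translate the pushed-down Spark filter (if any) into the JSON expression
+      // tree and leaves consumed by the native record reader; when the filter
+      // cannot be fully translated, push down is skipped for this split.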
+ if (pushedFilter != null) {
+ JSONObject jsonExpressionTree = new JSONObject();
+ JSONObject jsonLeaves = new JSONObject();
+ boolean flag = canPushDown(pushedFilter, jsonExpressionTree, jsonLeaves);
+ if (flag) {
+ job.put("expressionTree", jsonExpressionTree);
+ job.put("leaves", jsonLeaves);
+ }
+ }
+
+ job.put("includedColumns", includedColumns.toArray());
recordReader = jniReader.initializeRecordReader(reader, job);
return recordReader;
}
@@ -253,14 +188,22 @@ public class OrcColumnarBatchScanReader {
}
}
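+  /**
+   * Rebase timestamp values in microseconds from the hybrid Julian calendar to the
+   * proleptic Gregorian calendar in place, using Spark's RebaseDateTime helper.
+   *
+   * @param longVec vector of timestamp values in microseconds
+   * @param rowNumber number of rows to convert
+   */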
+ public void convertJulianToGregorianMicros(LongVec longVec, long rowNumber) {
+ long gregorianValue;
+ for (int rowIndex = 0; rowIndex < rowNumber; rowIndex++) {
+ gregorianValue = RebaseDateTime.rebaseJulianToGregorianMicros(longVec.get(rowIndex));
+ longVec.set(rowIndex, gregorianValue);
+ }
+ }
+
public int next(Vec[] vecList, int[] typeIds) {
- long[] vecNativeIds = new long[realColsCnt];
+ long[] vecNativeIds = new long[typeIds.length];
long rtn = jniReader.recordReaderNext(recordReader, batchReader, typeIds, vecNativeIds);
if (rtn == 0) {
return 0;
}
int nativeGetId = 0;
- for (int i = 0; i < realColsCnt; i++) {
+ for (int i = 0; i < colsToGet.length; i++) {
if (colsToGet[i] != 0) {
continue;
}
@@ -301,7 +244,7 @@ public class OrcColumnarBatchScanReader {
}
default: {
throw new RuntimeException("UnSupport type for ColumnarFileScan:" +
- DataType.DataTypeId.values()[typeIds[i]]);
+ DataType.DataTypeId.values()[typeIds[i]]);
}
}
nativeGetId++;
@@ -309,6 +252,228 @@ public class OrcColumnarBatchScanReader {
return (int)rtn;
}
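+
+  // The ordinal values of these enums are serialized into the push down JSON and
+  // interpreted on the native side, so their declaration order is assumed to match
+  // the native reader's definitions.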
+ enum OrcOperator {
+ OR,
+ AND,
+ NOT,
+ LEAF,
+ CONSTANT
+ }
+
+ enum OrcLeafOperator {
+ EQUALS,
+ NULL_SAFE_EQUALS,
+ LESS_THAN,
+ LESS_THAN_EQUALS,
+ IN,
+    BETWEEN, // not used; Spark rewrites it into gt and lt comparisons
+ IS_NULL
+ }
+
+ enum OrcPredicateDataType {
+    LONG, // all integer types
+ FLOAT, // float and double
+ STRING, // string, char, varchar
+ DATE,
+ DECIMAL,
+ TIMESTAMP,
+ BOOLEAN
+ }
+
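+  // Resolve a pushed-down column through the required schema and map its Spark
+  // data type to the predicate data type understood by the native reader;
+  // unsupported types (e.g. char) abort the translation and disable push down.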
+ private OrcPredicateDataType getOrcPredicateDataType(String attribute) {
+ StructField field = requiredSchema.apply(attribute);
+ org.apache.spark.sql.types.DataType dataType = field.dataType();
+ if (dataType instanceof ShortType || dataType instanceof IntegerType ||
+ dataType instanceof LongType) {
+ return OrcPredicateDataType.LONG;
+ } else if (dataType instanceof DoubleType) {
+ return OrcPredicateDataType.FLOAT;
+ } else if (dataType instanceof StringType) {
+ if (isCharType(field.metadata())) {
+ throw new UnsupportedOperationException("Unsupported orc push down filter data type: char");
+ }
+ return OrcPredicateDataType.STRING;
+ } else if (dataType instanceof DateType) {
+ return OrcPredicateDataType.DATE;
+ } else if (dataType instanceof DecimalType) {
+ return OrcPredicateDataType.DECIMAL;
+ } else if (dataType instanceof BooleanType) {
+ return OrcPredicateDataType.BOOLEAN;
+ } else {
+ throw new UnsupportedOperationException("Unsupported orc push down filter data type: " +
+ dataType.getClass().getSimpleName());
+ }
+ }
+
+  // Check whether the type is a char type, which the ORC native reader does not support for push down
+ private boolean isCharType(Metadata metadata) {
+ if (metadata != null) {
+ String rawTypeString = CharVarcharUtils.getRawTypeString(metadata).getOrElse(null);
+ if (rawTypeString != null) {
+ Matcher matcher = CHAR_TYPE.matcher(rawTypeString);
+ return matcher.matches();
+ }
+ }
+ return false;
+ }
+
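+  // Attempt to translate the Spark filter into the JSON expression tree and leaves;
+  // any unsupported operator or type, or an excessive number of leaves, throws, which
+  // is caught here and disables push down for this scan.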
+ private boolean canPushDown(Filter pushedFilter, JSONObject jsonExpressionTree,
+ JSONObject jsonLeaves) {
+ try {
+ getExprJson(pushedFilter, jsonExpressionTree, jsonLeaves);
+ if (leafIndex > MAX_LEAF_THRESHOLD) {
+        throw new UnsupportedOperationException("leaf node count is " + leafIndex +
+          ", which exceeds the max threshold " + MAX_LEAF_THRESHOLD + ".");
+ }
+ return true;
+ } catch (Exception e) {
+      LOGGER.info("Unable to push down orc filter because {}", e.getMessage());
+ return false;
+ }
+ }
+
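+  // Recursively translate a Spark Filter into "leaf-<n>" entries plus an expression
+  // tree. GreaterThan, GreaterThanOrEqual and IsNotNull have no direct leaf operator
+  // here, so they are encoded as NOT over LESS_THAN_EQUALS / LESS_THAN / IS_NULL.
+  // Rough sketch of the output for GreaterThan("a", 5) on a LONG column (enum names
+  // stand for their ordinals, literal formatting elided):
+  //   expressionTree: {"op": NOT, "child": [{"op": LEAF, "leaf": "leaf-0"}]}
+  //   leaves: {"leaf-0": {"op": LESS_THAN_EQUALS, "name": "a", "type": LONG,
+  //                       "literal": ..., "literalList": []}}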
+ private void getExprJson(Filter filterPredicate, JSONObject jsonExpressionTree,
+ JSONObject jsonLeaves) {
+ if (filterPredicate instanceof And) {
+ addChildJson(jsonExpressionTree, jsonLeaves, OrcOperator.AND,
+ ((And) filterPredicate).left(), ((And) filterPredicate).right());
+ } else if (filterPredicate instanceof Or) {
+ addChildJson(jsonExpressionTree, jsonLeaves, OrcOperator.OR,
+ ((Or) filterPredicate).left(), ((Or) filterPredicate).right());
+ } else if (filterPredicate instanceof Not) {
+ addChildJson(jsonExpressionTree, jsonLeaves, OrcOperator.NOT,
+ ((Not) filterPredicate).child());
+ } else if (filterPredicate instanceof EqualTo) {
+ addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false);
+ addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.EQUALS, jsonLeaves,
+ ((EqualTo) filterPredicate).attribute(), ((EqualTo) filterPredicate).value(), null);
+ leafIndex++;
+ } else if (filterPredicate instanceof GreaterThan) {
+ addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, true);
+ addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.LESS_THAN_EQUALS, jsonLeaves,
+ ((GreaterThan) filterPredicate).attribute(), ((GreaterThan) filterPredicate).value(), null);
+ leafIndex++;
+ } else if (filterPredicate instanceof GreaterThanOrEqual) {
+ addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, true);
+ addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.LESS_THAN, jsonLeaves,
+ ((GreaterThanOrEqual) filterPredicate).attribute(), ((GreaterThanOrEqual) filterPredicate).value(), null);
+ leafIndex++;
+ } else if (filterPredicate instanceof LessThan) {
+ addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false);
+ addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.LESS_THAN, jsonLeaves,
+ ((LessThan) filterPredicate).attribute(), ((LessThan) filterPredicate).value(), null);
+ leafIndex++;
+ } else if (filterPredicate instanceof LessThanOrEqual) {
+ addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false);
+ addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.LESS_THAN_EQUALS, jsonLeaves,
+ ((LessThanOrEqual) filterPredicate).attribute(), ((LessThanOrEqual) filterPredicate).value(), null);
+ leafIndex++;
+ } else if (filterPredicate instanceof IsNotNull) {
+ addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, true);
+ addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.IS_NULL, jsonLeaves,
+ ((IsNotNull) filterPredicate).attribute(), null, null);
+ leafIndex++;
+ } else if (filterPredicate instanceof IsNull) {
+ addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false);
+ addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.IS_NULL, jsonLeaves,
+ ((IsNull) filterPredicate).attribute(), null, null);
+ leafIndex++;
+ } else if (filterPredicate instanceof In) {
+ addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false);
+ addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.IN, jsonLeaves,
+ ((In) filterPredicate).attribute(), null, Arrays.stream(((In) filterPredicate).values()).toArray());
+ leafIndex++;
+ } else {
+ throw new UnsupportedOperationException("Unsupported orc push down filter operation: " +
+ filterPredicate.getClass().getSimpleName());
+ }
+ }
+
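+  // Build one leaf entry (operator ordinal, column name, predicate type, literal
+  // value and optional literal list) and register it under the given "leaf-<n>" key.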
+ private void addLiteralToJsonLeaves(String leaf, OrcLeafOperator leafOperator, JSONObject jsonLeaves,
+ String name, Object literal, Object[] literals) {
+ JSONObject leafJson = new JSONObject();
+ leafJson.put("op", leafOperator.ordinal());
+ leafJson.put("name", name);
+ leafJson.put("type", getOrcPredicateDataType(name).ordinal());
+
+ leafJson.put("literal", getLiteralValue(literal));
+
+    ArrayList<Object> literalList = new ArrayList<>();
+    if (literals != null) {
+      for (Object lit : literals) {
+        literalList.add(getLiteralValue(lit));
+      }
+    }
+ leafJson.put("literalList", literalList);
+ jsonLeaves.put(leaf, leafJson);
+ }
+
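+  // Attach a single leaf to the expression tree, optionally wrapped in a NOT node
+  // when the Spark operator was rewritten to its negation.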
+ private void addToJsonExpressionTree(String leaf, JSONObject jsonExpressionTree, boolean addNot) {
+ if (addNot) {
+ jsonExpressionTree.put("op", OrcOperator.NOT.ordinal());
+      ArrayList<JSONObject> child = new ArrayList<>();
+ JSONObject subJson = new JSONObject();
+ subJson.put("op", OrcOperator.LEAF.ordinal());
+ subJson.put("leaf", leaf);
+ child.add(subJson);
+ jsonExpressionTree.put("child", child);
+ } else {
+ jsonExpressionTree.put("op", OrcOperator.LEAF.ordinal());
+ jsonExpressionTree.put("leaf", leaf);
+ }
+ }
+
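+  // Build a composite AND/OR/NOT node for the given child filters.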
+ private void addChildJson(JSONObject jsonExpressionTree, JSONObject jsonLeaves,
+ OrcOperator orcOperator, Filter ... filters) {
+ jsonExpressionTree.put("op", orcOperator.ordinal());
+ ArrayList