From 9ef629ec8183e0ee4fd25cfda5913bcc38a594d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 19:33:12 +0800 Subject: [PATCH 01/43] modify slf4 dependcy --- omnioperator/omniop-native-reader/java/pom.xml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/omnioperator/omniop-native-reader/java/pom.xml b/omnioperator/omniop-native-reader/java/pom.xml index e7ddfe6c3..bdad33a3a 100644 --- a/omnioperator/omniop-native-reader/java/pom.xml +++ b/omnioperator/omniop-native-reader/java/pom.xml @@ -8,13 +8,13 @@ com.huawei.boostkit boostkit-omniop-native-reader jar - 3.3.1-1.6.0 + 3.4.3-1.6.0 BoostKit Spark Native Sql Engine Extension With OmniOperator 2.12 - 3.3.1 + 3.4.3 FALSE ../cpp/ ../cpp/build/releases/ @@ -35,8 +35,8 @@ org.slf4j - slf4j-api - 1.7.32 + slf4j-simple + 1.7.36 junit @@ -132,4 +132,4 @@ - \ No newline at end of file + -- Gitee From 6895d78bf89db14198d89e5fda097664c7d6b0cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 19:36:23 +0800 Subject: [PATCH 02/43] modify version and spark_version --- omnioperator/omniop-spark-extension/pom.xml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/omnioperator/omniop-spark-extension/pom.xml b/omnioperator/omniop-spark-extension/pom.xml index 4376a89be..e949cea2d 100644 --- a/omnioperator/omniop-spark-extension/pom.xml +++ b/omnioperator/omniop-spark-extension/pom.xml @@ -8,13 +8,13 @@ com.huawei.kunpeng boostkit-omniop-spark-parent pom - 3.3.1-1.6.0 + 3.4.3-1.6.0 BoostKit Spark Native Sql Engine Extension Parent Pom 2.12.10 2.12 - 3.3.1 + 3.4.3 3.2.2 UTF-8 UTF-8 -- Gitee From 0cc54a9911f8cdb1b90914588cafc1ef8b30f317 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 19:45:53 +0800 Subject: [PATCH 03/43] modify version and spark_version --- omnioperator/omniop-spark-extension/java/pom.xml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/pom.xml b/omnioperator/omniop-spark-extension/java/pom.xml index 9cc1b9d25..a40b415fd 100644 --- a/omnioperator/omniop-spark-extension/java/pom.xml +++ b/omnioperator/omniop-spark-extension/java/pom.xml @@ -7,7 +7,7 @@ com.huawei.kunpeng boostkit-omniop-spark-parent - 3.3.1-1.6.0 + 3.4.3-1.6.0 ../pom.xml @@ -52,7 +52,7 @@ com.huawei.boostkit boostkit-omniop-native-reader - 3.3.1-1.6.0 + 3.4.3-1.6.0 junit @@ -247,7 +247,7 @@ true false ${project.basedir}/src/main/scala - ${project.basedir}/src/test/scala + ${project.basedir}/src/test/scala ${user.dir}/scalastyle-config.xml ${project.basedir}/target/scalastyle-output.xml ${project.build.sourceEncoding} @@ -335,7 +335,7 @@ org.scalatest scalatest-maven-plugin - false + false ${project.build.directory}/surefire-reports . 
@@ -352,4 +352,4 @@ - \ No newline at end of file + -- Gitee From 630028282846f9e3cd28474e7c7a047e24ccd3ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 19:50:41 +0800 Subject: [PATCH 04/43] modify SubqueryExpression parametre nums and add new method --- .../apache/spark/sql/catalyst/expressions/runtimefilter.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala index 0a5d509b0..85192fc36 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala @@ -39,7 +39,7 @@ case class RuntimeFilterSubquery( exprId: ExprId = NamedExpression.newExprId, hint: Option[HintInfo] = None) extends SubqueryExpression( - filterCreationSidePlan, Seq(filterApplicationSideExp), exprId, Seq.empty) + filterCreationSidePlan, Seq(filterApplicationSideExp), exprId, Seq.empty, hint) with Unevaluable with UnaryLike[Expression] { @@ -74,6 +74,8 @@ case class RuntimeFilterSubquery( override protected def withNewChildInternal(newChild: Expression): RuntimeFilterSubquery = copy(filterApplicationSideExp = newChild) + + override def withNewHint(hint: Option[HintInfo]): RuntimeFilterSubquery = copy(hint = hint) } /** -- Gitee From a93c1f9a86dad08d23068ba5cbf065dcf5ea75bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 19:51:40 +0800 Subject: [PATCH 05/43] add and remove Expression patterns --- .../sql/catalyst/Tree/TreePatterns.scala | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala index ea2712447..ef17a0740 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala @@ -25,7 +25,8 @@ object TreePattern extends Enumeration { // Expression patterns (alphabetically ordered) val AGGREGATE_EXPRESSION = Value(0) val ALIAS: Value = Value - val AND_OR: Value = Value +// val AND_OR: Value = Value + val AND: Value = Value val ARRAYS_ZIP: Value = Value val ATTRIBUTE_REFERENCE: Value = Value val APPEND_COLUMNS: Value = Value @@ -58,6 +59,7 @@ object TreePattern extends Enumeration { val JSON_TO_STRUCT: Value = Value val LAMBDA_FUNCTION: Value = Value val LAMBDA_VARIABLE: Value = Value + val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value // spark3.4.3 val LATERAL_SUBQUERY: Value = Value val LIKE_FAMLIY: Value = Value val LIST_SUBQUERY: Value = Value @@ -69,27 +71,33 @@ object TreePattern extends Enumeration { val NULL_CHECK: Value = Value val NULL_LITERAL: Value = Value val SERIALIZE_FROM_OBJECT: Value = Value + val OR: Value = Value // spark3.4.3 val OUTER_REFERENCE: Value = Value + val PARAMETER: Value = Value // spark3.4.3 + val PARAMETERIZED_QUERY: Value = Value // spark3.4.3 val 
PIVOT: Value = Value val PLAN_EXPRESSION: Value = Value val PYTHON_UDF: Value = Value val REGEXP_EXTRACT_FAMILY: Value = Value val REGEXP_REPLACE: Value = Value val RUNTIME_REPLACEABLE: Value = Value - val RUNTIME_FILTER_EXPRESSION: Value = Value - val RUNTIME_FILTER_SUBQUERY: Value = Value + val RUNTIME_FILTER_EXPRESSION: Value = Value // spark3.4.3移除 + val RUNTIME_FILTER_SUBQUERY: Value = Value // spark3.4.3移除 val SCALAR_SUBQUERY: Value = Value val SCALAR_SUBQUERY_REFERENCE: Value = Value val SCALA_UDF: Value = Value + val SESSION_WINDOW: Value = Value // spark3.4.3 val SORT: Value = Value val SUBQUERY_ALIAS: Value = Value - val SUBQUERY_WRAPPER: Value = Value + val SUBQUERY_WRAPPER: Value = Value // spark3.4.3移除 val SUM: Value = Value val TIME_WINDOW: Value = Value val TIME_ZONE_AWARE_EXPRESSION: Value = Value val TRUE_OR_FALSE_LITERAL: Value = Value val WINDOW_EXPRESSION: Value = Value + val WINDOW_TIME: Value = Value // saprk3.4.3 val UNARY_POSITIVE: Value = Value + val UNPIVOT: Value = Value // spark3.4.3 val UPDATE_FIELDS: Value = Value val UPPER_OR_LOWER: Value = Value val UP_CAST: Value = Value @@ -119,6 +127,7 @@ object TreePattern extends Enumeration { val UNION: Value = Value val UNRESOLVED_RELATION: Value = Value val UNRESOLVED_WITH: Value = Value + val TEMP_RESOLVED_COLUMN: Value = Value // spark3.4.3 val TYPED_FILTER: Value = Value val WINDOW: Value = Value val WITH_WINDOW_DEFINITION: Value = Value @@ -127,6 +136,7 @@ object TreePattern extends Enumeration { val UNRESOLVED_ALIAS: Value = Value val UNRESOLVED_ATTRIBUTE: Value = Value val UNRESOLVED_DESERIALIZER: Value = Value + val UNRESOLVED_HAVING: Value = Value // spark3.4.3 val UNRESOLVED_ORDINAL: Value = Value val UNRESOLVED_FUNCTION: Value = Value val UNRESOLVED_HINT: Value = Value @@ -135,6 +145,8 @@ object TreePattern extends Enumeration { // Unresolved Plan patterns (Alphabetically ordered) val UNRESOLVED_SUBQUERY_COLUMN_ALIAS: Value = Value val UNRESOLVED_FUNC: Value = Value + val UNRESOLVED_TABLE_VALUED_FUNCTION: Value = Value // spark3.4.3 + val UNRESOLVED_TVF_ALIASES: Value = Value // spark3.4.3 // Execution expression patterns (alphabetically ordered) val IN_SUBQUERY_EXEC: Value = Value -- Gitee From ac1473dc6e4104405caf6e71896aba689429d931 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 19:52:34 +0800 Subject: [PATCH 06/43] repair parametre nums --- .../sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala index 9e4029025..f6ebd716d 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala @@ -61,7 +61,7 @@ object RewriteSelfJoinInInPredicate extends Rule[LogicalPlan] with PredicateHelp case f: Filter => f transformExpressions { case in @ InSubquery(_, listQuery @ ListQuery(Project(projectList, - Join(left, right, Inner, Some(joinCond), _)), _, _, _, _)) + Join(left, right, Inner, Some(joinCond), _)), _, _, _, _, _)) if left.canonicalized ne right.canonicalized => val 
attrMapping = AttributeMap(right.output.zip(left.output)) val subCondExprs = splitConjunctivePredicates(joinCond transform { -- Gitee From 6efd40b5b7f3123935f75a3b4f209890122f0299 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 19:55:09 +0800 Subject: [PATCH 07/43] add normalized ,adjust parametres and modify method toInternalError --- .../spark/sql/execution/QueryExecution.scala | 30 ++++++++++++++----- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index ef33a84de..37f3db5d6 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -105,6 +105,23 @@ class QueryExecution( case other => other } + // The plan that has been normalized by custom rules, so that it's more likely to hit cache. + lazy val normalized: LogicalPlan = { + val normalizationRules = sparkSession.sessionState.planNormalizationRules + if (normalizationRules.isEmpty) { + commandExecuted + } else { + val planChangeLogger = new PlanChangeLogger[LogicalPlan]() + val normalized = normalizationRules.foldLeft(commandExecuted) { (p, rule) => + val result = rule.apply(p) + planChangeLogger.logRule(rule.ruleName, p, result) + result + } + planChangeLogger.logBatch("Plan Normalization", commandExecuted, normalized) + normalized + } + } // Spark3.4.3 + lazy val withCachedData: LogicalPlan = sparkSession.withActive { assertAnalyzed() assertSupported() @@ -227,7 +244,7 @@ class QueryExecution( // output mode does not matter since there is no `Sink`. new IncrementalExecution( sparkSession, logical, OutputMode.Append(), "", - UUID.randomUUID, UUID.randomUUID, 0, OffsetSeqMetadata(0, 0)) + UUID.randomUUID, UUID.randomUUID, 0, None ,OffsetSeqMetadata(0, 0)) } else { this } @@ -494,11 +511,10 @@ object QueryExecution { */ private[sql] def toInternalError(msg: String, e: Throwable): Throwable = e match { case e @ (_: java.lang.NullPointerException | _: java.lang.AssertionError) => - new SparkException( - errorClass = "INTERNAL_ERROR", - messageParameters = Array(msg + - " Please, fill a bug report in, and provide the full stack trace."), - cause = e) + SparkException.internalError( + msg + " You hit a bug in Spark or the Spark plugins you use. 
Please, report this bug " + + "to the corresponding communities or vendors, and provide the full stack trace.", + e) case e: Throwable => e } @@ -513,4 +529,4 @@ object QueryExecution { case e: Throwable => throw toInternalError(msg, e) } } -} \ No newline at end of file +} -- Gitee From 817c0425ce4efb7d7adea3725ac75454662ccc2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 19:55:32 +0800 Subject: [PATCH 08/43] modify parametres --- .../sql/execution/adaptive/PlanAdaptiveSubqueries.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala index b5a1ad375..dfdbe2c70 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala @@ -30,11 +30,11 @@ case class PlanAdaptiveSubqueries( def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressionsWithPruning(_.containsAnyPattern( SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY, RUNTIME_FILTER_SUBQUERY)) { - case expressions.ScalarSubquery(_, _, exprId, _) => + case expressions.ScalarSubquery(_, _, exprId, _, _, _) => val subquery = SubqueryExec.createForScalarSubquery( s"subquery#${exprId.id}", subqueryMap(exprId.id)) execution.ScalarSubquery(subquery, exprId) - case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _)) => + case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _, _)) => val expr = if (values.length == 1) { values.head } else { @@ -47,7 +47,7 @@ case class PlanAdaptiveSubqueries( val subquery = SubqueryExec(s"subquery#${exprId.id}", subqueryMap(exprId.id)) InSubqueryExec(expr, subquery, exprId, shouldBroadcast = true) case expressions.DynamicPruningSubquery(value, buildPlan, - buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId) => + buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId, _) => val name = s"dynamicpruning#${exprId.id}" val subquery = SubqueryAdaptiveBroadcastExec(name, broadcastKeyIndex, onlyInBroadcast, buildPlan, buildKeys, subqueryMap(exprId.id)) -- Gitee From 7013a77760c297ada613d17245d95a7e5aa94e59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 19:56:07 +0800 Subject: [PATCH 09/43] modify filePath --- .../execution/datasources/parquet/OmniParquetFileFormat.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala index 9504b34d1..3d61f873d 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala @@ -80,7 +80,7 @@ class OmniParquetFileFormat extends FileFormat with DataSourceRegister with Logg (file: PartitionedFile) => { 
assert(file.partitionValues.numFields == partitionSchema.size) - val filePath = new Path(new URI(file.filePath)) + val filePath = file.toPath val split = new org.apache.parquet.hadoop.ParquetInputSplit( filePath, -- Gitee From 902deae2b0092bbd810588a185a13d54b82f8b97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 19:57:33 +0800 Subject: [PATCH 10/43] modify filePath and ORC predicate pushdown --- .../datasources/orc/OmniOrcFileFormat.scala | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala index 334800f51..807369004 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala @@ -118,17 +118,27 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ (file: PartitionedFile) => { val conf = broadcastedConf.value.value - val filePath = new Path(new URI(file.filePath)) - val isPPDSafeValue = isPPDSafe(filters, dataSchema).reduceOption(_ && _) +// val filePath = new Path(new URI(file.filePath.urlEncoded)) + val filePath = file.toPath - // ORC predicate pushdown - if (orcFilterPushDown && filters.nonEmpty && isPPDSafeValue.getOrElse(false)) { - OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach { - fileSchema => OrcFilters.createFilter(fileSchema, filters).foreach { f => - OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames) - } - } - } + val fs = filePath.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + val orcSchema = + Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions))(_.getSchema) + val isPPDSafeValue = isPPDSafe(filters, dataSchema).reduceOption(_ && _) + val resultedColPruneInfo = OrcUtils.requestedColumnIds( + isCaseSensitive, dataSchema, requiredSchema, orcSchema, conf) + + + // ORC predicate pushdown + if (orcFilterPushDown && filters.nonEmpty && isPPDSafeValue.getOrElse(false)) { + val fileSchema = OrcUtils.toCatalystSchema(orcSchema) + // OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach { fileSchema => + OrcFilters.createFilter(fileSchema, filters).foreach { f => + OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames) + // } + } + } val taskConf = new Configuration(conf) val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty) -- Gitee From a1de18295a282d69537f355e84c0a70f929656f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 19:59:47 +0800 Subject: [PATCH 11/43] modify parametres , remove promotePrecision, modify CastBase to Cast --- .../spark/expression/OmniExpressionAdaptor.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index d1f911a8c..f7ed75fa3 100644 --- 
a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -76,7 +76,7 @@ object OmniExpressionAdaptor extends Logging { } } - private def unsupportedCastCheck(expr: Expression, cast: CastBase): Unit = { + private def unsupportedCastCheck(expr: Expression, cast: Cast): Unit = { def doSupportCastToString(dataType: DataType): Boolean = { dataType.isInstanceOf[DecimalType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[IntegerType] || dataType.isInstanceOf[LongType] || dataType.isInstanceOf[DateType] || dataType.isInstanceOf[DoubleType] || @@ -160,8 +160,8 @@ object OmniExpressionAdaptor extends Logging { throw new UnsupportedOperationException(s"Unsupported datatype for MakeDecimal: ${makeDecimal.child.dataType}") } - case promotePrecision: PromotePrecision => - rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap) +// case promotePrecision: PromotePrecision => +// rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap) case sub: Subtract => new JSONObject().put("exprType", "BINARY") @@ -296,7 +296,7 @@ object OmniExpressionAdaptor extends Logging { .put(rewriteToOmniJsonExpressionLiteralJsonObject(subString.len, exprsIndexMap))) // Cast - case cast: CastBase => + case cast: Cast => unsupportedCastCheck(expr, cast) cast.child.dataType match { case NullType => @@ -588,10 +588,10 @@ object OmniExpressionAdaptor extends Logging { rewriteToOmniJsonExpressionLiteralJsonObject(children.head, exprsIndexMap) } else { children.head match { - case base: CastBase if base.child.dataType.isInstanceOf[NullType] => + case base: Cast if base.child.dataType.isInstanceOf[NullType] => rewriteToOmniJsonExpressionLiteralJsonObject(children(1), exprsIndexMap) case _ => children(1) match { - case base: CastBase if base.child.dataType.isInstanceOf[NullType] => + case base: Cast if base.child.dataType.isInstanceOf[NullType] => rewriteToOmniJsonExpressionLiteralJsonObject(children.head, exprsIndexMap) case _ => new JSONObject().put("exprType", "FUNCTION") -- Gitee From c86a67d2da6bdf99ea66b87b7d3bb7e1acbdd3ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 20:00:23 +0800 Subject: [PATCH 12/43] modify parametres --- .../sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala index aaa244cdf..e1c620e1c 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala @@ -43,7 +43,7 @@ class MergeSubqueryFiltersSuite extends PlanTest { } private def extractorExpression(cteIndex: Int, output: Seq[Attribute], fieldIndex: Int) = { - GetStructField(ScalarSubquery(CTERelationRef(cteIndex, _resolved = true, output)), fieldIndex) + GetStructField(ScalarSubquery(CTERelationRef(cteIndex, _resolved = true, output, isStreaming = false)), fieldIndex) 
.as("scalarsubquery()") } -- Gitee From 2dcdc3e5cd912af7ec0bcefda0e77594504400d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 20:00:35 +0800 Subject: [PATCH 13/43] modify parametres --- .../spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala index 1b5baa230..c4435379f 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala @@ -643,7 +643,7 @@ object MergeSubqueryFilters extends Rule[LogicalPlan] { val subqueryCTE = header.plan.asInstanceOf[CTERelationDef] GetStructField( ScalarSubquery( - CTERelationRef(subqueryCTE.id, _resolved = true, subqueryCTE.output), + CTERelationRef(subqueryCTE.id, _resolved = true, subqueryCTE.output, subqueryCTE.isStreaming), exprId = ssr.exprId), ssr.headerIndex) } else { -- Gitee From 3c57000ec57c50a6d683458a61026e68e5747b25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 20:04:31 +0800 Subject: [PATCH 14/43] modify injectBloomFilter,injectInsubqueryFilter,isSelectiveFilterOverScan,hasDynamicPruningSubquery,hasInSubquery,tryInJectRuntimeFilter --- .../optimizer/InjectRuntimeFilter.scala | 67 ++++++++++++++----- 1 file changed, 51 insertions(+), 16 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index 812c387bc..e6c266242 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import com.huawei.boostkit.spark.ColumnarPluginConfig +import scala.annotation.tailrec + /** * Insert a filter on one side of the join if the other side has a selective predicate. 
* The filter could be an IN subquery (converted to a semi join), a bloom filter, or something @@ -85,12 +87,12 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J val bloomFilterAgg = if (rowCount.isDefined && rowCount.get.longValue > 0L) { new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)), - Literal(rowCount.get.longValue)) + rowCount.get.longValue) } else { new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp))) } - val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None) - val alias = Alias(aggExp, "bloomFilter")() + + val alias = Alias(bloomFilterAgg.toAggregateExpression(), "bloomFilter")() val aggregate = ConstantFolding(ColumnPruning(Aggregate(Nil, Seq(alias), filterCreationSidePlan))) val bloomFilterSubquery = if (canReuseExchange) { @@ -112,7 +114,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J require(filterApplicationSideExp.dataType == filterCreationSideExp.dataType) val actualFilterKeyExpr = mayWrapWithHash(filterCreationSideExp) val alias = Alias(actualFilterKeyExpr, actualFilterKeyExpr.toString)() - val aggregate = Aggregate(Seq(alias), Seq(alias), filterCreationSidePlan) + val aggregate = + ColumnPruning(Aggregate(Seq(filterCreationSideExp), Seq(alias), filterCreationSidePlan)) if (!canBroadcastBySize(aggregate, conf)) { // Skip the InSubquery filter if the size of `aggregate` is beyond broadcast join threshold, // i.e., the semi-join will be a shuffled join, which is not worthwhile. @@ -129,13 +132,39 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J * do not add a subquery that might have an expensive computation */ private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = { - val ret = plan match { - case PhysicalOperation(_, filters, child) if child.isInstanceOf[LeafNode] => - filters.forall(isSimpleExpression) && - filters.exists(isLikelySelective) + @tailrec + def isSelective( + p: LogicalPlan, + predicateReference: AttributeSet, + hasHitFilter: Boolean, + hasHitSelectiveFilter: Boolean): Boolean = p match { + case Project(projectList, child) => + if (hasHitFilter) { + // We need to make sure all expressions referenced by filter predicates are simple + // expressions. 
+ val referencedExprs = projectList.filter(predicateReference.contains) + referencedExprs.forall(isSimpleExpression) && + isSelective( + child, + referencedExprs.map(_.references).foldLeft(AttributeSet.empty)(_ ++ _), + hasHitFilter, + hasHitSelectiveFilter) + } else { + assert(predicateReference.isEmpty && !hasHitSelectiveFilter) + isSelective(child, predicateReference, hasHitFilter, hasHitSelectiveFilter) + } + case Filter(condition, child) => + isSimpleExpression(condition) && isSelective( + child, + predicateReference ++ condition.references, + hasHitFilter = true, + hasHitSelectiveFilter = hasHitSelectiveFilter || isLikelySelective(condition)) + case _: LeafNode => hasHitSelectiveFilter case _ => false } - !plan.isStreaming && ret + + !plan.isStreaming && + isSelective(plan, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false) } private def isSimpleExpression(e: Expression): Boolean = { @@ -184,8 +213,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J /** * Check that: - * - The filterApplicationSideJoinExp can be pushed down through joins and aggregates (ie the - * expression references originate from a single leaf node) + * - The filterApplicationSideJoinExp can be pushed down through joins, aggregates and windows + * (ie the expression references originate from a single leaf node) * - The filter creation side has a selective predicate * - The current join is a shuffle join or a broadcast join that has a shuffle below it * - The max filterApplicationSide scan size is greater than a configurable threshold @@ -218,9 +247,9 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J leftKey: Expression, rightKey: Expression): Boolean = { (left, right) match { - case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan), _) => + case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _, _), plan), _) => pruningKey.fastEquals(leftKey) || hasDynamicPruningSubquery(plan, right, leftKey, rightKey) - case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan)) => + case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _, _), plan)) => pruningKey.fastEquals(rightKey) || hasDynamicPruningSubquery(left, plan, leftKey, rightKey) case _ => false @@ -251,10 +280,10 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J rightKey: Expression): Boolean = { (left, right) match { case (Filter(InSubquery(Seq(key), - ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _), _) => + ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _, _)), _), _) => key.fastEquals(leftKey) || key.fastEquals(new Murmur3Hash(Seq(leftKey))) case (_, Filter(InSubquery(Seq(key), - ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _)) => + ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _, _)), _)) => key.fastEquals(rightKey) || key.fastEquals(new Murmur3Hash(Seq(rightKey))) case _ => false } @@ -299,7 +328,13 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J case s: Subquery if s.correlated => plan case _ if !conf.runtimeFilterSemiJoinReductionEnabled && !conf.runtimeFilterBloomFilterEnabled => plan - case _ => tryInjectRuntimeFilter(plan) + case _ => + val newPlan = tryInjectRuntimeFilter(plan) + if (conf.runtimeFilterSemiJoinReductionEnabled && !plan.fastEquals(newPlan)) { + RewritePredicateSubquery(newPlan) + } else { + newPlan + } } } \ No newline at end of file -- 
Gitee From ae03791184939a9f51756a057b2f2de618f7d623 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 20:05:15 +0800 Subject: [PATCH 15/43] modify metastore.uris and hive.db --- .../java/src/test/resources/HiveResource.properties | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties b/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties index 89eabe8e6..099e28e8d 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties +++ b/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties @@ -2,11 +2,13 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. # -hive.metastore.uris=thrift://server1:9083 +#hive.metastore.uris=thrift://server1:9083 +hive.metastore.uris=thrift://OmniOperator:9083 spark.sql.warehouse.dir=/user/hive/warehouse spark.memory.offHeap.size=8G spark.sql.codegen.wholeStage=false spark.sql.extensions=com.huawei.boostkit.spark.ColumnarPlugin spark.shuffle.manager=org.apache.spark.shuffle.sort.OmniColumnarShuffleManager spark.sql.orc.impl=native -hive.db=tpcds_bin_partitioned_orc_2 \ No newline at end of file +#hive.db=tpcds_bin_partitioned_orc_2 +hive.db=tpcds_bin_partitioned_varchar_orc_2 -- Gitee From 2a06713e82ae19b1607369f856dae7b59e6dcd25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 20:06:35 +0800 Subject: [PATCH 16/43] modify "test child max row" limit(0) to limit(1) --- .../spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index f83edb9ca..ea52aca62 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -61,7 +61,7 @@ class CombiningLimitsSuite extends PlanTest { comparePlans(optimized1, expected1) // test child max row > limit. 
- val query2 = testRelation.select().groupBy()(count(1)).limit(0).analyze + val query2 = testRelation2.select($"x").groupBy($"x")(count(1)).limit(1).analyze val optimized2 = Optimize.execute(query2) comparePlans(optimized2, query2) -- Gitee From 5fd470d2e77d38f3de7a3bdb12d15fc789a4adf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 20:07:43 +0800 Subject: [PATCH 17/43] modify type of path --- .../spark/sql/execution/ColumnarFileSourceScanExec.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala index 800dcf1a0..4b70db452 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala @@ -479,10 +479,11 @@ abstract class BaseColumnarFileSourceScanExec( } }.groupBy { f => BucketingUtils - .getBucketId(new Path(f.filePath).getName) - .getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.filePath)) + .getBucketId(f.toPath.getName) + .getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.urlEncodedPath)) } + val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { val bucketSet = optionalBucketSet.get filesGroupedToBuckets.filter { -- Gitee From deb446159bb9483b44cb7e7595145d021b820c96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 20:08:38 +0800 Subject: [PATCH 18/43] modify traits of ColumnarProjectExec --- .../execution/ColumnarBasicPhysicalOperators.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala index 486369843..630034bd7 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala @@ -16,10 +16,10 @@ */ package org.apache.spark.sql.execution - import java.util.concurrent.TimeUnit.NANOSECONDS import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP + import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._ import com.huawei.boostkit.spark.util.OmniAdaptorUtil import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs @@ -41,11 +41,13 @@ import org.apache.spark.sql.execution.vectorized.OmniColumnVector import org.apache.spark.sql.expression.ColumnarExpressionConverter import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.sql.catalyst.plans.{AliasAwareOutputExpression, AliasAwareQueryOutputOrdering} + case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryExecNode - with AliasAwareOutputPartitioning - with AliasAwareOutputOrdering { + with AliasAwareOutputExpression + with 
AliasAwareQueryOutputOrdering[SparkPlan] { override def supportsColumnar: Boolean = true @@ -267,8 +269,8 @@ case class ColumnarConditionProjectExec(projectList: Seq[NamedExpression], condition: Expression, child: SparkPlan) extends UnaryExecNode - with AliasAwareOutputPartitioning - with AliasAwareOutputOrdering { + with AliasAwareOutputExpression + with AliasAwareQueryOutputOrdering[SparkPlan] { override def supportsColumnar: Boolean = true -- Gitee From ccfc3230bd3912eab482e2851db0026f582da553 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 20:12:41 +0800 Subject: [PATCH 19/43] modify "SPARK-30953" ,"SPARK-31658" ADAPTIVE_EXECUTION_FORCE_APPLY,"SPARK-32932" ADAPTIVE_EXECUTION_ENABLED --- .../ColumnarAdaptiveQueryExecSuite.scala | 47 +++++++++++-------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala index c34ff5bb1..c0be72f31 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala @@ -19,17 +19,15 @@ package org.apache.spark.sql.execution.adaptive import java.io.File import java.net.URI - import org.apache.logging.log4j.Level import org.scalatest.PrivateMethodTester import org.scalatest.time.SpanSugar._ - import org.apache.spark.SparkException import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} -import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnaryExecNode, UnionExec} +import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnionExec} import org.apache.spark.sql.execution.aggregate.BaseAggregateExec import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.execution.datasources.noop.NoopDataSource @@ -37,7 +35,7 @@ import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, Exchange, REPARTITION_BY_COL, REPARTITION_BY_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec} import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter -import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate +import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLExecutionStart} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf 
import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode @@ -1122,13 +1120,21 @@ class AdaptiveQueryExecSuite test("SPARK-30953: InsertAdaptiveSparkPlan should apply AQE on child plan of write commands") { withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + var plan: SparkPlan = null + val listener = new QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + plan = qe.executedPlan + } + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} + } + spark.listenerManager.register(listener) withTable("t1") { - val plan = sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").queryExecution.executedPlan - assert(plan.isInstanceOf[CommandResultExec]) - val commandResultExec = plan.asInstanceOf[CommandResultExec] - assert(commandResultExec.commandPhysicalPlan.isInstanceOf[DataWritingCommandExec]) - assert(commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec] - .child.isInstanceOf[AdaptiveSparkPlanExec]) + val format = classOf[NoopDataSource].getName + Seq((0, 1)).toDF("x", "y").write.format(format).mode("overwrite").save() + sparkContext.listenerBus.waitUntilEmpty() + assert(plan.isInstanceOf[V2TableWriteExec]) + assert(plan.asInstanceOf[V2TableWriteExec].child.isInstanceOf[AdaptiveSparkPlanExec]) + spark.listenerManager.unregister(listener) } } } @@ -1172,15 +1178,14 @@ class AdaptiveQueryExecSuite test("SPARK-31658: SQL UI should show write commands") { withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "false") { withTable("t1") { - var checkDone = false + var commands: Seq[SparkPlanInfo] = Seq.empty val listener = new SparkListener { override def onOtherEvent(event: SparkListenerEvent): Unit = { event match { - case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) => - assert(planInfo.nodeName == "Execute CreateDataSourceTableAsSelectCommand") - checkDone = true + case start: SparkListenerSQLExecutionStart => + commands = commands ++ Seq(start.sparkPlanInfo) case _ => // ignore other events } } @@ -1189,7 +1194,12 @@ class AdaptiveQueryExecSuite try { sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect() spark.sparkContext.listenerBus.waitUntilEmpty() - assert(checkDone) + assert(commands.size == 3) + assert(commands.head.nodeName == "Execute CreateDataSourceTableAsSelectCommand") + assert(commands(1).nodeName == "Execute InsertIntoHadoopFsRelationCommand") + assert(commands(1).children.size == 1) + assert(commands(1).children.head.nodeName == "WriteFiles") + assert(commands(2).nodeName == "CommandResult") } finally { spark.sparkContext.removeSparkListener(listener) } @@ -1574,7 +1584,7 @@ class AdaptiveQueryExecSuite test("SPARK-32932: Do not use local shuffle read at final stage on write command") { withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString, SQLConf.SHUFFLE_PARTITIONS.key -> "5", - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { val data = for ( i <- 1L to 10L; j <- 1L to 3L @@ -1584,9 +1594,8 @@ class AdaptiveQueryExecSuite var noLocalread: Boolean = false val listener = new QueryExecutionListener { override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { - qe.executedPlan match { + stripAQEPlan(qe.executedPlan) match { case 
plan@(_: DataWritingCommandExec | _: V2TableWriteExec) => - assert(plan.asInstanceOf[UnaryExecNode].child.isInstanceOf[AdaptiveSparkPlanExec]) noLocalread = collect(plan) { case exec: AQEShuffleReadExec if exec.isLocalRead => exec }.isEmpty -- Gitee From 6c2b46f1878f86bdf60559ed19b947aa818af850 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 20:15:01 +0800 Subject: [PATCH 20/43] modify spark exception --- .../sql/catalyst/expressions/CastSuite.scala | 55 +++++++++++++------ 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index e6b786c2a..329295bac 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.SparkException import org.apache.spark.sql.{Column, DataFrame} import org.apache.spark.sql.execution.ColumnarSparkPlanTest import org.apache.spark.sql.types.{DataType, Decimal} @@ -64,7 +65,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getByte(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as short") { @@ -72,7 +74,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getShort(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as int") { @@ -80,7 +83,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getInt(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as long") { @@ -88,7 +92,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getLong(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as float") { @@ -96,7 +101,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getFloat(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as double") { @@ -104,7 +110,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getDouble(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") +// 
assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as date") { @@ -154,13 +161,15 @@ class CastSuite extends ColumnarSparkPlanTest { val exception4 = intercept[Exception]( result4.collect().toSeq.head.getBoolean(0) ) - assert(exception4.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception4.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception4.isInstanceOf[SparkException], s"sql: ${sql}") val result5 = spark.sql("select cast('test' as boolean);") val exception5 = intercept[Exception]( result5.collect().toSeq.head.getBoolean(0) ) - assert(exception5.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception5.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception5.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast boolean to string") { @@ -182,13 +191,15 @@ class CastSuite extends ColumnarSparkPlanTest { val exception2 = intercept[Exception]( result2.collect().toSeq.head.getByte(0) ) - assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}") val result3 = spark.sql("select cast('false' as byte);") val exception3 = intercept[Exception]( result3.collect().toSeq.head.getByte(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast byte to string") { @@ -210,13 +221,15 @@ class CastSuite extends ColumnarSparkPlanTest { val exception2 = intercept[Exception]( result2.collect().toSeq.head.getShort(0) ) - assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}") val result3 = spark.sql("select cast('false' as short);") val exception3 = intercept[Exception]( result3.collect().toSeq.head.getShort(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast short to string") { @@ -238,13 +251,15 @@ class CastSuite extends ColumnarSparkPlanTest { val exception2 = intercept[Exception]( result2.collect().toSeq.head.getInt(0) ) - assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}") val result3 = spark.sql("select cast('false' as int);") val exception3 = intercept[Exception]( result3.collect().toSeq.head.getInt(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast int to string") { @@ -266,13 +281,15 @@ class CastSuite extends ColumnarSparkPlanTest { val exception2 = intercept[Exception]( result2.collect().toSeq.head.getLong(0) ) - assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}") val result3 = spark.sql("select 
cast('false' as long);") val exception3 = intercept[Exception]( result3.collect().toSeq.head.getLong(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast long to string") { @@ -298,7 +315,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception3 = intercept[Exception]( result3.collect().toSeq.head.getFloat(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast float to string") { @@ -324,7 +342,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception3 = intercept[Exception]( result3.collect().toSeq.head.getDouble(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast double to string") { -- Gitee From f7e5c5b235131fb3eeae814e330ccffb099042fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 20:16:57 +0800 Subject: [PATCH 21/43] modify AQEOptimizer and optimizeQueryStage --- .../adaptive/AdaptiveSparkPlanExec.scala | 32 +++++++++++++------ 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 2599a1410..e5c8fc0ab 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -19,13 +19,11 @@ package org.apache.spark.sql.execution.adaptive import java.util import java.util.concurrent.LinkedBlockingQueue - import scala.collection.JavaConverters._ import scala.collection.concurrent.TrieMap import scala.collection.mutable import scala.concurrent.ExecutionContext import scala.util.control.NonFatal - import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession @@ -35,12 +33,13 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule} import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._ import org.apache.spark.sql.execution.bucketing.DisableUnnecessaryBucketedScan import org.apache.spark.sql.execution.exchange._ -import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric} +import org.apache.spark.sql.execution.ui.{SQLPlanMetric, SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates} import org.apache.spark.sql.internal.SQLConf import 
org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.{SparkFatalException, ThreadUtils} @@ -83,7 +82,9 @@ case class AdaptiveSparkPlanExec( @transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]() // The logical plan optimizer for re-optimizing the current logical plan. - @transient private val optimizer = new AQEOptimizer(conf) + @transient private val optimizer = new AQEOptimizer(conf, + session.sessionState.adaptiveRulesHolder.runtimeOptimizerRules + ) // `EnsureRequirements` may remove user-specified repartition and assume the query plan won't // change its output partitioning. This assumption is not true in AQE. Here we check the @@ -122,7 +123,7 @@ case class AdaptiveSparkPlanExec( RemoveRedundantSorts, DisableUnnecessaryBucketedScan, OptimizeSkewedJoin(ensureRequirements) - ) ++ context.session.sessionState.queryStagePrepRules + ) ++ context.session.sessionState.adaptiveRulesHolder.queryStagePrepRules } // A list of physical optimizer rules to be applied to a new stage before its execution. These @@ -152,7 +153,13 @@ case class AdaptiveSparkPlanExec( ) private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): SparkPlan = { - val optimized = queryStageOptimizerRules.foldLeft(plan) { case (latestPlan, rule) => + val rules = if (isFinalStage && + !conf.getConf(SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS)) { + queryStageOptimizerRules.filterNot(_.isInstanceOf[AQEShuffleReadRule]) + } else { + queryStageOptimizerRules + } + val optimized = rules.foldLeft(plan) { case (latestPlan, rule) => val applied = rule.apply(latestPlan) val result = rule match { case _: AQEShuffleReadRule if !applied.fastEquals(latestPlan) => @@ -187,7 +194,7 @@ case class AdaptiveSparkPlanExec( @volatile private var currentPhysicalPlan = initialPlan - private var isFinalPlan = false + @volatile private var _isFinalPlan = false private var currentStageId = 0 @@ -204,6 +211,8 @@ case class AdaptiveSparkPlanExec( def executedPlan: SparkPlan = currentPhysicalPlan + def isFinalPlan: Boolean = _isFinalPlan + override def conf: SQLConf = context.session.sessionState.conf override def output: Seq[Attribute] = inputPlan.output @@ -222,6 +231,8 @@ case class AdaptiveSparkPlanExec( .map(_.toLong).filter(SQLExecution.getQueryExecution(_) eq context.qe) } + def finalPhysicalPlan: SparkPlan = withFinalPlanUpdate(identity) // saprk 3.4.3 + private def getFinalPhysicalPlan(): SparkPlan = lock.synchronized { if (isFinalPlan) return currentPhysicalPlan @@ -309,7 +320,8 @@ case class AdaptiveSparkPlanExec( val newCost = costEvaluator.evaluateCost(newPhysicalPlan) if (newCost < origCost || (newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) { - logOnLevel(s"Plan changed from $currentPhysicalPlan to $newPhysicalPlan") + logOnLevel("Plan changed:\n" + + sideBySide(currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n")) cleanUpTempTags(newPhysicalPlan) currentPhysicalPlan = newPhysicalPlan currentLogicalPlan = newLogicalPlan @@ -325,7 +337,7 @@ case class AdaptiveSparkPlanExec( optimizeQueryStage(result.newPlan, isFinalStage = true), postStageCreationRules(supportsColumnar), Some((planChangeLogger, "AQE Post Stage Creation"))) - isFinalPlan = true + _isFinalPlan = true executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan))) currentPhysicalPlan } @@ -339,7 +351,7 @@ case class AdaptiveSparkPlanExec( if (!isSubquery && currentPhysicalPlan.exists(_.subqueries.nonEmpty)) { getExecutionId.foreach(onUpdatePlan(_, Seq.empty)) } 
- logOnLevel(s"Final plan: $currentPhysicalPlan") + logOnLevel(s"Final plan: \n$currentPhysicalPlan") } override def executeCollect(): Array[InternalRow] = { -- Gitee From 905f5e86e14f0fcba2db39e3e31c025dda96c3a8 Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Thu, 17 Oct 2024 06:58:44 +0000 Subject: [PATCH 22/43] update omnioperator/omniop-native-reader/java/pom.xml. modify spark_version and logger dependency Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com> --- omnioperator/omniop-native-reader/java/pom.xml | 1 + 1 file changed, 1 insertion(+) diff --git a/omnioperator/omniop-native-reader/java/pom.xml b/omnioperator/omniop-native-reader/java/pom.xml index bdad33a3a..773ed2c39 100644 --- a/omnioperator/omniop-native-reader/java/pom.xml +++ b/omnioperator/omniop-native-reader/java/pom.xml @@ -34,6 +34,7 @@ 1.6.0 + org.slf4j slf4j-simple 1.7.36 -- Gitee From 999221a97b98dcb398b7cf3a6fa5f071737a16b1 Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Thu, 17 Oct 2024 07:10:15 +0000 Subject: [PATCH 23/43] restore PromotePrecision mode and modify class CastBase to Cast Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com> --- .../boostkit/spark/expression/OmniExpressionAdaptor.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index f7ed75fa3..be1538172 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -160,8 +160,8 @@ object OmniExpressionAdaptor extends Logging { throw new UnsupportedOperationException(s"Unsupported datatype for MakeDecimal: ${makeDecimal.child.dataType}") } -// case promotePrecision: PromotePrecision => -// rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap) + case promotePrecision: PromotePrecision => + rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap) case sub: Subtract => new JSONObject().put("exprType", "BINARY") -- Gitee From 30dd5a213b420c8b3f7ffef99adb5105de8b9097 Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Thu, 17 Oct 2024 07:15:21 +0000 Subject: [PATCH 24/43] =?UTF-8?q?=E6=96=B0=E5=BB=BA=20analysis?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/main/scala/org/apache/spark/sql/catalyst/analysis/.keep | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/.keep diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/.keep b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/.keep new file mode 100644 index 000000000..e69de29bb -- Gitee From 2e17b148b13b72ba5a89b20989abfefc99d3e8a3 Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Thu, 17 Oct 2024 07:25:39 +0000 Subject: [PATCH 25/43] add DecimalPrecision.scala Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com> --- 
.../apache/spark/sql/catalyst/analysis/.keep | 0 .../catalyst/analysis/DecimalPrecision.scala | 347 ++++++++++++++++++ 2 files changed, 347 insertions(+) delete mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/.keep create mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/.keep b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/.keep deleted file mode 100644 index e69de29bb..000000000 diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala new file mode 100644 index 000000000..5a04f02ce --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal._ +import org.apache.spark.sql.types._ + + +// scalastyle:off +/** + * Calculates and propagates precision for fixed-precision decimals. Hive has a number of + * rules for this based on the SQL standard and MS SQL: + * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + * https://msdn.microsoft.com/en-us/library/ms190476.aspx + * + * In particular, if we have expressions e1 and e2 with precision/scale p1/s2 and p2/s2 + * respectively, then the following operations have the following precision / scale: + * + * Operation Result Precision Result Scale + * ------------------------------------------------------------------------ + * e1 + e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2) + * e1 - e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2) + * e1 * e2 p1 + p2 + 1 s1 + s2 + * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) + * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) + * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) + * + * When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale + * needed are out of the range of available values, the scale is reduced up to 6, in order to + * prevent the truncation of the integer part of the decimals. 
+ * + * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited + * precision, do the math on unlimited-precision numbers, then introduce casts back to the + * required fixed precision. This allows us to do all rounding and overflow handling in the + * cast-to-fixed-precision operator. + * + * In addition, when mixing non-decimal types with decimals, we use the following rules: + * - BYTE gets turned into DECIMAL(3, 0) + * - SHORT gets turned into DECIMAL(5, 0) + * - INT gets turned into DECIMAL(10, 0) + * - LONG gets turned into DECIMAL(20, 0) + * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE + * - Literals INT and LONG get turned into DECIMAL with the precision strictly needed by the value + */ +// scalastyle:on +object DecimalPrecision extends TypeCoercionRule { + import scala.math.{max, min} + + private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType + + // Returns the wider decimal type that's wider than both of them + def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = { + widerDecimalType(d1.precision, d1.scale, d2.precision, d2.scale) + } + // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) + def widerDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = { + val scale = max(s1, s2) + val range = max(p1 - s1, p2 - s2) + DecimalType.bounded(range + scale, scale) + } + + private def promotePrecision(e: Expression, dataType: DataType): Expression = { + PromotePrecision(Cast(e, dataType)) + } + + override def transform: PartialFunction[Expression, Expression] = { + decimalAndDecimal() + .orElse(integralAndDecimalLiteral) + .orElse(nondecimalAndDecimal(conf.literalPickMinimumPrecision)) + } + + private[catalyst] def decimalAndDecimal(): PartialFunction[Expression, Expression] = { + decimalAndDecimal(conf.decimalOperationsAllowPrecisionLoss, !conf.ansiEnabled) + } + + /** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. 
*/ + private[catalyst] def decimalAndDecimal(allowPrecisionLoss: Boolean, nullOnOverflow: Boolean) + : PartialFunction[Expression, Expression] = { + // Skip nodes whose children have not been resolved yet + case e if !e.childrenResolved => e + + // Skip nodes who is already promoted + case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e + + case a @ Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) => + val resultScale = max(s1, s2) + val resultType = if (allowPrecisionLoss) { + DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, + resultScale) + } else { + DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) + } + CheckOverflow( + a.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)), + resultType, nullOnOverflow) + + case s @ Subtract(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2), _) => + val resultScale = max(s1, s2) + val resultType = if (allowPrecisionLoss) { + DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, + resultScale) + } else { + DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) + } + CheckOverflow( + s.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)), + resultType, nullOnOverflow) + + case m @ Multiply( + e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) => + val resultType = if (allowPrecisionLoss) { + DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2) + } else { + DecimalType.bounded(p1 + p2 + 1, s1 + s2) + } + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow( + m.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)), + resultType, nullOnOverflow) + + case d @ Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) => + val resultType = if (allowPrecisionLoss) { + // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) + // Scale: max(6, s1 + p2 + 1) + val intDig = p1 - s1 + s2 + val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1) + val prec = intDig + scale + DecimalType.adjustPrecisionScale(prec, scale) + } else { + var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) + var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) + val diff = (intDig + decDig) - DecimalType.MAX_SCALE + if (diff > 0) { + decDig -= diff / 2 + 1 + intDig = DecimalType.MAX_SCALE - decDig + } + DecimalType.bounded(intDig + decDig, decDig) + } + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow( + d.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)), + resultType, nullOnOverflow) + + case r @ Remainder( + e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) => + val resultType = if (allowPrecisionLoss) { + DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } else { + DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } + // resultType may have lower precision, so we cast them into wider type first. 
+ val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow( + r.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)), + resultType, nullOnOverflow) + + case p @ Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) => + val resultType = if (allowPrecisionLoss) { + DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } else { + DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } + // resultType may have lower precision, so we cast them into wider type first. + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow( + p.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)), + resultType, nullOnOverflow) + + case expr @ IntegralDivide( + e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) => + val widerType = widerDecimalType(p1, s1, p2, s2) + val promotedExpr = expr.copy( + left = promotePrecision(e1, widerType), + right = promotePrecision(e2, widerType)) + if (expr.dataType.isInstanceOf[DecimalType]) { + // This follows division rule + val intDig = p1 - s1 + s2 + // No precision loss can happen as the result scale is 0. + // Overflow can happen only in the promote precision of the operands, but if none of them + // overflows in that phase, no overflow can happen, but CheckOverflow is needed in order + // to return a decimal with the proper scale and precision + CheckOverflow(promotedExpr, DecimalType.bounded(intDig, 0), nullOnOverflow) + } else { + promotedExpr + } + + case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + val resultType = widerDecimalType(p1, s1, p2, s2) + val newE1 = if (e1.dataType == resultType) e1 else Cast(e1, resultType) + val newE2 = if (e2.dataType == resultType) e2 else Cast(e2, resultType) + b.makeCopy(Array(newE1, newE2)) + } + + /** + * Strength reduction for comparing integral expressions with decimal literals. + * 1. int_col > decimal_literal => int_col > floor(decimal_literal) + * 2. int_col >= decimal_literal => int_col >= ceil(decimal_literal) + * 3. int_col < decimal_literal => int_col < ceil(decimal_literal) + * 4. int_col <= decimal_literal => int_col <= floor(decimal_literal) + * 5. decimal_literal > int_col => ceil(decimal_literal) > int_col + * 6. decimal_literal >= int_col => floor(decimal_literal) >= int_col + * 7. decimal_literal < int_col => floor(decimal_literal) < int_col + * 8. decimal_literal <= int_col => ceil(decimal_literal) <= int_col + * + * Note that technically this is an "optimization" and should go into the optimizer. However, + * by the time the optimizer runs, these comparison expressions would be pretty hard to pattern + * match because there are multiple (at least 2) levels of casts involved. + * + * There are a lot more possible rules we can implement, but we don't do them + * because we are not sure how common they are. 
+ */ + private val integralAndDecimalLiteral: PartialFunction[Expression, Expression] = { + + case GreaterThan(i @ IntegralType(), DecimalLiteral(value)) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + TrueLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + FalseLiteral + } else { + GreaterThan(i, Literal(value.floor.toLong)) + } + + case GreaterThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + TrueLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + FalseLiteral + } else { + GreaterThanOrEqual(i, Literal(value.ceil.toLong)) + } + + case LessThan(i @ IntegralType(), DecimalLiteral(value)) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + FalseLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + TrueLiteral + } else { + LessThan(i, Literal(value.ceil.toLong)) + } + + case LessThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + FalseLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + TrueLiteral + } else { + LessThanOrEqual(i, Literal(value.floor.toLong)) + } + + case GreaterThan(DecimalLiteral(value), i @ IntegralType()) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + FalseLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + TrueLiteral + } else { + GreaterThan(Literal(value.ceil.toLong), i) + } + + case GreaterThanOrEqual(DecimalLiteral(value), i @ IntegralType()) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + FalseLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + TrueLiteral + } else { + GreaterThanOrEqual(Literal(value.floor.toLong), i) + } + + case LessThan(DecimalLiteral(value), i @ IntegralType()) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + TrueLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + FalseLiteral + } else { + LessThan(Literal(value.floor.toLong), i) + } + + case LessThanOrEqual(DecimalLiteral(value), i @ IntegralType()) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + TrueLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + FalseLiteral + } else { + LessThanOrEqual(Literal(value.ceil.toLong), i) + } + } + + /** + * Type coercion for BinaryOperator in which one side is a non-decimal numeric, and the other + * side is a decimal. + */ + private def nondecimalAndDecimal(literalPickMinimumPrecision: Boolean) + : PartialFunction[Expression, Expression] = { + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => + (left, right) match { + // Promote literal integers inside a binary expression with fixed-precision decimals to + // decimals. The precision and scale are the ones strictly needed by the integer value. + // Requiring more precision than necessary may lead to a useless loss of precision. + // Consider the following example: multiplying a column which is DECIMAL(38, 18) by 2. + // If we use the default precision and scale for the integer type, 2 is considered a + // DECIMAL(10, 0). According to the rules, the result would be DECIMAL(38 + 10 + 1, 18), + // which is out of range and therefore it will become DECIMAL(38, 7), leading to + // potentially loosing 11 digits of the fractional part. 
Using only the precision needed + // by the Literal, instead, the result would be DECIMAL(38 + 1 + 1, 18), which would + // become DECIMAL(38, 16), safely having a much lower precision loss. + case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] && + l.dataType.isInstanceOf[IntegralType] && + literalPickMinimumPrecision => + b.makeCopy(Array(Cast(l, DecimalType.fromLiteral(l)), r)) + case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] && + r.dataType.isInstanceOf[IntegralType] && + literalPickMinimumPrecision => + b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r)))) + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case (l @ IntegralType(), r @ DecimalType.Expression(_, _)) => + b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r)) + case (l @ DecimalType.Expression(_, _), r @ IntegralType()) => + b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType)))) + case (l, r @ DecimalType.Expression(_, _)) if isFloat(l.dataType) => + b.makeCopy(Array(l, Cast(r, DoubleType))) + case (l @ DecimalType.Expression(_, _), r) if isFloat(r.dataType) => + b.makeCopy(Array(Cast(l, DoubleType), r)) + case _ => b + } + } + +} \ No newline at end of file -- Gitee From 6f5460f564982ca139b58f4ac125dfe22d1fb87c Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Thu, 17 Oct 2024 07:30:53 +0000 Subject: [PATCH 26/43] add decimalExpressions Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com> --- .../expressions/decimalExpressions.scala | 234 ++++++++++++++++++ 1 file changed, 234 insertions(+) create mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala new file mode 100644 index 000000000..ac02f03c2 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +/** + * Return the unscaled Long value of a Decimal, assuming it fits in a Long. + * Note: this expression is internal and created only by the optimizer, + * we don't need to do type check for it. + */ +case class UnscaledValue(child: Expression) extends UnaryExpression with NullIntolerant { + + override def dataType: DataType = LongType + override def toString: String = s"UnscaledValue($child)" + + protected override def nullSafeEval(input: Any): Any = + input.asInstanceOf[Decimal].toUnscaledLong + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") + } + + override protected def withNewChildInternal(newChild: Expression): UnscaledValue = + copy(child = newChild) +} + +/** + * Create a Decimal from an unscaled Long value. + * Note: this expression is internal and created only by the optimizer, + * we don't need to do type check for it. + */ +case class MakeDecimal( + child: Expression, + precision: Int, + scale: Int, + nullOnOverflow: Boolean) extends UnaryExpression with NullIntolerant { + + def this(child: Expression, precision: Int, scale: Int) = { + this(child, precision, scale, !SQLConf.get.ansiEnabled) + } + + override def dataType: DataType = DecimalType(precision, scale) + override def nullable: Boolean = child.nullable || nullOnOverflow + override def toString: String = s"MakeDecimal($child,$precision,$scale)" + + protected override def nullSafeEval(input: Any): Any = { + val longInput = input.asInstanceOf[Long] + val result = new Decimal() + if (nullOnOverflow) { + result.setOrNull(longInput, precision, scale) + } else { + result.set(longInput, precision, scale) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, eval => { + val setMethod = if (nullOnOverflow) { + "setOrNull" + } else { + "set" + } + val setNull = if (nullable) { + s"${ev.isNull} = ${ev.value} == null;" + } else { + "" + } + s""" + |${ev.value} = (new Decimal()).$setMethod($eval, $precision, $scale); + |$setNull + |""".stripMargin + }) + } + + override protected def withNewChildInternal(newChild: Expression): MakeDecimal = + copy(child = newChild) +} + +object MakeDecimal { + def apply(child: Expression, precision: Int, scale: Int): MakeDecimal = { + new MakeDecimal(child, precision, scale) + } +} + +/** + * An expression used to wrap the children when promote the precision of DecimalType to avoid + * promote multiple times. 
+ */ +case class PromotePrecision(child: Expression) extends UnaryExpression { + override def dataType: DataType = child.dataType + override def eval(input: InternalRow): Any = child.eval(input) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + child.genCode(ctx) + override def prettyName: String = "promote_precision" + override def sql: String = child.sql + override lazy val canonicalized: Expression = child.canonicalized + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +/** + * Rounds the decimal to given scale and check whether the decimal can fit in provided precision + * or not. If not, if `nullOnOverflow` is `true`, it returns `null`; otherwise an + * `ArithmeticException` is thrown. + */ +case class CheckOverflow( + child: Expression, + dataType: DecimalType, + nullOnOverflow: Boolean) extends UnaryExpression with SupportQueryContext { + + override def nullable: Boolean = true + + override def nullSafeEval(input: Any): Any = + input.asInstanceOf[Decimal].toPrecision( + dataType.precision, + dataType.scale, + Decimal.ROUND_HALF_UP, + nullOnOverflow, + queryContext) + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val errorContextCode = if (nullOnOverflow) { + "\"\"" + } else { + ctx.addReferenceObj("errCtx", queryContext) + } + nullSafeCodeGen(ctx, ev, eval => { + // scalastyle:off line.size.limit + s""" + |${ev.value} = $eval.toPrecision( + | ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow, $errorContextCode); + |${ev.isNull} = ${ev.value} == null; + """.stripMargin + // scalastyle:on line.size.limit + }) + } + + override def toString: String = s"CheckOverflow($child, $dataType)" + + override def sql: String = child.sql + + override protected def withNewChildInternal(newChild: Expression): CheckOverflow = + copy(child = newChild) + + override def initQueryContext(): String = if (nullOnOverflow) { + "" + } else { + origin.context + } +} + +// A variant `CheckOverflow`, which treats null as overflow. This is necessary in `Sum`. 
+case class CheckOverflowInSum( + child: Expression, + dataType: DecimalType, + nullOnOverflow: Boolean, + queryContext: String = "") extends UnaryExpression { + + override def nullable: Boolean = true + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + if (nullOnOverflow) null + else throw QueryExecutionErrors.overflowInSumOfDecimalError(queryContext) + } else { + value.asInstanceOf[Decimal].toPrecision( + dataType.precision, + dataType.scale, + Decimal.ROUND_HALF_UP, + nullOnOverflow, + queryContext) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val errorContextCode = if (nullOnOverflow) { + "\"\"" + } else { + ctx.addReferenceObj("errCtx", queryContext) + } + val nullHandling = if (nullOnOverflow) { + "" + } else { + s"throw QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode);" + } + // scalastyle:off line.size.limit + val code = code""" + |${childGen.code} + |boolean ${ev.isNull} = ${childGen.isNull}; + |Decimal ${ev.value} = null; + |if (${childGen.isNull}) { + | $nullHandling + |} else { + | ${ev.value} = ${childGen.value}.toPrecision( + | ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow, $errorContextCode); + | ${ev.isNull} = ${ev.value} == null; + |} + |""".stripMargin + // scalastyle:on line.size.limit + + ev.copy(code = code) + } + + override def toString: String = s"CheckOverflowInSum($child, $dataType, $nullOnOverflow)" + + override def sql: String = child.sql + + override protected def withNewChildInternal(newChild: Expression): CheckOverflowInSum = + copy(child = newChild) +} \ No newline at end of file -- Gitee From 853985b46efab210acc49d04614113e3f7eead03 Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Thu, 17 Oct 2024 07:48:33 +0000 Subject: [PATCH 27/43] update decimalExpressions Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com> --- .../expressions/decimalExpressions.scala | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index ac02f03c2..2b6d3ff33 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -138,7 +139,7 @@ case class CheckOverflow( dataType.scale, Decimal.ROUND_HALF_UP, nullOnOverflow, - queryContext) + getContextOrNull()) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val errorContextCode = if (nullOnOverflow) { @@ -164,10 +165,10 @@ case class CheckOverflow( override protected def withNewChildInternal(newChild: Expression): CheckOverflow = 
copy(child = newChild) - override def initQueryContext(): String = if (nullOnOverflow) { - "" + override def initQueryContext(): Option[SQLQueryContext] = if (nullOnOverflow) { + Some(origin.context) } else { - origin.context + None } } @@ -176,7 +177,7 @@ case class CheckOverflowInSum( child: Expression, dataType: DecimalType, nullOnOverflow: Boolean, - queryContext: String = "") extends UnaryExpression { + context: SQLQueryContext) extends UnaryExpression { override def nullable: Boolean = true @@ -184,14 +185,14 @@ case class CheckOverflowInSum( val value = child.eval(input) if (value == null) { if (nullOnOverflow) null - else throw QueryExecutionErrors.overflowInSumOfDecimalError(queryContext) + else throw QueryExecutionErrors.overflowInSumOfDecimalError(context) } else { value.asInstanceOf[Decimal].toPrecision( dataType.precision, dataType.scale, Decimal.ROUND_HALF_UP, nullOnOverflow, - queryContext) + context) } } @@ -200,7 +201,7 @@ case class CheckOverflowInSum( val errorContextCode = if (nullOnOverflow) { "\"\"" } else { - ctx.addReferenceObj("errCtx", queryContext) + ctx.addReferenceObj("errCtx", context) } val nullHandling = if (nullOnOverflow) { "" -- Gitee From c3519a4694f21d114b1fb16a5fc31a8c31c161fb Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Thu, 17 Oct 2024 08:19:51 +0000 Subject: [PATCH 28/43] update injectRuntimeFilter Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com> --- .../optimizer/InjectRuntimeFilter.scala | 59 ++++--------------- 1 file changed, 12 insertions(+), 47 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index e6c266242..b758b69e6 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import com.huawei.boostkit.spark.ColumnarPluginConfig -import scala.annotation.tailrec - /** * Insert a filter on one side of the join if the other side has a selective predicate. 
* The filter could be an IN subquery (converted to a semi join), a bloom filter, or something @@ -87,12 +85,12 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J val bloomFilterAgg = if (rowCount.isDefined && rowCount.get.longValue > 0L) { new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)), - rowCount.get.longValue) + Literal(rowCount.get.longValue) } else { new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp))) } - - val alias = Alias(bloomFilterAgg.toAggregateExpression(), "bloomFilter")() + val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None) + val alias = Alias(aggExp, "bloomFilter")() val aggregate = ConstantFolding(ColumnPruning(Aggregate(Nil, Seq(alias), filterCreationSidePlan))) val bloomFilterSubquery = if (canReuseExchange) { @@ -114,8 +112,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J require(filterApplicationSideExp.dataType == filterCreationSideExp.dataType) val actualFilterKeyExpr = mayWrapWithHash(filterCreationSideExp) val alias = Alias(actualFilterKeyExpr, actualFilterKeyExpr.toString)() - val aggregate = - ColumnPruning(Aggregate(Seq(filterCreationSideExp), Seq(alias), filterCreationSidePlan)) + val aggregate = Aggregate(Seq(alias), Seq(alias), filterCreationSidePlan) if (!canBroadcastBySize(aggregate, conf)) { // Skip the InSubquery filter if the size of `aggregate` is beyond broadcast join threshold, // i.e., the semi-join will be a shuffled join, which is not worthwhile. @@ -132,39 +129,13 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J * do not add a subquery that might have an expensive computation */ private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = { - @tailrec - def isSelective( - p: LogicalPlan, - predicateReference: AttributeSet, - hasHitFilter: Boolean, - hasHitSelectiveFilter: Boolean): Boolean = p match { - case Project(projectList, child) => - if (hasHitFilter) { - // We need to make sure all expressions referenced by filter predicates are simple - // expressions. 
- val referencedExprs = projectList.filter(predicateReference.contains) - referencedExprs.forall(isSimpleExpression) && - isSelective( - child, - referencedExprs.map(_.references).foldLeft(AttributeSet.empty)(_ ++ _), - hasHitFilter, - hasHitSelectiveFilter) - } else { - assert(predicateReference.isEmpty && !hasHitSelectiveFilter) - isSelective(child, predicateReference, hasHitFilter, hasHitSelectiveFilter) - } - case Filter(condition, child) => - isSimpleExpression(condition) && isSelective( - child, - predicateReference ++ condition.references, - hasHitFilter = true, - hasHitSelectiveFilter = hasHitSelectiveFilter || isLikelySelective(condition)) - case _: LeafNode => hasHitSelectiveFilter + val ret = plan match { + case PhysicalOperation(_, filters, child) if child.isInstance[LeafNode] => + filters.forall(isSimpleExpression) && + filters.exists(isLikelySelective) case _ => false } - - !plan.isStreaming && - isSelective(plan, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false) + !plan.isStreaming && ret } private def isSimpleExpression(e: Expression): Boolean = { @@ -213,8 +184,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J /** * Check that: - * - The filterApplicationSideJoinExp can be pushed down through joins, aggregates and windows - * (ie the expression references originate from a single leaf node) + * - The filterApplicationSideJoinExp can be pushed down through joins and aggregates (ie the + * expression references originate from a single leaf node) * - The filter creation side has a selective predicate * - The current join is a shuffle join or a broadcast join that has a shuffle below it * - The max filterApplicationSide scan size is greater than a configurable threshold @@ -328,13 +299,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J case s: Subquery if s.correlated => plan case _ if !conf.runtimeFilterSemiJoinReductionEnabled && !conf.runtimeFilterBloomFilterEnabled => plan - case _ => - val newPlan = tryInjectRuntimeFilter(plan) - if (conf.runtimeFilterSemiJoinReductionEnabled && !plan.fastEquals(newPlan)) { - RewritePredicateSubquery(newPlan) - } else { - newPlan - } + case _ => tryInjectRuntimeFilter(plan) } } \ No newline at end of file -- Gitee From 6eeee26dc0015796f8a6ef7eec40af0d2bb4c3c9 Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Thu, 17 Oct 2024 08:23:52 +0000 Subject: [PATCH 29/43] update TreePatterns Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com> --- .../sql/catalyst/Tree/TreePatterns.scala | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala index ef17a0740..4d8c246ab 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala @@ -25,7 +25,7 @@ object TreePattern extends Enumeration { // Expression patterns (alphabetically ordered) val AGGREGATE_EXPRESSION = Value(0) val ALIAS: Value = Value -// val AND_OR: Value = Value + val AND_OR: Value = Value val AND: Value = Value val ARRAYS_ZIP: Value = Value val ATTRIBUTE_REFERENCE: Value = Value @@ -59,7 +59,7 @@ object TreePattern extends 
Enumeration { val JSON_TO_STRUCT: Value = Value val LAMBDA_FUNCTION: Value = Value val LAMBDA_VARIABLE: Value = Value - val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value // spark3.4.3 + val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value val LATERAL_SUBQUERY: Value = Value val LIKE_FAMLIY: Value = Value val LIST_SUBQUERY: Value = Value @@ -71,33 +71,33 @@ object TreePattern extends Enumeration { val NULL_CHECK: Value = Value val NULL_LITERAL: Value = Value val SERIALIZE_FROM_OBJECT: Value = Value - val OR: Value = Value // spark3.4.3 + val OR: Value = Value val OUTER_REFERENCE: Value = Value - val PARAMETER: Value = Value // spark3.4.3 - val PARAMETERIZED_QUERY: Value = Value // spark3.4.3 + val PARAMETER: Value = Value + val PARAMETERIZED_QUERY: Value = Value val PIVOT: Value = Value val PLAN_EXPRESSION: Value = Value val PYTHON_UDF: Value = Value val REGEXP_EXTRACT_FAMILY: Value = Value val REGEXP_REPLACE: Value = Value val RUNTIME_REPLACEABLE: Value = Value - val RUNTIME_FILTER_EXPRESSION: Value = Value // spark3.4.3移除 - val RUNTIME_FILTER_SUBQUERY: Value = Value // spark3.4.3移除 + val RUNTIME_FILTER_EXPRESSION: Value = Value + val RUNTIME_FILTER_SUBQUERY: Value = Value val SCALAR_SUBQUERY: Value = Value val SCALAR_SUBQUERY_REFERENCE: Value = Value val SCALA_UDF: Value = Value - val SESSION_WINDOW: Value = Value // spark3.4.3 + val SESSION_WINDOW: Value = Value val SORT: Value = Value val SUBQUERY_ALIAS: Value = Value - val SUBQUERY_WRAPPER: Value = Value // spark3.4.3移除 + val SUBQUERY_WRAPPER: Value = Value val SUM: Value = Value val TIME_WINDOW: Value = Value val TIME_ZONE_AWARE_EXPRESSION: Value = Value val TRUE_OR_FALSE_LITERAL: Value = Value val WINDOW_EXPRESSION: Value = Value - val WINDOW_TIME: Value = Value // saprk3.4.3 + val WINDOW_TIME: Value = Value val UNARY_POSITIVE: Value = Value - val UNPIVOT: Value = Value // spark3.4.3 + val UNPIVOT: Value = Value val UPDATE_FIELDS: Value = Value val UPPER_OR_LOWER: Value = Value val UP_CAST: Value = Value @@ -127,7 +127,7 @@ object TreePattern extends Enumeration { val UNION: Value = Value val UNRESOLVED_RELATION: Value = Value val UNRESOLVED_WITH: Value = Value - val TEMP_RESOLVED_COLUMN: Value = Value // spark3.4.3 + val TEMP_RESOLVED_COLUMN: Value = Value val TYPED_FILTER: Value = Value val WINDOW: Value = Value val WITH_WINDOW_DEFINITION: Value = Value @@ -136,7 +136,7 @@ object TreePattern extends Enumeration { val UNRESOLVED_ALIAS: Value = Value val UNRESOLVED_ATTRIBUTE: Value = Value val UNRESOLVED_DESERIALIZER: Value = Value - val UNRESOLVED_HAVING: Value = Value // spark3.4.3 + val UNRESOLVED_HAVING: Value = Value val UNRESOLVED_ORDINAL: Value = Value val UNRESOLVED_FUNCTION: Value = Value val UNRESOLVED_HINT: Value = Value @@ -145,8 +145,8 @@ object TreePattern extends Enumeration { // Unresolved Plan patterns (Alphabetically ordered) val UNRESOLVED_SUBQUERY_COLUMN_ALIAS: Value = Value val UNRESOLVED_FUNC: Value = Value - val UNRESOLVED_TABLE_VALUED_FUNCTION: Value = Value // spark3.4.3 - val UNRESOLVED_TVF_ALIASES: Value = Value // spark3.4.3 + val UNRESOLVED_TABLE_VALUED_FUNCTION: Value = Value + val UNRESOLVED_TVF_ALIASES: Value = Value // Execution expression patterns (alphabetically ordered) val IN_SUBQUERY_EXEC: Value = Value -- Gitee From cd9fe00c2c3ed2adb232da4d47a160965f4ce3d3 Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Thu, 17 Oct 2024 08:31:04 +0000 Subject: [PATCH 30/43] update AdaptiveSparkPlanExec Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com> --- 
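Note (illustrative, not part of the patch): with this change optimizeQueryStage goes back to folding every rule in queryStageOptimizerRules over the stage plan, with no final-stage filtering. A self-contained sketch of that fold pattern, using hypothetical Plan/Rule stand-ins rather than Spark's classes:

final case class Plan(steps: List[String])

trait Rule { def apply(plan: Plan): Plan }

object RuleFoldSketch {
  // Apply each rule in order, feeding the output of one rule into the next.
  def optimize(plan: Plan, rules: Seq[Rule]): Plan =
    rules.foldLeft(plan)((latest, rule) => rule.apply(latest))

  def main(args: Array[String]): Unit = {
    val addExchange: Rule = new Rule { def apply(p: Plan): Plan = Plan(p.steps :+ "exchange") }
    val addSort: Rule = new Rule { def apply(p: Plan): Plan = Plan(p.steps :+ "sort") }
    // Prints Plan(List(scan, exchange, sort))
    println(optimize(Plan(List("scan")), Seq(addExchange, addSort)))
  }
}
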
.../adaptive/AdaptiveSparkPlanExec.scala | 24 ++++++------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index e5c8fc0ab..b5d638138 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.adaptive import java.util import java.util.concurrent.LinkedBlockingQueue + import scala.collection.JavaConverters._ import scala.collection.concurrent.TrieMap import scala.collection.mutable import scala.concurrent.ExecutionContext import scala.util.control.NonFatal + import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession @@ -33,7 +35,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule} import org.apache.spark.sql.catalyst.trees.TreeNodeTag -import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._ @@ -153,13 +154,7 @@ case class AdaptiveSparkPlanExec( ) private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): SparkPlan = { - val rules = if (isFinalStage && - !conf.getConf(SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS)) { - queryStageOptimizerRules.filterNot(_.isInstanceOf[AQEShuffleReadRule]) - } else { - queryStageOptimizerRules - } - val optimized = rules.foldLeft(plan) { case (latestPlan, rule) => + val optimized = queryStageOptimizerRules.foldLeft(plan) { case (latestPlan, rule) => val applied = rule.apply(latestPlan) val result = rule match { case _: AQEShuffleReadRule if !applied.fastEquals(latestPlan) => @@ -194,7 +189,7 @@ case class AdaptiveSparkPlanExec( @volatile private var currentPhysicalPlan = initialPlan - @volatile private var _isFinalPlan = false + volatile private var isFinalPlan = false private var currentStageId = 0 @@ -211,8 +206,6 @@ case class AdaptiveSparkPlanExec( def executedPlan: SparkPlan = currentPhysicalPlan - def isFinalPlan: Boolean = _isFinalPlan - override def conf: SQLConf = context.session.sessionState.conf override def output: Seq[Attribute] = inputPlan.output @@ -231,8 +224,6 @@ case class AdaptiveSparkPlanExec( .map(_.toLong).filter(SQLExecution.getQueryExecution(_) eq context.qe) } - def finalPhysicalPlan: SparkPlan = withFinalPlanUpdate(identity) // saprk 3.4.3 - private def getFinalPhysicalPlan(): SparkPlan = lock.synchronized { if (isFinalPlan) return currentPhysicalPlan @@ -320,8 +311,7 @@ case class AdaptiveSparkPlanExec( val newCost = costEvaluator.evaluateCost(newPhysicalPlan) if (newCost < origCost || (newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) { - logOnLevel("Plan changed:\n" + - sideBySide(currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n")) + logOnLevel("Plan changed from $currentPhysicalPlan to $newPhysicalPlan") 
cleanUpTempTags(newPhysicalPlan) currentPhysicalPlan = newPhysicalPlan currentLogicalPlan = newLogicalPlan @@ -337,7 +327,7 @@ case class AdaptiveSparkPlanExec( optimizeQueryStage(result.newPlan, isFinalStage = true), postStageCreationRules(supportsColumnar), Some((planChangeLogger, "AQE Post Stage Creation"))) - _isFinalPlan = true + isFinalPlan = true executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan))) currentPhysicalPlan } @@ -351,7 +341,7 @@ case class AdaptiveSparkPlanExec( if (!isSubquery && currentPhysicalPlan.exists(_.subqueries.nonEmpty)) { getExecutionId.foreach(onUpdatePlan(_, Seq.empty)) } - logOnLevel(s"Final plan: \n$currentPhysicalPlan") + logOnLevel(s"Final plan: $currentPhysicalPlan") } override def executeCollect(): Array[InternalRow] = { -- Gitee From cac18688e9ec1ed567dd8521b44f27c5bee02d2c Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Thu, 17 Oct 2024 08:35:09 +0000 Subject: [PATCH 31/43] update ColumnarBasicPhysicalOperators Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com> --- .../spark/sql/execution/ColumnarBasicPhysicalOperators.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala index 630034bd7..b6dd2ab51 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution import java.util.concurrent.TimeUnit.NANOSECONDS import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP - import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._ import com.huawei.boostkit.spark.util.OmniAdaptorUtil import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs @@ -43,7 +42,6 @@ import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.sql.catalyst.plans.{AliasAwareOutputExpression, AliasAwareQueryOutputOrdering} - case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryExecNode with AliasAwareOutputExpression -- Gitee From 5d435130ec8b3c618391d55d5ddcfc5d90e66263 Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Thu, 17 Oct 2024 08:37:20 +0000 Subject: [PATCH 32/43] update QueryExecution Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com> --- .../spark/sql/execution/QueryExecution.scala | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 37f3db5d6..e49c88191 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -105,23 +105,6 @@ class QueryExecution( case other => other } - // The plan that has been normalized by custom rules, so that it's more likely to hit cache. 
- lazy val normalized: LogicalPlan = { - val normalizationRules = sparkSession.sessionState.planNormalizationRules - if (normalizationRules.isEmpty) { - commandExecuted - } else { - val planChangeLogger = new PlanChangeLogger[LogicalPlan]() - val normalized = normalizationRules.foldLeft(commandExecuted) { (p, rule) => - val result = rule.apply(p) - planChangeLogger.logRule(rule.ruleName, p, result) - result - } - planChangeLogger.logBatch("Plan Normalization", commandExecuted, normalized) - normalized - } - } // Spark3.4.3 - lazy val withCachedData: LogicalPlan = sparkSession.withActive { assertAnalyzed() assertSupported() @@ -529,4 +512,4 @@ object QueryExecution { case e: Throwable => throw toInternalError(msg, e) } } -} +} \ No newline at end of file -- Gitee From a14e7de692a395ad21cabbcff492c03f84d1d301 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A8=81=E5=A8=81=E7=8C=AB?= <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 14 Oct 2024 19:45:53 +0800 Subject: [PATCH 33/43] modify version and spark_version modify SubqueryExpression parametre nums and add new method add and remove Expression patterns repair parametre nums add normalized ,adjust parametres and modify method toInternalError modify parametres modify filePath modify filePath and ORC predicate pushdown modify parametres , remove promotePrecision, modify CastBase to Cast modify parametres modify parametres modify injectBloomFilter,injectInsubqueryFilter,isSelectiveFilterOverScan,hasDynamicPruningSubquery,hasInSubquery,tryInJectRuntimeFilter modify metastore.uris and hive.db modify "test child max row" limit(0) to limit(1) modify type of path modify traits of ColumnarProjectExec modify "SPARK-30953" ,"SPARK-31658" ADAPTIVE_EXECUTION_FORCE_APPLY,"SPARK-32932" ADAPTIVE_EXECUTION_ENABLED modify spark exception modify AQEOptimizer and optimizeQueryStage --- .../omniop-spark-extension/java/pom.xml | 10 +-- .../expression/OmniExpressionAdaptor.scala | 12 ++-- .../sql/catalyst/Tree/TreePatterns.scala | 20 ++++-- .../catalyst/expressions/runtimefilter.scala | 4 +- .../optimizer/InjectRuntimeFilter.scala | 67 ++++++++++++++----- .../optimizer/MergeSubqueryFilters.scala | 2 +- .../RewriteSelfJoinInInPredicate.scala | 2 +- .../ColumnarBasicPhysicalOperators.scala | 12 ++-- .../ColumnarFileSourceScanExec.scala | 5 +- .../spark/sql/execution/QueryExecution.scala | 30 +++++++-- .../adaptive/AdaptiveSparkPlanExec.scala | 32 ++++++--- .../adaptive/PlanAdaptiveSubqueries.scala | 6 +- .../datasources/orc/OmniOrcFileFormat.scala | 30 ++++++--- .../parquet/OmniParquetFileFormat.scala | 2 +- .../test/resources/HiveResource.properties | 6 +- .../sql/catalyst/expressions/CastSuite.scala | 55 ++++++++++----- .../optimizer/CombiningLimitsSuite.scala | 2 +- .../optimizer/MergeSubqueryFiltersSuite.scala | 2 +- .../ColumnarAdaptiveQueryExecSuite.scala | 47 +++++++------ omnioperator/omniop-spark-extension/pom.xml | 4 +- 20 files changed, 235 insertions(+), 115 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/pom.xml b/omnioperator/omniop-spark-extension/java/pom.xml index 9cc1b9d25..a40b415fd 100644 --- a/omnioperator/omniop-spark-extension/java/pom.xml +++ b/omnioperator/omniop-spark-extension/java/pom.xml @@ -7,7 +7,7 @@ com.huawei.kunpeng boostkit-omniop-spark-parent - 3.3.1-1.6.0 + 3.4.3-1.6.0 ../pom.xml @@ -52,7 +52,7 @@ com.huawei.boostkit boostkit-omniop-native-reader - 3.3.1-1.6.0 + 3.4.3-1.6.0 junit @@ -247,7 +247,7 @@ true false ${project.basedir}/src/main/scala - ${project.basedir}/src/test/scala + 
${project.basedir}/src/test/scala ${user.dir}/scalastyle-config.xml ${project.basedir}/target/scalastyle-output.xml ${project.build.sourceEncoding} @@ -335,7 +335,7 @@ org.scalatest scalatest-maven-plugin - false + false ${project.build.directory}/surefire-reports . @@ -352,4 +352,4 @@ - \ No newline at end of file + diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index d1f911a8c..f7ed75fa3 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -76,7 +76,7 @@ object OmniExpressionAdaptor extends Logging { } } - private def unsupportedCastCheck(expr: Expression, cast: CastBase): Unit = { + private def unsupportedCastCheck(expr: Expression, cast: Cast): Unit = { def doSupportCastToString(dataType: DataType): Boolean = { dataType.isInstanceOf[DecimalType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[IntegerType] || dataType.isInstanceOf[LongType] || dataType.isInstanceOf[DateType] || dataType.isInstanceOf[DoubleType] || @@ -160,8 +160,8 @@ object OmniExpressionAdaptor extends Logging { throw new UnsupportedOperationException(s"Unsupported datatype for MakeDecimal: ${makeDecimal.child.dataType}") } - case promotePrecision: PromotePrecision => - rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap) +// case promotePrecision: PromotePrecision => +// rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap) case sub: Subtract => new JSONObject().put("exprType", "BINARY") @@ -296,7 +296,7 @@ object OmniExpressionAdaptor extends Logging { .put(rewriteToOmniJsonExpressionLiteralJsonObject(subString.len, exprsIndexMap))) // Cast - case cast: CastBase => + case cast: Cast => unsupportedCastCheck(expr, cast) cast.child.dataType match { case NullType => @@ -588,10 +588,10 @@ object OmniExpressionAdaptor extends Logging { rewriteToOmniJsonExpressionLiteralJsonObject(children.head, exprsIndexMap) } else { children.head match { - case base: CastBase if base.child.dataType.isInstanceOf[NullType] => + case base: Cast if base.child.dataType.isInstanceOf[NullType] => rewriteToOmniJsonExpressionLiteralJsonObject(children(1), exprsIndexMap) case _ => children(1) match { - case base: CastBase if base.child.dataType.isInstanceOf[NullType] => + case base: Cast if base.child.dataType.isInstanceOf[NullType] => rewriteToOmniJsonExpressionLiteralJsonObject(children.head, exprsIndexMap) case _ => new JSONObject().put("exprType", "FUNCTION") diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala index ea2712447..ef17a0740 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala @@ -25,7 +25,8 @@ object TreePattern extends Enumeration { // Expression patterns (alphabetically ordered) val AGGREGATE_EXPRESSION = Value(0) val ALIAS: Value = Value - val AND_OR: Value = Value +// val 
AND_OR: Value = Value + val AND: Value = Value val ARRAYS_ZIP: Value = Value val ATTRIBUTE_REFERENCE: Value = Value val APPEND_COLUMNS: Value = Value @@ -58,6 +59,7 @@ object TreePattern extends Enumeration { val JSON_TO_STRUCT: Value = Value val LAMBDA_FUNCTION: Value = Value val LAMBDA_VARIABLE: Value = Value + val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value // spark3.4.3 val LATERAL_SUBQUERY: Value = Value val LIKE_FAMLIY: Value = Value val LIST_SUBQUERY: Value = Value @@ -69,27 +71,33 @@ object TreePattern extends Enumeration { val NULL_CHECK: Value = Value val NULL_LITERAL: Value = Value val SERIALIZE_FROM_OBJECT: Value = Value + val OR: Value = Value // spark3.4.3 val OUTER_REFERENCE: Value = Value + val PARAMETER: Value = Value // spark3.4.3 + val PARAMETERIZED_QUERY: Value = Value // spark3.4.3 val PIVOT: Value = Value val PLAN_EXPRESSION: Value = Value val PYTHON_UDF: Value = Value val REGEXP_EXTRACT_FAMILY: Value = Value val REGEXP_REPLACE: Value = Value val RUNTIME_REPLACEABLE: Value = Value - val RUNTIME_FILTER_EXPRESSION: Value = Value - val RUNTIME_FILTER_SUBQUERY: Value = Value + val RUNTIME_FILTER_EXPRESSION: Value = Value // spark3.4.3移除 + val RUNTIME_FILTER_SUBQUERY: Value = Value // spark3.4.3移除 val SCALAR_SUBQUERY: Value = Value val SCALAR_SUBQUERY_REFERENCE: Value = Value val SCALA_UDF: Value = Value + val SESSION_WINDOW: Value = Value // spark3.4.3 val SORT: Value = Value val SUBQUERY_ALIAS: Value = Value - val SUBQUERY_WRAPPER: Value = Value + val SUBQUERY_WRAPPER: Value = Value // spark3.4.3移除 val SUM: Value = Value val TIME_WINDOW: Value = Value val TIME_ZONE_AWARE_EXPRESSION: Value = Value val TRUE_OR_FALSE_LITERAL: Value = Value val WINDOW_EXPRESSION: Value = Value + val WINDOW_TIME: Value = Value // saprk3.4.3 val UNARY_POSITIVE: Value = Value + val UNPIVOT: Value = Value // spark3.4.3 val UPDATE_FIELDS: Value = Value val UPPER_OR_LOWER: Value = Value val UP_CAST: Value = Value @@ -119,6 +127,7 @@ object TreePattern extends Enumeration { val UNION: Value = Value val UNRESOLVED_RELATION: Value = Value val UNRESOLVED_WITH: Value = Value + val TEMP_RESOLVED_COLUMN: Value = Value // spark3.4.3 val TYPED_FILTER: Value = Value val WINDOW: Value = Value val WITH_WINDOW_DEFINITION: Value = Value @@ -127,6 +136,7 @@ object TreePattern extends Enumeration { val UNRESOLVED_ALIAS: Value = Value val UNRESOLVED_ATTRIBUTE: Value = Value val UNRESOLVED_DESERIALIZER: Value = Value + val UNRESOLVED_HAVING: Value = Value // spark3.4.3 val UNRESOLVED_ORDINAL: Value = Value val UNRESOLVED_FUNCTION: Value = Value val UNRESOLVED_HINT: Value = Value @@ -135,6 +145,8 @@ object TreePattern extends Enumeration { // Unresolved Plan patterns (Alphabetically ordered) val UNRESOLVED_SUBQUERY_COLUMN_ALIAS: Value = Value val UNRESOLVED_FUNC: Value = Value + val UNRESOLVED_TABLE_VALUED_FUNCTION: Value = Value // spark3.4.3 + val UNRESOLVED_TVF_ALIASES: Value = Value // spark3.4.3 // Execution expression patterns (alphabetically ordered) val IN_SUBQUERY_EXEC: Value = Value diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala index 0a5d509b0..85192fc36 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala +++ 
b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala @@ -39,7 +39,7 @@ case class RuntimeFilterSubquery( exprId: ExprId = NamedExpression.newExprId, hint: Option[HintInfo] = None) extends SubqueryExpression( - filterCreationSidePlan, Seq(filterApplicationSideExp), exprId, Seq.empty) + filterCreationSidePlan, Seq(filterApplicationSideExp), exprId, Seq.empty, hint) with Unevaluable with UnaryLike[Expression] { @@ -74,6 +74,8 @@ case class RuntimeFilterSubquery( override protected def withNewChildInternal(newChild: Expression): RuntimeFilterSubquery = copy(filterApplicationSideExp = newChild) + + override def withNewHint(hint: Option[HintInfo]): RuntimeFilterSubquery = copy(hint = hint) } /** diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index 812c387bc..e6c266242 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import com.huawei.boostkit.spark.ColumnarPluginConfig +import scala.annotation.tailrec + /** * Insert a filter on one side of the join if the other side has a selective predicate. * The filter could be an IN subquery (converted to a semi join), a bloom filter, or something @@ -85,12 +87,12 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J val bloomFilterAgg = if (rowCount.isDefined && rowCount.get.longValue > 0L) { new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)), - Literal(rowCount.get.longValue)) + rowCount.get.longValue) } else { new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp))) } - val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None) - val alias = Alias(aggExp, "bloomFilter")() + + val alias = Alias(bloomFilterAgg.toAggregateExpression(), "bloomFilter")() val aggregate = ConstantFolding(ColumnPruning(Aggregate(Nil, Seq(alias), filterCreationSidePlan))) val bloomFilterSubquery = if (canReuseExchange) { @@ -112,7 +114,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J require(filterApplicationSideExp.dataType == filterCreationSideExp.dataType) val actualFilterKeyExpr = mayWrapWithHash(filterCreationSideExp) val alias = Alias(actualFilterKeyExpr, actualFilterKeyExpr.toString)() - val aggregate = Aggregate(Seq(alias), Seq(alias), filterCreationSidePlan) + val aggregate = + ColumnPruning(Aggregate(Seq(filterCreationSideExp), Seq(alias), filterCreationSidePlan)) if (!canBroadcastBySize(aggregate, conf)) { // Skip the InSubquery filter if the size of `aggregate` is beyond broadcast join threshold, // i.e., the semi-join will be a shuffled join, which is not worthwhile. 
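// Context for the BloomFilterAggregate change in the hunk above: a minimal sketch, assuming the
// Spark 3.4 catalyst API, of why agg.toAggregateExpression() is an equivalent, shorter form of
// the manual AggregateExpression(...) construction it replaces: both wrap the aggregate function
// in Complete mode with isDistinct = false and no FILTER clause. The attribute name and the row
// count below are illustrative values only, not taken from the patch.
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, XxHash64}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate, Complete}
import org.apache.spark.sql.types.LongType

object BloomFilterAggExpressionSketch {
  val key = AttributeReference("k", LongType)()
  // Spark 3.4 accepts the estimated item count as a plain Long, which is why the patch drops
  // the Literal(...) wrapper around rowCount when constructing the aggregate.
  val agg = new BloomFilterAggregate(new XxHash64(Seq(key)), 10000L)
  val viaHelper = agg.toAggregateExpression()
  val manual = AggregateExpression(agg, Complete, isDistinct = false, None)
  // Both values are Complete-mode, non-distinct aggregate expressions over the same function.
}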
@@ -129,13 +132,39 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J * do not add a subquery that might have an expensive computation */ private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = { - val ret = plan match { - case PhysicalOperation(_, filters, child) if child.isInstanceOf[LeafNode] => - filters.forall(isSimpleExpression) && - filters.exists(isLikelySelective) + @tailrec + def isSelective( + p: LogicalPlan, + predicateReference: AttributeSet, + hasHitFilter: Boolean, + hasHitSelectiveFilter: Boolean): Boolean = p match { + case Project(projectList, child) => + if (hasHitFilter) { + // We need to make sure all expressions referenced by filter predicates are simple + // expressions. + val referencedExprs = projectList.filter(predicateReference.contains) + referencedExprs.forall(isSimpleExpression) && + isSelective( + child, + referencedExprs.map(_.references).foldLeft(AttributeSet.empty)(_ ++ _), + hasHitFilter, + hasHitSelectiveFilter) + } else { + assert(predicateReference.isEmpty && !hasHitSelectiveFilter) + isSelective(child, predicateReference, hasHitFilter, hasHitSelectiveFilter) + } + case Filter(condition, child) => + isSimpleExpression(condition) && isSelective( + child, + predicateReference ++ condition.references, + hasHitFilter = true, + hasHitSelectiveFilter = hasHitSelectiveFilter || isLikelySelective(condition)) + case _: LeafNode => hasHitSelectiveFilter case _ => false } - !plan.isStreaming && ret + + !plan.isStreaming && + isSelective(plan, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false) } private def isSimpleExpression(e: Expression): Boolean = { @@ -184,8 +213,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J /** * Check that: - * - The filterApplicationSideJoinExp can be pushed down through joins and aggregates (ie the - * expression references originate from a single leaf node) + * - The filterApplicationSideJoinExp can be pushed down through joins, aggregates and windows + * (ie the expression references originate from a single leaf node) * - The filter creation side has a selective predicate * - The current join is a shuffle join or a broadcast join that has a shuffle below it * - The max filterApplicationSide scan size is greater than a configurable threshold @@ -218,9 +247,9 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J leftKey: Expression, rightKey: Expression): Boolean = { (left, right) match { - case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan), _) => + case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _, _), plan), _) => pruningKey.fastEquals(leftKey) || hasDynamicPruningSubquery(plan, right, leftKey, rightKey) - case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan)) => + case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _, _), plan)) => pruningKey.fastEquals(rightKey) || hasDynamicPruningSubquery(left, plan, leftKey, rightKey) case _ => false @@ -251,10 +280,10 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J rightKey: Expression): Boolean = { (left, right) match { case (Filter(InSubquery(Seq(key), - ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _), _) => + ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _, _)), _), _) => key.fastEquals(leftKey) || key.fastEquals(new Murmur3Hash(Seq(leftKey))) case (_, Filter(InSubquery(Seq(key), - ListQuery(Aggregate(Seq(Alias(_, 
_)), Seq(Alias(_, _)), _), _, _, _, _)), _)) => + ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _, _)), _)) => key.fastEquals(rightKey) || key.fastEquals(new Murmur3Hash(Seq(rightKey))) case _ => false } @@ -299,7 +328,13 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J case s: Subquery if s.correlated => plan case _ if !conf.runtimeFilterSemiJoinReductionEnabled && !conf.runtimeFilterBloomFilterEnabled => plan - case _ => tryInjectRuntimeFilter(plan) + case _ => + val newPlan = tryInjectRuntimeFilter(plan) + if (conf.runtimeFilterSemiJoinReductionEnabled && !plan.fastEquals(newPlan)) { + RewritePredicateSubquery(newPlan) + } else { + newPlan + } } } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala index 1b5baa230..c4435379f 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala @@ -643,7 +643,7 @@ object MergeSubqueryFilters extends Rule[LogicalPlan] { val subqueryCTE = header.plan.asInstanceOf[CTERelationDef] GetStructField( ScalarSubquery( - CTERelationRef(subqueryCTE.id, _resolved = true, subqueryCTE.output), + CTERelationRef(subqueryCTE.id, _resolved = true, subqueryCTE.output, subqueryCTE.isStreaming), exprId = ssr.exprId), ssr.headerIndex) } else { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala index 9e4029025..f6ebd716d 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala @@ -61,7 +61,7 @@ object RewriteSelfJoinInInPredicate extends Rule[LogicalPlan] with PredicateHelp case f: Filter => f transformExpressions { case in @ InSubquery(_, listQuery @ ListQuery(Project(projectList, - Join(left, right, Inner, Some(joinCond), _)), _, _, _, _)) + Join(left, right, Inner, Some(joinCond), _)), _, _, _, _, _)) if left.canonicalized ne right.canonicalized => val attrMapping = AttributeMap(right.output.zip(left.output)) val subCondExprs = splitConjunctivePredicates(joinCond transform { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala index 486369843..630034bd7 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala @@ -16,10 +16,10 @@ */ package org.apache.spark.sql.execution - import java.util.concurrent.TimeUnit.NANOSECONDS import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP + import 
com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._ import com.huawei.boostkit.spark.util.OmniAdaptorUtil import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs @@ -41,11 +41,13 @@ import org.apache.spark.sql.execution.vectorized.OmniColumnVector import org.apache.spark.sql.expression.ColumnarExpressionConverter import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.sql.catalyst.plans.{AliasAwareOutputExpression, AliasAwareQueryOutputOrdering} + case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryExecNode - with AliasAwareOutputPartitioning - with AliasAwareOutputOrdering { + with AliasAwareOutputExpression + with AliasAwareQueryOutputOrdering[SparkPlan] { override def supportsColumnar: Boolean = true @@ -267,8 +269,8 @@ case class ColumnarConditionProjectExec(projectList: Seq[NamedExpression], condition: Expression, child: SparkPlan) extends UnaryExecNode - with AliasAwareOutputPartitioning - with AliasAwareOutputOrdering { + with AliasAwareOutputExpression + with AliasAwareQueryOutputOrdering[SparkPlan] { override def supportsColumnar: Boolean = true diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala index 800dcf1a0..4b70db452 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala @@ -479,10 +479,11 @@ abstract class BaseColumnarFileSourceScanExec( } }.groupBy { f => BucketingUtils - .getBucketId(new Path(f.filePath).getName) - .getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.filePath)) + .getBucketId(f.toPath.getName) + .getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.urlEncodedPath)) } + val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { val bucketSet = optionalBucketSet.get filesGroupedToBuckets.filter { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index ef33a84de..37f3db5d6 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -105,6 +105,23 @@ class QueryExecution( case other => other } + // The plan that has been normalized by custom rules, so that it's more likely to hit cache. 
+ lazy val normalized: LogicalPlan = { + val normalizationRules = sparkSession.sessionState.planNormalizationRules + if (normalizationRules.isEmpty) { + commandExecuted + } else { + val planChangeLogger = new PlanChangeLogger[LogicalPlan]() + val normalized = normalizationRules.foldLeft(commandExecuted) { (p, rule) => + val result = rule.apply(p) + planChangeLogger.logRule(rule.ruleName, p, result) + result + } + planChangeLogger.logBatch("Plan Normalization", commandExecuted, normalized) + normalized + } + } // Spark3.4.3 + lazy val withCachedData: LogicalPlan = sparkSession.withActive { assertAnalyzed() assertSupported() @@ -227,7 +244,7 @@ class QueryExecution( // output mode does not matter since there is no `Sink`. new IncrementalExecution( sparkSession, logical, OutputMode.Append(), "", - UUID.randomUUID, UUID.randomUUID, 0, OffsetSeqMetadata(0, 0)) + UUID.randomUUID, UUID.randomUUID, 0, None ,OffsetSeqMetadata(0, 0)) } else { this } @@ -494,11 +511,10 @@ object QueryExecution { */ private[sql] def toInternalError(msg: String, e: Throwable): Throwable = e match { case e @ (_: java.lang.NullPointerException | _: java.lang.AssertionError) => - new SparkException( - errorClass = "INTERNAL_ERROR", - messageParameters = Array(msg + - " Please, fill a bug report in, and provide the full stack trace."), - cause = e) + SparkException.internalError( + msg + " You hit a bug in Spark or the Spark plugins you use. Please, report this bug " + + "to the corresponding communities or vendors, and provide the full stack trace.", + e) case e: Throwable => e } @@ -513,4 +529,4 @@ object QueryExecution { case e: Throwable => throw toInternalError(msg, e) } } -} \ No newline at end of file +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 2599a1410..e5c8fc0ab 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -19,13 +19,11 @@ package org.apache.spark.sql.execution.adaptive import java.util import java.util.concurrent.LinkedBlockingQueue - import scala.collection.JavaConverters._ import scala.collection.concurrent.TrieMap import scala.collection.mutable import scala.concurrent.ExecutionContext import scala.util.control.NonFatal - import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession @@ -35,12 +33,13 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule} import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._ import org.apache.spark.sql.execution.bucketing.DisableUnnecessaryBucketedScan import org.apache.spark.sql.execution.exchange._ -import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric} +import 
org.apache.spark.sql.execution.ui.{SQLPlanMetric, SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.{SparkFatalException, ThreadUtils} @@ -83,7 +82,9 @@ case class AdaptiveSparkPlanExec( @transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]() // The logical plan optimizer for re-optimizing the current logical plan. - @transient private val optimizer = new AQEOptimizer(conf) + @transient private val optimizer = new AQEOptimizer(conf, + session.sessionState.adaptiveRulesHolder.runtimeOptimizerRules + ) // `EnsureRequirements` may remove user-specified repartition and assume the query plan won't // change its output partitioning. This assumption is not true in AQE. Here we check the @@ -122,7 +123,7 @@ case class AdaptiveSparkPlanExec( RemoveRedundantSorts, DisableUnnecessaryBucketedScan, OptimizeSkewedJoin(ensureRequirements) - ) ++ context.session.sessionState.queryStagePrepRules + ) ++ context.session.sessionState.adaptiveRulesHolder.queryStagePrepRules } // A list of physical optimizer rules to be applied to a new stage before its execution. These @@ -152,7 +153,13 @@ case class AdaptiveSparkPlanExec( ) private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): SparkPlan = { - val optimized = queryStageOptimizerRules.foldLeft(plan) { case (latestPlan, rule) => + val rules = if (isFinalStage && + !conf.getConf(SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS)) { + queryStageOptimizerRules.filterNot(_.isInstanceOf[AQEShuffleReadRule]) + } else { + queryStageOptimizerRules + } + val optimized = rules.foldLeft(plan) { case (latestPlan, rule) => val applied = rule.apply(latestPlan) val result = rule match { case _: AQEShuffleReadRule if !applied.fastEquals(latestPlan) => @@ -187,7 +194,7 @@ case class AdaptiveSparkPlanExec( @volatile private var currentPhysicalPlan = initialPlan - private var isFinalPlan = false + @volatile private var _isFinalPlan = false private var currentStageId = 0 @@ -204,6 +211,8 @@ case class AdaptiveSparkPlanExec( def executedPlan: SparkPlan = currentPhysicalPlan + def isFinalPlan: Boolean = _isFinalPlan + override def conf: SQLConf = context.session.sessionState.conf override def output: Seq[Attribute] = inputPlan.output @@ -222,6 +231,8 @@ case class AdaptiveSparkPlanExec( .map(_.toLong).filter(SQLExecution.getQueryExecution(_) eq context.qe) } + def finalPhysicalPlan: SparkPlan = withFinalPlanUpdate(identity) // saprk 3.4.3 + private def getFinalPhysicalPlan(): SparkPlan = lock.synchronized { if (isFinalPlan) return currentPhysicalPlan @@ -309,7 +320,8 @@ case class AdaptiveSparkPlanExec( val newCost = costEvaluator.evaluateCost(newPhysicalPlan) if (newCost < origCost || (newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) { - logOnLevel(s"Plan changed from $currentPhysicalPlan to $newPhysicalPlan") + logOnLevel("Plan changed:\n" + + sideBySide(currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n")) cleanUpTempTags(newPhysicalPlan) currentPhysicalPlan = newPhysicalPlan currentLogicalPlan = newLogicalPlan @@ -325,7 +337,7 @@ case class AdaptiveSparkPlanExec( optimizeQueryStage(result.newPlan, isFinalStage = true), postStageCreationRules(supportsColumnar), Some((planChangeLogger, "AQE Post Stage Creation"))) - isFinalPlan = true + _isFinalPlan = true executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan))) 
currentPhysicalPlan } @@ -339,7 +351,7 @@ case class AdaptiveSparkPlanExec( if (!isSubquery && currentPhysicalPlan.exists(_.subqueries.nonEmpty)) { getExecutionId.foreach(onUpdatePlan(_, Seq.empty)) } - logOnLevel(s"Final plan: $currentPhysicalPlan") + logOnLevel(s"Final plan: \n$currentPhysicalPlan") } override def executeCollect(): Array[InternalRow] = { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala index b5a1ad375..dfdbe2c70 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala @@ -30,11 +30,11 @@ case class PlanAdaptiveSubqueries( def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressionsWithPruning(_.containsAnyPattern( SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY, RUNTIME_FILTER_SUBQUERY)) { - case expressions.ScalarSubquery(_, _, exprId, _) => + case expressions.ScalarSubquery(_, _, exprId, _, _, _) => val subquery = SubqueryExec.createForScalarSubquery( s"subquery#${exprId.id}", subqueryMap(exprId.id)) execution.ScalarSubquery(subquery, exprId) - case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _)) => + case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _, _)) => val expr = if (values.length == 1) { values.head } else { @@ -47,7 +47,7 @@ case class PlanAdaptiveSubqueries( val subquery = SubqueryExec(s"subquery#${exprId.id}", subqueryMap(exprId.id)) InSubqueryExec(expr, subquery, exprId, shouldBroadcast = true) case expressions.DynamicPruningSubquery(value, buildPlan, - buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId) => + buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId, _) => val name = s"dynamicpruning#${exprId.id}" val subquery = SubqueryAdaptiveBroadcastExec(name, broadcastKeyIndex, onlyInBroadcast, buildPlan, buildKeys, subqueryMap(exprId.id)) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala index 334800f51..807369004 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala @@ -118,17 +118,27 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ (file: PartitionedFile) => { val conf = broadcastedConf.value.value - val filePath = new Path(new URI(file.filePath)) - val isPPDSafeValue = isPPDSafe(filters, dataSchema).reduceOption(_ && _) +// val filePath = new Path(new URI(file.filePath.urlEncoded)) + val filePath = file.toPath - // ORC predicate pushdown - if (orcFilterPushDown && filters.nonEmpty && isPPDSafeValue.getOrElse(false)) { - OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach { - fileSchema => OrcFilters.createFilter(fileSchema, filters).foreach { f => - OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames) - } - } - } + val fs = filePath.getFileSystem(conf) + val readerOptions = 
OrcFile.readerOptions(conf).filesystem(fs) + val orcSchema = + Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions))(_.getSchema) + val isPPDSafeValue = isPPDSafe(filters, dataSchema).reduceOption(_ && _) + val resultedColPruneInfo = OrcUtils.requestedColumnIds( + isCaseSensitive, dataSchema, requiredSchema, orcSchema, conf) + + + // ORC predicate pushdown + if (orcFilterPushDown && filters.nonEmpty && isPPDSafeValue.getOrElse(false)) { + val fileSchema = OrcUtils.toCatalystSchema(orcSchema) + // OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach { fileSchema => + OrcFilters.createFilter(fileSchema, filters).foreach { f => + OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames) + // } + } + } val taskConf = new Configuration(conf) val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala index 9504b34d1..3d61f873d 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala @@ -80,7 +80,7 @@ class OmniParquetFileFormat extends FileFormat with DataSourceRegister with Logg (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) - val filePath = new Path(new URI(file.filePath)) + val filePath = file.toPath val split = new org.apache.parquet.hadoop.ParquetInputSplit( filePath, diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties b/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties index 89eabe8e6..099e28e8d 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties +++ b/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties @@ -2,11 +2,13 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
# -hive.metastore.uris=thrift://server1:9083 +#hive.metastore.uris=thrift://server1:9083 +hive.metastore.uris=thrift://OmniOperator:9083 spark.sql.warehouse.dir=/user/hive/warehouse spark.memory.offHeap.size=8G spark.sql.codegen.wholeStage=false spark.sql.extensions=com.huawei.boostkit.spark.ColumnarPlugin spark.shuffle.manager=org.apache.spark.shuffle.sort.OmniColumnarShuffleManager spark.sql.orc.impl=native -hive.db=tpcds_bin_partitioned_orc_2 \ No newline at end of file +#hive.db=tpcds_bin_partitioned_orc_2 +hive.db=tpcds_bin_partitioned_varchar_orc_2 diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index e6b786c2a..329295bac 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.SparkException import org.apache.spark.sql.{Column, DataFrame} import org.apache.spark.sql.execution.ColumnarSparkPlanTest import org.apache.spark.sql.types.{DataType, Decimal} @@ -64,7 +65,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getByte(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as short") { @@ -72,7 +74,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getShort(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as int") { @@ -80,7 +83,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getInt(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as long") { @@ -88,7 +92,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getLong(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as float") { @@ -96,7 +101,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getFloat(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as double") { @@ -104,7 +110,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getDouble(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception.isInstanceOf[NullPointerException], 
s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as date") { @@ -154,13 +161,15 @@ class CastSuite extends ColumnarSparkPlanTest { val exception4 = intercept[Exception]( result4.collect().toSeq.head.getBoolean(0) ) - assert(exception4.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception4.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception4.isInstanceOf[SparkException], s"sql: ${sql}") val result5 = spark.sql("select cast('test' as boolean);") val exception5 = intercept[Exception]( result5.collect().toSeq.head.getBoolean(0) ) - assert(exception5.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception5.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception5.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast boolean to string") { @@ -182,13 +191,15 @@ class CastSuite extends ColumnarSparkPlanTest { val exception2 = intercept[Exception]( result2.collect().toSeq.head.getByte(0) ) - assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}") val result3 = spark.sql("select cast('false' as byte);") val exception3 = intercept[Exception]( result3.collect().toSeq.head.getByte(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast byte to string") { @@ -210,13 +221,15 @@ class CastSuite extends ColumnarSparkPlanTest { val exception2 = intercept[Exception]( result2.collect().toSeq.head.getShort(0) ) - assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}") val result3 = spark.sql("select cast('false' as short);") val exception3 = intercept[Exception]( result3.collect().toSeq.head.getShort(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast short to string") { @@ -238,13 +251,15 @@ class CastSuite extends ColumnarSparkPlanTest { val exception2 = intercept[Exception]( result2.collect().toSeq.head.getInt(0) ) - assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}") val result3 = spark.sql("select cast('false' as int);") val exception3 = intercept[Exception]( result3.collect().toSeq.head.getInt(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast int to string") { @@ -266,13 +281,15 @@ class CastSuite extends ColumnarSparkPlanTest { val exception2 = intercept[Exception]( result2.collect().toSeq.head.getLong(0) ) - assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}") val result3 = spark.sql("select cast('false' as long);") val exception3 = 
intercept[Exception]( result3.collect().toSeq.head.getLong(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast long to string") { @@ -298,7 +315,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception3 = intercept[Exception]( result3.collect().toSeq.head.getFloat(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast float to string") { @@ -324,7 +342,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception3 = intercept[Exception]( result3.collect().toSeq.head.getDouble(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast double to string") { diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index f83edb9ca..ea52aca62 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -61,7 +61,7 @@ class CombiningLimitsSuite extends PlanTest { comparePlans(optimized1, expected1) // test child max row > limit. 
- val query2 = testRelation.select().groupBy()(count(1)).limit(0).analyze + val query2 = testRelation2.select($"x").groupBy($"x")(count(1)).limit(1).analyze val optimized2 = Optimize.execute(query2) comparePlans(optimized2, query2) diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala index aaa244cdf..e1c620e1c 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala @@ -43,7 +43,7 @@ class MergeSubqueryFiltersSuite extends PlanTest { } private def extractorExpression(cteIndex: Int, output: Seq[Attribute], fieldIndex: Int) = { - GetStructField(ScalarSubquery(CTERelationRef(cteIndex, _resolved = true, output)), fieldIndex) + GetStructField(ScalarSubquery(CTERelationRef(cteIndex, _resolved = true, output, isStreaming = false)), fieldIndex) .as("scalarsubquery()") } diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala index c34ff5bb1..c0be72f31 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala @@ -19,17 +19,15 @@ package org.apache.spark.sql.execution.adaptive import java.io.File import java.net.URI - import org.apache.logging.log4j.Level import org.scalatest.PrivateMethodTester import org.scalatest.time.SpanSugar._ - import org.apache.spark.SparkException import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} -import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnaryExecNode, UnionExec} +import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnionExec} import org.apache.spark.sql.execution.aggregate.BaseAggregateExec import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.execution.datasources.noop.NoopDataSource @@ -37,7 +35,7 @@ import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, Exchange, REPARTITION_BY_COL, REPARTITION_BY_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec} import 
org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter -import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate +import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLExecutionStart} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode @@ -1122,13 +1120,21 @@ class AdaptiveQueryExecSuite test("SPARK-30953: InsertAdaptiveSparkPlan should apply AQE on child plan of write commands") { withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + var plan: SparkPlan = null + val listener = new QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + plan = qe.executedPlan + } + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} + } + spark.listenerManager.register(listener) withTable("t1") { - val plan = sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").queryExecution.executedPlan - assert(plan.isInstanceOf[CommandResultExec]) - val commandResultExec = plan.asInstanceOf[CommandResultExec] - assert(commandResultExec.commandPhysicalPlan.isInstanceOf[DataWritingCommandExec]) - assert(commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec] - .child.isInstanceOf[AdaptiveSparkPlanExec]) + val format = classOf[NoopDataSource].getName + Seq((0, 1)).toDF("x", "y").write.format(format).mode("overwrite").save() + sparkContext.listenerBus.waitUntilEmpty() + assert(plan.isInstanceOf[V2TableWriteExec]) + assert(plan.asInstanceOf[V2TableWriteExec].child.isInstanceOf[AdaptiveSparkPlanExec]) + spark.listenerManager.unregister(listener) } } } @@ -1172,15 +1178,14 @@ class AdaptiveQueryExecSuite test("SPARK-31658: SQL UI should show write commands") { withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "false") { withTable("t1") { - var checkDone = false + var commands: Seq[SparkPlanInfo] = Seq.empty val listener = new SparkListener { override def onOtherEvent(event: SparkListenerEvent): Unit = { event match { - case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) => - assert(planInfo.nodeName == "Execute CreateDataSourceTableAsSelectCommand") - checkDone = true + case start: SparkListenerSQLExecutionStart => + commands = commands ++ Seq(start.sparkPlanInfo) case _ => // ignore other events } } @@ -1189,7 +1194,12 @@ class AdaptiveQueryExecSuite try { sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect() spark.sparkContext.listenerBus.waitUntilEmpty() - assert(checkDone) + assert(commands.size == 3) + assert(commands.head.nodeName == "Execute CreateDataSourceTableAsSelectCommand") + assert(commands(1).nodeName == "Execute InsertIntoHadoopFsRelationCommand") + assert(commands(1).children.size == 1) + assert(commands(1).children.head.nodeName == "WriteFiles") + assert(commands(2).nodeName == "CommandResult") } finally { spark.sparkContext.removeSparkListener(listener) } @@ -1574,7 +1584,7 @@ class AdaptiveQueryExecSuite test("SPARK-32932: Do not use local shuffle read at final stage on write command") { withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString, SQLConf.SHUFFLE_PARTITIONS.key -> "5", - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + 
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { val data = for ( i <- 1L to 10L; j <- 1L to 3L @@ -1584,9 +1594,8 @@ class AdaptiveQueryExecSuite var noLocalread: Boolean = false val listener = new QueryExecutionListener { override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { - qe.executedPlan match { + stripAQEPlan(qe.executedPlan) match { case plan@(_: DataWritingCommandExec | _: V2TableWriteExec) => - assert(plan.asInstanceOf[UnaryExecNode].child.isInstanceOf[AdaptiveSparkPlanExec]) noLocalread = collect(plan) { case exec: AQEShuffleReadExec if exec.isLocalRead => exec }.isEmpty diff --git a/omnioperator/omniop-spark-extension/pom.xml b/omnioperator/omniop-spark-extension/pom.xml index 4376a89be..e949cea2d 100644 --- a/omnioperator/omniop-spark-extension/pom.xml +++ b/omnioperator/omniop-spark-extension/pom.xml @@ -8,13 +8,13 @@ com.huawei.kunpeng boostkit-omniop-spark-parent pom - 3.3.1-1.6.0 + 3.4.3-1.6.0 BoostKit Spark Native Sql Engine Extension Parent Pom 2.12.10 2.12 - 3.3.1 + 3.4.3 3.2.2 UTF-8 UTF-8 -- Gitee From f12d2c02f024f683fe998ea31042ff35dfefd573 Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Thu, 17 Oct 2024 06:58:44 +0000 Subject: [PATCH 34/43] update omnioperator/omniop-native-reader/java/pom.xml. modify spark_version and logger dependency Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com> --- .../omniop-native-reader/java/pom.xml | 1 + .../expression/OmniExpressionAdaptor.scala | 4 +- .../sql/catalyst/Tree/TreePatterns.scala | 30 +- .../catalyst/analysis/DecimalPrecision.scala | 347 ++++++++++++++++++ .../expressions/decimalExpressions.scala | 235 ++++++++++++ .../optimizer/InjectRuntimeFilter.scala | 59 +-- .../ColumnarBasicPhysicalOperators.scala | 2 - .../spark/sql/execution/QueryExecution.scala | 19 +- .../adaptive/AdaptiveSparkPlanExec.scala | 24 +- 9 files changed, 620 insertions(+), 101 deletions(-) create mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala create mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala diff --git a/omnioperator/omniop-native-reader/java/pom.xml b/omnioperator/omniop-native-reader/java/pom.xml index bdad33a3a..773ed2c39 100644 --- a/omnioperator/omniop-native-reader/java/pom.xml +++ b/omnioperator/omniop-native-reader/java/pom.xml @@ -34,6 +34,7 @@ 1.6.0 + org.slf4j slf4j-simple 1.7.36 diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index f7ed75fa3..be1538172 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -160,8 +160,8 @@ object OmniExpressionAdaptor extends Logging { throw new UnsupportedOperationException(s"Unsupported datatype for MakeDecimal: ${makeDecimal.child.dataType}") } -// case promotePrecision: PromotePrecision => -// rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap) + case promotePrecision: PromotePrecision => + rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap) case sub: Subtract => new 
JSONObject().put("exprType", "BINARY") diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala index ef17a0740..4d8c246ab 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala @@ -25,7 +25,7 @@ object TreePattern extends Enumeration { // Expression patterns (alphabetically ordered) val AGGREGATE_EXPRESSION = Value(0) val ALIAS: Value = Value -// val AND_OR: Value = Value + val AND_OR: Value = Value val AND: Value = Value val ARRAYS_ZIP: Value = Value val ATTRIBUTE_REFERENCE: Value = Value @@ -59,7 +59,7 @@ object TreePattern extends Enumeration { val JSON_TO_STRUCT: Value = Value val LAMBDA_FUNCTION: Value = Value val LAMBDA_VARIABLE: Value = Value - val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value // spark3.4.3 + val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value val LATERAL_SUBQUERY: Value = Value val LIKE_FAMLIY: Value = Value val LIST_SUBQUERY: Value = Value @@ -71,33 +71,33 @@ object TreePattern extends Enumeration { val NULL_CHECK: Value = Value val NULL_LITERAL: Value = Value val SERIALIZE_FROM_OBJECT: Value = Value - val OR: Value = Value // spark3.4.3 + val OR: Value = Value val OUTER_REFERENCE: Value = Value - val PARAMETER: Value = Value // spark3.4.3 - val PARAMETERIZED_QUERY: Value = Value // spark3.4.3 + val PARAMETER: Value = Value + val PARAMETERIZED_QUERY: Value = Value val PIVOT: Value = Value val PLAN_EXPRESSION: Value = Value val PYTHON_UDF: Value = Value val REGEXP_EXTRACT_FAMILY: Value = Value val REGEXP_REPLACE: Value = Value val RUNTIME_REPLACEABLE: Value = Value - val RUNTIME_FILTER_EXPRESSION: Value = Value // spark3.4.3移除 - val RUNTIME_FILTER_SUBQUERY: Value = Value // spark3.4.3移除 + val RUNTIME_FILTER_EXPRESSION: Value = Value + val RUNTIME_FILTER_SUBQUERY: Value = Value val SCALAR_SUBQUERY: Value = Value val SCALAR_SUBQUERY_REFERENCE: Value = Value val SCALA_UDF: Value = Value - val SESSION_WINDOW: Value = Value // spark3.4.3 + val SESSION_WINDOW: Value = Value val SORT: Value = Value val SUBQUERY_ALIAS: Value = Value - val SUBQUERY_WRAPPER: Value = Value // spark3.4.3移除 + val SUBQUERY_WRAPPER: Value = Value val SUM: Value = Value val TIME_WINDOW: Value = Value val TIME_ZONE_AWARE_EXPRESSION: Value = Value val TRUE_OR_FALSE_LITERAL: Value = Value val WINDOW_EXPRESSION: Value = Value - val WINDOW_TIME: Value = Value // saprk3.4.3 + val WINDOW_TIME: Value = Value val UNARY_POSITIVE: Value = Value - val UNPIVOT: Value = Value // spark3.4.3 + val UNPIVOT: Value = Value val UPDATE_FIELDS: Value = Value val UPPER_OR_LOWER: Value = Value val UP_CAST: Value = Value @@ -127,7 +127,7 @@ object TreePattern extends Enumeration { val UNION: Value = Value val UNRESOLVED_RELATION: Value = Value val UNRESOLVED_WITH: Value = Value - val TEMP_RESOLVED_COLUMN: Value = Value // spark3.4.3 + val TEMP_RESOLVED_COLUMN: Value = Value val TYPED_FILTER: Value = Value val WINDOW: Value = Value val WITH_WINDOW_DEFINITION: Value = Value @@ -136,7 +136,7 @@ object TreePattern extends Enumeration { val UNRESOLVED_ALIAS: Value = Value val UNRESOLVED_ATTRIBUTE: Value = Value val UNRESOLVED_DESERIALIZER: Value = Value - val UNRESOLVED_HAVING: Value = Value // spark3.4.3 + val UNRESOLVED_HAVING: Value = Value val UNRESOLVED_ORDINAL: Value 
= Value val UNRESOLVED_FUNCTION: Value = Value val UNRESOLVED_HINT: Value = Value @@ -145,8 +145,8 @@ object TreePattern extends Enumeration { // Unresolved Plan patterns (Alphabetically ordered) val UNRESOLVED_SUBQUERY_COLUMN_ALIAS: Value = Value val UNRESOLVED_FUNC: Value = Value - val UNRESOLVED_TABLE_VALUED_FUNCTION: Value = Value // spark3.4.3 - val UNRESOLVED_TVF_ALIASES: Value = Value // spark3.4.3 + val UNRESOLVED_TABLE_VALUED_FUNCTION: Value = Value + val UNRESOLVED_TVF_ALIASES: Value = Value // Execution expression patterns (alphabetically ordered) val IN_SUBQUERY_EXEC: Value = Value diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala new file mode 100644 index 000000000..5a04f02ce --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal._ +import org.apache.spark.sql.types._ + + +// scalastyle:off +/** + * Calculates and propagates precision for fixed-precision decimals. Hive has a number of + * rules for this based on the SQL standard and MS SQL: + * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + * https://msdn.microsoft.com/en-us/library/ms190476.aspx + * + * In particular, if we have expressions e1 and e2 with precision/scale p1/s2 and p2/s2 + * respectively, then the following operations have the following precision / scale: + * + * Operation Result Precision Result Scale + * ------------------------------------------------------------------------ + * e1 + e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2) + * e1 - e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2) + * e1 * e2 p1 + p2 + 1 s1 + s2 + * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) + * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) + * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) + * + * When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale + * needed are out of the range of available values, the scale is reduced up to 6, in order to + * prevent the truncation of the integer part of the decimals. + * + * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited + * precision, do the math on unlimited-precision numbers, then introduce casts back to the + * required fixed precision. 
This allows us to do all rounding and overflow handling in the + * cast-to-fixed-precision operator. + * + * In addition, when mixing non-decimal types with decimals, we use the following rules: + * - BYTE gets turned into DECIMAL(3, 0) + * - SHORT gets turned into DECIMAL(5, 0) + * - INT gets turned into DECIMAL(10, 0) + * - LONG gets turned into DECIMAL(20, 0) + * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE + * - Literals INT and LONG get turned into DECIMAL with the precision strictly needed by the value + */ +// scalastyle:on +object DecimalPrecision extends TypeCoercionRule { + import scala.math.{max, min} + + private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType + + // Returns the wider decimal type that's wider than both of them + def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = { + widerDecimalType(d1.precision, d1.scale, d2.precision, d2.scale) + } + // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) + def widerDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = { + val scale = max(s1, s2) + val range = max(p1 - s1, p2 - s2) + DecimalType.bounded(range + scale, scale) + } + + private def promotePrecision(e: Expression, dataType: DataType): Expression = { + PromotePrecision(Cast(e, dataType)) + } + + override def transform: PartialFunction[Expression, Expression] = { + decimalAndDecimal() + .orElse(integralAndDecimalLiteral) + .orElse(nondecimalAndDecimal(conf.literalPickMinimumPrecision)) + } + + private[catalyst] def decimalAndDecimal(): PartialFunction[Expression, Expression] = { + decimalAndDecimal(conf.decimalOperationsAllowPrecisionLoss, !conf.ansiEnabled) + } + + /** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */ + private[catalyst] def decimalAndDecimal(allowPrecisionLoss: Boolean, nullOnOverflow: Boolean) + : PartialFunction[Expression, Expression] = { + // Skip nodes whose children have not been resolved yet + case e if !e.childrenResolved => e + + // Skip nodes who is already promoted + case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e + + case a @ Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) => + val resultScale = max(s1, s2) + val resultType = if (allowPrecisionLoss) { + DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, + resultScale) + } else { + DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) + } + CheckOverflow( + a.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)), + resultType, nullOnOverflow) + + case s @ Subtract(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2), _) => + val resultScale = max(s1, s2) + val resultType = if (allowPrecisionLoss) { + DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, + resultScale) + } else { + DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) + } + CheckOverflow( + s.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)), + resultType, nullOnOverflow) + + case m @ Multiply( + e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) => + val resultType = if (allowPrecisionLoss) { + DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2) + } else { + DecimalType.bounded(p1 + p2 + 1, s1 + s2) + } + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow( + m.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)), + 
resultType, nullOnOverflow) + + case d @ Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) => + val resultType = if (allowPrecisionLoss) { + // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) + // Scale: max(6, s1 + p2 + 1) + val intDig = p1 - s1 + s2 + val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1) + val prec = intDig + scale + DecimalType.adjustPrecisionScale(prec, scale) + } else { + var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) + var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) + val diff = (intDig + decDig) - DecimalType.MAX_SCALE + if (diff > 0) { + decDig -= diff / 2 + 1 + intDig = DecimalType.MAX_SCALE - decDig + } + DecimalType.bounded(intDig + decDig, decDig) + } + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow( + d.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)), + resultType, nullOnOverflow) + + case r @ Remainder( + e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) => + val resultType = if (allowPrecisionLoss) { + DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } else { + DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } + // resultType may have lower precision, so we cast them into wider type first. + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow( + r.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)), + resultType, nullOnOverflow) + + case p @ Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) => + val resultType = if (allowPrecisionLoss) { + DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } else { + DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } + // resultType may have lower precision, so we cast them into wider type first. + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow( + p.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)), + resultType, nullOnOverflow) + + case expr @ IntegralDivide( + e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) => + val widerType = widerDecimalType(p1, s1, p2, s2) + val promotedExpr = expr.copy( + left = promotePrecision(e1, widerType), + right = promotePrecision(e2, widerType)) + if (expr.dataType.isInstanceOf[DecimalType]) { + // This follows division rule + val intDig = p1 - s1 + s2 + // No precision loss can happen as the result scale is 0. + // Overflow can happen only in the promote precision of the operands, but if none of them + // overflows in that phase, no overflow can happen, but CheckOverflow is needed in order + // to return a decimal with the proper scale and precision + CheckOverflow(promotedExpr, DecimalType.bounded(intDig, 0), nullOnOverflow) + } else { + promotedExpr + } + + case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + val resultType = widerDecimalType(p1, s1, p2, s2) + val newE1 = if (e1.dataType == resultType) e1 else Cast(e1, resultType) + val newE2 = if (e2.dataType == resultType) e2 else Cast(e2, resultType) + b.makeCopy(Array(newE1, newE2)) + } + + /** + * Strength reduction for comparing integral expressions with decimal literals. + * 1. int_col > decimal_literal => int_col > floor(decimal_literal) + * 2. int_col >= decimal_literal => int_col >= ceil(decimal_literal) + * 3. 
int_col < decimal_literal => int_col < ceil(decimal_literal) + * 4. int_col <= decimal_literal => int_col <= floor(decimal_literal) + * 5. decimal_literal > int_col => ceil(decimal_literal) > int_col + * 6. decimal_literal >= int_col => floor(decimal_literal) >= int_col + * 7. decimal_literal < int_col => floor(decimal_literal) < int_col + * 8. decimal_literal <= int_col => ceil(decimal_literal) <= int_col + * + * Note that technically this is an "optimization" and should go into the optimizer. However, + * by the time the optimizer runs, these comparison expressions would be pretty hard to pattern + * match because there are multiple (at least 2) levels of casts involved. + * + * There are a lot more possible rules we can implement, but we don't do them + * because we are not sure how common they are. + */ + private val integralAndDecimalLiteral: PartialFunction[Expression, Expression] = { + + case GreaterThan(i @ IntegralType(), DecimalLiteral(value)) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + TrueLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + FalseLiteral + } else { + GreaterThan(i, Literal(value.floor.toLong)) + } + + case GreaterThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + TrueLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + FalseLiteral + } else { + GreaterThanOrEqual(i, Literal(value.ceil.toLong)) + } + + case LessThan(i @ IntegralType(), DecimalLiteral(value)) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + FalseLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + TrueLiteral + } else { + LessThan(i, Literal(value.ceil.toLong)) + } + + case LessThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + FalseLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + TrueLiteral + } else { + LessThanOrEqual(i, Literal(value.floor.toLong)) + } + + case GreaterThan(DecimalLiteral(value), i @ IntegralType()) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + FalseLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + TrueLiteral + } else { + GreaterThan(Literal(value.ceil.toLong), i) + } + + case GreaterThanOrEqual(DecimalLiteral(value), i @ IntegralType()) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + FalseLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + TrueLiteral + } else { + GreaterThanOrEqual(Literal(value.floor.toLong), i) + } + + case LessThan(DecimalLiteral(value), i @ IntegralType()) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + TrueLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + FalseLiteral + } else { + LessThan(Literal(value.floor.toLong), i) + } + + case LessThanOrEqual(DecimalLiteral(value), i @ IntegralType()) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + TrueLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + FalseLiteral + } else { + LessThanOrEqual(Literal(value.ceil.toLong), i) + } + } + + /** + * Type coercion for BinaryOperator in which one side is a non-decimal numeric, and the other + * side is a decimal. 
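Working the decimalAndDecimal formulas above through with concrete types may help. The figures below are derived from the rules as written with allowPrecisionLoss enabled, and assume DecimalType.adjustPrecisionScale behaves as in upstream Spark (cap at 38 digits while keeping at least 6 fractional digits); they are illustrative only:

    // Addition: DECIMAL(10, 2) + DECIMAL(5, 3)
    //   resultScale = max(2, 3) = 3
    //   precision   = max(10 - 2, 5 - 3) + 3 + 1 = 12        => DECIMAL(12, 3)
    //
    // Division: DECIMAL(10, 2) / DECIMAL(5, 3)
    //   intDig = 10 - 2 + 3 = 11
    //   scale  = max(6, 2 + 5 + 1) = 8                       => DECIMAL(19, 8)
    //
    // Division hitting the cap: DECIMAL(38, 18) / DECIMAL(10, 0)
    //   prec = 20 + 29 = 49 > 38, so the scale is trimmed to
    //   max(38 - 20, 6) = 18                                 => DECIMAL(38, 18)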
+ */ + private def nondecimalAndDecimal(literalPickMinimumPrecision: Boolean) + : PartialFunction[Expression, Expression] = { + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => + (left, right) match { + // Promote literal integers inside a binary expression with fixed-precision decimals to + // decimals. The precision and scale are the ones strictly needed by the integer value. + // Requiring more precision than necessary may lead to a useless loss of precision. + // Consider the following example: multiplying a column which is DECIMAL(38, 18) by 2. + // If we use the default precision and scale for the integer type, 2 is considered a + // DECIMAL(10, 0). According to the rules, the result would be DECIMAL(38 + 10 + 1, 18), + // which is out of range and therefore it will become DECIMAL(38, 7), leading to + // potentially loosing 11 digits of the fractional part. Using only the precision needed + // by the Literal, instead, the result would be DECIMAL(38 + 1 + 1, 18), which would + // become DECIMAL(38, 16), safely having a much lower precision loss. + case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] && + l.dataType.isInstanceOf[IntegralType] && + literalPickMinimumPrecision => + b.makeCopy(Array(Cast(l, DecimalType.fromLiteral(l)), r)) + case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] && + r.dataType.isInstanceOf[IntegralType] && + literalPickMinimumPrecision => + b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r)))) + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case (l @ IntegralType(), r @ DecimalType.Expression(_, _)) => + b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r)) + case (l @ DecimalType.Expression(_, _), r @ IntegralType()) => + b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType)))) + case (l, r @ DecimalType.Expression(_, _)) if isFloat(l.dataType) => + b.makeCopy(Array(l, Cast(r, DoubleType))) + case (l @ DecimalType.Expression(_, _), r) if isFloat(r.dataType) => + b.makeCopy(Array(Cast(l, DoubleType), r)) + case _ => b + } + } + +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala new file mode 100644 index 000000000..2b6d3ff33 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +/** + * Return the unscaled Long value of a Decimal, assuming it fits in a Long. + * Note: this expression is internal and created only by the optimizer, + * we don't need to do type check for it. + */ +case class UnscaledValue(child: Expression) extends UnaryExpression with NullIntolerant { + + override def dataType: DataType = LongType + override def toString: String = s"UnscaledValue($child)" + + protected override def nullSafeEval(input: Any): Any = + input.asInstanceOf[Decimal].toUnscaledLong + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") + } + + override protected def withNewChildInternal(newChild: Expression): UnscaledValue = + copy(child = newChild) +} + +/** + * Create a Decimal from an unscaled Long value. + * Note: this expression is internal and created only by the optimizer, + * we don't need to do type check for it. + */ +case class MakeDecimal( + child: Expression, + precision: Int, + scale: Int, + nullOnOverflow: Boolean) extends UnaryExpression with NullIntolerant { + + def this(child: Expression, precision: Int, scale: Int) = { + this(child, precision, scale, !SQLConf.get.ansiEnabled) + } + + override def dataType: DataType = DecimalType(precision, scale) + override def nullable: Boolean = child.nullable || nullOnOverflow + override def toString: String = s"MakeDecimal($child,$precision,$scale)" + + protected override def nullSafeEval(input: Any): Any = { + val longInput = input.asInstanceOf[Long] + val result = new Decimal() + if (nullOnOverflow) { + result.setOrNull(longInput, precision, scale) + } else { + result.set(longInput, precision, scale) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, eval => { + val setMethod = if (nullOnOverflow) { + "setOrNull" + } else { + "set" + } + val setNull = if (nullable) { + s"${ev.isNull} = ${ev.value} == null;" + } else { + "" + } + s""" + |${ev.value} = (new Decimal()).$setMethod($eval, $precision, $scale); + |$setNull + |""".stripMargin + }) + } + + override protected def withNewChildInternal(newChild: Expression): MakeDecimal = + copy(child = newChild) +} + +object MakeDecimal { + def apply(child: Expression, precision: Int, scale: Int): MakeDecimal = { + new MakeDecimal(child, precision, scale) + } +} + +/** + * An expression used to wrap the children when promote the precision of DecimalType to avoid + * promote multiple times. 
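As a quick sanity check of the pair above, a minimal sketch of the unscaled-long round trip, written directly against Spark's Decimal helper with made-up values:

    import org.apache.spark.sql.types.Decimal

    // MakeDecimal(child, precision = 5, scale = 2) applied to the long 12345
    // conceptually performs:
    val d = new Decimal().set(12345L, 5, 2)   // decimal value 123.45
    // and UnscaledValue goes back the other way:
    val unscaled = d.toUnscaledLong           // 12345L again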
+ */ +case class PromotePrecision(child: Expression) extends UnaryExpression { + override def dataType: DataType = child.dataType + override def eval(input: InternalRow): Any = child.eval(input) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + child.genCode(ctx) + override def prettyName: String = "promote_precision" + override def sql: String = child.sql + override lazy val canonicalized: Expression = child.canonicalized + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +/** + * Rounds the decimal to given scale and check whether the decimal can fit in provided precision + * or not. If not, if `nullOnOverflow` is `true`, it returns `null`; otherwise an + * `ArithmeticException` is thrown. + */ +case class CheckOverflow( + child: Expression, + dataType: DecimalType, + nullOnOverflow: Boolean) extends UnaryExpression with SupportQueryContext { + + override def nullable: Boolean = true + + override def nullSafeEval(input: Any): Any = + input.asInstanceOf[Decimal].toPrecision( + dataType.precision, + dataType.scale, + Decimal.ROUND_HALF_UP, + nullOnOverflow, + getContextOrNull()) + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val errorContextCode = if (nullOnOverflow) { + "\"\"" + } else { + ctx.addReferenceObj("errCtx", queryContext) + } + nullSafeCodeGen(ctx, ev, eval => { + // scalastyle:off line.size.limit + s""" + |${ev.value} = $eval.toPrecision( + | ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow, $errorContextCode); + |${ev.isNull} = ${ev.value} == null; + """.stripMargin + // scalastyle:on line.size.limit + }) + } + + override def toString: String = s"CheckOverflow($child, $dataType)" + + override def sql: String = child.sql + + override protected def withNewChildInternal(newChild: Expression): CheckOverflow = + copy(child = newChild) + + override def initQueryContext(): Option[SQLQueryContext] = if (nullOnOverflow) { + Some(origin.context) + } else { + None + } +} + +// A variant `CheckOverflow`, which treats null as overflow. This is necessary in `Sum`. 
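The rounding and overflow behaviour CheckOverflow leans on can be sketched with Decimal.changePrecision, which is assumed here to round HALF_UP as in upstream Spark; the values are illustrative:

    import org.apache.spark.sql.types.Decimal

    val fits = Decimal("123.456")
    fits.changePrecision(5, 2)       // true: rounded to 123.46, which fits DECIMAL(5, 2)

    val tooWide = Decimal("123.456")
    tooWide.changePrecision(4, 2)    // false: 123.46 needs 5 digits, so CheckOverflow
                                     // returns null when nullOnOverflow, else it raises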
+case class CheckOverflowInSum( + child: Expression, + dataType: DecimalType, + nullOnOverflow: Boolean, + context: SQLQueryContext) extends UnaryExpression { + + override def nullable: Boolean = true + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + if (nullOnOverflow) null + else throw QueryExecutionErrors.overflowInSumOfDecimalError(context) + } else { + value.asInstanceOf[Decimal].toPrecision( + dataType.precision, + dataType.scale, + Decimal.ROUND_HALF_UP, + nullOnOverflow, + context) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val errorContextCode = if (nullOnOverflow) { + "\"\"" + } else { + ctx.addReferenceObj("errCtx", context) + } + val nullHandling = if (nullOnOverflow) { + "" + } else { + s"throw QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode);" + } + // scalastyle:off line.size.limit + val code = code""" + |${childGen.code} + |boolean ${ev.isNull} = ${childGen.isNull}; + |Decimal ${ev.value} = null; + |if (${childGen.isNull}) { + | $nullHandling + |} else { + | ${ev.value} = ${childGen.value}.toPrecision( + | ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow, $errorContextCode); + | ${ev.isNull} = ${ev.value} == null; + |} + |""".stripMargin + // scalastyle:on line.size.limit + + ev.copy(code = code) + } + + override def toString: String = s"CheckOverflowInSum($child, $dataType, $nullOnOverflow)" + + override def sql: String = child.sql + + override protected def withNewChildInternal(newChild: Expression): CheckOverflowInSum = + copy(child = newChild) +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index e6c266242..b758b69e6 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import com.huawei.boostkit.spark.ColumnarPluginConfig -import scala.annotation.tailrec - /** * Insert a filter on one side of the join if the other side has a selective predicate. 
* The filter could be an IN subquery (converted to a semi join), a bloom filter, or something @@ -87,12 +85,12 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J val bloomFilterAgg = if (rowCount.isDefined && rowCount.get.longValue > 0L) { new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)), - rowCount.get.longValue) + Literal(rowCount.get.longValue) } else { new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp))) } - - val alias = Alias(bloomFilterAgg.toAggregateExpression(), "bloomFilter")() + val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None) + val alias = Alias(aggExp, "bloomFilter")() val aggregate = ConstantFolding(ColumnPruning(Aggregate(Nil, Seq(alias), filterCreationSidePlan))) val bloomFilterSubquery = if (canReuseExchange) { @@ -114,8 +112,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J require(filterApplicationSideExp.dataType == filterCreationSideExp.dataType) val actualFilterKeyExpr = mayWrapWithHash(filterCreationSideExp) val alias = Alias(actualFilterKeyExpr, actualFilterKeyExpr.toString)() - val aggregate = - ColumnPruning(Aggregate(Seq(filterCreationSideExp), Seq(alias), filterCreationSidePlan)) + val aggregate = Aggregate(Seq(alias), Seq(alias), filterCreationSidePlan) if (!canBroadcastBySize(aggregate, conf)) { // Skip the InSubquery filter if the size of `aggregate` is beyond broadcast join threshold, // i.e., the semi-join will be a shuffled join, which is not worthwhile. @@ -132,39 +129,13 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J * do not add a subquery that might have an expensive computation */ private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = { - @tailrec - def isSelective( - p: LogicalPlan, - predicateReference: AttributeSet, - hasHitFilter: Boolean, - hasHitSelectiveFilter: Boolean): Boolean = p match { - case Project(projectList, child) => - if (hasHitFilter) { - // We need to make sure all expressions referenced by filter predicates are simple - // expressions. 
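For orientation, the explicit AggregateExpression built above should describe the same aggregate as the toAggregateExpression() helper it replaces, which defaults to Complete mode, non-distinct, and no filter; a sketch using the names already in scope in this file:

    // The two forms below should be interchangeable (modulo the generated result id):
    val viaHelper   = bloomFilterAgg.toAggregateExpression()
    val viaExplicit = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None)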
- val referencedExprs = projectList.filter(predicateReference.contains) - referencedExprs.forall(isSimpleExpression) && - isSelective( - child, - referencedExprs.map(_.references).foldLeft(AttributeSet.empty)(_ ++ _), - hasHitFilter, - hasHitSelectiveFilter) - } else { - assert(predicateReference.isEmpty && !hasHitSelectiveFilter) - isSelective(child, predicateReference, hasHitFilter, hasHitSelectiveFilter) - } - case Filter(condition, child) => - isSimpleExpression(condition) && isSelective( - child, - predicateReference ++ condition.references, - hasHitFilter = true, - hasHitSelectiveFilter = hasHitSelectiveFilter || isLikelySelective(condition)) - case _: LeafNode => hasHitSelectiveFilter + val ret = plan match { + case PhysicalOperation(_, filters, child) if child.isInstance[LeafNode] => + filters.forall(isSimpleExpression) && + filters.exists(isLikelySelective) case _ => false } - - !plan.isStreaming && - isSelective(plan, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false) + !plan.isStreaming && ret } private def isSimpleExpression(e: Expression): Boolean = { @@ -213,8 +184,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J /** * Check that: - * - The filterApplicationSideJoinExp can be pushed down through joins, aggregates and windows - * (ie the expression references originate from a single leaf node) + * - The filterApplicationSideJoinExp can be pushed down through joins and aggregates (ie the + * expression references originate from a single leaf node) * - The filter creation side has a selective predicate * - The current join is a shuffle join or a broadcast join that has a shuffle below it * - The max filterApplicationSide scan size is greater than a configurable threshold @@ -328,13 +299,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J case s: Subquery if s.correlated => plan case _ if !conf.runtimeFilterSemiJoinReductionEnabled && !conf.runtimeFilterBloomFilterEnabled => plan - case _ => - val newPlan = tryInjectRuntimeFilter(plan) - if (conf.runtimeFilterSemiJoinReductionEnabled && !plan.fastEquals(newPlan)) { - RewritePredicateSubquery(newPlan) - } else { - newPlan - } + case _ => tryInjectRuntimeFilter(plan) } } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala index 630034bd7..b6dd2ab51 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution import java.util.concurrent.TimeUnit.NANOSECONDS import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP - import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._ import com.huawei.boostkit.spark.util.OmniAdaptorUtil import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs @@ -43,7 +42,6 @@ import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.sql.catalyst.plans.{AliasAwareOutputExpression, AliasAwareQueryOutputOrdering} - case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryExecNode 
with AliasAwareOutputExpression diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 37f3db5d6..e49c88191 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -105,23 +105,6 @@ class QueryExecution( case other => other } - // The plan that has been normalized by custom rules, so that it's more likely to hit cache. - lazy val normalized: LogicalPlan = { - val normalizationRules = sparkSession.sessionState.planNormalizationRules - if (normalizationRules.isEmpty) { - commandExecuted - } else { - val planChangeLogger = new PlanChangeLogger[LogicalPlan]() - val normalized = normalizationRules.foldLeft(commandExecuted) { (p, rule) => - val result = rule.apply(p) - planChangeLogger.logRule(rule.ruleName, p, result) - result - } - planChangeLogger.logBatch("Plan Normalization", commandExecuted, normalized) - normalized - } - } // Spark3.4.3 - lazy val withCachedData: LogicalPlan = sparkSession.withActive { assertAnalyzed() assertSupported() @@ -529,4 +512,4 @@ object QueryExecution { case e: Throwable => throw toInternalError(msg, e) } } -} +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index e5c8fc0ab..b5d638138 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.adaptive import java.util import java.util.concurrent.LinkedBlockingQueue + import scala.collection.JavaConverters._ import scala.collection.concurrent.TrieMap import scala.collection.mutable import scala.concurrent.ExecutionContext import scala.util.control.NonFatal + import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession @@ -33,7 +35,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule} import org.apache.spark.sql.catalyst.trees.TreeNodeTag -import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._ @@ -153,13 +154,7 @@ case class AdaptiveSparkPlanExec( ) private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): SparkPlan = { - val rules = if (isFinalStage && - !conf.getConf(SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS)) { - queryStageOptimizerRules.filterNot(_.isInstanceOf[AQEShuffleReadRule]) - } else { - queryStageOptimizerRules - } - val optimized = rules.foldLeft(plan) { case (latestPlan, rule) => + val optimized = queryStageOptimizerRules.foldLeft(plan) { case (latestPlan, rule) => val applied = rule.apply(latestPlan) val 
result = rule match { case _: AQEShuffleReadRule if !applied.fastEquals(latestPlan) => @@ -194,7 +189,7 @@ case class AdaptiveSparkPlanExec( @volatile private var currentPhysicalPlan = initialPlan - @volatile private var _isFinalPlan = false + volatile private var isFinalPlan = false private var currentStageId = 0 @@ -211,8 +206,6 @@ case class AdaptiveSparkPlanExec( def executedPlan: SparkPlan = currentPhysicalPlan - def isFinalPlan: Boolean = _isFinalPlan - override def conf: SQLConf = context.session.sessionState.conf override def output: Seq[Attribute] = inputPlan.output @@ -231,8 +224,6 @@ case class AdaptiveSparkPlanExec( .map(_.toLong).filter(SQLExecution.getQueryExecution(_) eq context.qe) } - def finalPhysicalPlan: SparkPlan = withFinalPlanUpdate(identity) // saprk 3.4.3 - private def getFinalPhysicalPlan(): SparkPlan = lock.synchronized { if (isFinalPlan) return currentPhysicalPlan @@ -320,8 +311,7 @@ case class AdaptiveSparkPlanExec( val newCost = costEvaluator.evaluateCost(newPhysicalPlan) if (newCost < origCost || (newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) { - logOnLevel("Plan changed:\n" + - sideBySide(currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n")) + logOnLevel("Plan changed from $currentPhysicalPlan to $newPhysicalPlan") cleanUpTempTags(newPhysicalPlan) currentPhysicalPlan = newPhysicalPlan currentLogicalPlan = newLogicalPlan @@ -337,7 +327,7 @@ case class AdaptiveSparkPlanExec( optimizeQueryStage(result.newPlan, isFinalStage = true), postStageCreationRules(supportsColumnar), Some((planChangeLogger, "AQE Post Stage Creation"))) - _isFinalPlan = true + isFinalPlan = true executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan))) currentPhysicalPlan } @@ -351,7 +341,7 @@ case class AdaptiveSparkPlanExec( if (!isSubquery && currentPhysicalPlan.exists(_.subqueries.nonEmpty)) { getExecutionId.foreach(onUpdatePlan(_, Seq.empty)) } - logOnLevel(s"Final plan: \n$currentPhysicalPlan") + logOnLevel(s"Final plan: $currentPhysicalPlan") } override def executeCollect(): Array[InternalRow] = { -- Gitee From 40acdcc48e193b6d3386b17b1182ab75b0dc1d8e Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Fri, 18 Oct 2024 03:24:45 +0000 Subject: [PATCH 35/43] fix EOF bug Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com> --- .../spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index b758b69e6..328be083f 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -85,8 +85,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J val bloomFilterAgg = if (rowCount.isDefined && rowCount.get.longValue > 0L) { new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)), - Literal(rowCount.get.longValue) - } else { + Literal(rowCount.get.longValue)) + else { new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp))) } val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None) -- Gitee From 
9cc62a176cb499370ca287fce5c40625821c3e33 Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Fri, 18 Oct 2024 03:26:00 +0000 Subject: [PATCH 36/43] fix bug Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com> --- .../spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index b5d638138..a1a5d09f1 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -189,7 +189,7 @@ case class AdaptiveSparkPlanExec( @volatile private var currentPhysicalPlan = initialPlan - volatile private var isFinalPlan = false + private var isFinalPlan = false private var currentStageId = 0 -- Gitee From 83cc6aaf2fcb0e1c4cceda8953f154639a7d54d4 Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Fri, 18 Oct 2024 03:51:10 +0000 Subject: [PATCH 37/43] fix EOF bug Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com> --- .../spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index 328be083f..895e7c3d7 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -86,7 +86,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J if (rowCount.isDefined && rowCount.get.longValue > 0L) { new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)), Literal(rowCount.get.longValue)) - else { + } else { new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp))) } val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None) -- Gitee From d49c55ec71923ae0d70e73a995814ddd5afb953f Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Fri, 18 Oct 2024 04:12:46 +0000 Subject: [PATCH 38/43] fix bug Signed-off-by: wangwei <14424757+wiwimao@user.noreply.gitee.com> --- .../spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index 895e7c3d7..2dd8efe58 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -130,7 +130,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J 
*/ private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = { val ret = plan match { - case PhysicalOperation(_, filters, child) if child.isInstance[LeafNode] => + case PhysicalOperation(_, filters, child) if child.isInstanceOf[LeafNode] => filters.forall(isSimpleExpression) && filters.exists(isLikelySelective) case _ => false -- Gitee From 25e9fb1f0bd4c8a88fa37933a170f71b6520ccf4 Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Sun, 20 Oct 2024 19:07:54 +0800 Subject: [PATCH 39/43] add normalized method --- .../spark/sql/execution/QueryExecution.scala | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index e49c88191..3630b0f2e 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -105,12 +105,29 @@ class QueryExecution( case other => other } + // The plan that has been normalized by custom rules, so that it's more likely to hit cache. + lazy val normalized: LogicalPlan = { + val normalizationRules = sparkSession.sessionState.planNormalizationRules + if (normalizationRules.isEmpty) { + commandExecuted + } else { + val planChangeLogger = new PlanChangeLogger[LogicalPlan]() + val normalized = normalizationRules.foldLeft(commandExecuted) { (p, rule) => + val result = rule.apply(p) + planChangeLogger.logRule(rule.ruleName, p, result) + result + } + planChangeLogger.logBatch("Plan Normalization", commandExecuted, normalized) + normalized + } + } + lazy val withCachedData: LogicalPlan = sparkSession.withActive { assertAnalyzed() assertSupported() // clone the plan to avoid sharing the plan instance between different stages like analyzing, // optimizing and planning. 
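A minimal sketch of what one of the planNormalizationRules folded above could look like; the rule name and behaviour are hypothetical placeholders, not part of Spark or of this patch:

    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.catalyst.rules.Rule

    // Hypothetical: strip details that should not affect cache lookups so that
    // semantically equal plans compare equal when the cache manager probes them.
    object StripCustomTags extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = plan   // placeholder no-op
    }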
- sparkSession.sharedState.cacheManager.useCachedData(commandExecuted.clone()) + sparkSession.sharedState.cacheManager.useCachedData(normalized.clone()) } def assertCommandExecuted(): Unit = commandExecuted -- Gitee From ff0aafbe64859794ae67914e76af78b6b326c2be Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Sun, 20 Oct 2024 19:14:55 +0800 Subject: [PATCH 40/43] fix UT bug --- .../optimizer/InjectRuntimeFilter.scala | 59 +++++++++++++++---- 1 file changed, 46 insertions(+), 13 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index b758b69e6..a6124f5e3 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -84,15 +84,14 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J val rowCount = filterCreationSidePlan.stats.rowCount val bloomFilterAgg = if (rowCount.isDefined && rowCount.get.longValue > 0L) { - new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)), - Literal(rowCount.get.longValue) + new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)), rowCount.get.longValue) } else { new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp))) } - val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None) - val alias = Alias(aggExp, "bloomFilter")() + + val alias = Alias(bloomFilterAgg.toAggregateExpression(), "bloomFilter")() val aggregate = - ConstantFolding(ColumnPruning(Aggregate(Nil, Seq(alias), filterCreationSidePlan))) + ColumnPruning(Aggregate(Seq(filterCreationSideExp), Seq(alias), filterCreationSidePlan)) val bloomFilterSubquery = if (canReuseExchange) { // Try to reuse the results of broadcast exchange. RuntimeFilterSubquery(filterApplicationSideExp, aggregate, filterCreationSideExp) @@ -129,13 +128,39 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J * do not add a subquery that might have an expensive computation */ private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = { - val ret = plan match { - case PhysicalOperation(_, filters, child) if child.isInstance[LeafNode] => - filters.forall(isSimpleExpression) && - filters.exists(isLikelySelective) + @tailrec + def isSelective( + p: LogicalPlan, + predicateReference: AttributeSet, + hasHitFilter: Boolean, + hasHitSelectiveFilter: Boolean): Boolean = p match { + case Project(projectList, child) => + if (hasHitFilter) { + // We need to make sure all expressions referenced by filter predicates are simple + // expressions. 
+ val referencedExprs = projectList.filter(predicateReference.contains) + referencedExprs.forall(isSimpleExpression) && + isSelective( + child, + referencedExprs.map(_.references).foldLeft(AttributeSet.empty)(_ ++ _), + hasHitFilter, + hasHitSelectiveFilter) + } else { + assert(predicateReference.isEmpty && !hasHitSelectiveFilter) + isSelective(child, predicateReference, hasHitFilter, hasHitSelectiveFilter) + } + case Filter(condition, child) => + isSimpleExpression(condition) && isSelective( + child, + predicateReference ++ condition.references, + hasHitFilter = true, + hasHitSelectiveFilter = hasHitSelectiveFilter || isLikelySelective(condition)) + case _: LeafNode => hasHitSelectiveFilter case _ => false } - !plan.isStreaming && ret + + !plan.isStreaming && + isSelective(plan, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false) } private def isSimpleExpression(e: Expression): Boolean = { @@ -153,6 +178,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J plan.exists { case Join(left, right, _, _, hint) => isProbablyShuffleJoin(left, right, hint) case _: Aggregate => true + case _: Window => true case _ => false } } @@ -184,8 +210,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J /** * Check that: - * - The filterApplicationSideJoinExp can be pushed down through joins and aggregates (ie the - * expression references originate from a single leaf node) + * - The filterApplicationSideJoinExp can be pushed down through joins, aggregates and windows + * (ie the expression references originate from a single leaf node) * - The filter creation side has a selective predicate * - The current join is a shuffle join or a broadcast join that has a shuffle below it * - The max filterApplicationSide scan size is greater than a configurable threshold @@ -212,6 +238,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J } // This checks if there is already a DPP filter, as this rule is called just after DPP. 
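To make the traversal above concrete, a few plan shapes and how isSelective classifies them, written informally with placeholder column names:

    // Filter(a === 1, scan)                       -> accepted: a likely-selective predicate
    //                                                 sits directly over a leaf node
    // Project(Seq(a, b), Filter(a === 1, scan))   -> accepted: a projection above the
    //                                                 filter does not get in the way
    // Filter(a === 1, Aggregate(..., scan))       -> rejected: the walk only looks through
    //                                                 Project and Filter, so the Aggregate stops it
    // scan with no filter at all                  -> rejected: hasHitSelectiveFilter never
    //                                                 becomes true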
+ @tailrec def hasDynamicPruningSubquery( left: LogicalPlan, right: LogicalPlan, @@ -299,7 +326,13 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J case s: Subquery if s.correlated => plan case _ if !conf.runtimeFilterSemiJoinReductionEnabled && !conf.runtimeFilterBloomFilterEnabled => plan - case _ => tryInjectRuntimeFilter(plan) + case _ => + val newPlan = tryInjectRuntimeFilter(plan) + if (conf.runtimeFilterSemiJoinReductionEnabled && !plan.fastEquals(newPlan)) { + RewritePredicateSubquery(newPlan) + } else { + newPlan + } } } \ No newline at end of file -- Gitee From 2e880e09f5f25e466daaa2937df59bc7b0bca9e6 Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 21 Oct 2024 10:57:53 +0800 Subject: [PATCH 41/43] fix BUG --- .../spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index b5d638138..a1a5d09f1 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -189,7 +189,7 @@ case class AdaptiveSparkPlanExec( @volatile private var currentPhysicalPlan = initialPlan - volatile private var isFinalPlan = false + private var isFinalPlan = false private var currentStageId = 0 -- Gitee From e2222aab325779d2d80b79e901834af759d6b2d9 Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Mon, 21 Oct 2024 11:09:39 +0800 Subject: [PATCH 42/43] fix BUG --- .../spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index a6124f5e3..c4a94f5db 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import com.huawei.boostkit.spark.ColumnarPluginConfig +import scala.annotation.tailrec + /** * Insert a filter on one side of the join if the other side has a selective predicate. 
* The filter could be an IN subquery (converted to a semi join), a bloom filter, or something -- Gitee From 8a74794acde7c58155dcaa22f9c1334a0fed1435 Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Tue, 22 Oct 2024 22:44:21 +0800 Subject: [PATCH 43/43] refactor orc push down filter --- .../spark/jni/OrcColumnarBatchScanReader.java | 445 ++++++++++++------ .../orc/OmniOrcColumnarBatchReader.java | 161 +++---- .../datasources/orc/OmniOrcFileFormat.scala | 107 +---- ...OrcColumnarBatchJniReaderDataTypeTest.java | 2 +- ...ColumnarBatchJniReaderNotPushDownTest.java | 2 +- ...OrcColumnarBatchJniReaderPushDownTest.java | 2 +- ...BatchJniReaderSparkORCNotPushDownTest.java | 2 +- ...narBatchJniReaderSparkORCPushDownTest.java | 2 +- .../jni/OrcColumnarBatchJniReaderTest.java | 134 +++--- 9 files changed, 437 insertions(+), 420 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java index 73438aa43..064f50416 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java @@ -22,59 +22,78 @@ import com.huawei.boostkit.scan.jni.OrcColumnarBatchJniReader; import nova.hetu.omniruntime.type.DataType; import nova.hetu.omniruntime.vector.*; +import org.apache.spark.sql.catalyst.util.CharVarcharUtils; import org.apache.spark.sql.catalyst.util.RebaseDateTime; -import org.apache.hadoop.hive.ql.io.sarg.ExpressionTree; -import org.apache.hadoop.hive.ql.io.sarg.PredicateLeaf; -import org.apache.orc.OrcFile.ReaderOptions; -import org.apache.orc.Reader.Options; +import org.apache.spark.sql.sources.And; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.GreaterThanOrEqual; +import org.apache.spark.sql.sources.In; +import org.apache.spark.sql.sources.IsNotNull; +import org.apache.spark.sql.sources.IsNull; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.LessThanOrEqual; +import org.apache.spark.sql.sources.Not; +import org.apache.spark.sql.sources.Or; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; import org.json.JSONObject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.math.BigDecimal; import java.net.URI; -import java.sql.Date; +import java.time.LocalDate; +import java.text.DateFormat; +import java.text.ParseException; +import java.text.SimpleDateFormat; import java.util.ArrayList; -import java.util.List; +import java.util.Arrays; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.TimeZone; public class OrcColumnarBatchScanReader { private static final Logger LOGGER = 
LoggerFactory.getLogger(OrcColumnarBatchScanReader.class); + private boolean nativeSupportTimestampRebase; + private static final Pattern CHAR_TYPE = Pattern.compile("char\\(\\s*(\\d+)\\s*\\)"); + + private static final int MAX_LEAF_THRESHOLD = 256; + public long reader; public long recordReader; public long batchReader; - public int[] colsToGet; - public int realColsCnt; - public ArrayList fildsNames; + // All ORC fieldNames + public ArrayList allFieldsNames; - public ArrayList colToInclu; + // Indicate columns to read + public int[] colsToGet; - public String[] requiredfieldNames; + // Actual columns to read + public ArrayList includedColumns; - public int[] precisionArray; + // max threshold for leaf node + private int leafIndex; - public int[] scaleArray; + // spark required schema + private StructType requiredSchema; public OrcColumnarBatchJniReader jniReader; public OrcColumnarBatchScanReader() { jniReader = new OrcColumnarBatchJniReader(); - fildsNames = new ArrayList(); - } - - public JSONObject getSubJson(ExpressionTree node) { - JSONObject jsonObject = new JSONObject(); - jsonObject.put("op", node.getOperator().ordinal()); - if (node.getOperator().toString().equals("LEAF")) { - jsonObject.put("leaf", node.toString()); - return jsonObject; - } - ArrayList child = new ArrayList(); - for (ExpressionTree childNode : node.getChildren()) { - JSONObject rtnJson = getSubJson(childNode); - child.add(rtnJson); - } - jsonObject.put("child", child); - return jsonObject; + allFieldsNames = new ArrayList(); } public String padZeroForDecimals(String [] decimalStrArray, int decimalScale) { @@ -86,98 +105,15 @@ public class OrcColumnarBatchScanReader { return String.format("%1$-" + decimalScale + "s", decimalVal).replace(' ', '0'); } - public int getPrecision(String colname) { - for (int i = 0; i < requiredfieldNames.length; i++) { - if (colname.equals(requiredfieldNames[i])) { - return precisionArray[i]; - } - } - - return -1; - } - - public int getScale(String colname) { - for (int i = 0; i < requiredfieldNames.length; i++) { - if (colname.equals(requiredfieldNames[i])) { - return scaleArray[i]; - } - } - - return -1; - } - - public JSONObject getLeavesJson(List leaves) { - JSONObject jsonObjectList = new JSONObject(); - for (int i = 0; i < leaves.size(); i++) { - PredicateLeaf pl = leaves.get(i); - JSONObject jsonObject = new JSONObject(); - jsonObject.put("op", pl.getOperator().ordinal()); - jsonObject.put("name", pl.getColumnName()); - jsonObject.put("type", pl.getType().ordinal()); - if (pl.getLiteral() != null) { - if (pl.getType() == PredicateLeaf.Type.DATE) { - jsonObject.put("literal", ((int)Math.ceil(((Date)pl.getLiteral()).getTime()* 1.0/3600/24/1000)) + ""); - } else if (pl.getType() == PredicateLeaf.Type.DECIMAL) { - int decimalP = getPrecision(pl.getColumnName()); - int decimalS = getScale(pl.getColumnName()); - String[] spiltValues = pl.getLiteral().toString().split("\\."); - if (decimalS == 0) { - jsonObject.put("literal", spiltValues[0] + " " + decimalP + " " + decimalS); - } else { - String scalePadZeroStr = padZeroForDecimals(spiltValues, decimalS); - jsonObject.put("literal", spiltValues[0] + "." 
+ scalePadZeroStr + " " + decimalP + " " + decimalS); - } - } else { - jsonObject.put("literal", pl.getLiteral().toString()); - } - } else { - jsonObject.put("literal", ""); - } - if ((pl.getLiteralList() != null) && (pl.getLiteralList().size() != 0)){ - List lst = new ArrayList<>(); - for (Object ob : pl.getLiteralList()) { - if (ob == null) { - lst.add(null); - continue; - } - if (pl.getType() == PredicateLeaf.Type.DECIMAL) { - int decimalP = getPrecision(pl.getColumnName()); - int decimalS = getScale(pl.getColumnName()); - String[] spiltValues = ob.toString().split("\\."); - if (decimalS == 0) { - lst.add(spiltValues[0] + " " + decimalP + " " + decimalS); - } else { - String scalePadZeroStr = padZeroForDecimals(spiltValues, decimalS); - lst.add(spiltValues[0] + "." + scalePadZeroStr + " " + decimalP + " " + decimalS); - } - } else if (pl.getType() == PredicateLeaf.Type.DATE) { - lst.add(((int)Math.ceil(((Date)ob).getTime()* 1.0/3600/24/1000)) + ""); - } else { - lst.add(ob.toString()); - } - } - jsonObject.put("literalList", lst); - } else { - jsonObject.put("literalList", new ArrayList()); - } - jsonObjectList.put("leaf-" + i, jsonObject); - } - return jsonObjectList; - } - /** * Init Orc reader. * * @param uri split file path - * @param options split file options */ - public long initializeReaderJava(URI uri, ReaderOptions options) { + public long initializeReaderJava(URI uri) { JSONObject job = new JSONObject(); - if (options.getOrcTail() == null) { - job.put("serializedTail", ""); - } else { - job.put("serializedTail", options.getOrcTail().getSerializedTail().toString()); - } + + job.put("serializedTail", ""); job.put("tailLocation", 9223372036854775807L); job.put("scheme", uri.getScheme() == null ? "" : uri.getScheme()); @@ -185,37 +121,36 @@ public class OrcColumnarBatchScanReader { job.put("port", uri.getPort()); job.put("path", uri.getPath() == null ? "" : uri.getPath()); - reader = jniReader.initializeReader(job, fildsNames); + reader = jniReader.initializeReader(job, allFieldsNames); return reader; } /** * Init Orc RecordReader. * - * @param options split file options + * @param offset split file offset + * @param length split file read length + * @param pushedFilter the filter push down to native + * @param requiredSchema the columns read from native */ - public long initializeRecordReaderJava(Options options) { + public long initializeRecordReaderJava(long offset, long length, Filter pushedFilter, StructType requiredSchema) { + this.requiredSchema = requiredSchema; JSONObject job = new JSONObject(); - if (options.getInclude() == null) { - job.put("include", ""); - } else { - job.put("include", options.getInclude().toString()); - } - job.put("offset", options.getOffset()); - job.put("length", options.getLength()); - // When the number of pushedFilters > hive.CNF_COMBINATIONS_THRESHOLD, the expression is rewritten to - // 'YES_NO_NULL'. Under the circumstances, filter push down will be skipped. 
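For reference, the job object handed to jniReader.initializeRecordReader ends up shaped roughly as below for a pushed filter such as a > 10 AND a < 100 on a LONG column. The numeric op codes are the ordinals of the OrcOperator, OrcLeafOperator and OrcPredicateDataType enums defined further down (AND = 1, NOT = 2, LEAF = 3; LESS_THAN = 2, LESS_THAN_EQUALS = 3; LONG = 0); the offset, length and column names are made up for illustration:

    {
      "offset": 0,
      "length": 67108864,
      "includedColumns": ["a", "b"],
      "expressionTree": {
        "op": 1,
        "child": [
          {"op": 2, "child": [{"op": 3, "leaf": "leaf-0"}]},
          {"op": 3, "leaf": "leaf-1"}
        ]
      },
      "leaves": {
        "leaf-0": {"op": 3, "name": "a", "type": 0, "literal": "10",  "literalList": []},
        "leaf-1": {"op": 2, "name": "a", "type": 0, "literal": "100", "literalList": []}
      }
    }

Note that a > 10 is encoded as NOT over a LESS_THAN_EQUALS leaf, while a < 100 maps directly to a LESS_THAN leaf, mirroring the GreaterThan and LessThan branches of getExprJson.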
- if (options.getSearchArgument() != null - && !options.getSearchArgument().toString().contains("YES_NO_NULL")) { - LOGGER.debug("SearchArgument: {}", options.getSearchArgument().toString()); - JSONObject jsonexpressionTree = getSubJson(options.getSearchArgument().getExpression()); - job.put("expressionTree", jsonexpressionTree); - JSONObject jsonleaves = getLeavesJson(options.getSearchArgument().getLeaves()); - job.put("leaves", jsonleaves); - } - job.put("includedColumns", colToInclu.toArray()); + job.put("offset", offset); + job.put("length", length); + if (pushedFilter != null) { + JSONObject jsonExpressionTree = new JSONObject(); + JSONObject jsonLeaves = new JSONObject(); + boolean flag = canPushDown(pushedFilter, jsonExpressionTree, jsonLeaves); + if (flag) { + job.put("expressionTree", jsonExpressionTree); + job.put("leaves", jsonLeaves); + } + } + + job.put("includedColumns", includedColumns.toArray()); recordReader = jniReader.initializeRecordReader(reader, job); return recordReader; } @@ -253,14 +188,22 @@ public class OrcColumnarBatchScanReader { } } + public void convertJulianToGregorianMicros(LongVec longVec, long rowNumber) { + long gregorianValue; + for (int rowIndex = 0; rowIndex < rowNumber; rowIndex++) { + gregorianValue = RebaseDateTime.rebaseJulianToGregorianMicros(longVec.get(rowIndex)); + longVec.set(rowIndex, gregorianValue); + } + } + public int next(Vec[] vecList, int[] typeIds) { - long[] vecNativeIds = new long[realColsCnt]; + long[] vecNativeIds = new long[typeIds.length]; long rtn = jniReader.recordReaderNext(recordReader, batchReader, typeIds, vecNativeIds); if (rtn == 0) { return 0; } int nativeGetId = 0; - for (int i = 0; i < realColsCnt; i++) { + for (int i = 0; i < colsToGet.length; i++) { if (colsToGet[i] != 0) { continue; } @@ -301,7 +244,7 @@ public class OrcColumnarBatchScanReader { } default: { throw new RuntimeException("UnSupport type for ColumnarFileScan:" + - DataType.DataTypeId.values()[typeIds[i]]); + DataType.DataTypeId.values()[typeIds[i]]); } } nativeGetId++; @@ -309,6 +252,228 @@ public class OrcColumnarBatchScanReader { return (int)rtn; } + enum OrcOperator { + OR, + AND, + NOT, + LEAF, + CONSTANT + } + + enum OrcLeafOperator { + EQUALS, + NULL_SAFE_EQUALS, + LESS_THAN, + LESS_THAN_EQUALS, + IN, + BETWEEN, // not use, spark transfers it to gt and lt + IS_NULL + } + + enum OrcPredicateDataType { + LONG, // all of integer types + FLOAT, // float and double + STRING, // string, char, varchar + DATE, + DECIMAL, + TIMESTAMP, + BOOLEAN + } + + private OrcPredicateDataType getOrcPredicateDataType(String attribute) { + StructField field = requiredSchema.apply(attribute); + org.apache.spark.sql.types.DataType dataType = field.dataType(); + if (dataType instanceof ShortType || dataType instanceof IntegerType || + dataType instanceof LongType) { + return OrcPredicateDataType.LONG; + } else if (dataType instanceof DoubleType) { + return OrcPredicateDataType.FLOAT; + } else if (dataType instanceof StringType) { + if (isCharType(field.metadata())) { + throw new UnsupportedOperationException("Unsupported orc push down filter data type: char"); + } + return OrcPredicateDataType.STRING; + } else if (dataType instanceof DateType) { + return OrcPredicateDataType.DATE; + } else if (dataType instanceof DecimalType) { + return OrcPredicateDataType.DECIMAL; + } else if (dataType instanceof BooleanType) { + return OrcPredicateDataType.BOOLEAN; + } else { + throw new UnsupportedOperationException("Unsupported orc push down filter data type: " + + 
dataType.getClass().getSimpleName()); + } + } + + // Check the type whether is char type, which orc native does not support push down + private boolean isCharType(Metadata metadata) { + if (metadata != null) { + String rawTypeString = CharVarcharUtils.getRawTypeString(metadata).getOrElse(null); + if (rawTypeString != null) { + Matcher matcher = CHAR_TYPE.matcher(rawTypeString); + return matcher.matches(); + } + } + return false; + } + + private boolean canPushDown(Filter pushedFilter, JSONObject jsonExpressionTree, + JSONObject jsonLeaves) { + try { + getExprJson(pushedFilter, jsonExpressionTree, jsonLeaves); + if (leafIndex > MAX_LEAF_THRESHOLD) { + throw new UnsupportedOperationException("leaf node nums is " + leafIndex + + ", which is bigger than max threshold " + MAX_LEAF_THRESHOLD + "."); + } + return true; + } catch (Exception e) { + LOGGER.info("Unable to push down orc filter because " + e.getMessage()); + return false; + } + } + + private void getExprJson(Filter filterPredicate, JSONObject jsonExpressionTree, + JSONObject jsonLeaves) { + if (filterPredicate instanceof And) { + addChildJson(jsonExpressionTree, jsonLeaves, OrcOperator.AND, + ((And) filterPredicate).left(), ((And) filterPredicate).right()); + } else if (filterPredicate instanceof Or) { + addChildJson(jsonExpressionTree, jsonLeaves, OrcOperator.OR, + ((Or) filterPredicate).left(), ((Or) filterPredicate).right()); + } else if (filterPredicate instanceof Not) { + addChildJson(jsonExpressionTree, jsonLeaves, OrcOperator.NOT, + ((Not) filterPredicate).child()); + } else if (filterPredicate instanceof EqualTo) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.EQUALS, jsonLeaves, + ((EqualTo) filterPredicate).attribute(), ((EqualTo) filterPredicate).value(), null); + leafIndex++; + } else if (filterPredicate instanceof GreaterThan) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, true); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.LESS_THAN_EQUALS, jsonLeaves, + ((GreaterThan) filterPredicate).attribute(), ((GreaterThan) filterPredicate).value(), null); + leafIndex++; + } else if (filterPredicate instanceof GreaterThanOrEqual) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, true); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.LESS_THAN, jsonLeaves, + ((GreaterThanOrEqual) filterPredicate).attribute(), ((GreaterThanOrEqual) filterPredicate).value(), null); + leafIndex++; + } else if (filterPredicate instanceof LessThan) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.LESS_THAN, jsonLeaves, + ((LessThan) filterPredicate).attribute(), ((LessThan) filterPredicate).value(), null); + leafIndex++; + } else if (filterPredicate instanceof LessThanOrEqual) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.LESS_THAN_EQUALS, jsonLeaves, + ((LessThanOrEqual) filterPredicate).attribute(), ((LessThanOrEqual) filterPredicate).value(), null); + leafIndex++; + } else if (filterPredicate instanceof IsNotNull) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, true); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.IS_NULL, jsonLeaves, + ((IsNotNull) filterPredicate).attribute(), null, null); + leafIndex++; + } else if (filterPredicate instanceof IsNull) { + 
addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false);
+            addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.IS_NULL, jsonLeaves,
+                ((IsNull) filterPredicate).attribute(), null, null);
+            leafIndex++;
+        } else if (filterPredicate instanceof In) {
+            addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false);
+            addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.IN, jsonLeaves,
+                ((In) filterPredicate).attribute(), null, Arrays.stream(((In) filterPredicate).values()).toArray());
+            leafIndex++;
+        } else {
+            throw new UnsupportedOperationException("Unsupported orc push down filter operation: " +
+                filterPredicate.getClass().getSimpleName());
+        }
+    }
+
+    private void addLiteralToJsonLeaves(String leaf, OrcLeafOperator leafOperator, JSONObject jsonLeaves,
+        String name, Object literal, Object[] literals) {
+        JSONObject leafJson = new JSONObject();
+        leafJson.put("op", leafOperator.ordinal());
+        leafJson.put("name", name);
+        leafJson.put("type", getOrcPredicateDataType(name).ordinal());
+
+        leafJson.put("literal", getLiteralValue(literal));
+
+        ArrayList<String> literalList = new ArrayList<>();
+        if (literals != null) {
+            for (Object lit: literals) {
+                literalList.add(getLiteralValue(lit));
+            }
+        }
+        leafJson.put("literalList", literalList);
+        jsonLeaves.put(leaf, leafJson);
+    }
+
+    private void addToJsonExpressionTree(String leaf, JSONObject jsonExpressionTree, boolean addNot) {
+        if (addNot) {
+            jsonExpressionTree.put("op", OrcOperator.NOT.ordinal());
+            ArrayList<JSONObject> child = new ArrayList<>();
+            JSONObject subJson = new JSONObject();
+            subJson.put("op", OrcOperator.LEAF.ordinal());
+            subJson.put("leaf", leaf);
+            child.add(subJson);
+            jsonExpressionTree.put("child", child);
+        } else {
+            jsonExpressionTree.put("op", OrcOperator.LEAF.ordinal());
+            jsonExpressionTree.put("leaf", leaf);
+        }
+    }
+
+    private void addChildJson(JSONObject jsonExpressionTree, JSONObject jsonLeaves,
+        OrcOperator orcOperator, Filter ... filters) {
+        jsonExpressionTree.put("op", orcOperator.ordinal());
+        ArrayList<JSONObject> child = new ArrayList<>();
+        for (Filter filter: filters) {
+            JSONObject subJson = new JSONObject();
+            getExprJson(filter, subJson, jsonLeaves);
+            child.add(subJson);
+        }
+        jsonExpressionTree.put("child", child);
+    }
+
+    private String getLiteralValue(Object literal) {
+        // For null literal, the predicate will not be pushed down.
+        if (literal == null) {
+            throw new UnsupportedOperationException("Unsupported orc push down filter for literal is null");
+        }
+
+        // For Decimal Type, we use the special string format to represent, which is "$decimalVal
+        // $precision $scale".
+        // e.g., Decimal(9, 3) = 123456.789, it outputs "123456.789 9 3".
+        // e.g., Decimal(9, 3) = 123456.7, it outputs "123456.700 9 3".
+        if (literal instanceof BigDecimal) {
+            BigDecimal value = (BigDecimal) literal;
+            int precision = value.precision();
+            int scale = value.scale();
+            String[] split = value.toString().split("\\.");
+            if (scale == 0) {
+                return split[0] + " " + precision + " " + scale;
+            } else {
+                String padded = padZeroForDecimals(split, scale);
+                return split[0] + "." + padded + " " + precision + " " + scale;
+            }
+        }
+        // For Date Type, Spark uses the Gregorian calendar by default but ORC uses Julian, so the value must be converted.
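+        // A LocalDate literal is converted to its epoch-day count, rebased from Spark's proleptic
+        // Gregorian calendar to the hybrid Julian/Gregorian calendar expected by ORC, and then
+        // serialized as a plain day-count string.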
+ if (literal instanceof LocalDate) { + int epochDay = Math.toIntExact(((LocalDate) literal).toEpochDay()); + int rebased = RebaseDateTime.rebaseGregorianToJulianDays(epochDay); + return String.valueOf(rebased); + } + if (literal instanceof String) { + return (String) literal; + } + if (literal instanceof Integer || literal instanceof Long || literal instanceof Boolean || + literal instanceof Short || literal instanceof Double) { + return literal.toString(); + } + throw new UnsupportedOperationException("Unsupported orc push down filter date type: " + + literal.getClass().getSimpleName()); + } + private static String bytesToHexString(byte[] bytes) { if (bytes == null || bytes.length < 1) { throw new IllegalArgumentException("this bytes must not be null or empty"); diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java index 93950e9f0..1880208fb 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java @@ -22,18 +22,15 @@ import com.google.common.annotations.VisibleForTesting; import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor; import com.huawei.boostkit.spark.jni.OrcColumnarBatchScanReader; import nova.hetu.omniruntime.vector.Vec; -import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.RecordReader; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.hadoop.mapreduce.lib.input.FileSplit; -import org.apache.orc.OrcConf; -import org.apache.orc.OrcFile; -import org.apache.orc.Reader; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.OmniColumnVectorUtils; import org.apache.spark.sql.execution.vectorized.OmniColumnVector; +import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -49,26 +46,11 @@ import java.util.ArrayList; public class OmniOrcColumnarBatchReader extends RecordReader { // The capacity of vectorized batch. - private int capacity; - /** - * The column IDs of the physical ORC file schema which are required by this reader. - * -1 means this required column is partition column, or it doesn't exist in the ORC file. - * Ideally partition column should never appear in the physical file, and should only appear - * in the directory name. However, Spark allows partition columns inside physical file, - * but Spark will discard the values from the file, and use the partition value got from - * directory name. The column order will be reserved though. - */ - @VisibleForTesting - public int[] requestedDataColIds; - // Native Record reader from ORC row batch. private OrcColumnarBatchScanReader recordReader; - private StructField[] requiredFields; - private StructField[] resultFields; - // The result columnar batch for vectorized execution by whole-stage codegen. 
@VisibleForTesting public ColumnarBatch columnarBatch; @@ -82,13 +64,15 @@ public class OmniOrcColumnarBatchReader extends RecordReader orcfieldNames = recordReader.fildsNames; // save valid cols and numbers of valid cols recordReader.colsToGet = new int[requiredfieldNames.length]; - recordReader.realColsCnt = 0; - // save valid cols fieldsNames - recordReader.colToInclu = new ArrayList(); + recordReader.includedColumns = new ArrayList<>(); // collect read cols types ArrayList typeBuilder = new ArrayList<>(); + for (int i = 0; i < requiredfieldNames.length; i++) { String target = requiredfieldNames[i]; - boolean is_find = false; - for (int j = 0; j < orcfieldNames.size(); j++) { - String temp = orcfieldNames.get(j); - if (target.equals(temp)) { - requestedDataColIds[i] = i; - recordReader.colsToGet[i] = 0; - recordReader.colToInclu.add(requiredfieldNames[i]); - recordReader.realColsCnt++; - typeBuilder.add(OmniExpressionAdaptor.sparkTypeToOmniType(requiredSchema.fields()[i].dataType())); - is_find = true; - } - } - - // if invalid, set colsToGet value -1, else set colsToGet 0 - if (!is_find) { + // if not find, set colsToGet value -1, else set colsToGet 0 + if (recordReader.allFieldsNames.contains(target)) { + recordReader.colsToGet[i] = 0; + recordReader.includedColumns.add(requiredfieldNames[i]); + typeBuilder.add(OmniExpressionAdaptor.sparkTypeToOmniType(requiredSchema.fields()[i].dataType())); + } else { recordReader.colsToGet[i] = -1; } } vecTypeIds = typeBuilder.stream().mapToInt(Integer::intValue).toArray(); - - for (int i = 0; i < resultFields.length; i++) { - if (requestedPartitionColIds[i] != -1) { - requestedDataColIds[i] = -1; - } - } - - // set data members resultFields and requestedDataColIdS - this.resultFields = resultFields; - this.requestedDataColIds = requestedDataColIds; - - recordReader.requiredfieldNames = requiredfieldNames; - recordReader.precisionArray = precisionArray; - recordReader.scaleArray = scaleArray; - recordReader.initializeRecordReaderJava(options); } /** * Initialize columnar batch by setting required schema and partition information. * With this information, this creates ColumnarBatch with the full schema. * - * @param requiredFields The fields that are required to return,. - * @param resultFields All the fields that are required to return, including partition fields. - * @param requestedDataColIds Requested column ids from orcSchema. -1 if not existed. - * @param requestedPartitionColIds Requested column ids from partition schema. -1 if not existed. + * @param partitionColumns partition columns * @param partitionValues Values of partition columns. 
*/ - public void initBatch( - StructField[] requiredFields, - StructField[] resultFields, - int[] requestedDataColIds, - int[] requestedPartitionColIds, - InternalRow partitionValues) { - if (resultFields.length != requestedDataColIds.length || resultFields.length != requestedPartitionColIds.length){ - throw new UnsupportedOperationException("This operator doesn't support orc initBatch."); - } + public void initBatch(StructType partitionColumns, InternalRow partitionValues) { + StructType resultSchema = new StructType(); - this.requiredFields = requiredFields; + for (StructField f: requiredSchema.fields()) { + resultSchema = resultSchema.add(f); + } - StructType resultSchema = new StructType(resultFields); + if (partitionColumns != null) { + for (StructField f: partitionColumns.fields()) { + resultSchema = resultSchema.add(f); + } + } // Just wrap the ORC column vector instead of copying it to Spark column vector. orcVectorWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()]; templateWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()]; - for (int i = 0; i < resultFields.length; i++) { - DataType dt = resultFields[i].dataType(); - if (requestedPartitionColIds[i] != -1) { - OmniColumnVector partitionCol = new OmniColumnVector(capacity, dt, true); - OmniColumnVectorUtils.populate(partitionCol, partitionValues, requestedPartitionColIds[i]); + if (partitionColumns != null) { + int partitionIdx = requiredSchema.fields().length; + for (int i = 0; i < partitionColumns.fields().length; i++) { + OmniColumnVector partitionCol = new OmniColumnVector(capacity, partitionColumns.fields()[i].dataType(), true); + OmniColumnVectorUtils.populate(partitionCol, partitionValues, i); partitionCol.setIsConstant(); - templateWrappers[i] = partitionCol; - orcVectorWrappers[i] = new OmniColumnVector(capacity, dt, false);; + templateWrappers[i + partitionIdx] = partitionCol; + orcVectorWrappers[i + partitionIdx] = new OmniColumnVector(capacity, partitionColumns.fields()[i].dataType(), false); + } + } + + for (int i = 0; i < requiredSchema.fields().length; i++) { + DataType dt = requiredSchema.fields()[i].dataType(); + if (recordReader.colsToGet[i] == -1) { + // missing cols + OmniColumnVector missingCol = new OmniColumnVector(capacity, dt, true); + missingCol.putNulls(0, capacity); + missingCol.setIsConstant(); + templateWrappers[i] = missingCol; } else { - int colId = requestedDataColIds[i]; - // Initialize the missing columns once. 
- if (colId == -1) { - OmniColumnVector missingCol = new OmniColumnVector(capacity, dt, true); - missingCol.putNulls(0, capacity); - missingCol.setIsConstant(); - templateWrappers[i] = missingCol; - } else { - templateWrappers[i] = new OmniColumnVector(capacity, dt, false); - } - orcVectorWrappers[i] = new OmniColumnVector(capacity, dt, false); + templateWrappers[i] = new OmniColumnVector(capacity, dt, false); } + orcVectorWrappers[i] = new OmniColumnVector(capacity, dt, false); } + // init batch recordReader.initBatchJava(capacity); vecs = new Vec[orcVectorWrappers.length]; @@ -260,7 +210,7 @@ public class OmniOrcColumnarBatchReader extends RecordReader - convertibleFiltersHelper(left, dataSchema) && convertibleFiltersHelper(right, dataSchema) - case Or(left, right) => - convertibleFiltersHelper(left, dataSchema) && convertibleFiltersHelper(right, dataSchema) - case Not(pred) => - convertibleFiltersHelper(pred, dataSchema) - case other => - other match { - case EqualTo(name, _) => - dataSchema.apply(name).dataType != StringType - case EqualNullSafe(name, _) => - dataSchema.apply(name).dataType != StringType - case LessThan(name, _) => - dataSchema.apply(name).dataType != StringType - case LessThanOrEqual(name, _) => - dataSchema.apply(name).dataType != StringType - case GreaterThan(name, _) => - dataSchema.apply(name).dataType != StringType - case GreaterThanOrEqual(name, _) => - dataSchema.apply(name).dataType != StringType - case IsNull(name) => - dataSchema.apply(name).dataType != StringType - case IsNotNull(name) => - dataSchema.apply(name).dataType != StringType - case In(name, _) => - dataSchema.apply(name).dataType != StringType - case _ => false - } - } - - filters.map { filter => - convertibleFiltersHelper(filter, dataSchema) - } - } - override def buildReaderWithPartitionValues( sparkSession: SparkSession, dataSchema: StructType, @@ -103,7 +61,6 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ options: Map[String, String], hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { - val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields) val sqlConf = sparkSession.sessionState.conf val capacity = sqlConf.orcVectorizedReaderBatchSize @@ -113,32 +70,18 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis val orcFilterPushDown = sparkSession.sessionState.conf.orcFilterPushDown - val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles (file: PartitionedFile) => { val conf = broadcastedConf.value.value -// val filePath = new Path(new URI(file.filePath.urlEncoded)) val filePath = file.toPath - val fs = filePath.getFileSystem(conf) - val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) - val orcSchema = - Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions))(_.getSchema) - val isPPDSafeValue = isPPDSafe(filters, dataSchema).reduceOption(_ && _) - val resultedColPruneInfo = OrcUtils.requestedColumnIds( - isCaseSensitive, dataSchema, requiredSchema, orcSchema, conf) - - - // ORC predicate pushdown - if (orcFilterPushDown && filters.nonEmpty && isPPDSafeValue.getOrElse(false)) { - val fileSchema = OrcUtils.toCatalystSchema(orcSchema) - // OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach { fileSchema => - OrcFilters.createFilter(fileSchema, filters).foreach { f => - 
OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames) - // } - } - } + // ORC predicate pushdown + val pushed = if (orcFilterPushDown) { + filters.reduceOption(And(_, _)) + } else { + None + } val taskConf = new Configuration(conf) val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty) @@ -146,42 +89,16 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) // read data from vectorized reader - val batchReader = new OmniOrcColumnarBatchReader(capacity) + val batchReader = new OmniOrcColumnarBatchReader(capacity, requiredSchema, pushed.orNull) // SPARK-23399 Register a task completion listener first to call `close()` in all cases. // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM) // after opening a file. val iter = new RecordReaderIterator(batchReader) Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) - // fill requestedDataColIds with -1, fil real values int initDataColIds function - val requestedDataColIds = Array.fill(requiredSchema.length)(-1) ++ Array.fill(partitionSchema.length)(-1) - val requestedPartitionColIds = - Array.fill(requiredSchema.length)(-1) ++ Range(0, partitionSchema.length) - - // 初始化precision数组和scale数组,透传至java侧使用 - val requiredFields = requiredSchema.fields - val fieldslength = requiredFields.length - val precisionArray : Array[Int] = Array.ofDim[Int](fieldslength) - val scaleArray : Array[Int] = Array.ofDim[Int](fieldslength) - for ((reqField, index) <- requiredFields.zipWithIndex) { - val reqdatatype = reqField.dataType - if (reqdatatype.isInstanceOf[DecimalType]) { - val precision = reqdatatype.asInstanceOf[DecimalType].precision - val scale = reqdatatype.asInstanceOf[DecimalType].scale - precisionArray(index) = precision - scaleArray(index) = scale - } - } SparkMemoryUtils.init() batchReader.initialize(fileSplit, taskAttemptContext) - batchReader.initDataColIds(requiredSchema, requestedPartitionColIds, requestedDataColIds, resultSchema.fields, - precisionArray, scaleArray) - batchReader.initBatch( - requiredSchema.fields, - resultSchema.fields, - requestedDataColIds, - requestedPartitionColIds, - file.partitionValues) + batchReader.initBatch(partitionSchema, file.partitionValues) iter.asInstanceOf[Iterator[InternalRow]] } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java index fe1c55ffb..ca54ddd73 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java @@ -89,7 +89,7 @@ public class OrcColumnarBatchJniReaderDataTypeTest extends TestCase { includedColumns.add("i_current_price"); job.put("includedColumns", includedColumns.toArray()); - orcColumnarBatchScanReader.recordReader = orcColumnarBatchScanReader.jniReader.initializeRecordReader(orcColumnarBatchScanReader.reader, job); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.recordReader != 0); } diff --git 
a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java index 995c434f6..78fa455b8 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java @@ -81,7 +81,7 @@ public class OrcColumnarBatchJniReaderNotPushDownTest extends TestCase { includedColumns.add("i_item_id"); job.put("includedColumns", includedColumns.toArray()); - orcColumnarBatchScanReader.recordReader = orcColumnarBatchScanReader.jniReader.initializeRecordReader(orcColumnarBatchScanReader.reader, job); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.recordReader != 0); } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java index c9ad9fada..2c912d919 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java @@ -66,7 +66,7 @@ public class OrcColumnarBatchJniReaderPushDownTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java index 8f4535338..cf86c0a5a 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java @@ -68,7 +68,7 @@ public class OrcColumnarBatchJniReaderSparkORCNotPushDownTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java index 27bcf5d7b..ef8d037bf 100644 --- 
a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java @@ -68,7 +68,7 @@ public class OrcColumnarBatchJniReaderSparkORCPushDownTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java index eab15fef6..19f23db00 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java @@ -18,103 +18,90 @@ package com.huawei.boostkit.spark.jni; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.io.Input; +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor; import junit.framework.TestCase; import nova.hetu.omniruntime.vector.LongVec; import nova.hetu.omniruntime.vector.VarcharVec; import nova.hetu.omniruntime.vector.Vec; -import org.apache.commons.codec.binary.Base64; -import org.apache.hadoop.hive.ql.io.sarg.SearchArgument; -import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentImpl; -import org.apache.orc.OrcFile; -import org.apache.orc.TypeDescription; -import org.apache.orc.mapred.OrcInputFormat; +import org.apache.hadoop.conf.Configuration; +import org.apache.orc.Reader; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; import org.junit.After; import org.junit.Before; import org.junit.FixMethodOrder; import org.junit.Test; import org.junit.runners.MethodSorters; -import org.apache.hadoop.conf.Configuration; + import java.io.File; import java.net.URI; import java.net.URISyntaxException; import java.util.ArrayList; -import java.util.List; import java.util.Arrays; -import org.apache.orc.Reader.Options; - -import static org.junit.Assert.*; +import java.util.List; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; +import static org.apache.spark.sql.types.DataTypes.LongType; +import static org.apache.spark.sql.types.DataTypes.StringType; @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderTest extends TestCase { public Configuration conf = new Configuration(); public OrcColumnarBatchScanReader orcColumnarBatchScanReader; - public int batchSize = 4096; + private int batchSize = 4096; + + private StructType requiredSchema; + private int[] vecTypeIds; + + private long offset = 0; + + private long length = Integer.MAX_VALUE; @Before public void setUp() throws Exception { - Configuration conf = new Configuration(); - TypeDescription schema = - TypeDescription.fromString("struct<`i_item_sk`:bigint,`i_item_id`:string>"); - Options options = new Options(conf) - .range(0, 
Integer.MAX_VALUE) - .useZeroCopy(false) - .skipCorruptRecords(false) - .tolerateMissingSchema(true); - - options.schema(schema); - options.include(OrcInputFormat.parseInclude(schema, - null)); - String kryoSarg = "AQEAb3JnLmFwYWNoZS5oYWRvb3AuaGl2ZS5xbC5pby5zYXJnLkV4cHJlc3Npb25UcmXlAQEBamF2YS51dGlsLkFycmF5TGlz9AECAQABAQEBAQEAAQAAAAEEAAEBAwEAAQEBAQEBAAEAAAIIAAEJAAEBAgEBAQIBAscBb3JnLmFwYWNoZS5oYWRvb3AuaGl2ZS5xbC5pby5zYXJnLlNlYXJjaEFyZ3VtZW50SW1wbCRQcmVkaWNhdGVMZWFmSW1wbAEBaV9pdGVtX3PrAAABBwEBAQIBEAkAAAEEEg=="; - String sargColumns = "i_item_sk,i_item_id,i_rec_start_date,i_rec_end_date,i_item_desc,i_current_price,i_wholesale_cost,i_brand_id,i_brand,i_class_id,i_class,i_category_id,i_category,i_manufact_id,i_manufact,i_size,i_formulation,i_color,i_units,i_container,i_manager_id,i_product_name"; - if (kryoSarg != null && sargColumns != null) { - byte[] sargBytes = Base64.decodeBase64(kryoSarg); - SearchArgument sarg = - new Kryo().readObject(new Input(sargBytes), SearchArgumentImpl.class); - options.searchArgument(sarg, sargColumns.split(",")); - sarg.getExpression().toString(); - } - orcColumnarBatchScanReader = new OrcColumnarBatchScanReader(); + constructSchema(); initReaderJava(); - initDataColIds(options, orcColumnarBatchScanReader); - initRecordReaderJava(options); - initBatch(options); + initDataColIds(); + initRecordReaderJava(); + initBatch(); + } + + private void constructSchema() { + requiredSchema = new StructType() + .add("i_item_sk", LongType) + .add("i_item_id", StringType); } - public void initDataColIds( - Options options, OrcColumnarBatchScanReader orcColumnarBatchScanReader) { - List allCols; - allCols = Arrays.asList(options.getColumnNames()); - orcColumnarBatchScanReader.colToInclu = new ArrayList(); - List optionField = options.getSchema().getFieldNames(); - orcColumnarBatchScanReader.colsToGet = new int[optionField.size()]; - orcColumnarBatchScanReader.realColsCnt = 0; - for (int i = 0; i < optionField.size(); i++) { - if (allCols.contains(optionField.get(i))) { - orcColumnarBatchScanReader.colToInclu.add(optionField.get(i)); - orcColumnarBatchScanReader.colsToGet[i] = 0; - orcColumnarBatchScanReader.realColsCnt++; - } else { + private void initDataColIds() { + // find requiredS fieldNames + String[] requiredfieldNames = requiredSchema.fieldNames(); + // save valid cols and numbers of valid cols + orcColumnarBatchScanReader.colsToGet = new int[requiredfieldNames.length]; + orcColumnarBatchScanReader.includedColumns = new ArrayList<>(); + // collect read cols types + ArrayList typeBuilder = new ArrayList<>(); + + for (int i = 0; i < requiredfieldNames.length; i++) { + String target = requiredfieldNames[i]; + + // if not find, set colsToGet value -1, else set colsToGet 0 + boolean is_find = false; + for (int j = 0; j < orcColumnarBatchScanReader.allFieldsNames.size(); j++) { + if (target.equals(orcColumnarBatchScanReader.allFieldsNames.get(j))) { + orcColumnarBatchScanReader.colsToGet[i] = 0; + orcColumnarBatchScanReader.includedColumns.add(requiredfieldNames[i]); + typeBuilder.add(OmniExpressionAdaptor.sparkTypeToOmniType(requiredSchema.fields()[i].dataType())); + is_find = true; + break; + } + } + + if (!is_find) { orcColumnarBatchScanReader.colsToGet[i] = -1; } } - orcColumnarBatchScanReader.requiredfieldNames = new String[optionField.size()]; - TypeDescription schema = options.getSchema(); - int[] precisionArray = new int[optionField.size()]; - int[] scaleArray = new int[optionField.size()]; - for (int i = 0; i < optionField.size(); i++) { - precisionArray[i] = 
schema.findSubtype(optionField.get(i)).getPrecision(); - scaleArray[i] = schema.findSubtype(optionField.get(i)).getScale(); - orcColumnarBatchScanReader.requiredfieldNames[i] = optionField.get(i); - } - orcColumnarBatchScanReader.precisionArray = precisionArray; - orcColumnarBatchScanReader.scaleArray = scaleArray; + vecTypeIds = typeBuilder.stream().mapToInt(Integer::intValue).toArray(); } @After @@ -122,8 +109,7 @@ public class OrcColumnarBatchJniReaderTest extends TestCase { System.out.println("orcColumnarBatchJniReader test finished"); } - public void initReaderJava() throws URISyntaxException { - OrcFile.ReaderOptions readerOptions = OrcFile.readerOptions(conf); + private void initReaderJava() { File directory = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0"); String path = directory.getAbsolutePath(); URI uri = null; @@ -133,16 +119,17 @@ public class OrcColumnarBatchJniReaderTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, readerOptions); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } - public void initRecordReaderJava(Options options) { - orcColumnarBatchScanReader.recordReader = orcColumnarBatchScanReader.initializeRecordReaderJava(options); + private void initRecordReaderJava() { + orcColumnarBatchScanReader.recordReader = orcColumnarBatchScanReader. + initializeRecordReaderJava(offset, length, null, requiredSchema); assertTrue(orcColumnarBatchScanReader.recordReader != 0); } - public void initBatch(Options options) { + private void initBatch() { orcColumnarBatchScanReader.initBatchJava(batchSize); assertTrue(orcColumnarBatchScanReader.batchReader != 0); } @@ -150,8 +137,7 @@ public class OrcColumnarBatchJniReaderTest extends TestCase { @Test public void testNext() { Vec[] vecs = new Vec[2]; - int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal()}; - long rtn = orcColumnarBatchScanReader.next(vecs, typeId); + long rtn = orcColumnarBatchScanReader.next(vecs, vecTypeIds); assertTrue(rtn == 4096); assertTrue(((LongVec) vecs[0]).get(0) == 1); String str = new String(((VarcharVec) vecs[1]).get(0)); -- Gitee
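For reference, the reworked OrcColumnarBatchScanReader in this patch is driven in four steps: open the ORC file, open a record reader over a byte range with an optional pushed-down filter, allocate the native batch, and pull OmniRuntime vectors. The sketch below ties the steps together in the same order as OrcColumnarBatchJniReaderTest above; it is illustrative only, assumes the test ORC file and its two-column schema, assumes both columns exist in the file (so colsToGet and includedColumns can be filled directly instead of consulting allFieldsNames), and pushes no filter down. The class name OrcScanSketch is hypothetical.

    import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor;
    import com.huawei.boostkit.spark.jni.OrcColumnarBatchScanReader;
    import nova.hetu.omniruntime.vector.Vec;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructType;

    import java.io.File;
    import java.net.URI;
    import java.util.ArrayList;

    public class OrcScanSketch {
        public static void main(String[] args) throws Exception {
            StructType requiredSchema = new StructType()
                    .add("i_item_sk", DataTypes.LongType)
                    .add("i_item_id", DataTypes.StringType);

            OrcColumnarBatchScanReader scan = new OrcColumnarBatchScanReader();

            // 1. Open the ORC file; this populates scan.allFieldsNames with the file's column names.
            File file = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0");
            URI uri = new URI(file.getAbsolutePath());
            scan.reader = scan.initializeReaderJava(uri);

            // 2. Mark the required columns as readable (both assumed to exist in the file).
            scan.colsToGet = new int[]{0, 0};
            scan.includedColumns = new ArrayList<>();
            scan.includedColumns.add("i_item_sk");
            scan.includedColumns.add("i_item_id");

            // 3. Open the record reader over the whole file; the third argument is the pushed Filter (none here).
            scan.recordReader = scan.initializeRecordReaderJava(0, Integer.MAX_VALUE, null, requiredSchema);

            // 4. Allocate the native batch, then pull one batch of OmniRuntime vectors.
            scan.initBatchJava(4096);
            int[] typeIds = new int[]{
                    OmniExpressionAdaptor.sparkTypeToOmniType(DataTypes.LongType),
                    OmniExpressionAdaptor.sparkTypeToOmniType(DataTypes.StringType)
            };
            Vec[] vecs = new Vec[typeIds.length];
            long rows = scan.next(vecs, typeIds);
            System.out.println("read " + rows + " rows from the first batch");
        }
    }

In the production path the same sequence is performed by OmniOrcColumnarBatchReader.initialize, initDataColIds and initBatch, with the pushed-down Filter supplied by OmniOrcFileFormat.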