diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/NdpConnectorUtils.java b/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/NdpConnectorUtils.java index d412bcd1ca9b08d5903d8f107f3f5924f353cfd2..fac56bea50880c43dc85717c4b47dc4a0bde5116 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/NdpConnectorUtils.java +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/NdpConnectorUtils.java @@ -40,4 +40,165 @@ public class NdpConnectorUtils { return System.getenv("NDP_PLUGIN_ENABLE") == null ? "false" : System.getenv("NDP_PLUGIN_ENABLE"); } + public static int getPushDownTaskTotal(int taskTotal) { + if (System.getenv("DEFAULT_PUSHDOWN_TASK") != null) { + return Integer.parseInt(System.getenv("DEFAULT_PUSHDOWN_TASK")); + } else { + return taskTotal; + } + } + + public static String getNdpNumPartitionsStr(String numStr) { + if (System.getenv("DEFAULT_NDP_NUM_PARTITIONS") != null) { + return System.getenv("DEFAULT_NDP_NUM_PARTITIONS"); + } else { + return numStr; + } + } + + + public static int getCountTaskTotal(int taskTotal) { + if (System.getenv("COUNT_TASK_TOTAL") != null) { + return Integer.parseInt(System.getenv("COUNT_TASK_TOTAL")); + } else { + return taskTotal; + } + } + + public static String getCountMaxPartSize(String size) { + if (System.getenv("COUNT_MAX_PART_SIZE") != null) { + return System.getenv("COUNT_MAX_PART_SIZE"); + } else { + return size; + } + } + + public static int getCountDistinctTaskTotal(int taskTotal) { + if (System.getenv("COUNT_DISTINCT_TASK_TOTAL") != null) { + return Integer.parseInt(System.getenv("COUNT_DISTINCT_TASK_TOTAL")); + } else { + return taskTotal; + } + } + + public static String getSMJMaxPartSize(String size) { + if (System.getenv("SMJ_MAX_PART_SIZE") != null) { + return System.getenv("SMJ_MAX_PART_SIZE"); + } else { + return size; + } + } + + public static int getSMJNumPartitions(int numPartitions) { + if (System.getenv("SMJ_NUM_PARTITIONS") != null) { + return Integer.parseInt(System.getenv("SMJ_NUM_PARTITIONS")); + } else { + return numPartitions; + } + } + + public static int getOmniColumnarNumPartitions(int numPartitions) { + if (System.getenv("OMNI_COLUMNAR_PARTITIONS") != null) { + return Integer.parseInt(System.getenv("OMNI_COLUMNAR_PARTITIONS")); + } else { + return numPartitions; + } + } + + public static int getOmniColumnarTaskCount(int taskTotal) { + if (System.getenv("OMNI_COLUMNAR_TASK_TOTAL") != null) { + return Integer.parseInt(System.getenv("OMNI_COLUMNAR_TASK_TOTAL")); + } else { + return taskTotal; + } + } + + public static int getFilterPartitions(int numPartitions) { + if (System.getenv("FILTER_COLUMNAR_PARTITIONS") != null) { + return Integer.parseInt(System.getenv("FILTER_COLUMNAR_PARTITIONS")); + } else { + return numPartitions; + } + } + + public static int getFilterTaskCount(int taskTotal) { + if (System.getenv("FILTER_TASK_TOTAL") != null) { + return Integer.parseInt(System.getenv("FILTER_TASK_TOTAL")); + } else { + return taskTotal; + } + } + + public static String getSortRepartitionSizeStr(String sizeStr) { + if (System.getenv("SORT_REPARTITION_SIZE") != null) { + return System.getenv("SORT_REPARTITION_SIZE"); + } else { + return sizeStr; + } + } + + public static String getCastDecimalPrecisionStr(String numStr) { + if (System.getenv("CAST_DECIMAL_PRECISION") != null) { + return System.getenv("CAST_DECIMAL_PRECISION"); + } else { 
+ return numStr; + } + } + + public static String getCountAggMaxFilePtBytesStr(String BytesStr) { + if (System.getenv("COUNT_AGG_MAX_FILE_BYTES") != null) { + return System.getenv("COUNT_AGG_MAX_FILE_BYTES"); + } else { + return BytesStr; + } + } + + public static String getAvgAggMaxFilePtBytesStr(String BytesStr) { + if (System.getenv("AVG_AGG_MAX_FILE_BYTES") != null) { + return System.getenv("AVG_AGG_MAX_FILE_BYTES"); + } else { + return BytesStr; + } + } + + public static String getBhjMaxFilePtBytesStr(String BytesStr) { + if (System.getenv("BHJ_MAX_FILE_BYTES") != null) { + return System.getenv("BHJ_MAX_FILE_BYTES"); + } else { + return BytesStr; + } + } + + public static String getGroupMaxFilePtBytesStr(String BytesStr) { + if (System.getenv("GROUP_MAX_FILE_BYTES") != null) { + return System.getenv("GROUP_MAX_FILE_BYTES"); + } else { + return BytesStr; + } + } + + public static String getAggShufflePartitionsStr(String BytesStr) { + if (System.getenv("AGG_SHUFFLE_PARTITIONS") != null) { + return System.getenv("AGG_SHUFFLE_PARTITIONS"); + } else { + return BytesStr; + } + } + + public static String getShufflePartitionsStr(String BytesStr) { + if (System.getenv("SHUFFLE_PARTITIONS") != null) { + return System.getenv("SHUFFLE_PARTITIONS"); + } else { + return BytesStr; + } + } + + public static String getSortShufflePartitionsStr(String BytesStr) { + if (System.getenv("SORT_SHUFFLE_PARTITIONS") != null) { + return System.getenv("SORT_SHUFFLE_PARTITIONS"); + } else { + return BytesStr; + } + } + } diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/com/huawei/boostkit/omnioffload/spark/ColumnarPlugin.scala b/omnidata/omnidata-spark-connector/connector/src/main/scala/com/huawei/boostkit/omnioffload/spark/ColumnarPlugin.scala index 75b5097ddd1e26b97b24563fd143f173c5351cca..be9ea197edc610a2b3c686a5644ac1a25ed233e3 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/scala/com/huawei/boostkit/omnioffload/spark/ColumnarPlugin.scala +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/com/huawei/boostkit/omnioffload/spark/ColumnarPlugin.scala @@ -1,6 +1,5 @@ -package com.huawei.boostkit.omnioffload.spark +package com.huawei.boostkit.omnidata.spark -import com.huawei.boostkit.omnidata.spark.NdpConnectorUtils import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HiveTableRelation} @@ -28,25 +27,15 @@ import scala.collection.JavaConverters case class NdpOverrides() extends Rule[SparkPlan] { - var numPartitions: Int = SQLConf.get.getConfString("spark.omni.sql.ndpPlugin.coalesce.numPartitions", "10000").toInt + var numPartitions: Int = -1 var pushDownTaskCount: Int = -1 var isSMJ = false var isSort = false - var isCount = false var hasCoalesce = false var hasShuffle = false def apply(plan: SparkPlan): SparkPlan = { - pushDownTaskCount = getOptimizerPushDownThreshold(plan.session) - if (CountReplaceRule.shouldReplaceCountOne(plan)) { - isCount = true - pushDownTaskCount = 200 - SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, "512MB") - } - if (CountReplaceRule.shouldReplaceDistinctCount(plan)) { - isCount = true - pushDownTaskCount = 2000 - } + PreRuleApply(plan) val ruleList = Seq(CountReplaceRule) val afterPlan = ruleList.foldLeft(plan) { case (sp, rule) => val result = rule.apply(sp) @@ -54,10 +43,29 @@ case class NdpOverrides() extends Rule[SparkPlan] { } val optimizedPlan = replaceWithOptimizedPlan(afterPlan) val finalPlan = 
replaceWithScanPlan(optimizedPlan) + PostRuleApply(finalPlan) + finalPlan + } + + def PreRuleApply(plan: SparkPlan): Unit = { + numPartitions = SQLConf.get.getConfString("spark.omni.sql.ndpPlugin.coalesce.numPartitions", + NdpConnectorUtils.getNdpNumPartitionsStr("10000")).toInt + pushDownTaskCount = NdpConnectorUtils.getPushDownTaskTotal(getOptimizerPushDownThreshold(plan.session)) + if (CountReplaceRule.shouldReplaceCountOne(plan)) { + pushDownTaskCount = NdpConnectorUtils.getCountTaskTotal(200) + SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, + NdpConnectorUtils.getCountMaxPartSize("512MB")) + } + if (CountReplaceRule.shouldReplaceDistinctCount(plan)) { + pushDownTaskCount = NdpConnectorUtils.getCountDistinctTaskTotal(2000) + } + } + + def PostRuleApply(plan: SparkPlan): Unit = { if (isSMJ) { - SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, "536870912") + SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, + NdpConnectorUtils.getSMJMaxPartSize("536870912")) } - finalPlan } //now set task total number, we can use this number pushDown task in thread @@ -77,7 +85,7 @@ case class NdpOverrides() extends Rule[SparkPlan] { def replaceWithOptimizedPlan(plan: SparkPlan): SparkPlan = { val p = plan.transformUp { - case shuffle:ShuffleExchangeExec => + case shuffle: ShuffleExchangeExec => hasShuffle = true shuffle case p@ColumnarSortExec(sortOrder, global, child, testSpillFrequency) if isRadixSortExecEnable(sortOrder) => @@ -87,27 +95,37 @@ case class NdpOverrides() extends Rule[SparkPlan] { isSort = true RadixSortExec(sortOrder, global, child, testSpillFrequency) case p@DataWritingCommandExec(cmd, child) => - if (isSort || isCount || isVagueAndAccurateHd(child)) { + if (isSort || isVagueAndAccurateHd(child)) { p } else { hasCoalesce = true DataWritingCommandExec(cmd, CoalesceExec(numPartitions, child)) } - case p@ColumnarSortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right, isSkewJoin, projectList) if joinType.equals(LeftOuter) => + case p@ColumnarSortMergeJoinExec(_, _, joinType, _, _, _, _, projectList) + if joinType.equals(LeftOuter) => isSMJ = true - numPartitions = 5000 - ColumnarSortMergeJoinExec(leftKeys = p.leftKeys, rightKeys = p.rightKeys, joinType = LeftAnti, condition = p.condition, left = p.left, right = p.right, isSkewJoin = p.isSkewJoin, projectList) - case p@SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right, isSkewJoin) if joinType.equals(LeftOuter) => + numPartitions = NdpConnectorUtils.getSMJNumPartitions(5000) + ColumnarSortMergeJoinExec(leftKeys = p.leftKeys, rightKeys = p.rightKeys, joinType = LeftAnti, + condition = p.condition, left = p.left, right = p.right, isSkewJoin = p.isSkewJoin, projectList) + case p@SortMergeJoinExec(_, _, joinType, _, _, _, _) + if joinType.equals(LeftOuter) => isSMJ = true - numPartitions = 5000 - SortMergeJoinExec(leftKeys = p.leftKeys, rightKeys = p.rightKeys, joinType = LeftAnti, condition = p.condition, left = p.left, right = p.right, isSkewJoin = p.isSkewJoin) - case p@ColumnarBroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right, isNullAwareAntiJoin, projectList) if joinType.equals(LeftOuter) => - ColumnarBroadcastHashJoinExec(leftKeys = p.leftKeys, rightKeys = p.rightKeys, joinType = LeftAnti, buildSide = p.buildSide, condition = p.condition, left = p.left, right = p.right, isNullAwareAntiJoin = p.isNullAwareAntiJoin, projectList) - case p@BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, 
right, isNullAwareAntiJoin) if joinType.equals(LeftOuter) => - BroadcastHashJoinExec(leftKeys = p.leftKeys, rightKeys = p.rightKeys, joinType = LeftAnti, buildSide = p.buildSide, condition = p.condition, left = p.left, right = p.right, isNullAwareAntiJoin = p.isNullAwareAntiJoin) - case p@ColumnarShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right, projectList) if joinType.equals(LeftOuter) => - ColumnarShuffledHashJoinExec(p.leftKeys, p.rightKeys, LeftAnti, p.buildSide, p.condition, p.left, p.right, projectList) - case p@ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right, isSkewJoin) if joinType.equals(LeftOuter) => + numPartitions = NdpConnectorUtils.getSMJNumPartitions(5000) + SortMergeJoinExec(leftKeys = p.leftKeys, rightKeys = p.rightKeys, joinType = LeftAnti, condition = p.condition, + left = p.left, right = p.right, isSkewJoin = p.isSkewJoin) + case p@ColumnarBroadcastHashJoinExec(_, _, joinType, _, _, _, _, _, projectList) if joinType.equals(LeftOuter) => + ColumnarBroadcastHashJoinExec(leftKeys = p.leftKeys, rightKeys = p.rightKeys, + joinType = LeftAnti, buildSide = p.buildSide, condition = p.condition, left = p.left, + right = p.right, isNullAwareAntiJoin = p.isNullAwareAntiJoin, projectList) + case p@BroadcastHashJoinExec(_, _, joinType, _, _, _, _, _) if joinType.equals(LeftOuter) => + BroadcastHashJoinExec(leftKeys = p.leftKeys, rightKeys = p.rightKeys, joinType = LeftAnti, + buildSide = p.buildSide, condition = p.condition, left = p.left, right = p.right, + isNullAwareAntiJoin = p.isNullAwareAntiJoin) + case p@ColumnarShuffledHashJoinExec(_, _, joinType, _, _, _, _, projectList) + if joinType.equals(LeftOuter) => + ColumnarShuffledHashJoinExec(p.leftKeys, p.rightKeys, LeftAnti, p.buildSide, p.condition, + p.left, p.right, projectList) + case p@ShuffledHashJoinExec(_, _, joinType, _, _, _, _, isSkewJoin) if joinType.equals(LeftOuter) => ShuffledHashJoinExec(p.leftKeys, p.rightKeys, LeftAnti, p.buildSide, p.condition, p.left, p.right, isSkewJoin) case p@FilterExec(condition, child: OmniColumnarToRowExec, selectivity) => val childPlan = child.transform { @@ -128,20 +146,22 @@ case class NdpOverrides() extends Rule[SparkPlan] { } FilterExec(condition, childPlan, selectivity) case c1@OmniColumnarToRowExec(c2@ColumnarFilterExec(condition, c3: FileSourceScanExec)) => - numPartitions = 1000 + numPartitions = NdpConnectorUtils.getOmniColumnarNumPartitions(1000) if (isAccurate(condition)) { - pushDownTaskCount = 400 + pushDownTaskCount = NdpConnectorUtils.getOmniColumnarTaskCount(400) } FilterExec(condition, ColumnarToRowExec(c3)) - case p@FilterExec(condition, child, selectivity) if isAccurate(condition) => - numPartitions = 1000 - pushDownTaskCount = 400 + case p@FilterExec(condition, _, _) if isAccurate(condition) => + numPartitions = NdpConnectorUtils.getFilterPartitions(1000) + pushDownTaskCount = NdpConnectorUtils.getFilterTaskCount(400) p - case p@ColumnarConditionProjectExec(projectList, condition, child) if condition.toString().startsWith("isnull") && (child.isInstanceOf[ColumnarSortMergeJoinExec] - || child.isInstanceOf[ColumnarBroadcastHashJoinExec] || child.isInstanceOf[ColumnarShuffledHashJoinExec]) => + case p@ColumnarConditionProjectExec(projectList, condition, child) + if condition.toString().startsWith("isnull") && (child.isInstanceOf[ColumnarSortMergeJoinExec] + || child.isInstanceOf[ColumnarBroadcastHashJoinExec] || child.isInstanceOf[ColumnarShuffledHashJoinExec]) => 
ColumnarProjectExec(changeProjectList(projectList), child) - case p@ProjectExec(projectList, filter: FilterExec) if filter.condition.toString().startsWith("isnull") && (filter.child.isInstanceOf[SortMergeJoinExec] - || filter.child.isInstanceOf[BroadcastHashJoinExec] || filter.child.isInstanceOf[ShuffledHashJoinExec]) => + case p@ProjectExec(projectList, filter: FilterExec) + if filter.condition.toString().startsWith("isnull") && (filter.child.isInstanceOf[SortMergeJoinExec] + || filter.child.isInstanceOf[BroadcastHashJoinExec] || filter.child.isInstanceOf[ShuffledHashJoinExec]) => ProjectExec(changeProjectList(projectList), filter.child) case p: SortAggregateExec if p.child.isInstanceOf[OmniColumnarToRowExec] && p.child.asInstanceOf[OmniColumnarToRowExec].child.isInstanceOf[ColumnarSortExec] @@ -173,9 +193,11 @@ case class NdpOverrides() extends Rule[SparkPlan] { p.resultExpressions, SortExec(omniColumnarSort.sortOrder, omniColumnarSort.global, - ShuffleExchangeExec(omniShuffleExchange.outputPartitioning, rowToOmniColumnar.child, omniShuffleExchange.shuffleOrigin), + ShuffleExchangeExec(omniShuffleExchange.outputPartitioning, rowToOmniColumnar.child, + omniShuffleExchange.shuffleOrigin), omniColumnarSort.testSpillFrequency)) - case p@OmniColumnarToRowExec(agg: ColumnarHashAggregateExec) if agg.groupingExpressions.nonEmpty && agg.child.isInstanceOf[ColumnarShuffleExchangeExec] => + case p@OmniColumnarToRowExec(agg: ColumnarHashAggregateExec) + if agg.groupingExpressions.nonEmpty && agg.child.isInstanceOf[ColumnarShuffleExchangeExec] => val omniExchange = agg.child.asInstanceOf[ColumnarShuffleExchangeExec] val omniHashAgg = omniExchange.child.asInstanceOf[ColumnarHashAggregateExec] HashAggregateExec(agg.requiredChildDistributionExpressions, @@ -318,11 +340,14 @@ case class NdpOptimizerRules(session: SparkSession) extends Rule[LogicalPlan] { ) val SORT_REPARTITION_SIZE: Int = SQLConf.get.getConfString( - "spark.omni.sql.ndpPlugin.sort.repartition.size", "104857600").toInt + "spark.omni.sql.ndpPlugin.sort.repartition.size", + NdpConnectorUtils.getSortRepartitionSizeStr("104857600")).toInt val DECIMAL_PRECISION: Int = SQLConf.get.getConfString( - "spark.omni.sql.ndpPlugin.cast.decimal.precision", "15").toInt + "spark.omni.sql.ndpPlugin.cast.decimal.precision", + NdpConnectorUtils.getCastDecimalPrecisionStr("15")).toInt val MAX_PARTITION_BYTES_ENABLE_FACTOR: Int = SQLConf.get.getConfString( - "spark.omni.sql.ndpPlugin.max.partitionBytesEnable.factor", "2").toInt + "spark.omni.sql.ndpPlugin.max.partitionBytesEnable.factor", + NdpConnectorUtils.getNdpMaxPtFactorStr("2")).toInt override def apply(plan: LogicalPlan): LogicalPlan = { @@ -347,9 +372,11 @@ case class NdpOptimizerRules(session: SparkSession) extends Rule[LogicalPlan] { .toBoolean => var ifCast = false if (groupingExpressions.nonEmpty && hasCount(aggregateExpressions)) { - SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, "1024MB") + SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, + NdpConnectorUtils.getCountAggMaxFilePtBytesStr("1024MB")) } else if (groupingExpressions.nonEmpty && hasAvg(aggregateExpressions)) { - SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, "256MB") + SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, + NdpConnectorUtils.getAvgAggMaxFilePtBytesStr("256MB")) ifCast = true } if (ifCast) { @@ -363,7 +390,8 @@ case class NdpOptimizerRules(session: SparkSession) extends Rule[LogicalPlan] { case j@Join(_, _, Inner, condition, _) => // turnOffOperator() // 
6-x-bhj - SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, "512MB") + SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, + NdpConnectorUtils.getBhjMaxFilePtBytesStr("512MB")) if (condition.isDefined) { condition.get match { case e@EqualTo(attr1: AttributeReference, attr2: AttributeReference) => @@ -444,9 +472,11 @@ case class NdpOptimizerRules(session: SparkSession) extends Rule[LogicalPlan] { // agg shuffle partition 200 ,other 5000 if (existsTable && existsAgg) { // SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, "536870912") - SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, "200") + SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, + NdpConnectorUtils.getAggShufflePartitionsStr("200")) } else { - SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, "5000") + SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, + NdpConnectorUtils.getShufflePartitionsStr("5000")) } repartitionShuffleForSort(fs, tables, planContents) repartitionHdfsReadForDistinct(fs, tables, plan) @@ -461,7 +491,8 @@ case class NdpOptimizerRules(session: SparkSession) extends Rule[LogicalPlan] { if (tables.length == 1 && SORT_REPARTITION_PLANS.exists(planContent.contains(_))) { val partitions = Math.max(1, fs.getContentSummary(new Path(tables.head)).getLength / SORT_REPARTITION_SIZE) - SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, "1000") + SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, + NdpConnectorUtils.getSortShufflePartitionsStr("1000")) // SQLConf.get.setConfString("spark.shuffle.sort.bypassMergeThreshold", "1000") turnOffOperator() } @@ -477,14 +508,8 @@ case class NdpOptimizerRules(session: SparkSession) extends Rule[LogicalPlan] { plan.foreach { case Aggregate(groupingExpressions, aggregateExpressions, _) if groupingExpressions == aggregateExpressions => - val executors = SQLConf.get.getConfString("spark.executor.instances", "14").toLong - val cores = SQLConf.get.getConfString("spark.executor.cores", "21").toLong - val multi = SQLConf.get.getConfString("spark.omni.sql.ndpPlugin.read.cores.multi", "16").toFloat - val partitionByte = (fs.getContentSummary(new Path(tables.head)).getLength / (executors * cores * multi)).toLong - // SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, - // Math.max(SQLConf.get.filesMaxPartitionBytes, partitionByte).toString) - SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, "1024MB") - // println(s"partitionByte:${partitionByte},partitions:${executors * cores * multi}") + SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, + NdpConnectorUtils.getGroupMaxFilePtBytesStr("1024MB")) return case _ => } @@ -558,4 +583,4 @@ object NdpPluginEnableFlag { ndpEnabled && (isMatchedIpAddress || NdpConnectorUtils.getNdpEnable.equals("true")) } -} \ No newline at end of file +} diff --git a/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt index dd0b79dbabc3b9c65ff8edf8b6c34853f51a0f63..491cfb7086037229608f2963cf6c278ca132b198 100644 --- a/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt @@ -5,7 +5,7 @@ project(spark-thestral-plugin) cmake_minimum_required(VERSION 3.10) # configure cmake -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_COMPILER "g++") set(root_directory ${PROJECT_BINARY_DIR}) diff --git a/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt 
b/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt index e954e4b1c06bd9b9507ba003ff5bddf1536bc862..ab93271cc0ece3de8255699bc52378bdd15c97c9 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt @@ -44,7 +44,6 @@ target_link_libraries (${PROJ_TARGET} PUBLIC snappy lz4 zstd - boostkit-omniop-runtime-1.2.0-aarch64 boostkit-omniop-vector-1.2.0-aarch64 ) diff --git a/omnioperator/omniop-spark-extension/cpp/src/common/common.cpp b/omnioperator/omniop-spark-extension/cpp/src/common/common.cpp index 2c6b9fab89ec31c7df596cc4e9b14e3f869a12b2..f33d5c4c9df9695c2464b622587dea9e3546c39c 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/common/common.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/common/common.cpp @@ -76,21 +76,4 @@ spark::CompressionKind GetCompressionType(const std::string& name) { int IsFileExist(const std::string path) { return !access(path.c_str(), F_OK); -} - -void ReleaseVectorBatch(omniruntime::vec::VectorBatch& vb) -{ - int tmpVectorNum = vb.GetVectorCount(); - std::set vectorBatchAddresses; - vectorBatchAddresses.clear(); - for (int vecIndex = 0; vecIndex < tmpVectorNum; ++vecIndex) { - vectorBatchAddresses.insert(vb.GetVector(vecIndex)); - } - for (Vector * tmpAddress : vectorBatchAddresses) { - if (nullptr == tmpAddress) { - throw std::runtime_error("delete nullptr error for release vectorBatch"); - } - delete tmpAddress; - } - vectorBatchAddresses.clear(); } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/common/common.h b/omnioperator/omniop-spark-extension/cpp/src/common/common.h index fdc3b10e692e3944eeee9cf70f96ed47262a5e77..733dac920727489b205727d32300252bd32626c5 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/common/common.h +++ b/omnioperator/omniop-spark-extension/cpp/src/common/common.h @@ -45,6 +45,4 @@ spark::CompressionKind GetCompressionType(const std::string& name); int IsFileExist(const std::string path); -void ReleaseVectorBatch(omniruntime::vec::VectorBatch& vb); - #endif //CPP_COMMON_H \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.cpp b/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.cpp index 7506424fbeaa98f0a84f546662e6f4b361eeddaf..bfaeca66378356370a36b32f16298d11c9ebad2c 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.cpp @@ -18,9 +18,11 @@ */ #include "OrcColumnarBatchJniReader.h" +#include #include "jni_common.h" using namespace omniruntime::vec; +using namespace omniruntime::type; using namespace std; using namespace orc; @@ -36,6 +38,8 @@ jmethodID arrayListGet; jmethodID arrayListSize; jmethodID jsonMethodObj; +static constexpr int32_t MAX_DECIMAL64_DIGITS = 18; + int initJniId(JNIEnv *env) { /* @@ -128,19 +132,18 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniRe JNI_FUNC_END(runtimeExceptionClass) } -bool stringToBool(string boolStr) +bool StringToBool(const std::string &boolStr) { - transform(boolStr.begin(), boolStr.end(), boolStr.begin(), ::tolower); - if (boolStr == "true") { - return true; - } else if (boolStr == "false") { - return false; + if (boost::iequals(boolStr, "true")) { + return true; + } else if (boost::iequals(boolStr, "false")) { + return false; } else { - throw std::runtime_error("Invalid input for stringToBool."); + throw 
std::runtime_error("Invalid input for stringToBool."); } } -int getLiteral(orc::Literal &lit, int leafType, string value) +int GetLiteral(orc::Literal &lit, int leafType, const std::string &value) { switch ((orc::PredicateDataType)leafType) { case orc::PredicateDataType::LONG: { @@ -173,7 +176,7 @@ int getLiteral(orc::Literal &lit, int leafType, string value) break; } case orc::PredicateDataType::BOOLEAN: { - lit = orc::Literal(static_cast(stringToBool(value))); + lit = orc::Literal(static_cast(StringToBool(value))); break; } default: { @@ -183,8 +186,8 @@ int getLiteral(orc::Literal &lit, int leafType, string value) return 0; } -int buildLeaves(PredicateOperatorType leafOp, vector &litList, Literal &lit, string leafNameString, PredicateDataType leafType, - SearchArgumentBuilder &builder) +int BuildLeaves(PredicateOperatorType leafOp, vector &litList, Literal &lit, const std::string &leafNameString, + PredicateDataType leafType, SearchArgumentBuilder &builder) { switch (leafOp) { case PredicateOperatorType::LESS_THAN: { @@ -234,7 +237,7 @@ int initLeaves(JNIEnv *env, SearchArgumentBuilder &builder, jobject &jsonExp, jo if (leafValue != nullptr) { std::string leafValueString(env->GetStringUTFChars(leafValue, nullptr)); if (leafValueString.size() != 0) { - getLiteral(lit, leafType, leafValueString); + GetLiteral(lit, leafType, leafValueString); } } std::vector litList; @@ -244,11 +247,11 @@ int initLeaves(JNIEnv *env, SearchArgumentBuilder &builder, jobject &jsonExp, jo for (int i = 0; i < childs; i++) { jstring child = (jstring)env->CallObjectMethod(litListValue, arrayListGet, i); std::string childString(env->GetStringUTFChars(child, nullptr)); - getLiteral(lit, leafType, childString); + GetLiteral(lit, leafType, childString); litList.push_back(lit); } } - buildLeaves((PredicateOperatorType)leafOp, litList, lit, leafNameString, (PredicateDataType)leafType, builder); + BuildLeaves((PredicateOperatorType)leafOp, litList, lit, leafNameString, (PredicateDataType)leafType, builder); return 1; } @@ -346,133 +349,225 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniRe JNI_FUNC_END(runtimeExceptionClass) } -template uint64_t copyFixwidth(orc::ColumnVectorBatch *field) +template uint64_t CopyFixedWidth(orc::ColumnVectorBatch *field) { - VectorAllocator *allocator = omniruntime::vec::GetProcessGlobalVecAllocator(); using T = typename NativeType::type; ORC_TYPE *lvb = dynamic_cast(field); - FixedWidthVector *originalVector = new FixedWidthVector(allocator, lvb->numElements); - for (uint i = 0; i < lvb->numElements; i++) { - if (lvb->notNull.data()[i]) { - originalVector->SetValue(i, (T)(lvb->data.data()[i])); - } else { - originalVector->SetValueNull(i); + auto numElements = lvb->numElements; + auto values = lvb->data.data(); + auto notNulls = lvb->notNull.data(); + auto originalVector = new Vector(numElements); + // Check ColumnVectorBatch has null or not firstly + if (lvb->hasNulls) { + for (uint i = 0; i < numElements; i++) { + if (notNulls[i]) { + originalVector->SetValue(i, (T)(values[i])); + } else { + originalVector->SetNull(i); + } + } + } else { + for (uint i = 0; i < numElements; i++) { + originalVector->SetValue(i, (T)(values[i])); } } return (uint64_t)originalVector; } +template uint64_t CopyOptimizedForInt64(orc::ColumnVectorBatch *field) +{ + using T = typename NativeType::type; + ORC_TYPE *lvb = dynamic_cast(field); + auto numElements = lvb->numElements; + auto values = lvb->data.data(); + auto notNulls = lvb->notNull.data(); + auto originalVector = new 
Vector(numElements); + // Check ColumnVectorBatch has null or not firstly + if (lvb->hasNulls) { + for (uint i = 0; i < numElements; i++) { + if (!notNulls[i]) { + originalVector->SetNull(i); + } + } + } + originalVector->SetValues(0, values, numElements); + return (uint64_t)originalVector; +} -uint64_t copyVarwidth(int maxLen, orc::ColumnVectorBatch *field, int vcType) +uint64_t CopyVarWidth(orc::ColumnVectorBatch *field) { - VectorAllocator *allocator = omniruntime::vec::GetProcessGlobalVecAllocator(); orc::StringVectorBatch *lvb = dynamic_cast(field); - uint64_t totalLen = - maxLen * (lvb->numElements) > lvb->getMemoryUsage() ? maxLen * (lvb->numElements) : lvb->getMemoryUsage(); - VarcharVector *originalVector = new VarcharVector(allocator, totalLen, lvb->numElements); - for (uint i = 0; i < lvb->numElements; i++) { - if (lvb->notNull.data()[i]) { - string tmpStr(reinterpret_cast(lvb->data.data()[i]), lvb->length.data()[i]); - if (vcType == orc::TypeKind::CHAR && tmpStr.back() == ' ') { - tmpStr.erase(tmpStr.find_last_not_of(" ") + 1); + auto numElements = lvb->numElements; + auto values = lvb->data.data(); + auto notNulls = lvb->notNull.data(); + auto lens = lvb->length.data(); + auto originalVector = new Vector>(numElements); + if (lvb->hasNulls) { + for (uint i = 0; i < numElements; i++) { + if (notNulls[i]) { + auto data = std::string_view(reinterpret_cast(values[i]), lens[i]); + originalVector->SetValue(i, data); + } else { + originalVector->SetNull(i); } - originalVector->SetValue(i, reinterpret_cast(tmpStr.data()), tmpStr.length()); - } else { - originalVector->SetValueNull(i); + } + } else { + for (uint i = 0; i < numElements; i++) { + auto data = std::string_view(reinterpret_cast(values[i]), lens[i]); + originalVector->SetValue(i, data); } } return (uint64_t)originalVector; } -int copyToOmniVec(orc::TypeKind vcType, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field, ...) 
+inline void FindLastNotEmpty(const char *chars, long &len) { - switch (vcType) { - case orc::TypeKind::BOOLEAN: { + while (len > 0 && chars[len - 1] == ' ') { + len--; + } +} + +uint64_t CopyCharType(orc::ColumnVectorBatch *field) +{ + orc::StringVectorBatch *lvb = dynamic_cast(field); + auto numElements = lvb->numElements; + auto values = lvb->data.data(); + auto notNulls = lvb->notNull.data(); + auto lens = lvb->length.data(); + auto originalVector = new Vector>(numElements); + if (lvb->hasNulls) { + for (uint i = 0; i < numElements; i++) { + if (notNulls[i]) { + auto chars = reinterpret_cast(values[i]); + auto len = lens[i]; + FindLastNotEmpty(chars, len); + auto data = std::string_view(chars, len); + originalVector->SetValue(i, data); + } else { + originalVector->SetNull(i); + } + } + } else { + for (uint i = 0; i < numElements; i++) { + auto chars = reinterpret_cast(values[i]); + auto len = lens[i]; + FindLastNotEmpty(chars, len); + auto data = std::string_view(chars, len); + originalVector->SetValue(i, data); + } + } + return (uint64_t)originalVector; +} + +inline void TransferDecimal128(int64_t &highbits, uint64_t &lowbits) +{ + if (highbits < 0) { // int128's 2s' complement code + lowbits = ~lowbits + 1; // 2s' complement code + highbits = ~highbits; //1s' complement code + if (lowbits == 0) { + highbits += 1; // carry a number as in adding + } + highbits ^= ((uint64_t)1 << 63); + } +} + +uint64_t CopyToOmniDecimal128Vec(orc::ColumnVectorBatch *field) +{ + orc::Decimal128VectorBatch *lvb = dynamic_cast(field); + auto numElements = lvb->numElements; + auto values = lvb->values.data(); + auto notNulls = lvb->notNull.data(); + auto originalVector = new Vector(numElements); + if (lvb->hasNulls) { + for (uint i = 0; i < numElements; i++) { + if (notNulls[i]) { + auto highbits = values[i].getHighBits(); + auto lowbits = values[i].getLowBits(); + TransferDecimal128(highbits, lowbits); + Decimal128 d128(highbits, lowbits); + originalVector->SetValue(i, d128); + } else { + originalVector->SetNull(i); + } + } + } else { + for (uint i = 0; i < numElements; i++) { + auto highbits = values[i].getHighBits(); + auto lowbits = values[i].getLowBits(); + TransferDecimal128(highbits, lowbits); + Decimal128 d128(highbits, lowbits); + originalVector->SetValue(i, d128); + } + } + return (uint64_t)originalVector; +} + +uint64_t CopyToOmniDecimal64Vec(orc::ColumnVectorBatch *field) +{ + orc::Decimal64VectorBatch *lvb = dynamic_cast(field); + auto numElements = lvb->numElements; + auto values = lvb->values.data(); + auto notNulls = lvb->notNull.data(); + auto originalVector = new Vector(numElements); + if (lvb->hasNulls) { + for (uint i = 0; i < numElements; i++) { + if (!notNulls[i]) { + originalVector->SetNull(i); + } + } + } + originalVector->SetValues(0, values, numElements); + return (uint64_t)originalVector; +} + +int CopyToOmniVec(const orc::Type *type, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field) +{ + switch (type->getKind()) { + case orc::TypeKind::BOOLEAN: omniTypeId = static_cast(OMNI_BOOLEAN); - omniVecId = copyFixwidth(field); + omniVecId = CopyFixedWidth(field); break; - } - case orc::TypeKind::SHORT: { + case orc::TypeKind::SHORT: omniTypeId = static_cast(OMNI_SHORT); - omniVecId = copyFixwidth(field); + omniVecId = CopyFixedWidth(field); break; - } - case orc::TypeKind::DATE: { + case orc::TypeKind::DATE: omniTypeId = static_cast(OMNI_DATE32); - omniVecId = copyFixwidth(field); + omniVecId = CopyFixedWidth(field); break; - } - case orc::TypeKind::INT: { + case 
orc::TypeKind::INT: omniTypeId = static_cast(OMNI_INT); - omniVecId = copyFixwidth(field); + omniVecId = CopyFixedWidth(field); break; - } - case orc::TypeKind::LONG: { + case orc::TypeKind::LONG: omniTypeId = static_cast(OMNI_LONG); - omniVecId = copyFixwidth(field); + omniVecId = CopyOptimizedForInt64(field); break; - } - case orc::TypeKind::DOUBLE: { + case orc::TypeKind::DOUBLE: omniTypeId = static_cast(OMNI_DOUBLE); - omniVecId = copyFixwidth(field); + omniVecId = CopyOptimizedForInt64(field); break; - } case orc::TypeKind::CHAR: + omniTypeId = static_cast(OMNI_VARCHAR); + omniVecId = CopyCharType(field); + break; case orc::TypeKind::STRING: - case orc::TypeKind::VARCHAR: { + case orc::TypeKind::VARCHAR: omniTypeId = static_cast(OMNI_VARCHAR); - va_list args; - va_start(args, field); - omniVecId = (uint64_t)copyVarwidth(va_arg(args, int), field, vcType); - va_end(args); + omniVecId = CopyVarWidth(field); break; - } - default: { - throw std::runtime_error("Native ColumnarFileScan Not support For This Type: " + vcType); - } - } - return 1; -} - -int copyToOmniDecimalVec(int precision, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field) -{ - VectorAllocator *allocator = VectorAllocator::GetGlobalAllocator(); - if (precision > 18) { - omniTypeId = static_cast(OMNI_DECIMAL128); - orc::Decimal128VectorBatch *lvb = dynamic_cast(field); - FixedWidthVector *originalVector = - new FixedWidthVector(allocator, lvb->numElements); - for (uint i = 0; i < lvb->numElements; i++) { - if (lvb->notNull.data()[i]) { - int64_t highbits = lvb->values.data()[i].getHighBits(); - uint64_t lowbits = lvb->values.data()[i].getLowBits(); - if (highbits < 0) { // int128's 2s' complement code - lowbits = ~lowbits + 1; // 2s' complement code - highbits = ~highbits; //1s' complement code - if (lowbits == 0) { - highbits += 1; // carry a number as in adding - } - highbits ^= ((uint64_t)1 << 63); - } - Decimal128 d128(highbits, lowbits); - originalVector->SetValue(i, d128); - } else { - originalVector->SetValueNull(i); - } - } - omniVecId = (uint64_t)originalVector; - } else { - omniTypeId = static_cast(OMNI_DECIMAL64); - orc::Decimal64VectorBatch *lvb = dynamic_cast(field); - FixedWidthVector *originalVector = new FixedWidthVector(allocator, lvb->numElements); - for (uint i = 0; i < lvb->numElements; i++) { - if (lvb->notNull.data()[i]) { - originalVector->SetValue(i, (int64_t)(lvb->values.data()[i])); + case orc::TypeKind::DECIMAL: + if (type->getPrecision() > MAX_DECIMAL64_DIGITS) { + omniTypeId = static_cast(OMNI_DECIMAL128); + omniVecId = CopyToOmniDecimal128Vec(field); } else { - originalVector->SetValueNull(i); + omniTypeId = static_cast(OMNI_DECIMAL64); + omniVecId = CopyToOmniDecimal64Vec(field); } + break; + default: { + throw std::runtime_error("Native ColumnarFileScan Not support For This Type: " + type->getKind()); } - omniVecId = (uint64_t)originalVector; } return 1; } @@ -491,16 +586,10 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniRe vecCnt = root->fields.size(); batchRowSize = root->fields[0]->numElements; for (int id = 0; id < vecCnt; id++) { - orc::TypeKind vcType = baseTp.getSubtype(id)->getKind(); - int maxLen = baseTp.getSubtype(id)->getMaximumLength(); + auto type = baseTp.getSubtype(id); int omniTypeId = 0; uint64_t omniVecId = 0; - if (vcType != orc::TypeKind::DECIMAL) { - copyToOmniVec(vcType, omniTypeId, omniVecId, root->fields[id], maxLen); - } else { - copyToOmniDecimalVec(baseTp.getSubtype(id)->getPrecision(), omniTypeId, omniVecId, - 
root->fields[id]); - } + CopyToOmniVec(type, omniTypeId, omniVecId, root->fields[id]); env->SetIntArrayRegion(typeId, id, 1, &omniTypeId); jlong omniVec = static_cast(omniVecId); env->SetLongArrayRegion(vecNativeId, id, 1, &omniVec); diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.h b/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.h index 975de176f9c99f5bb78001a3beb88db5d43d9298..714d97ee67df137da8c6dcc79f8aa2173a33066d 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.h +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.h @@ -22,28 +22,27 @@ #ifndef THESTRAL_PLUGIN_ORCCOLUMNARBATCHJNIREADER_H #define THESTRAL_PLUGIN_ORCCOLUMNARBATCHJNIREADER_H -#include "orc/ColumnPrinter.hh" -#include "orc/Exceptions.hh" -#include "orc/Type.hh" -#include "orc/Vector.hh" -#include "orc/Reader.hh" -#include "orc/OrcFile.hh" -#include "orc/MemoryPool.hh" -#include "orc/sargs/SearchArgument.hh" -#include "orc/sargs/Literal.hh" -#include -#include #include #include #include -#include -#include "jni.h" -#include "json/json.h" -#include "vector/vector_common.h" -#include "util/omni_exception.h" -#include +#include +#include #include -#include "../common/debug.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "common/debug.h" #ifdef __cplusplus extern "C" { @@ -135,18 +134,14 @@ JNIEXPORT jobjectArray JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBat JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_getNumberOfRows(JNIEnv *env, jobject jObj, jlong rowReader, jlong batch); -int getLiteral(orc::Literal &lit, int leafType, std::string value); - -int buildLeaves(PredicateOperatorType leafOp, std::vector &litList, orc::Literal &lit, std::string leafNameString, orc::PredicateDataType leafType, - orc::SearchArgumentBuilder &builder); - -bool stringToBool(std::string boolStr); +int GetLiteral(orc::Literal &lit, int leafType, const std::string &value); -int copyToOmniVec(orc::TypeKind vcType, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field, ...); +int BuildLeaves(PredicateOperatorType leafOp, std::vector &litList, orc::Literal &lit, + const std::string &leafNameString, orc::PredicateDataType leafType, orc::SearchArgumentBuilder &builder); -int copyToOmniDecimalVec(int precision, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field); +bool StringToBool(const std::string &boolStr); -int copyToOmniDecimalVec(int precision, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field); +int CopyToOmniVec(const orc::Type *type, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field); #ifdef __cplusplus } diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp index 2f75c23a770b8d40d61ec575b035f899ed22decb..9d357afb51bfc2b1352339e47ced45e651edb677 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp @@ -89,17 +89,17 @@ Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_nativeMake( DataTypes inputVecTypes = Deserialize(inputTypeCharPtr); const int32_t *inputVecTypeIds = inputVecTypes.GetIds(); // - std::vector inputDataTpyes = inputVecTypes.Get(); - int32_t size = inputDataTpyes.size(); + std::vector 
inputDataTypes = inputVecTypes.Get(); + int32_t size = inputDataTypes.size(); uint32_t *inputDataPrecisions = new uint32_t[size]; uint32_t *inputDataScales = new uint32_t[size]; for (int i = 0; i < size; ++i) { - if(inputDataTpyes[i]->GetId() == OMNI_DECIMAL64 || inputDataTpyes[i]->GetId() == OMNI_DECIMAL128) { - inputDataScales[i] = std::dynamic_pointer_cast(inputDataTpyes[i])->GetScale(); - inputDataPrecisions[i] = std::dynamic_pointer_cast(inputDataTpyes[i])->GetPrecision(); + if (inputDataTypes[i]->GetId() == OMNI_DECIMAL64 || inputDataTypes[i]->GetId() == OMNI_DECIMAL128) { + inputDataScales[i] = std::dynamic_pointer_cast(inputDataTypes[i])->GetScale(); + inputDataPrecisions[i] = std::dynamic_pointer_cast(inputDataTypes[i])->GetPrecision(); } } - inputDataTpyes.clear(); + inputDataTypes.clear(); InputDataTypes inputDataTypesTmp; inputDataTypesTmp.inputVecTypeIds = (int32_t *)inputVecTypeIds; diff --git a/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto b/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto index c40472020171692ea7b0acde2dd873efeda691f4..725f9fa070aa1f8d188d85118df9765a63d299f3 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto +++ b/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto @@ -57,4 +57,4 @@ message VecType { NANOSEC = 3; } TimeUnit timeUnit = 6; -} \ No newline at end of file +} diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp index 2eba4b92930591e97c1a264c40cd5e4a110ec0af..e1152c1da7adaef7315505d10b984b8673163cc4 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp @@ -37,10 +37,10 @@ int Splitter::ComputeAndCountPartitionId(VectorBatch& vb) { partition_id_[i] = 0; } } else { - IntVector* hashVct = static_cast(vb.GetVector(0)); + auto hash_vct = reinterpret_cast *>(vb.Get(0)); for (auto i = 0; i < num_rows; ++i) { // positive mod - int32_t pid = hashVct->GetValue(i); + int32_t pid = hash_vct->GetValue(i); if (pid >= num_partitions_) { LogsError(" Illegal pid Value: %d >= partition number %d .", pid, num_partitions_); throw std::runtime_error("Shuffle pidVec Illegal pid Value!"); @@ -76,7 +76,7 @@ int Splitter::AllocatePartitionBuffers(int32_t partition_id, int32_t new_size) { case SHUFFLE_8BYTE: case SHUFFLE_DECIMAL128: default: { - void *ptr_tmp = static_cast(options_.allocator->alloc(new_size * (1 << column_type_id_[i]))); + void *ptr_tmp = static_cast(options_.allocator->Alloc(new_size * (1 << column_type_id_[i]))); fixed_valueBuffer_size_[partition_id] = new_size * (1 << column_type_id_[i]); if (nullptr == ptr_tmp) { throw std::runtime_error("Allocator for AllocatePartitionBuffers Failed! "); @@ -128,15 +128,12 @@ int Splitter::SplitFixedWidthValueBuffer(VectorBatch& vb) { auto col_idx_vb = fixed_width_array_idx_[col]; auto col_idx_schema = singlePartitionFlag ? 
col_idx_vb : (col_idx_vb - 1); const auto& dst_addrs = partition_fixed_width_value_addrs_[col]; - if (vb.GetVector(col_idx_vb)->GetEncoding() == OMNI_VEC_ENCODING_DICTIONARY) { + if (vb.Get(col_idx_vb)->GetEncoding() == OMNI_DICTIONARY) { LogsDebug("Dictionary Columnar process!"); - auto ids_tmp = static_cast(options_.allocator->alloc(num_rows * sizeof(int32_t))); - Buffer *ids (new Buffer((uint8_t*)ids_tmp, 0, num_rows * sizeof(int32_t))); - if (ids->data_ == nullptr) { - throw std::runtime_error("Allocator for SplitFixedWidthValueBuffer ids Failed! "); - } - auto dictionaryTmp = ((DictionaryVector *)(vb.GetVector(col_idx_vb)))->ExtractDictionaryAndIds(0, num_rows, (int32_t *)(ids->data_)); - auto src_addr = VectorHelper::GetValuesAddr(dictionaryTmp); + + DataTypeId type_id = vector_batch_col_types_.at(col_idx_schema); + auto ids_addr = VectorHelper::UnsafeGetValues(vb.Get(col_idx_vb), type_id); + auto src_addr = reinterpret_cast(VectorHelper::UnsafeGetDictionary(vb.Get(col_idx_vb), type_id)); switch (column_type_id_[col_idx_schema]) { #define PROCESS(SHUFFLE_TYPE, CTYPE) \ case SHUFFLE_TYPE: \ @@ -145,8 +142,8 @@ int Splitter::SplitFixedWidthValueBuffer(VectorBatch& vb) { auto dst_offset = \ partition_buffer_idx_base_[pid] + partition_buffer_idx_offset_[pid]; \ reinterpret_cast(dst_addrs[pid])[dst_offset] = \ - reinterpret_cast(src_addr)[reinterpret_cast(ids->data_)[row]]; \ - partition_fixed_width_buffers_[col][pid][1]->size_ += (1 << SHUFFLE_TYPE); \ + reinterpret_cast(src_addr)[reinterpret_cast(ids_addr)[row]]; \ + partition_fixed_width_buffers_[col][pid][1]->size_ += (1 << SHUFFLE_TYPE); \ partition_buffer_idx_offset_[pid]++; \ } \ break; @@ -160,10 +157,12 @@ int Splitter::SplitFixedWidthValueBuffer(VectorBatch& vb) { auto pid = partition_id_[row]; auto dst_offset = partition_buffer_idx_base_[pid] + partition_buffer_idx_offset_[pid]; + // 前64位取值、赋值 reinterpret_cast(dst_addrs[pid])[dst_offset << 1] = - reinterpret_cast(src_addr)[reinterpret_cast(ids->data_)[row] << 1]; // 前64位取值、赋值 - reinterpret_cast(dst_addrs[pid])[dst_offset << 1 | 1] = - reinterpret_cast(src_addr)[reinterpret_cast(ids->data_)[row] << 1 | 1]; // 后64位取值、赋值 + reinterpret_cast(src_addr)[reinterpret_cast(ids_addr)[row] << 1]; + // 后64位取值、赋值 + reinterpret_cast(dst_addrs[pid])[(dst_offset << 1) | 1] = + reinterpret_cast(src_addr)[(reinterpret_cast(ids_addr)[row] << 1) | 1]; partition_fixed_width_buffers_[col][pid][1]->size_ += (1 << SHUFFLE_DECIMAL128); //decimal128 16Bytes partition_buffer_idx_offset_[pid]++; @@ -174,13 +173,9 @@ int Splitter::SplitFixedWidthValueBuffer(VectorBatch& vb) { throw std::runtime_error("SplitFixedWidthValueBuffer not match this type: " + column_type_id_[col_idx_schema]); } } - options_.allocator->free(ids->data_, ids->capacity_); - if (nullptr == ids) { - throw std::runtime_error("delete nullptr error for ids"); - } - delete ids; } else { - auto src_addr = VectorHelper::GetValuesAddr(vb.GetVector(col_idx_vb)); + DataTypeId type_id = vector_batch_col_types_.at(col_idx_schema); + auto src_addr = reinterpret_cast(VectorHelper::UnsafeGetValues(vb.Get(col_idx_vb), type_id)); switch (column_type_id_[col_idx_schema]) { #define PROCESS(SHUFFLE_TYPE, CTYPE) \ case SHUFFLE_TYPE: \ @@ -225,54 +220,65 @@ int Splitter::SplitFixedWidthValueBuffer(VectorBatch& vb) { int Splitter::SplitBinaryArray(VectorBatch& vb) { - const auto numRows = vb.GetRowCount(); - auto vecCntVb = vb.GetVectorCount(); - auto vecCntSchema = singlePartitionFlag ? 
vecCntVb : vecCntVb - 1; - for (auto colSchema = 0; colSchema < vecCntSchema; ++colSchema) { - switch (column_type_id_[colSchema]) { + const auto num_rows = vb.GetRowCount(); + auto vec_cnt_vb = vb.GetVectorCount(); + auto vec_cnt_schema = singlePartitionFlag ? vec_cnt_vb : vec_cnt_vb - 1; + for (auto col_schema = 0; col_schema < vec_cnt_schema; ++col_schema) { + switch (column_type_id_[col_schema]) { case SHUFFLE_BINARY: { - auto colVb = singlePartitionFlag ? colSchema : colSchema + 1; - varcharVectorCache.insert(vb.GetVector(colVb)); // record varchar vector for release - if (vb.GetVector(colVb)->GetEncoding() == OMNI_VEC_ENCODING_DICTIONARY) { - for (auto row = 0; row < numRows; ++row) { + auto col_vb = singlePartitionFlag ? col_schema : col_schema + 1; + varcharVectorCache.insert(vb.Get(col_vb)); + if (vb.Get(col_vb)->GetEncoding() == OMNI_DICTIONARY) { + auto vc = reinterpret_cast> *>( + vb.Get(col_vb)); + for (auto row = 0; row < num_rows; ++row) { auto pid = partition_id_[row]; uint8_t *dst = nullptr; - auto str_len = ((DictionaryVector *)(vb.GetVector(colVb)))->GetVarchar(row, &dst); - bool isnull = ((DictionaryVector *)(vb.GetVector(colVb)))->IsValueNull(row); + uint32_t str_len = 0; + if (!vc->IsNull(row)) { + std::string_view value = vc->GetValue(row); + dst = reinterpret_cast(reinterpret_cast(value.data())); + str_len = static_cast(value.length()); + } + bool is_null = vc->IsNull(row); cached_vectorbatch_size_ += str_len; // 累计变长部分cache数据 - VCLocation cl((uint64_t) dst, str_len, isnull); - if ((vc_partition_array_buffers_[pid][colSchema].size() != 0) && - (vc_partition_array_buffers_[pid][colSchema].back().getVcList().size() < + VCLocation cl((uint64_t) dst, str_len, is_null); + if ((vc_partition_array_buffers_[pid][col_schema].size() != 0) && + (vc_partition_array_buffers_[pid][col_schema].back().getVcList().size() < options_.spill_batch_row_num)) { - vc_partition_array_buffers_[pid][colSchema].back().getVcList().push_back(cl); - vc_partition_array_buffers_[pid][colSchema].back().vcb_total_len += str_len; + vc_partition_array_buffers_[pid][col_schema].back().getVcList().push_back(cl); + vc_partition_array_buffers_[pid][col_schema].back().vcb_total_len += str_len; } else { VCBatchInfo svc(options_.spill_batch_row_num); svc.getVcList().push_back(cl); svc.vcb_total_len += str_len; - vc_partition_array_buffers_[pid][colSchema].push_back(svc); + vc_partition_array_buffers_[pid][col_schema].push_back(svc); } } } else { - VarcharVector *vc = nullptr; - vc = static_cast(vb.GetVector(colVb)); - for (auto row = 0; row < numRows; ++row) { + auto vc = reinterpret_cast> *>(vb.Get(col_vb)); + for (auto row = 0; row < num_rows; ++row) { auto pid = partition_id_[row]; uint8_t *dst = nullptr; - int str_len = vc->GetValue(row, &dst); - bool isnull = vc->IsValueNull(row); + uint32_t str_len = 0; + if (!vc->IsNull(row)) { + std::string_view value = vc->GetValue(row); + dst = reinterpret_cast(reinterpret_cast(value.data())); + str_len = static_cast(value.length()); + } + bool is_null = vc->IsNull(row); cached_vectorbatch_size_ += str_len; // 累计变长部分cache数据 - VCLocation cl((uint64_t) dst, str_len, isnull); - if ((vc_partition_array_buffers_[pid][colSchema].size() != 0) && - (vc_partition_array_buffers_[pid][colSchema].back().getVcList().size() < + VCLocation cl((uint64_t) dst, str_len, is_null); + if ((vc_partition_array_buffers_[pid][col_schema].size() != 0) && + (vc_partition_array_buffers_[pid][col_schema].back().getVcList().size() < options_.spill_batch_row_num)) { - 
vc_partition_array_buffers_[pid][colSchema].back().getVcList().push_back(cl); - vc_partition_array_buffers_[pid][colSchema].back().vcb_total_len += str_len; + vc_partition_array_buffers_[pid][col_schema].back().getVcList().push_back(cl); + vc_partition_array_buffers_[pid][col_schema].back().vcb_total_len += str_len; } else { VCBatchInfo svc(options_.spill_batch_row_num); svc.getVcList().push_back(cl); svc.vcb_total_len += str_len; - vc_partition_array_buffers_[pid][colSchema].push_back(svc); + vc_partition_array_buffers_[pid][col_schema].push_back(svc); } } } @@ -297,7 +303,7 @@ int Splitter::SplitFixedWidthValidityBuffer(VectorBatch& vb){ if (partition_id_cnt_cur_[pid] > 0 && dst_addrs[pid] == nullptr) { // init bitmap if it's null auto new_size = partition_id_cnt_cur_[pid] > options_.buffer_size ? partition_id_cnt_cur_[pid] : options_.buffer_size; - auto ptr_tmp = static_cast(options_.allocator->alloc(new_size)); + auto ptr_tmp = static_cast(options_.allocator->Alloc(new_size)); if (nullptr == ptr_tmp) { throw std::runtime_error("Allocator for ValidityBuffer Failed! "); } @@ -310,7 +316,8 @@ int Splitter::SplitFixedWidthValidityBuffer(VectorBatch& vb){ } // 计算并填充数据 - auto src_addr = const_cast((uint8_t*)(VectorHelper::GetNullsAddr(vb.GetVector(col_idx)))); + auto src_addr = const_cast((uint8_t *)( + reinterpret_cast(omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(vb.Get(col_idx))))); std::fill(std::begin(partition_buffer_idx_offset_), std::end(partition_buffer_idx_offset_), 0); const auto num_rows = vb.GetRowCount(); @@ -550,7 +557,7 @@ int Splitter::Split(VectorBatch& vb ) } std::shared_ptr Splitter::CaculateSpilledTmpFilePartitionOffsets() { - void *ptr_tmp = static_cast(options_.allocator->alloc((num_partitions_ + 1) * sizeof(uint64_t))); + void *ptr_tmp = static_cast(options_.allocator->Alloc((num_partitions_ + 1) * sizeof(uint64_t))); if (nullptr == ptr_tmp) { throw std::runtime_error("Allocator for partitionOffsets Failed! "); } @@ -606,7 +613,7 @@ spark::VecType::VecTypeId CastShuffleTypeIdToVecType(int32_t tmpType) { return spark::VecType::VEC_TYPE_CHAR; case OMNI_CONTAINER: return spark::VecType::VEC_TYPE_CONTAINER; - case OMNI_INVALID: + case DataTypeId::OMNI_INVALID: return spark::VecType::VEC_TYPE_INVALID; default: { throw std::runtime_error("castShuffleTypeIdToVecType() unexpected ShuffleTypeId"); @@ -625,9 +632,9 @@ void Splitter::SerializingFixedColumns(int32_t partitionId, colIndexTmpSchema = singlePartitionFlag ? fixed_width_array_idx_[fixColIndexTmp] : fixed_width_array_idx_[fixColIndexTmp] - 1; auto onceCopyLen = splitRowInfoTmp->onceCopyRow * (1 << column_type_id_[colIndexTmpSchema]); // 临时内存,拷贝拼接onceCopyRow批,用完释放 - void *ptr_value_tmp = static_cast(options_.allocator->alloc(onceCopyLen)); + void *ptr_value_tmp = static_cast(options_.allocator->Alloc(onceCopyLen)); std::shared_ptr ptr_value (new Buffer((uint8_t*)ptr_value_tmp, 0, onceCopyLen)); - void *ptr_validity_tmp = static_cast(options_.allocator->alloc(splitRowInfoTmp->onceCopyRow)); + void *ptr_validity_tmp = static_cast(options_.allocator->Alloc(splitRowInfoTmp->onceCopyRow)); std::shared_ptr ptr_validity (new Buffer((uint8_t*)ptr_validity_tmp, 0, splitRowInfoTmp->onceCopyRow)); if (nullptr == ptr_value->data_ || nullptr == ptr_validity->data_) { throw std::runtime_error("Allocator for tmp buffer Failed! 
"); @@ -659,9 +666,9 @@ void Splitter::SerializingFixedColumns(int32_t partitionId, partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->data_ + (splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp] / (1 << column_type_id_[colIndexTmpSchema])), memCopyLen / (1 << column_type_id_[colIndexTmpSchema])); // 释放内存 - options_.allocator->free(partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->data_, + options_.allocator->Free(partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->data_, partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->capacity_); - options_.allocator->free(partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->data_, + options_.allocator->Free(partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->data_, partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->capacity_); destCopyedLength += memCopyLen; splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp] += 1; // cacheBatchIndex下标后移 @@ -688,8 +695,8 @@ void Splitter::SerializingFixedColumns(int32_t partitionId, vec.set_values(ptr_value->data_, onceCopyLen); vec.set_nulls(ptr_validity->data_, splitRowInfoTmp->onceCopyRow); // 临时内存,拷贝拼接onceCopyRow批,用完释放 - options_.allocator->free(ptr_value->data_, ptr_value->capacity_); - options_.allocator->free(ptr_validity->data_, ptr_validity->capacity_); + options_.allocator->Free(ptr_value->data_, ptr_value->capacity_); + options_.allocator->Free(ptr_validity->data_, ptr_validity->capacity_); } // partition_cached_vectorbatch_[partition_id][cache_index][col][0]代表ByteMap, // partition_cached_vectorbatch_[partition_id][cache_index][col][1]代表value @@ -869,7 +876,7 @@ int Splitter::DeleteSpilledTmpFile() { for (auto &pair : spilled_tmp_files_info_) { auto tmpDataFilePath = pair.first + ".data"; // 释放存储有各个临时文件的偏移数据内存 - options_.allocator->free(pair.second->data_, pair.second->capacity_); + options_.allocator->Free(pair.second->data_, pair.second->capacity_); if (IsFileExist(tmpDataFilePath)) { remove(tmpDataFilePath.c_str()); } @@ -957,7 +964,4 @@ int Splitter::Stop() { } delete vecBatchProto; //free protobuf vecBatch memory return 0; -} - - - +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h index 0ef1989968a764eb67bb5c7aa35853e71a2fbe06..a57f868a335ebbf711b03a00329a882a82ee21f0 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h @@ -41,7 +41,6 @@ using namespace spark; using namespace google::protobuf::io; using namespace omniruntime::vec; using namespace omniruntime::type; -using namespace omniruntime::mem; struct SplitRowInfo { uint32_t copyedRow = 0; @@ -137,7 +136,7 @@ class Splitter { private: void ReleaseVarcharVector() { - std::set::iterator it; + std::set::iterator it; for (it = varcharVectorCache.begin(); it != varcharVectorCache.end(); it++) { delete *it; } @@ -147,9 +146,9 @@ private: void ReleaseVectorBatch(VectorBatch *vb) { int vectorCnt = vb->GetVectorCount(); - std::set vectorAddress; // vector deduplication + std::set vectorAddress; // vector deduplication for 
        for (int vecIndex = 0; vecIndex < vectorCnt; vecIndex++) {
-            Vector *vector = vb->GetVector(vecIndex);
+            BaseVector *vector = vb->Get(vecIndex);
             // not varchar vector can be released;
             if (varcharVectorCache.find(vector) == varcharVectorCache.end() &&
                 vectorAddress.find(vector) == vectorAddress.end()) {
@@ -161,7 +160,7 @@ private:
         delete vb;
     }

-    std::set varcharVectorCache;
+    std::set varcharVectorCache;
     bool first_vector_batch_ = false;
     std::vector vector_batch_col_types_;
     InputDataTypes input_col_types;
@@ -176,7 +175,7 @@ public:
     std::map> spilled_tmp_files_info_;

-    VecBatch *vecBatchProto = new VecBatch(); //protobuf serialization object structure
+    spark::VecBatch *vecBatchProto = new VecBatch(); // protobuf serialization object structure

     virtual int Split_Init();
diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/type.h b/omnioperator/omniop-spark-extension/cpp/src/shuffle/type.h
index 446cedc5f89988f115aedb7d9b3bc9b7c1c0a177..04d90130dea30a83651fff3526c08dc0992f9928 100644
--- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/type.h
+++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/type.h
@@ -40,7 +40,7 @@ struct SplitOptions {
     int64_t thread_id = -1;
     int64_t task_attempt_id = -1;

-    BaseAllocator *allocator = omniruntime::mem::GetProcessRootAllocator();
+    Allocator *allocator = Allocator::GetAllocator();

     uint64_t spill_batch_row_num = 4096; // default value
     uint64_t spill_mem_threshold = 1024 * 1024 * 1024; // default value
diff --git a/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt
index ca8c3848b775add4a1add153245357cb0b799f2f..209972501d52bcf7ff468b4c7a56b88e04123161 100644
--- a/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt
+++ b/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt
@@ -29,7 +29,6 @@ target_link_libraries(${TP_TEST_TARGET}
         pthread
         stdc++
         dl
-        boostkit-omniop-runtime-1.2.0-aarch64
        boostkit-omniop-vector-1.2.0-aarch64
         securec
         spark_columnar_plugin)
diff --git a/omnioperator/omniop-spark-extension/cpp/test/shuffle/shuffle_test.cpp b/omnioperator/omniop-spark-extension/cpp/test/shuffle/shuffle_test.cpp
index 1834345d54466d8e65f34eaea4ba2c99396440e0..c7a55759558e0713f3c3f265c052e5fcce94aa1c 100644
--- a/omnioperator/omniop-spark-extension/cpp/test/shuffle/shuffle_test.cpp
+++ b/omnioperator/omniop-spark-extension/cpp/test/shuffle/shuffle_test.cpp
@@ -242,7 +242,7 @@ TEST_F (ShuffleTest, Split_Short_10WRows) {
                                        0,
                                        tmpTestingDir);
     for (uint64_t j = 0; j < 100; j++) {
-        VectorBatch* vb = CreateVectorBatch_1FixCol_withPid(partitionNum, 1000, OMNI_SHORT);
+        VectorBatch* vb = CreateVectorBatch_1FixCol_withPid(partitionNum, 1000, ShortType());
         Test_splitter_split(splitterId, vb);
     }
     Test_splitter_stop(splitterId);
@@ -270,7 +270,7 @@ TEST_F (ShuffleTest, Split_Boolean_10WRows) {
                                        0,
                                        tmpTestingDir);
     for (uint64_t j = 0; j < 100; j++) {
-        VectorBatch* vb = CreateVectorBatch_1FixCol_withPid(partitionNum, 1000, OMNI_BOOLEAN);
+        VectorBatch* vb = CreateVectorBatch_1FixCol_withPid(partitionNum, 1000, BooleanType());
         Test_splitter_split(splitterId, vb);
     }
     Test_splitter_stop(splitterId);
@@ -298,7 +298,7 @@ TEST_F (ShuffleTest, Split_Long_100WRows) {
                                        0,
                                        tmpTestingDir);
     for (uint64_t j = 0; j < 100; j++) {
-        VectorBatch* vb = CreateVectorBatch_1FixCol_withPid(partitionNum, 10000, OMNI_LONG);
+        VectorBatch* vb = CreateVectorBatch_1FixCol_withPid(partitionNum, 10000, LongType());
         Test_splitter_split(splitterId, vb);
     }
     Test_splitter_stop(splitterId);
diff --git
a/omnioperator/omniop-spark-extension/cpp/test/tablescan/scan_test.cpp b/omnioperator/omniop-spark-extension/cpp/test/tablescan/scan_test.cpp index f8a6a6b7f2776f212d7ba6b1c9ee8d9260509116..2ed604e50420c402e9184c0a4011f66d69c00158 100644 --- a/omnioperator/omniop-spark-extension/cpp/test/tablescan/scan_test.cpp +++ b/omnioperator/omniop-spark-extension/cpp/test/tablescan/scan_test.cpp @@ -17,15 +17,13 @@ * limitations under the License. */ -#include "gtest/gtest.h" -#include -#include -#include "../../src/jni/OrcColumnarBatchJniReader.h" +#include +#include +#include +#include "jni/OrcColumnarBatchJniReader.h" #include "scan_test.h" -#include "orc/sargs/SearchArgument.hh" static std::string filename = "/resources/orc_data_all_type"; -static orc::ColumnVectorBatch *batchPtr; static orc::StructVectorBatch *root; /* @@ -53,17 +51,24 @@ protected: orc::ReaderOptions readerOpts; orc::RowReaderOptions rowReaderOptions; std::unique_ptr reader = orc::createReader(orc::readFile(PROJECT_PATH + filename), readerOpts); - std::unique_ptr rowReader = reader->createRowReader(); + rowReader = reader->createRowReader().release(); std::unique_ptr batch = rowReader->createRowBatch(4096); rowReader->next(*batch); - batchPtr = batch.release(); - root = static_cast(batchPtr); + types = &(rowReader->getSelectedType()); + root = static_cast(batch.release()); } // run after each case... virtual void TearDown() override { - delete batchPtr; + delete root; + root = nullptr; + types = nullptr; + delete rowReader; + rowReader = nullptr; } + + const orc::Type *types; + orc::RowReader *rowReader; }; TEST_F(ScanTest, test_literal_get_long) @@ -71,11 +76,11 @@ TEST_F(ScanTest, test_literal_get_long) orc::Literal tmpLit(0L); // test get long - getLiteral(tmpLit, (int)(orc::PredicateDataType::LONG), "655361"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::LONG), "655361"); ASSERT_EQ(tmpLit.getLong(), 655361); - getLiteral(tmpLit, (int)(orc::PredicateDataType::LONG), "-655361"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::LONG), "-655361"); ASSERT_EQ(tmpLit.getLong(), -655361); - getLiteral(tmpLit, (int)(orc::PredicateDataType::LONG), "0"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::LONG), "0"); ASSERT_EQ(tmpLit.getLong(), 0); } @@ -84,11 +89,11 @@ TEST_F(ScanTest, test_literal_get_float) orc::Literal tmpLit(0L); // test get float - getLiteral(tmpLit, (int)(orc::PredicateDataType::FLOAT), "12345.6789"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::FLOAT), "12345.6789"); ASSERT_EQ(tmpLit.getFloat(), 12345.6789); - getLiteral(tmpLit, (int)(orc::PredicateDataType::FLOAT), "-12345.6789"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::FLOAT), "-12345.6789"); ASSERT_EQ(tmpLit.getFloat(), -12345.6789); - getLiteral(tmpLit, (int)(orc::PredicateDataType::FLOAT), "0"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::FLOAT), "0"); ASSERT_EQ(tmpLit.getFloat(), 0); } @@ -97,9 +102,9 @@ TEST_F(ScanTest, test_literal_get_string) orc::Literal tmpLit(0L); // test get string - getLiteral(tmpLit, (int)(orc::PredicateDataType::STRING), "testStringForLit"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::STRING), "testStringForLit"); ASSERT_EQ(tmpLit.getString(), "testStringForLit"); - getLiteral(tmpLit, (int)(orc::PredicateDataType::STRING), ""); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::STRING), ""); ASSERT_EQ(tmpLit.getString(), ""); } @@ -108,7 +113,7 @@ TEST_F(ScanTest, test_literal_get_date) orc::Literal tmpLit(0L); // test get date - getLiteral(tmpLit, (int)(orc::PredicateDataType::DATE), 
"987654321"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::DATE), "987654321"); ASSERT_EQ(tmpLit.getDate(), 987654321); } @@ -117,15 +122,15 @@ TEST_F(ScanTest, test_literal_get_decimal) orc::Literal tmpLit(0L); // test get decimal - getLiteral(tmpLit, (int)(orc::PredicateDataType::DECIMAL), "199999999999998.998000 22 6"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::DECIMAL), "199999999999998.998000 22 6"); ASSERT_EQ(tmpLit.getDecimal().toString(), "199999999999998.998000"); - getLiteral(tmpLit, (int)(orc::PredicateDataType::DECIMAL), "10.998000 10 6"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::DECIMAL), "10.998000 10 6"); ASSERT_EQ(tmpLit.getDecimal().toString(), "10.998000"); - getLiteral(tmpLit, (int)(orc::PredicateDataType::DECIMAL), "-10.998000 10 6"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::DECIMAL), "-10.998000 10 6"); ASSERT_EQ(tmpLit.getDecimal().toString(), "-10.998000"); - getLiteral(tmpLit, (int)(orc::PredicateDataType::DECIMAL), "9999.999999 10 6"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::DECIMAL), "9999.999999 10 6"); ASSERT_EQ(tmpLit.getDecimal().toString(), "9999.999999"); - getLiteral(tmpLit, (int)(orc::PredicateDataType::DECIMAL), "-0.000000 10 6"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::DECIMAL), "-0.000000 10 6"); ASSERT_EQ(tmpLit.getDecimal().toString(), "0.000000"); } @@ -134,17 +139,17 @@ TEST_F(ScanTest, test_literal_get_bool) orc::Literal tmpLit(0L); // test get bool - getLiteral(tmpLit, (int)(orc::PredicateDataType::BOOLEAN), "true"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::BOOLEAN), "true"); ASSERT_EQ(tmpLit.getBool(), true); - getLiteral(tmpLit, (int)(orc::PredicateDataType::BOOLEAN), "True"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::BOOLEAN), "True"); ASSERT_EQ(tmpLit.getBool(), true); - getLiteral(tmpLit, (int)(orc::PredicateDataType::BOOLEAN), "false"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::BOOLEAN), "false"); ASSERT_EQ(tmpLit.getBool(), false); - getLiteral(tmpLit, (int)(orc::PredicateDataType::BOOLEAN), "False"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::BOOLEAN), "False"); ASSERT_EQ(tmpLit.getBool(), false); std::string tmpStr = ""; try { - getLiteral(tmpLit, (int)(orc::PredicateDataType::BOOLEAN), "exception"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::BOOLEAN), "exception"); } catch (std::exception &e) { tmpStr = e.what(); } @@ -156,9 +161,9 @@ TEST_F(ScanTest, test_copy_intVec) int omniType = 0; uint64_t omniVecId = 0; // int type - copyToOmniVec(orc::TypeKind::INT, omniType, omniVecId, root->fields[0]); + CopyToOmniVec(types->getSubtype(0), omniType, omniVecId, root->fields[0]); ASSERT_EQ(omniType, omniruntime::type::OMNI_INT); - omniruntime::vec::IntVector *olbInt = (omniruntime::vec::IntVector *)(omniVecId); + auto *olbInt = (omniruntime::vec::Vector *)(omniVecId); ASSERT_EQ(olbInt->GetValue(0), 10); delete olbInt; } @@ -168,12 +173,11 @@ TEST_F(ScanTest, test_copy_varCharVec) int omniType = 0; uint64_t omniVecId = 0; // varchar type - copyToOmniVec(orc::TypeKind::VARCHAR, omniType, omniVecId, root->fields[1], 60); + CopyToOmniVec(types->getSubtype(1), omniType, omniVecId, root->fields[1]); ASSERT_EQ(omniType, omniruntime::type::OMNI_VARCHAR); - uint8_t *actualChar = nullptr; - omniruntime::vec::VarcharVector *olbVc = (omniruntime::vec::VarcharVector *)(omniVecId); - int len = olbVc->GetValue(0, &actualChar); - std::string actualStr(reinterpret_cast(actualChar), 0, len); + auto *olbVc = (omniruntime::vec::Vector> *)( + omniVecId); + 
std::string_view actualStr = olbVc->GetValue(0); ASSERT_EQ(actualStr, "varchar_1"); delete olbVc; } @@ -182,14 +186,13 @@ TEST_F(ScanTest, test_copy_stringVec) { int omniType = 0; uint64_t omniVecId = 0; - uint8_t *actualChar = nullptr; // string type - copyToOmniVec(orc::TypeKind::STRING, omniType, omniVecId, root->fields[2]); + CopyToOmniVec(types->getSubtype(2), omniType, omniVecId, root->fields[2]); ASSERT_EQ(omniType, omniruntime::type::OMNI_VARCHAR); - omniruntime::vec::VarcharVector *olbStr = (omniruntime::vec::VarcharVector *)(omniVecId); - int len = olbStr->GetValue(0, &actualChar); - std::string actualStr2(reinterpret_cast(actualChar), 0, len); - ASSERT_EQ(actualStr2, "string_type_1"); + auto *olbStr = (omniruntime::vec::Vector> *)( + omniVecId); + std::string_view actualStr = olbStr->GetValue(0); + ASSERT_EQ(actualStr, "string_type_1"); delete olbStr; } @@ -198,9 +201,9 @@ TEST_F(ScanTest, test_copy_longVec) int omniType = 0; uint64_t omniVecId = 0; // bigint type - copyToOmniVec(orc::TypeKind::LONG, omniType, omniVecId, root->fields[3]); + CopyToOmniVec(types->getSubtype(3), omniType, omniVecId, root->fields[3]); ASSERT_EQ(omniType, omniruntime::type::OMNI_LONG); - omniruntime::vec::LongVector *olbLong = (omniruntime::vec::LongVector *)(omniVecId); + auto *olbLong = (omniruntime::vec::Vector *)(omniVecId); ASSERT_EQ(olbLong->GetValue(0), 10000); delete olbLong; } @@ -209,15 +212,14 @@ TEST_F(ScanTest, test_copy_charVec) { int omniType = 0; uint64_t omniVecId = 0; - uint8_t *actualChar = nullptr; // char type - copyToOmniVec(orc::TypeKind::CHAR, omniType, omniVecId, root->fields[4], 40); + CopyToOmniVec(types->getSubtype(4), omniType, omniVecId, root->fields[4]); ASSERT_EQ(omniType, omniruntime::type::OMNI_VARCHAR); - omniruntime::vec::VarcharVector *olbChar40 = (omniruntime::vec::VarcharVector *)(omniVecId); - int len = olbChar40->GetValue(0, &actualChar); - std::string actualStr3(reinterpret_cast(actualChar), 0, len); - ASSERT_EQ(actualStr3, "char_1"); - delete olbChar40; + auto *olbChar = (omniruntime::vec::Vector> *)( + omniVecId); + std::string_view actualStr = olbChar->GetValue(0); + ASSERT_EQ(actualStr, "char_1"); + delete olbChar; } TEST_F(ScanTest, test_copy_doubleVec) @@ -225,9 +227,9 @@ TEST_F(ScanTest, test_copy_doubleVec) int omniType = 0; uint64_t omniVecId = 0; // double type - copyToOmniVec(orc::TypeKind::DOUBLE, omniType, omniVecId, root->fields[6]); + CopyToOmniVec(types->getSubtype(6), omniType, omniVecId, root->fields[6]); ASSERT_EQ(omniType, omniruntime::type::OMNI_DOUBLE); - omniruntime::vec::DoubleVector *olbDouble = (omniruntime::vec::DoubleVector *)(omniVecId); + auto *olbDouble = (omniruntime::vec::Vector *)(omniVecId); ASSERT_EQ(olbDouble->GetValue(0), 1111.1111); delete olbDouble; } @@ -237,9 +239,9 @@ TEST_F(ScanTest, test_copy_booleanVec) int omniType = 0; uint64_t omniVecId = 0; // boolean type - copyToOmniVec(orc::TypeKind::BOOLEAN, omniType, omniVecId, root->fields[9]); + CopyToOmniVec(types->getSubtype(9), omniType, omniVecId, root->fields[9]); ASSERT_EQ(omniType, omniruntime::type::OMNI_BOOLEAN); - omniruntime::vec::BooleanVector *olbBoolean = (omniruntime::vec::BooleanVector *)(omniVecId); + auto *olbBoolean = (omniruntime::vec::Vector *)(omniVecId); ASSERT_EQ(olbBoolean->GetValue(0), true); delete olbBoolean; } @@ -249,9 +251,9 @@ TEST_F(ScanTest, test_copy_shortVec) int omniType = 0; uint64_t omniVecId = 0; // short type - copyToOmniVec(orc::TypeKind::SHORT, omniType, omniVecId, root->fields[10]); + CopyToOmniVec(types->getSubtype(10), 
omniType, omniVecId, root->fields[10]); ASSERT_EQ(omniType, omniruntime::type::OMNI_SHORT); - omniruntime::vec::ShortVector *olbShort = (omniruntime::vec::ShortVector *)(omniVecId); + auto *olbShort = (omniruntime::vec::Vector *)(omniVecId); ASSERT_EQ(olbShort->GetValue(0), 11); delete olbShort; } @@ -265,24 +267,26 @@ TEST_F(ScanTest, test_build_leafs) orc::Literal lit(100L); // test EQUALS - buildLeaves(PredicateOperatorType::EQUALS, litList, lit, "leaf-0", orc::PredicateDataType::LONG, *builder); + BuildLeaves(PredicateOperatorType::EQUALS, litList, lit, "leaf-0", orc::PredicateDataType::LONG, *builder); // test LESS_THAN - buildLeaves(PredicateOperatorType::LESS_THAN, litList, lit, "leaf-1", orc::PredicateDataType::LONG, *builder); + BuildLeaves(PredicateOperatorType::LESS_THAN, litList, lit, "leaf-1", orc::PredicateDataType::LONG, *builder); // test LESS_THAN_EQUALS - buildLeaves(PredicateOperatorType::LESS_THAN_EQUALS, litList, lit, "leaf-1", orc::PredicateDataType::LONG, *builder); + BuildLeaves(PredicateOperatorType::LESS_THAN_EQUALS, litList, lit, "leaf-1", orc::PredicateDataType::LONG, + *builder); // test NULL_SAFE_EQUALS - buildLeaves(PredicateOperatorType::NULL_SAFE_EQUALS, litList, lit, "leaf-1", orc::PredicateDataType::LONG, *builder); + BuildLeaves(PredicateOperatorType::NULL_SAFE_EQUALS, litList, lit, "leaf-1", orc::PredicateDataType::LONG, + *builder); // test IS_NULL - buildLeaves(PredicateOperatorType::IS_NULL, litList, lit, "leaf-1", orc::PredicateDataType::LONG, *builder); + BuildLeaves(PredicateOperatorType::IS_NULL, litList, lit, "leaf-1", orc::PredicateDataType::LONG, *builder); // test BETWEEN std::string tmpStr = ""; try { - buildLeaves(PredicateOperatorType::BETWEEN, litList, lit, "leaf-1", orc::PredicateDataType::LONG, *builder); + BuildLeaves(PredicateOperatorType::BETWEEN, litList, lit, "leaf-1", orc::PredicateDataType::LONG, *builder); } catch (std::exception &e) { tmpStr = e.what(); } diff --git a/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.cpp b/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.cpp index 586f4bbdb95721b22422d715f645eb502dc1a894..d70a62003645893af12df8f8980c9195bbd6d389 100644 --- a/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.cpp +++ b/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.cpp @@ -21,199 +21,33 @@ using namespace omniruntime::vec; -void ToVectorTypes(const int32_t *dataTypeIds, int32_t dataTypeCount, std::vector &dataTypes) -{ - for (int i = 0; i < dataTypeCount; ++i) { - if (dataTypeIds[i] == OMNI_VARCHAR) { - dataTypes.push_back(std::make_shared(50)); - continue; - } else if (dataTypeIds[i] == OMNI_CHAR) { - dataTypes.push_back(std::make_shared(50)); - continue; - } - dataTypes.push_back(std::make_shared(dataTypeIds[i])); - } -} - -VectorBatch* CreateInputData(const int32_t numRows, - const int32_t numCols, - int32_t* inputTypeIds, - int64_t* allData) -{ - auto *vecBatch = new VectorBatch(numCols, numRows); - vector inputTypes; - ToVectorTypes(inputTypeIds, numCols, inputTypes); - vecBatch->NewVectors(omniruntime::vec::GetProcessGlobalVecAllocator(), inputTypes); - for (int i = 0; i < numCols; ++i) { - switch (inputTypeIds[i]) { - case OMNI_BOOLEAN: - ((BooleanVector *)vecBatch->GetVector(i))->SetValues(0, (int32_t *)allData[i], numRows); - break; - case OMNI_INT: - ((IntVector *)vecBatch->GetVector(i))->SetValues(0, (int32_t *)allData[i], numRows); - break; - case OMNI_LONG: - ((LongVector *)vecBatch->GetVector(i))->SetValues(0, (int64_t *)allData[i], numRows); - 
break; - case OMNI_DOUBLE: - ((DoubleVector *)vecBatch->GetVector(i))->SetValues(0, (double *)allData[i], numRows); - break; - case OMNI_SHORT: - ((ShortVector *)vecBatch->GetVector(i))->SetValues(0, (int16_t *)allData[i], numRows); - break; - case OMNI_VARCHAR: - case OMNI_CHAR: { - for (int j = 0; j < numRows; ++j) { - int64_t addr = (reinterpret_cast(allData[i]))[j]; - std::string s (reinterpret_cast(addr)); - ((VarcharVector *)vecBatch->GetVector(i))->SetValue(j, (uint8_t *)(s.c_str()), s.length()); - } - break; - } - case OMNI_DECIMAL128: - ((Decimal128Vector *)vecBatch->GetVector(i))->SetValues(0, (int64_t *) allData[i], numRows); - break; - default:{ - LogError("No such data type %d", inputTypeIds[i]); - } - } - } - return vecBatch; -} - -VarcharVector *CreateVarcharVector(VarcharDataType type, std::string *values, int32_t length) -{ - VectorAllocator * vecAllocator = omniruntime::vec::GetProcessGlobalVecAllocator(); - uint32_t width = type.GetWidth(); - VarcharVector *vector = std::make_unique(vecAllocator, length * width, length).release(); - for (int32_t i = 0; i < length; i++) { - vector->SetValue(i, reinterpret_cast(values[i].c_str()), values[i].length()); - } - return vector; -} - -Decimal128Vector *CreateDecimal128Vector(Decimal128 *values, int32_t length) -{ - VectorAllocator *vecAllocator = omniruntime::vec::GetProcessGlobalVecAllocator(); - Decimal128Vector *vector = std::make_unique(vecAllocator, length).release(); - for (int32_t i = 0; i < length; i++) { - vector->SetValue(i, values[i]); - } - return vector; -} - -Vector *CreateVector(DataType &vecType, int32_t rowCount, va_list &args) -{ - switch (vecType.GetId()) { - case OMNI_INT: - case OMNI_DATE32: - return CreateVector(va_arg(args, int32_t *), rowCount); - case OMNI_LONG: - case OMNI_DECIMAL64: - return CreateVector(va_arg(args, int64_t *), rowCount); - case OMNI_DOUBLE: - return CreateVector(va_arg(args, double *), rowCount); - case OMNI_BOOLEAN: - return CreateVector(va_arg(args, bool *), rowCount); - case OMNI_VARCHAR: - case OMNI_CHAR: - return CreateVarcharVector(static_cast(vecType), va_arg(args, std::string *), rowCount); - case OMNI_DECIMAL128: - return CreateDecimal128Vector(va_arg(args, Decimal128 *), rowCount); - default: - std::cerr << "Unsupported type : " << vecType.GetId() << std::endl; - return nullptr; - } -} - -DictionaryVector *CreateDictionaryVector(DataType &dataType, int32_t rowCount, int32_t *ids, int32_t idsCount, ...) +VectorBatch *CreateVectorBatch(const DataTypes &types, int32_t rowCount, ...) 
{ + int32_t typesCount = types.GetSize(); + auto *vectorBatch = new VectorBatch(rowCount); va_list args; - va_start(args, idsCount); - Vector *dictionary = CreateVector(dataType, rowCount, args); + va_start(args, rowCount); + for (int32_t i = 0; i < typesCount; i++) { + DataTypePtr type = types.GetType(i); + vectorBatch->Append(CreateVector(*type, rowCount, args).release()); + } va_end(args); - auto vec = new DictionaryVector(dictionary, ids, idsCount); - delete dictionary; - return vec; + return vectorBatch; } -Vector *buildVector(const DataType &aggType, int32_t rowNumber) +std::unique_ptr CreateVector(DataType &dataType, int32_t rowCount, va_list &args) { - VectorAllocator *vecAllocator = omniruntime::vec::GetProcessGlobalVecAllocator(); - switch (aggType.GetId()) { - case OMNI_NONE: { - LongVector *col = new LongVector(vecAllocator, rowNumber); - for (int32_t j = 0; j < rowNumber; ++j) { - col->SetValueNull(j); - } - return col; - } - case OMNI_INT: - case OMNI_DATE32: { - IntVector *col = new IntVector(vecAllocator, rowNumber); - for (int32_t j = 0; j < rowNumber; ++j) { - col->SetValue(j, 1); - } - return col; - } - case OMNI_LONG: - case OMNI_DECIMAL64: { - LongVector *col = new LongVector(vecAllocator, rowNumber); - for (int32_t j = 0; j < rowNumber; ++j) { - col->SetValue(j, 1); - } - return col; - } - case OMNI_DOUBLE: { - DoubleVector *col = new DoubleVector(vecAllocator, rowNumber); - for (int32_t j = 0; j < rowNumber; ++j) { - col->SetValue(j, 1); - } - return col; - } - case OMNI_BOOLEAN: { - BooleanVector *col = new BooleanVector(vecAllocator, rowNumber); - for (int32_t j = 0; j < rowNumber; ++j) { - col->SetValue(j, 1); - } - return col; - } - case OMNI_DECIMAL128: { - Decimal128Vector *col = new Decimal128Vector(vecAllocator, rowNumber); - for (int32_t j = 0; j < rowNumber; ++j) { - col->SetValue(j, Decimal128(0, 1)); - } - return col; - } - case OMNI_VARCHAR: - case OMNI_CHAR: { - VarcharDataType charType = (VarcharDataType &)aggType; - VarcharVector *col = new VarcharVector(vecAllocator, charType.GetWidth() * rowNumber, rowNumber); - for (int32_t j = 0; j < rowNumber; ++j) { - std::string str = std::to_string(j); - col->SetValue(j, reinterpret_cast(str.c_str()), str.size()); - } - return col; - } - default: { - LogError("No such %d type support", aggType.GetId()); - return nullptr; - } - } + return DYNAMIC_TYPE_DISPATCH(CreateFlatVector, dataType.GetId(), rowCount, args); } -VectorBatch *CreateVectorBatch(const DataTypes &types, int32_t rowCount, ...) +std::unique_ptr CreateDictionaryVector(DataType &dataType, int32_t rowCount, int32_t *ids, int32_t idsCount, + ...) { - int32_t typesCount = types.GetSize(); - auto *vectorBatch = new VectorBatch(typesCount, rowCount); va_list args; - va_start(args, rowCount); - for (int32_t i = 0; i < typesCount; i++) { - DataTypePtr type = types.GetType(i); - vectorBatch->SetVector(i, CreateVector(*type, rowCount, args)); - } + va_start(args, idsCount); + std::unique_ptr dictionary = CreateVector(dataType, rowCount, args); va_end(args); - return vectorBatch; + return DYNAMIC_TYPE_DISPATCH(CreateDictionary, dataType.GetId(), dictionary.get(), ids, idsCount); } /** @@ -225,24 +59,16 @@ VectorBatch *CreateVectorBatch(const DataTypes &types, int32_t rowCount, ...) 
*/ VectorBatch* CreateVectorBatch_1row_varchar_withPid(int pid, std::string inputString) { // gen vectorBatch - const int32_t numCols = 2; - int32_t* inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_VARCHAR; + DataTypes inputTypes(std::vector({ IntType(), VarcharType() })); const int32_t numRows = 1; auto* col1 = new int32_t[numRows]; col1[0] = pid; - auto* col2 = new int64_t[numRows]; - std::string* strTmp = new std::string(inputString); - col2[0] = (int64_t)(strTmp->c_str()); + auto* col2 = new std::string[numRows]; + col2[0] = std::move(inputString); - int64_t allData[numCols] = {reinterpret_cast(col1), - reinterpret_cast(col2)}; - VectorBatch* in = CreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col1, col2); delete[] col1; delete[] col2; - delete strTmp; return in; } @@ -255,224 +81,144 @@ VectorBatch* CreateVectorBatch_1row_varchar_withPid(int pid, std::string inputSt */ VectorBatch* CreateVectorBatch_4col_withPid(int parNum, int rowNum) { int partitionNum = parNum; - const int32_t numCols = 5; - int32_t* inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_INT; - inputTypes[2] = OMNI_LONG; - inputTypes[3] = OMNI_DOUBLE; - inputTypes[4] = OMNI_VARCHAR; + DataTypes inputTypes(std::vector({ IntType(), IntType(), LongType(), DoubleType(), VarcharType() })); const int32_t numRows = rowNum; auto* col0 = new int32_t[numRows]; auto* col1 = new int32_t[numRows]; auto* col2 = new int64_t[numRows]; auto* col3 = new double[numRows]; - auto* col4 = new int64_t[numRows]; - string startStr = "_START_"; - string endStr = "_END_"; + auto* col4 = new std::string[numRows]; + std::string startStr = "_START_"; + std::string endStr = "_END_"; std::vector string_cache_test_; for (int i = 0; i < numRows; i++) { - col0[i] = (i+1) % partitionNum; + col0[i] = (i + 1) % partitionNum; col1[i] = i + 1; col2[i] = i + 1; col3[i] = i + 1; - std::string* strTmp = new std::string(startStr + to_string(i + 1) + endStr); - string_cache_test_.push_back(strTmp); - col4[i] = (int64_t)((*strTmp).c_str()); + std::string strTmp = std::string(startStr + to_string(i + 1) + endStr); + col4[i] = std::move(strTmp); } - int64_t allData[numCols] = {reinterpret_cast(col0), - reinterpret_cast(col1), - reinterpret_cast(col2), - reinterpret_cast(col3), - reinterpret_cast(col4)}; - VectorBatch* in = CreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2, col3, col4); delete[] col0; delete[] col1; delete[] col2; delete[] col3; delete[] col4; - - for (uint p = 0; p < string_cache_test_.size(); p++) { - delete string_cache_test_[p]; // release memory - } return in; } -VectorBatch* CreateVectorBatch_1FixCol_withPid(int parNum, int rowNum, int32_t fixColType) { +VectorBatch* CreateVectorBatch_1FixCol_withPid(int parNum, int rowNum, DataTypePtr fixColType) { int partitionNum = parNum; - const int32_t numCols = 2; - int32_t* inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = fixColType; + DataTypes inputTypes(std::vector({ IntType(), std::move(fixColType) })); const int32_t numRows = rowNum; auto* col0 = new int32_t[numRows]; auto* col1 = new int64_t[numRows]; for (int i = 0; i < numRows; i++) { - col0[i] = (i+1) % partitionNum; + col0[i] = (i + 1) % partitionNum; col1[i] = i + 1; } - int64_t allData[numCols] = {reinterpret_cast(col0), - 
reinterpret_cast(col1)}; - VectorBatch* in = CreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1); delete[] col0; delete[] col1; return in; } VectorBatch* CreateVectorBatch_2column_1row_withPid(int pid, std::string strVar, int intVar) { - const int32_t numCols = 3; - int32_t* inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_VARCHAR; - inputTypes[2] = OMNI_INT; + DataTypes inputTypes(std::vector({ IntType(), VarcharType(), IntType() })); const int32_t numRows = 1; auto* col0 = new int32_t[numRows]; - auto* col1 = new int64_t[numRows]; + auto* col1 = new std::string[numRows]; auto* col2 = new int32_t[numRows]; col0[0] = pid; - std::string* strTmp = new std::string(strVar); - col1[0] = (int64_t)(strTmp->c_str()); + col1[0] = std::move(strVar); col2[0] = intVar; - int64_t allData[numCols] = {reinterpret_cast(col0), - reinterpret_cast(col1), - reinterpret_cast(col2)}; - VectorBatch* in = CreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2); delete[] col0; delete[] col1; delete[] col2; - delete strTmp; return in; } VectorBatch* CreateVectorBatch_4varcharCols_withPid(int parNum, int rowNum) { int partitionNum = parNum; - const int32_t numCols = 5; - int32_t* inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_VARCHAR; - inputTypes[2] = OMNI_VARCHAR; - inputTypes[3] = OMNI_VARCHAR; - inputTypes[4] = OMNI_VARCHAR; + DataTypes inputTypes( + std::vector({ IntType(), VarcharType(), VarcharType(), VarcharType(), VarcharType() })); const int32_t numRows = rowNum; auto* col0 = new int32_t[numRows]; - auto* col1 = new int64_t[numRows]; - auto* col2 = new int64_t[numRows]; - auto* col3 = new int64_t[numRows]; - auto* col4 = new int64_t[numRows]; + auto* col1 = new std::string[numRows]; + auto* col2 = new std::string[numRows]; + auto* col3 = new std::string[numRows]; + auto* col4 = new std::string[numRows]; - std::vector string_cache_test_; for (int i = 0; i < numRows; i++) { - col0[i] = (i+1) % partitionNum; - std::string* strTmp1 = new std::string("Col1_START_" + to_string(i + 1) + "_END_"); - col1[i] = (int64_t)((*strTmp1).c_str()); - std::string* strTmp2 = new std::string("Col2_START_" + to_string(i + 1) + "_END_"); - col2[i] = (int64_t)((*strTmp2).c_str()); - std::string* strTmp3 = new std::string("Col3_START_" + to_string(i + 1) + "_END_"); - col3[i] = (int64_t)((*strTmp3).c_str()); - std::string* strTmp4 = new std::string("Col4_START_" + to_string(i + 1) + "_END_"); - col4[i] = (int64_t)((*strTmp4).c_str()); - string_cache_test_.push_back(strTmp1); - string_cache_test_.push_back(strTmp2); - string_cache_test_.push_back(strTmp3); - string_cache_test_.push_back(strTmp4); + col0[i] = (i + 1) % partitionNum; + std::string strTmp1 = std::string("Col1_START_" + to_string(i + 1) + "_END_"); + col1[i] = std::move(strTmp1); + std::string strTmp2 = std::string("Col2_START_" + to_string(i + 1) + "_END_"); + col2[i] = std::move(strTmp2); + std::string strTmp3 = std::string("Col3_START_" + to_string(i + 1) + "_END_"); + col3[i] = std::move(strTmp3); + std::string strTmp4 = std::string("Col4_START_" + to_string(i + 1) + "_END_"); + col4[i] = std::move(strTmp4); } - int64_t allData[numCols] = {reinterpret_cast(col0), - reinterpret_cast(col1), - reinterpret_cast(col2), - reinterpret_cast(col3), - reinterpret_cast(col4)}; - VectorBatch* in = 
CreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2, col3, col4); delete[] col0; delete[] col1; delete[] col2; delete[] col3; delete[] col4; - - for (uint p = 0; p < string_cache_test_.size(); p++) { - delete string_cache_test_[p]; // release memory - } return in; } VectorBatch* CreateVectorBatch_4charCols_withPid(int parNum, int rowNum) { int partitionNum = parNum; - const int32_t numCols = 5; - int32_t* inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_CHAR; - inputTypes[2] = OMNI_CHAR; - inputTypes[3] = OMNI_CHAR; - inputTypes[4] = OMNI_CHAR; + DataTypes inputTypes(std::vector({ IntType(), CharType(), CharType(), CharType(), CharType() })); const int32_t numRows = rowNum; auto* col0 = new int32_t[numRows]; - auto* col1 = new int64_t[numRows]; - auto* col2 = new int64_t[numRows]; - auto* col3 = new int64_t[numRows]; - auto* col4 = new int64_t[numRows]; + auto* col1 = new std::string[numRows]; + auto* col2 = new std::string[numRows]; + auto* col3 = new std::string[numRows]; + auto* col4 = new std::string[numRows]; std::vector string_cache_test_; for (int i = 0; i < numRows; i++) { - col0[i] = (i+1) % partitionNum; - std::string* strTmp1 = new std::string("Col1_CHAR_" + to_string(i + 1) + "_END_"); - col1[i] = (int64_t)((*strTmp1).c_str()); - std::string* strTmp2 = new std::string("Col2_CHAR_" + to_string(i + 1) + "_END_"); - col2[i] = (int64_t)((*strTmp2).c_str()); - std::string* strTmp3 = new std::string("Col3_CHAR_" + to_string(i + 1) + "_END_"); - col3[i] = (int64_t)((*strTmp3).c_str()); - std::string* strTmp4 = new std::string("Col4_CHAR_" + to_string(i + 1) + "_END_"); - col4[i] = (int64_t)((*strTmp4).c_str()); - string_cache_test_.push_back(strTmp1); - string_cache_test_.push_back(strTmp2); - string_cache_test_.push_back(strTmp3); - string_cache_test_.push_back(strTmp4); + col0[i] = (i + 1) % partitionNum; + std::string strTmp1 = std::string("Col1_CHAR_" + to_string(i + 1) + "_END_"); + col1[i] = std::move(strTmp1); + std::string strTmp2 = std::string("Col2_CHAR_" + to_string(i + 1) + "_END_"); + col2[i] = std::move(strTmp2); + std::string strTmp3 = std::string("Col3_CHAR_" + to_string(i + 1) + "_END_"); + col3[i] = std::move(strTmp3); + std::string strTmp4 = std::string("Col4_CHAR_" + to_string(i + 1) + "_END_"); + col4[i] = std::move(strTmp4); } - int64_t allData[numCols] = {reinterpret_cast(col0), - reinterpret_cast(col1), - reinterpret_cast(col2), - reinterpret_cast(col3), - reinterpret_cast(col4)}; - VectorBatch* in = CreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2, col3, col4); delete[] col0; delete[] col1; delete[] col2; delete[] col3; delete[] col4; - - for (uint p = 0; p < string_cache_test_.size(); p++) { - delete string_cache_test_[p]; // release memory - } return in; } VectorBatch* CreateVectorBatch_5fixedCols_withPid(int parNum, int rowNum) { int partitionNum = parNum; - // gen vectorBatch - const int32_t numCols = 6; - int32_t* inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_BOOLEAN; - inputTypes[2] = OMNI_SHORT; - inputTypes[3] = OMNI_INT; - inputTypes[4] = OMNI_LONG; - inputTypes[5] = OMNI_DOUBLE; + DataTypes inputTypes( + std::vector({ IntType(), BooleanType(), ShortType(), IntType(), LongType(), DoubleType() })); const int32_t numRows = rowNum; auto* col0 = new int32_t[numRows]; @@ 
-490,14 +236,7 @@ VectorBatch* CreateVectorBatch_5fixedCols_withPid(int parNum, int rowNum) { col5[i] = i + 1; } - int64_t allData[numCols] = {reinterpret_cast(col0), - reinterpret_cast(col1), - reinterpret_cast(col2), - reinterpret_cast(col3), - reinterpret_cast(col4), - reinterpret_cast(col5)}; - VectorBatch* in = CreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2, col3, col4, col5); delete[] col0; delete[] col1; delete[] col2; @@ -512,71 +251,85 @@ VectorBatch* CreateVectorBatch_2dictionaryCols_withPid(int partitionNum) { // construct input data const int32_t dataSize = 6; // prepare data - int32_t data0[dataSize] = {111, 112, 113, 114, 115, 116}; - int64_t data1[dataSize] = {221, 222, 223, 224, 225, 226}; - void *datas[2] = {data0, data1}; - DataTypes sourceTypes(std::vector({ std::make_unique(), std::make_unique()})); - int32_t ids[] = {0, 1, 2, 3, 4, 5}; - VectorBatch *vectorBatch = new VectorBatch(3, dataSize); - VectorAllocator *allocator = omniruntime::vec::GetProcessGlobalVecAllocator(); - IntVector *intVectorTmp = new IntVector(allocator, 6); - for (int i = 0; i < intVectorTmp->GetSize(); i++) { - intVectorTmp->SetValue(i, (i+1) % partitionNum); - } - for (int32_t i = 0; i < 3; i ++) { - if (i == 0) { - vectorBatch->SetVector(i, intVectorTmp); - } else { - omniruntime::vec::DataType dataType = *(sourceTypes.Get()[i - 1]); - vectorBatch->SetVector(i, CreateDictionaryVector(dataType, dataSize, ids, dataSize, datas[i - 1])); - } + auto *col0 = new int32_t[dataSize]; + for (int32_t i = 0; i< dataSize; i++) { + col0[i] = (i + 1) % partitionNum; } + int32_t col1[dataSize] = {111, 112, 113, 114, 115, 116}; + int64_t col2[dataSize] = {221, 222, 223, 224, 225, 226}; + void *datas[2] = {col1, col2}; + DataTypes sourceTypes(std::vector({ IntType(), LongType() })); + int32_t ids[] = {0, 1, 2, 3, 4, 5}; + + VectorBatch *vectorBatch = new VectorBatch(dataSize); + auto Vec0 = CreateVector(dataSize, col0); + vectorBatch->Append(Vec0.release()); + auto dicVec0 = CreateDictionaryVector(*sourceTypes.GetType(0), dataSize, ids, dataSize, datas[0]); + auto dicVec1 = CreateDictionaryVector(*sourceTypes.GetType(1), dataSize, ids, dataSize, datas[1]); + vectorBatch->Append(dicVec0.release()); + vectorBatch->Append(dicVec1.release()); + + delete[] col0; return vectorBatch; } VectorBatch* CreateVectorBatch_1decimal128Col_withPid(int partitionNum, int rowNum) { - auto decimal128InputVec = buildVector(Decimal128DataType(38, 2), rowNum); - VectorAllocator *allocator = VectorAllocator::GetGlobalAllocator(); - IntVector *intVectorPid = new IntVector(allocator, rowNum); - for (int i = 0; i < intVectorPid->GetSize(); i++) { - intVectorPid->SetValue(i, (i+1) % partitionNum); + const int32_t numRows = rowNum; + DataTypes inputTypes(std::vector({ IntType(), Decimal128Type(38, 2) })); + + auto *col0 = new int32_t[numRows]; + auto *col1 = new Decimal128[numRows]; + for (int32_t i = 0; i < numRows; i++) { + col0[i] = (i + 1) % partitionNum; + col1[i] = Decimal128(0, 1); } - VectorBatch *vecBatch = new VectorBatch(2); - vecBatch->SetVector(0, intVectorPid); - vecBatch->SetVector(1, decimal128InputVec); - return vecBatch; + + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1); + delete[] col0; + delete[] col1; + return in; } VectorBatch* CreateVectorBatch_1decimal64Col_withPid(int partitionNum, int rowNum) { - auto decimal64InputVec = buildVector(Decimal64DataType(7, 2), rowNum); - VectorAllocator 
*allocator = VectorAllocator::GetGlobalAllocator(); - IntVector *intVectorPid = new IntVector(allocator, rowNum); - for (int i = 0; i < intVectorPid->GetSize(); i++) { - intVectorPid->SetValue(i, (i+1) % partitionNum); + const int32_t numRows = rowNum; + DataTypes inputTypes(std::vector({ IntType(), Decimal64Type(7, 2) })); + + auto *col0 = new int32_t[numRows]; + auto *col1 = new int64_t[numRows]; + for (int32_t i = 0; i < numRows; i++) { + col0[i] = (i + 1) % partitionNum; + col1[i] = 1; } - VectorBatch *vecBatch = new VectorBatch(2); - vecBatch->SetVector(0, intVectorPid); - vecBatch->SetVector(1, decimal64InputVec); - return vecBatch; + + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1); + delete[] col0; + delete[] col1; + return in; } VectorBatch* CreateVectorBatch_2decimalCol_withPid(int partitionNum, int rowNum) { - auto decimal64InputVec = buildVector(Decimal64DataType(7, 2), rowNum); - auto decimal128InputVec = buildVector(Decimal128DataType(38, 2), rowNum); - VectorAllocator *allocator = VectorAllocator::GetGlobalAllocator(); - IntVector *intVectorPid = new IntVector(allocator, rowNum); - for (int i = 0; i < intVectorPid->GetSize(); i++) { - intVectorPid->SetValue(i, (i+1) % partitionNum); + const int32_t numRows = rowNum; + DataTypes inputTypes(std::vector({ IntType(), Decimal64Type(7, 2), Decimal128Type(38, 2) })); + + auto *col0 = new int32_t[numRows]; + auto *col1 = new int64_t[numRows]; + auto *col2 = new Decimal128[numRows]; + for (int32_t i = 0; i < numRows; i++) { + col0[i] = (i + 1) % partitionNum; + col1[i] = 1; + col2[i] = Decimal128(0, 1); } - VectorBatch *vecBatch = new VectorBatch(3); - vecBatch->SetVector(0, intVectorPid); - vecBatch->SetVector(1, decimal64InputVec); - vecBatch->SetVector(2, decimal128InputVec); - return vecBatch; + + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2); + delete[] col0; + delete[] col1; + delete[] col2; + return in; } VectorBatch* CreateVectorBatch_someNullRow_vectorBatch() { const int32_t numRows = 6; + const int32_t numCols = 6; bool data0[numRows] = {true, false, true, false, true, false}; int16_t data1[numRows] = {0, 1, 2, 3, 4, 6}; int32_t data2[numRows] = {0, 1, 2, 0, 1, 2}; @@ -584,50 +337,32 @@ VectorBatch* CreateVectorBatch_someNullRow_vectorBatch() { double data4[numRows] = {0.0, 1.1, 2.2, 3.3, 4.4, 5.5}; std::string data5[numRows] = {"abcde", "fghij", "klmno", "pqrst", "", ""}; - auto vec0 = CreateVector(data0, numRows); - auto vec1 = CreateVector(data1, numRows); - auto vec2 = CreateVector(data2, numRows); - auto vec3 = CreateVector(data3, numRows); - auto vec4 = CreateVector(data4, numRows); - auto vec5 = CreateVarcharVector(VarcharDataType(5), data5, numRows); - for (int i = 0; i < numRows; i = i + 2) { - vec0->SetValueNull(i); - vec1->SetValueNull(i); - vec2->SetValueNull(i); - vec3->SetValueNull(i); - vec4->SetValueNull(i); - vec5->SetValueNull(i); + DataTypes inputTypes( + std::vector({ BooleanType(), ShortType(), IntType(), LongType(), DoubleType(), VarcharType(5) })); + VectorBatch* vecBatch = CreateVectorBatch(inputTypes, numRows, data0, data1, data2, data3, data4, data5); + for (int32_t i = 0; i < numCols; i++) { + for (int32_t j = 0; j < numRows; j = j + 2) { + vecBatch->Get(i)->SetNull(j); + } } - VectorBatch *vecBatch = new VectorBatch(6); - vecBatch->SetVector(0, vec0); - vecBatch->SetVector(1, vec1); - vecBatch->SetVector(2, vec2); - vecBatch->SetVector(3, vec3); - vecBatch->SetVector(4, vec4); - vecBatch->SetVector(5, vec5); return vecBatch; } VectorBatch* 
CreateVectorBatch_someNullCol_vectorBatch() { const int32_t numRows = 6; + const int32_t numCols = 4; int32_t data1[numRows] = {0, 1, 2, 0, 1, 2}; int64_t data2[numRows] = {0, 1, 2, 3, 4, 5}; double data3[numRows] = {0.0, 1.1, 2.2, 3.3, 4.4, 5.5}; std::string data4[numRows] = {"abcde", "fghij", "klmno", "pqrst", "", ""}; - auto vec0 = CreateVector(data1, numRows); - auto vec1 = CreateVector(data2, numRows); - auto vec2 = CreateVector(data3, numRows); - auto vec3 = CreateVarcharVector(VarcharDataType(5), data4, numRows); - for (int i = 0; i < numRows; i = i + 1) { - vec1->SetValueNull(i); - vec3->SetValueNull(i); + DataTypes inputTypes(std::vector({ IntType(), LongType(), DoubleType(), VarcharType(5) })); + VectorBatch* vecBatch = CreateVectorBatch(inputTypes, numRows, data1, data2, data3, data4); + for (int32_t i = 0; i < numCols; i = i + 2) { + for (int32_t j = 0; j < numRows; j++) { + vecBatch->Get(i)->SetNull(j); + } } - VectorBatch *vecBatch = new VectorBatch(4); - vecBatch->SetVector(0, vec0); - vecBatch->SetVector(1, vec1); - vecBatch->SetVector(2, vec2); - vecBatch->SetVector(3, vec3); return vecBatch; } diff --git a/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.h b/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.h index 496a4cc6fc6d0a8834a95db72ccccb5376fe02b6..aad8ca49fb3ded5cdcfc44ee53f7b18d52389efa 100644 --- a/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.h +++ b/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.h @@ -32,15 +32,62 @@ static ConcurrentMap> shuffle_splitter_holder_; static std::string s_shuffle_tests_dir = "/tmp/shuffleTests"; -VectorBatch* CreateInputData(const int32_t numRows, const int32_t numCols, int32_t* inputTypeIds, int64_t* allData); +VectorBatch *CreateVectorBatch(const DataTypes &types, int32_t rowCount, ...); -Vector *buildVector(const DataType &aggType, int32_t rowNumber); +std::unique_ptr CreateVector(DataType &dataType, int32_t rowCount, va_list &args); + +template std::unique_ptr CreateVector(int32_t length, T *values) +{ + std::unique_ptr> vector = std::make_unique>(length); + for (int32_t i = 0; i < length; i++) { + vector->SetValue(i, values[i]); + } + return vector; +} + +template +std::unique_ptr CreateFlatVector(int32_t length, va_list &args) +{ + using namespace omniruntime::type; + using T = typename NativeType::type; + using VarcharVector = Vector>; + if constexpr (std::is_same_v || std::is_same_v) { + std::unique_ptr vector = std::make_unique(length); + std::string *str = va_arg(args, std::string *); + for (int32_t i = 0; i < length; i++) { + std::string_view value(str[i].data(), str[i].length()); + vector->SetValue(i, value); + } + return vector; + } else { + std::unique_ptr> vector = std::make_unique>(length); + T *value = va_arg(args, T *); + for (int32_t i = 0; i < length; i++) { + vector->SetValue(i, value[i]); + } + return vector; + } +} + +std::unique_ptr CreateDictionaryVector(DataType &dataType, int32_t rowCount, int32_t *ids, int32_t idsCount, + ...); + +template +std::unique_ptr CreateDictionary(BaseVector *vector, int32_t *ids, int32_t size) +{ + using T = typename NativeType::type; + if constexpr (std::is_same_v || std::is_same_v) { + return VectorHelper::CreateStringDictionary(ids, size, + reinterpret_cast> *>(vector)); + } + return VectorHelper::CreateDictionary(ids, size, reinterpret_cast *>(vector)); +} VectorBatch* CreateVectorBatch_1row_varchar_withPid(int pid, std::string inputChar); VectorBatch* CreateVectorBatch_4col_withPid(int parNum, int rowNum); 
-VectorBatch* CreateVectorBatch_1FixCol_withPid(int parNum, int rowNum, int32_t fixColType); +VectorBatch* CreateVectorBatch_1FixCol_withPid(int parNum, int rowNum, DataTypePtr fixColType); VectorBatch* CreateVectorBatch_2column_1row_withPid(int pid, std::string strVar, int intVar); @@ -79,14 +126,6 @@ void Test_splitter_stop(long splitter_id); void Test_splitter_close(long splitter_id); -template T *CreateVector(V *values, int32_t length) -{ - VectorAllocator *vecAllocator = omniruntime::vec::GetProcessGlobalVecAllocator(); - auto vector = new T(vecAllocator, length); - vector->SetValues(0, values, length); - return vector; -} - void GetFilePath(const char *path, const char *filename, char *filepath); void DeletePathAll(const char* path); diff --git a/omnioperator/omniop-spark-extension/java/pom.xml b/omnioperator/omniop-spark-extension/java/pom.xml index c38a853744cd2011682fe2af2ca85c6affafdc87..26ba96ff69bc04872f4c1d94e877a0831099c037 100644 --- a/omnioperator/omniop-spark-extension/java/pom.xml +++ b/omnioperator/omniop-spark-extension/java/pom.xml @@ -7,7 +7,7 @@ com.huawei.kunpeng boostkit-omniop-spark-parent - 3.1.1-1.2.0 + 3.2.1-1.2.0 ../pom.xml @@ -20,6 +20,7 @@ ../cpp/build/releases/ ${cpp.test} incremental + 3.2.1 0.6.1 3.0.0 1.6.2 @@ -103,20 +104,20 @@ spark-core_${scala.binary.version} test-jar test - 3.1.1 + ${spark.version} org.apache.spark spark-catalyst_${scala.binary.version} test-jar test - 3.1.1 + ${spark.version} org.apache.spark spark-sql_${scala.binary.version} test-jar - 3.1.1 + ${spark.version} test @@ -127,7 +128,7 @@ org.apache.spark spark-hive_${scala.binary.version} - 3.1.1 + ${spark.version} provided diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReader.java index 1e4d1c7bb053489559359408c5460b59adc36a4b..d80a236533c6b2b3305b2f443b759877239d6089 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReader.java @@ -19,7 +19,6 @@ package com.huawei.boostkit.spark.jni; import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.type.Decimal128DataType; import nova.hetu.omniruntime.vector.*; import org.apache.spark.sql.catalyst.util.RebaseDateTime; @@ -273,7 +272,7 @@ public class OrcColumnarBatchJniReader { break; } case OMNI_DECIMAL128: { - vecList[i] = new Decimal128Vec(vecNativeIds[nativeGetId], Decimal128DataType.DECIMAL128); + vecList[i] = new Decimal128Vec(vecNativeIds[nativeGetId]); break; } default: { diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java index 808f96e1fb666def4ff9fc224f01020a81a5baf7..5379fd7c9501762279f4fa0279263c9658e4d827 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java @@ -194,32 +194,32 @@ public class OmniColumnVector extends WritableColumnVector { @Override public boolean hasNull() { if (dictionaryData != null) { - return dictionaryData.hasNullValue(); 
+ return dictionaryData.hasNull(); } if (type instanceof BooleanType) { - return booleanDataVec.hasNullValue(); + return booleanDataVec.hasNull(); } else if (type instanceof ByteType) { - return charsTypeDataVec.hasNullValue(); + return charsTypeDataVec.hasNull(); } else if (type instanceof ShortType) { - return shortDataVec.hasNullValue(); + return shortDataVec.hasNull(); } else if (type instanceof IntegerType) { - return intDataVec.hasNullValue(); + return intDataVec.hasNull(); } else if (type instanceof DecimalType) { if (DecimalType.is64BitDecimalType(type)) { - return longDataVec.hasNullValue(); + return longDataVec.hasNull(); } else { - return decimal128DataVec.hasNullValue(); + return decimal128DataVec.hasNull(); } } else if (type instanceof LongType || DecimalType.is64BitDecimalType(type)) { - return longDataVec.hasNullValue(); + return longDataVec.hasNull(); } else if (type instanceof FloatType) { return false; } else if (type instanceof DoubleType) { - return doubleDataVec.hasNullValue(); + return doubleDataVec.hasNull(); } else if (type instanceof StringType) { - return charsTypeDataVec.hasNullValue(); + return charsTypeDataVec.hasNull(); } else if (type instanceof DateType) { - return intDataVec.hasNullValue(); + return intDataVec.hasNull(); } throw new UnsupportedOperationException("hasNull is not supported for type:" + type); } @@ -806,7 +806,7 @@ public class OmniColumnVector extends WritableColumnVector { if (type instanceof BooleanType) { booleanDataVec = new BooleanVec(newCapacity); } else if (type instanceof ByteType) { - charsTypeDataVec = new VarcharVec(newCapacity * 4, newCapacity); + charsTypeDataVec = new VarcharVec(newCapacity); } else if (type instanceof ShortType) { shortDataVec = new ShortVec(newCapacity); } else if (type instanceof IntegerType) { @@ -825,7 +825,7 @@ public class OmniColumnVector extends WritableColumnVector { doubleDataVec = new DoubleVec(newCapacity); } else if (type instanceof StringType) { // need to set with real column size, suppose char(200) utf8 - charsTypeDataVec = new VarcharVec(newCapacity * 4 * 200, newCapacity); + charsTypeDataVec = new VarcharVec(newCapacity); } else if (type instanceof DateType) { intDataVec = new IntVec(newCapacity); } else { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala index a4e4eaa0a877f7ee2e3401ecf4ee98fecfcb7314..6d18abdb714f6ab88de0ca6e719e5b938f7fbb32 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, CustomShuffleReaderExec} +import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, OmniAQEShuffleReadExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins._ @@ -37,6 +37,9 @@ case class RowGuard(child: SparkPlan) extends SparkPlan { } def children: Seq[SparkPlan] = Seq(child) + + 
override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan = + legacyWithNewChildren(newChildren) } case class ColumnarGuardRule() extends Rule[SparkPlan] { @@ -127,9 +130,9 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { left match { case exec: BroadcastExchangeExec => new ColumnarBroadcastExchangeExec(exec.mode, exec.child) - case BroadcastQueryStageExec(_, plan: BroadcastExchangeExec) => + case BroadcastQueryStageExec(_, plan: BroadcastExchangeExec, _) => new ColumnarBroadcastExchangeExec(plan.mode, plan.child) - case BroadcastQueryStageExec(_, plan: ReusedExchangeExec) => + case BroadcastQueryStageExec(_, plan: ReusedExchangeExec, _) => plan match { case ReusedExchangeExec(_, b: BroadcastExchangeExec) => new ColumnarBroadcastExchangeExec(b.mode, b.child) @@ -141,9 +144,9 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { right match { case exec: BroadcastExchangeExec => new ColumnarBroadcastExchangeExec(exec.mode, exec.child) - case BroadcastQueryStageExec(_, plan: BroadcastExchangeExec) => + case BroadcastQueryStageExec(_, plan: BroadcastExchangeExec, _) => new ColumnarBroadcastExchangeExec(plan.mode, plan.child) - case BroadcastQueryStageExec(_, plan: ReusedExchangeExec) => + case BroadcastQueryStageExec(_, plan: ReusedExchangeExec, _) => plan match { case ReusedExchangeExec(_, b: BroadcastExchangeExec) => new ColumnarBroadcastExchangeExec(b.mode, b.child) @@ -182,7 +185,8 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { plan.buildSide, plan.condition, plan.left, - plan.right).buildCheck() + plan.right, + plan.isSkewJoin).buildCheck() case plan: BroadcastNestedLoopJoinExec => return false case p => p @@ -237,7 +241,7 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { case p if !supportCodegen(p) => // insert row guard them recursively p.withNewChildren(p.children.map(insertRowGuardOrNot)) - case p: CustomShuffleReaderExec => + case p: OmniAQEShuffleReadExec => p.withNewChildren(p.children.map(insertRowGuardOrNot)) case p: BroadcastQueryStageExec => p diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala index d3fcbaf539493ff8a26b72b6e1b98c4c448b54c9..ed8980e15c2f25ec14d525a0a0e6adae6855e048 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala @@ -20,11 +20,11 @@ package com.huawei.boostkit.spark import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} -import org.apache.spark.sql.catalyst.expressions.DynamicPruningSubquery +import org.apache.spark.sql.catalyst.expressions.{Ascending, DynamicPruningSubquery, SortOrder} import org.apache.spark.sql.catalyst.expressions.aggregate.Partial import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{RowToOmniColumnarExec, _} -import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ColumnarCustomShuffleReaderExec, CustomShuffleReaderExec, QueryStageExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, OmniAQEShuffleReadExec, AQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec} import 
org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins._ @@ -122,16 +122,45 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { ColumnarConditionProjectExec(plan.projectList, condition, child) case join : ColumnarBroadcastHashJoinExec => if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { - ColumnarBroadcastHashJoinExec( - join.leftKeys, - join.rightKeys, - join.joinType, - join.buildSide, - join.condition, - join.left, - join.right, - join.isNullAwareAntiJoin, - plan.projectList) + ColumnarBroadcastHashJoinExec( + join.leftKeys, + join.rightKeys, + join.joinType, + join.buildSide, + join.condition, + join.left, + join.right, + join.isNullAwareAntiJoin, + plan.projectList) + } else { + ColumnarProjectExec(plan.projectList, child) + } + case join : ColumnarShuffledHashJoinExec => + if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { + ColumnarShuffledHashJoinExec( + join.leftKeys, + join.rightKeys, + join.joinType, + join.buildSide, + join.condition, + join.left, + join.right, + join.isSkewJoin, + plan.projectList) + } else { + ColumnarProjectExec(plan.projectList, child) + } + case join : ColumnarSortMergeJoinExec => + if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { + ColumnarSortMergeJoinExec( + join.leftKeys, + join.rightKeys, + join.joinType, + join.condition, + join.left, + join.right, + join.isSkewJoin, + plan.projectList) } else { ColumnarProjectExec(plan.projectList, child) } @@ -311,7 +340,8 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { plan.buildSide, plan.condition, left, - right) + right, + plan.isSkewJoin) case plan: SortMergeJoinExec if enableColumnarSortMergeJoin => logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") val left = replaceWithColumnarPlan(plan.left) @@ -332,7 +362,16 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { case plan: WindowExec if enableColumnarWindow => val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, child) + child match { + case ColumnarSortExec(sortOrder, _, sortChild, _) => + if (Seq(plan.partitionSpec.map(SortOrder(_, Ascending)) ++ plan.orderSpec) == Seq(sortOrder)) { + ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, sortChild) + } else { + ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, child) + } + case _ => + ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, child) + } case plan: UnionExec if enableColumnarUnion => val children = plan.children.map(replaceWithColumnarPlan) logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") @@ -341,19 +380,19 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin) - case plan: CustomShuffleReaderExec if columnarConf.enableColumnarShuffle => + case plan: AQEShuffleReadExec if 
columnarConf.enableColumnarShuffle => plan.child match { case shuffle: ColumnarShuffleExchangeExec => logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarCustomShuffleReaderExec(plan.child, plan.partitionSpecs) - case ShuffleQueryStageExec(_, shuffle: ColumnarShuffleExchangeExec) => + OmniAQEShuffleReadExec(plan.child, plan.partitionSpecs) + case ShuffleQueryStageExec(_, shuffle: ColumnarShuffleExchangeExec, _) => logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarCustomShuffleReaderExec(plan.child, plan.partitionSpecs) - case ShuffleQueryStageExec(_, reused: ReusedExchangeExec) => + OmniAQEShuffleReadExec(plan.child, plan.partitionSpecs) + case ShuffleQueryStageExec(_, reused: ReusedExchangeExec, _) => reused match { case ReusedExchangeExec(_, shuffle: ColumnarShuffleExchangeExec) => logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarCustomShuffleReaderExec( + OmniAQEShuffleReadExec( plan.child, plan.partitionSpecs) case _ => @@ -375,13 +414,15 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { curPlan.id, BroadcastExchangeExec( originalBroadcastPlan.mode, - ColumnarBroadcastExchangeAdaptorExec(originalBroadcastPlan, 1))) + ColumnarBroadcastExchangeAdaptorExec(originalBroadcastPlan, 1)), + curPlan._canonicalized) case ReusedExchangeExec(_, originalBroadcastPlan: ColumnarBroadcastExchangeExec) => BroadcastQueryStageExec( curPlan.id, BroadcastExchangeExec( originalBroadcastPlan.mode, - ColumnarBroadcastExchangeAdaptorExec(curPlan.plan, 1))) + ColumnarBroadcastExchangeAdaptorExec(curPlan.plan, 1)), + curPlan._canonicalized) case _ => curPlan } @@ -394,7 +435,22 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] { var isSupportAdaptive: Boolean = true def apply(plan: SparkPlan): SparkPlan = { - replaceWithColumnarPlan(plan) + handleColumnarToRowParitalFetch(replaceWithColumnarPlan(plan)) + } + + private def handleColumnarToRowParitalFetch(plan: SparkPlan): SparkPlan = { + // simple check plan tree have OmniColumnarToRow and no LimitExec and TakeOrderedAndProjectExec plan + val noParitalFetch = if (plan.find(_.isInstanceOf[OmniColumnarToRowExec]).isDefined) { + (!plan.find(node => + node.isInstanceOf[LimitExec] || node.isInstanceOf[TakeOrderedAndProjectExec]).isDefined) + } else { + false + } + val newPlan = plan.transformUp { + case c: OmniColumnarToRowExec if noParitalFetch => + c.copy(c.child, false) + } + newPlan } def setAdaptiveSupport(enable: Boolean): Unit = { isSupportAdaptive = enable } @@ -409,11 +465,26 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] { case ColumnarToRowExec(child: ColumnarBroadcastExchangeExec) => replaceWithColumnarPlan(child) case plan: ColumnarToRowExec => - val child = replaceWithColumnarPlan(plan.child) - if (conf.getConfString("spark.omni.sql.columnar.columnarToRow", "true").toBoolean) { - OmniColumnarToRowExec(child) - } else { - ColumnarToRowExec(child) + plan.child match { + case child: BroadcastQueryStageExec => + child.plan match { + case originalBroadcastPlan: ColumnarBroadcastExchangeExec => + BroadcastQueryStageExec( + child.id, + BroadcastExchangeExec( + originalBroadcastPlan.mode, + ColumnarBroadcastExchangeAdaptorExec(originalBroadcastPlan, 1)), child._canonicalized) + case ReusedExchangeExec(_, originalBroadcastPlan: ColumnarBroadcastExchangeExec) => + BroadcastQueryStageExec( + child.id, + BroadcastExchangeExec( + originalBroadcastPlan.mode, + ColumnarBroadcastExchangeAdaptorExec(child.plan, 1)), 
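// Editor's aside, not part of the patch: hedged sketch of the partial-fetch check that
// handleColumnarToRowParitalFetch above performs. When the plan has a columnar-to-row boundary but
// no LimitExec / TakeOrderedAndProjectExec anywhere, every batch is drained completely, so the
// OmniColumnarToRowExec nodes can be rebuilt with mayPartialFetch = false and skip the
// task-completion close listener. The isColumnarToRow predicate stands in for the plugin class.
import org.apache.spark.sql.execution.{LimitExec, SparkPlan, TakeOrderedAndProjectExec}

object PartialFetchSketch {
  def canDisablePartialFetch(plan: SparkPlan, isColumnarToRow: SparkPlan => Boolean): Boolean =
    plan.find(isColumnarToRow).isDefined &&
      plan.find(n => n.isInstanceOf[LimitExec] || n.isInstanceOf[TakeOrderedAndProjectExec]).isEmpty
}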
child._canonicalized) + case _ => + replaceColumnarToRow(plan, conf) + } + case _ => + replaceColumnarToRow(plan, conf) } case r: SparkPlan if !r.isInstanceOf[QueryStageExec] && !r.supportsColumnar && r.children.exists(c => @@ -421,7 +492,11 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] { val children = r.children.map { case c: ColumnarToRowExec => val child = replaceWithColumnarPlan(c.child) - OmniColumnarToRowExec(child) + if (conf.getConfString("spark.omni.sql.columnar.columnarToRow", "true").toBoolean) { + OmniColumnarToRowExec(child) + } else { + ColumnarToRowExec(child) + } case other => replaceWithColumnarPlan(other) } @@ -430,6 +505,15 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] { val children = p.children.map(replaceWithColumnarPlan) p.withNewChildren(children) } + + def replaceColumnarToRow(plan: ColumnarToRowExec, conf: SQLConf) : SparkPlan = { + val child = replaceWithColumnarPlan(plan.child) + if (conf.getConfString("spark.omni.sql.columnar.columnarToRow", "true").toBoolean) { + OmniColumnarToRowExec(child) + } else { + ColumnarToRowExec(child) + } + } } case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule with Logging { @@ -487,4 +571,4 @@ class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging { extensions.injectColumnar(session => ColumnarOverrideRules(session)) extensions.injectPlannerStrategy(_ => ShuffleJoinStrategy) } -} \ No newline at end of file +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/Constant.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/Constant.scala index e773a780dcfa6e66dc8c96e97e29d80f59703e73..9d7f844bcc19601ac065083b988085c340631ad3 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/Constant.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/Constant.scala @@ -24,7 +24,7 @@ import nova.hetu.omniruntime.`type`.DataType.DataTypeId * @since 2022/4/15 */ object Constant { - val DEFAULT_STRING_TYPE_LENGTH = 2000 + val DEFAULT_STRING_TYPE_LENGTH = 50 val OMNI_VARCHAR_TYPE: String = DataTypeId.OMNI_VARCHAR.ordinal().toString val OMNI_SHOR_TYPE: String = DataTypeId.OMNI_SHORT.ordinal().toString val OMNI_INTEGER_TYPE: String = DataTypeId.OMNI_INT.ordinal().toString diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala index 9a45854da24375e1db310356db943df8f245fa49..1f9a3ce129bc46cd9fa70e6e4c85f25097c77776 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala @@ -22,8 +22,10 @@ import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper} import org.apache.spark.sql.catalyst.planning._ +import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{joins, SparkPlan} +import org.apache.spark.sql.execution.{SparkPlan, joins} +import org.apache.spark.sql.internal.SQLConf object ShuffleJoinStrategy extends Strategy with PredicateHelper @@ 
-107,6 +109,39 @@ object ShuffleJoinStrategy extends Strategy case _ => Nil } + override def getShuffleHashJoinBuildSide( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + hint: JoinHint, + hintOnly: Boolean, + conf: SQLConf): Option[BuildSide] = { + val buildLeft = if (hintOnly) { + hintToShuffleHashJoinLeft(hint) + } else { + canBuildLocalHashMapBySize(left, conf) && muchSmaller(left, right) + } + val buildRight = if (hintOnly) { + hintToShuffleHashJoinRight(hint) + } else { + canBuildLocalHashMapBySize(right, conf) && muchSmaller(right, left) + } + getBuildSide( + canBuildShuffledHashJoinLeft(joinType) && buildLeft, + canBuildShuffledHashJoinRight(joinType) && buildRight, + left, + right + ) + } + + private def canBuildLocalHashMapBySize(plan: LogicalPlan, conf: SQLConf): Boolean = { + plan.stats.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions + } + + private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = { + a.stats.sizeInBytes * 3 <= b.stats.sizeInBytes + } + private def getBuildSide( canBuildLeft: Boolean, canBuildRight: Boolean, diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index 49f60368853931c52052a4e556ef060a61cc6061..978308bcd2099b1a7121b5e9ab7ff085ba4c7187 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -666,14 +666,14 @@ object OmniExpressionAdaptor extends Logging { } } - def toOmniAggFunType(agg: AggregateExpression, isHashAgg: Boolean = false, isFinal: Boolean = false): FunctionType = { + def toOmniAggFunType(agg: AggregateExpression, isHashAgg: Boolean = false, isMergeCount: Boolean = false): FunctionType = { agg.aggregateFunction match { - case Sum(_) => OMNI_AGGREGATION_TYPE_SUM + case Sum(_, _) => OMNI_AGGREGATION_TYPE_SUM case Max(_) => OMNI_AGGREGATION_TYPE_MAX - case Average(_) => OMNI_AGGREGATION_TYPE_AVG + case Average(_, _) => OMNI_AGGREGATION_TYPE_AVG case Min(_) => OMNI_AGGREGATION_TYPE_MIN case Count(Literal(1, IntegerType) :: Nil) | Count(ArrayBuffer(Literal(1, IntegerType))) => - if (isFinal) { + if (isMergeCount) { OMNI_AGGREGATION_TYPE_COUNT_COLUMN } else { OMNI_AGGREGATION_TYPE_COUNT_ALL diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala index e95ab8dcbfd4f480cd0820f2cc60d774f7f87ddb..ed99f6b4311a48492438095a87d450f7d9d89a5a 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala @@ -26,7 +26,7 @@ import nova.hetu.omniruntime.operator.OmniOperator import nova.hetu.omniruntime.operator.aggregator.{OmniAggregationWithExprOperatorFactory, OmniHashAggregationWithExprOperatorFactory} import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} import nova.hetu.omniruntime.vector._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, ExprId, NamedExpression, SortOrder} +import 
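// Editor's aside, not part of the patch: hedged sketch of the shuffled-hash-join build-side
// heuristic added in getShuffleHashJoinBuildSide above, restated over plain sizes so it runs
// without a LogicalPlan. With Spark's defaults (autoBroadcastJoinThreshold = 10 MB,
// spark.sql.shuffle.partitions = 200) a side may build a local hash map when its total size is
// under roughly 2 GB, and it must also be at most one third of the other side.
object BuildSideHeuristicSketch {
  def canBuildLocalHashMapBySize(sizeInBytes: BigInt,
                                 autoBroadcastJoinThreshold: Long,
                                 numShufflePartitions: Int): Boolean =
    sizeInBytes < BigInt(autoBroadcastJoinThreshold) * numShufflePartitions

  def muchSmaller(a: BigInt, b: BigInt): Boolean = a * 3 <= b

  def main(args: Array[String]): Unit = {
    val threshold = 10L * 1024 * 1024 // 10 MB
    val partitions = 200
    println(canBuildLocalHashMapBySize(BigInt(1L) << 30, threshold, partitions)) // 1 GB -> true
    println(muchSmaller(BigInt(1L) << 30, BigInt(4L) << 30))                     // 1 GB vs 4 GB -> true
  }
}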
org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, ExprId, NamedExpression, SortOrder} import org.apache.spark.sql.execution.datasources.orc.OrcColumnVector import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.vectorized.{OmniColumnVector, OnHeapColumnVector} @@ -34,6 +34,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch} +import scala.collection.mutable.ListBuffer import java.util object OmniAdaptorUtil { @@ -122,7 +123,7 @@ object OmniAdaptorUtil { } offsets(i + 1) = totalSize } - val vec = new VarcharVec(totalSize, columnSize) + val vec = new VarcharVec(columnSize) val values = new Array[Byte](totalSize) for (i <- 0 until columnSize) { if (null != columnVector.getUTF8String(i)) { @@ -276,6 +277,7 @@ object OmniAdaptorUtil { def getAggOperator(groupingExpressions: Seq[NamedExpression], omniGroupByChanel: Array[String], omniAggChannels: Array[Array[String]], + omniAggChannelsFilter: Array[String], omniSourceTypes: Array[nova.hetu.omniruntime.`type`.DataType], omniAggFunctionTypes: Array[FunctionType], omniAggOutputTypes: Array[Array[nova.hetu.omniruntime.`type`.DataType]], @@ -286,6 +288,7 @@ object OmniAdaptorUtil { operator = new OmniHashAggregationWithExprOperatorFactory( omniGroupByChanel, omniAggChannels, + omniAggChannelsFilter, omniSourceTypes, omniAggFunctionTypes, omniAggOutputTypes, @@ -296,6 +299,7 @@ object OmniAdaptorUtil { operator = new OmniAggregationWithExprOperatorFactory( omniGroupByChanel, omniAggChannels, + omniAggChannelsFilter, omniSourceTypes, omniAggFunctionTypes, omniAggOutputTypes, @@ -305,4 +309,61 @@ object OmniAdaptorUtil { } operator } + + def pruneOutput(output: Seq[Attribute], projectList: Seq[NamedExpression]): Seq[Attribute] = { + if (projectList.nonEmpty) { + val projectOutput = ListBuffer[Attribute]() + for (project <- projectList) { + for (col <- output) { + if (col.exprId.equals(getProjectAliasExprId(project))) { + projectOutput += col + } + } + } + projectOutput + } else { + output + } + } + + def getIndexArray(output: Seq[Attribute], projectList: Seq[NamedExpression]): Array[Int] = { + if (projectList.nonEmpty) { + val indexList = ListBuffer[Int]() + for (project <- projectList) { + for (i <- output.indices) { + val col = output(i) + if (col.exprId.equals(getProjectAliasExprId(project))) { + indexList += i + } + } + } + indexList.toArray + } else { + output.indices.toArray + } + } + + def reorderVecs(prunedOutput: Seq[Attribute], projectList: Seq[NamedExpression], resultVecs: Array[nova.hetu.omniruntime.vector.Vec], vecs: Array[OmniColumnVector]) = { + for (index <- projectList.indices) { + val project = projectList(index) + for (i <- prunedOutput.indices) { + val col = prunedOutput(i) + if (col.exprId.equals(getProjectAliasExprId(project))) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(i)) + } + } + } + } + + def getProjectAliasExprId(project: NamedExpression): ExprId = { + project match { + case alias: Alias => + // The condition of parameter is restricted. If parameter type is alias, its child type must be attributeReference. 
+ alias.child.asInstanceOf[AttributeReference].exprId + case _ => + project.exprId + } + } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala index 7eca3427ec3f6c618f84e70aeb85ce98d0267176..615ddb6b7449d3ba3f2b6839df8eee67b6d5b05e 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala @@ -71,7 +71,7 @@ class ColumnarShuffleWriter[K, V]( override def write(records: Iterator[Product2[K, V]]): Unit = { if (!records.hasNext) { partitionLengths = new Array[Long](dep.partitioner.numPartitions) - shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, null) + shuffleBlockResolver.writeMetadataFileAndCommit(dep.shuffleId, mapId, partitionLengths, Array[Long](), null) mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) return } @@ -107,7 +107,7 @@ class ColumnarShuffleWriter[K, V]( jniWrapper.split(nativeSplitter, vb.getNativeVectorBatch) dep.splitTime.add(System.nanoTime() - startTime) dep.numInputRows.add(cb.numRows) - writeMetrics.incRecordsWritten(1) + writeMetrics.incRecordsWritten(cb.numRows) } } val startTime = System.nanoTime() @@ -122,10 +122,11 @@ class ColumnarShuffleWriter[K, V]( partitionLengths = splitResult.getPartitionLengths try { - shuffleBlockResolver.writeIndexFileAndCommit( + shuffleBlockResolver.writeMetadataFileAndCommit( dep.shuffleId, mapId, partitionLengths, + Array[Long](), dataTmp) } finally { if (dataTmp.exists() && !dataTmp.delete()) { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/sort/OmniColumnarShuffleManager.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/sort/OmniColumnarShuffleManager.scala index e7c66ee726ae4b9090e41e5d71de386e4b94ed13..28427bba2842f77d53327121001966dbbdb17a01 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/sort/OmniColumnarShuffleManager.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/sort/OmniColumnarShuffleManager.scala @@ -99,7 +99,7 @@ class OmniColumnarShuffleManager(conf: SparkConf) extends ColumnarShuffleManager env.conf, metrics, shuffleExecutorComponents) - case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K@unchecked, V@unchecked] => + case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new BypassMergeSortShuffleWriter( env.blockManager, bypassMergeSortHandle, @@ -107,9 +107,8 @@ class OmniColumnarShuffleManager(conf: SparkConf) extends ColumnarShuffleManager env.conf, metrics, shuffleExecutorComponents) - case other: BaseShuffleHandle[K@unchecked, V@unchecked, _] => + case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => new SortShuffleWriter( - shuffleBlockResolver, other, mapId, context, diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala index cb23b68f09bb085d86e133d3ef40628e8c5ca4c2..47dc5c80614b9b11f466a5164e8e3bdfbd966b53 100644 --- 
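// Editor's aside, not part of the patch: hedged sketch of how the pruneOutput / getIndexArray /
// getProjectAliasExprId helpers added to OmniAdaptorUtil above are meant to behave. Given an
// operator's output and a fused project list of bare columns or simple aliases, the helpers keep
// only the referenced attributes and record their positions so the produced vectors can be
// reordered. Column names here are invented for illustration.
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, NamedExpression}
import org.apache.spark.sql.types.{IntegerType, StringType}

object PruneOutputSketch {
  def main(args: Array[String]): Unit = {
    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", StringType)()
    val c = AttributeReference("c", IntegerType)()
    val output = Seq(a, b, c)
    // Keep only `c` and an alias over `a`; like getProjectAliasExprId, an Alias is matched
    // through the exprId of its child attribute, so `b` is pruned away.
    val projectList: Seq[NamedExpression] = Seq(c, Alias(a, "a_renamed")())
    val indices = projectList.map {
      case Alias(ref: AttributeReference, _) => output.indexWhere(_.exprId == ref.exprId)
      case named                             => output.indexWhere(_.exprId == named.exprId)
    }
    println(indices) // List(2, 0) -> the pruned output would be Seq(c, a)
  }
}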
a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala @@ -101,6 +101,9 @@ case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPl |${ExplainUtils.generateFieldString("Input", child.output)} |""".stripMargin } + + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarProjectExec = + copy(child = newChild) } case class ColumnarFilterExec(condition: Expression, child: SparkPlan) @@ -109,6 +112,10 @@ case class ColumnarFilterExec(condition: Expression, child: SparkPlan) override def supportsColumnar: Boolean = true override def nodeName: String = "OmniColumnarFilter" + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarFilterExec = { + copy(this.condition, newChild) + } + // Split out all the IsNotNulls from condition. private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition { case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet) @@ -116,7 +123,7 @@ case class ColumnarFilterExec(condition: Expression, child: SparkPlan) } // If one expression and its children are null intolerant, it is null intolerant. - private def isNullIntolerant(expr: Expression): Boolean = expr match { + override def isNullIntolerant(expr: Expression): Boolean = expr match { case e: NullIntolerant => e.children.forall(isNullIntolerant) case _ => false } @@ -267,6 +274,9 @@ case class ColumnarConditionProjectExec(projectList: Seq[NamedExpression], override def output: Seq[Attribute] = projectList.map(_.toAttribute) + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarConditionProjectExec = + copy(child = newChild) + override lazy val metrics = Map( "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), "numInputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatchs"), @@ -383,7 +393,7 @@ case class ColumnarUnionExec(children: Seq[SparkPlan]) extends SparkPlan { children.map(_.output).transpose.map { attrs => val firstAttr = attrs.head val nullable = attrs.exists(_.nullable) - val newDt = attrs.map(_.dataType).reduce(StructType.merge) + val newDt = attrs.map(_.dataType).reduce(StructType.unionLikeMerge) if (firstAttr.dataType == newDt) { firstAttr.withNullability(nullable) } else { @@ -393,6 +403,10 @@ case class ColumnarUnionExec(children: Seq[SparkPlan]) extends SparkPlan { } } + override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan = { + copy(children = newChildren) + } + def buildCheck(): Unit = { val inputTypes = new Array[DataType](output.size) output.zipWithIndex.foreach { @@ -420,7 +434,7 @@ class ColumnarRangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { val numOutputRows = longMetric("numOutputRows") - sqlContext + session.sqlContext .sparkContext .parallelize(0 until numSlices, numSlices) .mapPartitionsWithIndex { (i, _) => diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala index d137388ab3c41c3ee103ac974cb594990379d394..1d236c16d3849906ec997726ca62ee5957ff0740 100644 
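// Editor's aside, not part of the patch: many operators in this change gain one-line
// withNewChildInternal / withNewChildrenInternal overrides. Since Spark 3.2 the TreeNode API asks
// every node to rebuild itself through these hooks rather than reflective copying, so each
// columnar exec repeats the same idiom. Minimal hypothetical node (PassThroughExec is invented)
// showing the pattern:
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}

case class PassThroughExec(child: SparkPlan) extends UnaryExecNode {
  override def output: Seq[Attribute] = child.output
  override protected def doExecute(): RDD[InternalRow] = child.execute()
  // The point of the pattern: rewire the child via copy() so the case-class fields stay consistent.
  override protected def withNewChildInternal(newChild: SparkPlan): PassThroughExec =
    copy(child = newChild)
}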
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala @@ -64,4 +64,7 @@ case class ColumnarBroadcastExchangeAdaptorExec(child: SparkPlan, numPartitions: "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "output_batches"), "processTime" -> SQLMetrics.createTimingMetric(sparkContext, "totaltime_datatoarrowcolumnar")) + + override protected def withNewChildInternal(newChild: SparkPlan): + ColumnarBroadcastExchangeAdaptorExec = copy(child = newChild) } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala index 72d1aae05d8f9e4a22a7f0bc17e68aca8b157d74..8a29e0d2bc1531351210fe7ae77b7da0577e2fa2 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala @@ -65,7 +65,7 @@ class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan) @transient override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]]( - sqlContext.sparkSession, ColumnarBroadcastExchangeExec.executionContext) { + session.sqlContext.sparkSession, ColumnarBroadcastExchangeExec.executionContext) { try { // Setup a job group here so later it may get cancelled by groupId if necessary. sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)", @@ -159,6 +159,9 @@ class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan) } } + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarBroadcastExchangeExec = + new ColumnarBroadcastExchangeExec(this.mode, newChild) + override protected def doPrepare(): Unit = { // Materialize the future. 
relationFuture diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala index 4af2ec06566b2b535d92dcd5765671fa26938716..75b0dfb7a297f117dab5ffacc068501b96ecb2a7 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala @@ -18,10 +18,8 @@ package org.apache.spark.sql.execution import java.util.concurrent.TimeUnit.NANOSECONDS - import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer - import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -31,9 +29,8 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.util.SparkMemoryUtils import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OmniColumnVector, WritableColumnVector} -import org.apache.spark.sql.types.{BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DecimalType, DoubleType, IntegerType, LongType, ShortType, StringType, StructType, TimestampType} +import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DecimalType, DoubleType, IntegerType, LongType, ShortType, StringType, StructType, TimestampType} import org.apache.spark.sql.vectorized.ColumnarBatch - import nova.hetu.omniruntime.vector.Vec /** @@ -101,6 +98,7 @@ private object OmniRowToColumnConverter { private def getConverterForType(dataType: DataType, nullable: Boolean): TypeConverter = { val core = dataType match { + case BinaryType => BinaryConverter case BooleanType => BooleanConverter case ByteType => ByteConverter case ShortType => ShortConverter @@ -123,6 +121,13 @@ private object OmniRowToColumnConverter { } } + private object BinaryConverter extends TypeConverter { + override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = { + val bytes = row.getBinary(column) + cv.appendByteArray(bytes, 0, bytes.length) + } + } + private object BooleanConverter extends TypeConverter { override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = cv.appendBoolean(row.getBoolean(column)) @@ -232,8 +237,11 @@ case class RowToOmniColumnarExec(child: SparkPlan) extends RowToColumnarTransiti "rowToOmniColumnarTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in row to OmniColumnar") ) + override protected def withNewChildInternal(newChild: SparkPlan): RowToOmniColumnarExec = + copy(child = newChild) + override def doExecuteColumnar(): RDD[ColumnarBatch] = { - val enableOffHeapColumnVector = sqlContext.conf.offHeapColumnVectorEnabled + val enableOffHeapColumnVector = session.sqlContext.conf.offHeapColumnVectorEnabled val numInputRows = longMetric("numInputRows") val numOutputBatches = longMetric("numOutputBatches") val rowToOmniColumnarTime = longMetric("rowToOmniColumnarTime") @@ -285,7 +293,8 @@ case class RowToOmniColumnarExec(child: SparkPlan) extends RowToColumnarTransiti } -case class OmniColumnarToRowExec(child: SparkPlan) extends ColumnarToRowTransition { +case class OmniColumnarToRowExec(child: SparkPlan, + mayPartialFetch: Boolean = true) extends ColumnarToRowTransition { assert(child.supportsColumnar) 
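// Editor's aside, not part of the patch: the new BinaryConverter above lets row-to-columnar
// conversion handle BinaryType columns by copying each row's byte array into the writable vector.
// Hedged, self-contained sketch of that single call against an on-heap vector:
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.types.BinaryType

object BinaryConverterSketch {
  def main(args: Array[String]): Unit = {
    val rows: Seq[Array[Byte]] = Seq(Array[Byte](1, 2, 3), Array[Byte](4))
    val vector = new OnHeapColumnVector(rows.length, BinaryType)
    // Same call the converter makes per row: append the raw bytes of the binary column value.
    rows.foreach(bytes => vector.appendByteArray(bytes, 0, bytes.length))
    println(vector.getBinary(0).mkString(",")) // 1,2,3
    vector.close()
  }
}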
override def nodeName: String = "OmniColumnarToRow" @@ -302,6 +311,14 @@ case class OmniColumnarToRowExec(child: SparkPlan) extends ColumnarToRowTransiti "omniColumnarToRowTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omniColumnar to row") ) + override def verboseStringWithOperatorId(): String = { + s""" + |$formattedNodeName + |$simpleStringWithNodeId + |${ExplainUtils.generateFieldString("mayPartialFetch", String.valueOf(mayPartialFetch))} + |""".stripMargin + } + override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val numInputBatches = longMetric("numInputBatches") @@ -310,16 +327,20 @@ case class OmniColumnarToRowExec(child: SparkPlan) extends ColumnarToRowTransiti // plan (this) in the closure. val localOutput = this.output child.executeColumnar().mapPartitionsInternal { batches => - ColumnarBatchToInternalRow.convert(localOutput, batches, numOutputRows, numInputBatches, omniColumnarToRowTime) + ColumnarBatchToInternalRow.convert(localOutput, batches, numOutputRows, numInputBatches, omniColumnarToRowTime, mayPartialFetch) } } + + override protected def withNewChildInternal(newChild: SparkPlan): + OmniColumnarToRowExec = copy(child = newChild) } object ColumnarBatchToInternalRow { def convert(output: Seq[Attribute], batches: Iterator[ColumnarBatch], numOutputRows: SQLMetric, numInputBatches: SQLMetric, - rowToOmniColumnarTime: SQLMetric): Iterator[InternalRow] = { + rowToOmniColumnarTime: SQLMetric, + mayPartialFetch: Boolean = true): Iterator[InternalRow] = { val startTime = System.nanoTime() val toUnsafe = UnsafeProjection.create(output, output) @@ -345,11 +366,13 @@ object ColumnarBatchToInternalRow { val numOutputRowsMetric: SQLMetric = numOutputRows var closed = false - SparkMemoryUtils.addLeakSafeTaskCompletionListener { _ => - // only invoke if fetch partial rows of batch - if (!closed) { - toClosedVecs.foreach {vec => - vec.close() + // only invoke if fetch partial rows of batch + if (mayPartialFetch) { + SparkMemoryUtils.addLeakSafeTaskCompletionListener { _ => + if (!closed) { + toClosedVecs.foreach {vec => + vec.close() + } } } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala index 27b05b16c017c43e73a0c3b6d4f05ea02d11f951..b25d97d604da1ae0cbaef04b34bbf53e61b8af83 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package org.apache.spark.sql.execution import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP @@ -161,4 +178,6 @@ case class ColumnarExpandExec( throw new UnsupportedOperationException(s"ColumnarExpandExec operator doesn't support doExecute().") } + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarExpandExec = + copy(child = newChild) } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala index 73091d069cb311f129fef45078d936ad365e14e0..155e289aaa3243a1765a75387e3b9c1229a293af 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala @@ -496,20 +496,35 @@ abstract class BaseColumnarFileSourceScanExec( logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + s"open cost is considered as scanning $openCostInBytes bytes.") - val splitFiles = selectedPartitions.flatMap { partition => + // Filter files with bucket pruning if possible + val bucketingEnabled = fsRelation.sparkSession.sessionState.conf.bucketingEnabled + val shouldProcess: Path => Boolean = optionalBucketSet match { + case Some(bucketSet) if bucketingEnabled => + // Do not prune the file if bucket file name is invalid + filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get) + case _ => + _ => true + } + + var splitFiles = selectedPartitions.flatMap { partition => partition.files.flatMap { file => // getPath() is very expensive so we only want to call it once in this block: val filePath = file.getPath - val isSplitable = relation.fileFormat.isSplitable( - relation.sparkSession, relation.options, filePath) - PartitionedFileUtil.splitFiles( - sparkSession = relation.sparkSession, - file = file, - filePath = filePath, - isSplitable = isSplitable, - maxSplitBytes = maxSplitBytes, - partitionValues = partition.values - ) + + if (shouldProcess(filePath)) { + val isSplitable = relation.fileFormat.isSplitable( + relation.sparkSession, relation.options, filePath) + PartitionedFileUtil.splitFiles( + sparkSession = relation.sparkSession, + file = file, + filePath = filePath, + isSplitable = isSplitable, + maxSplitBytes = maxSplitBytes, + partitionValues = partition.values + ) + } else { + Seq.empty + } } }.sortBy(_.length)(implicitly[Ordering[Long]].reverse) @@ -544,14 +559,19 @@ abstract class BaseColumnarFileSourceScanExec( val omniAggFunctionTypes = new Array[FunctionType](agg.aggregateExpressions.size) val omniAggOutputTypes = new Array[Array[DataType]](agg.aggregateExpressions.size) val omniAggChannels = new Array[Array[String]](agg.aggregateExpressions.size) + val omniAggChannelsFilter = new Array[String](agg.aggregateExpressions.size) var omniAggindex = 0 for (exp <- agg.aggregateExpressions) { + if (exp.filter.isDefined) { + omniAggChannelsFilter(omniAggindex) = + rewriteToOmniJsonExpressionLiteral(exp.filter.get, attrAggExpsIdMap) + } if (exp.mode == Final) { throw new UnsupportedOperationException(s"Unsupported final aggregate expression in operator fusion, exp: $exp") } else if (exp.mode == Partial) { exp.aggregateFunction match { - case Sum(_) | Min(_) | Average(_) | Max(_) | Count(_) | First(_, _) => + case Sum(_, _) | Min(_) | Average(_, _) | Max(_) | 
Count(_) | First(_, _) => val aggExp = exp.aggregateFunction.children.head omniOutputExressionOrder += { exp.aggregateFunction.inputAggBufferAttributes.head.exprId -> @@ -569,7 +589,7 @@ abstract class BaseColumnarFileSourceScanExec( } } else if (exp.mode == PartialMerge) { exp.aggregateFunction match { - case Sum(_) | Min(_) | Average(_) | Max(_) | Count(_) | First(_, _) => + case Sum(_, _) | Min(_) | Average(_, _) | Max(_) | Count(_) | First(_, _) => val aggExp = exp.aggregateFunction.children.head omniOutputExressionOrder += { exp.aggregateFunction.inputAggBufferAttributes.head.exprId -> @@ -604,8 +624,8 @@ abstract class BaseColumnarFileSourceScanExec( case (attr, i) => omniAggSourceTypes(i) = sparkTypeToOmniType(attr.dataType, attr.metadata) } - (omniGroupByChanel, omniAggChannels, omniAggSourceTypes, omniAggFunctionTypes, omniAggOutputTypes, - omniAggInputRaws, omniAggOutputPartials, resultIdxToOmniResultIdxMap) + (omniGroupByChanel, omniAggChannels, omniAggChannelsFilter, omniAggSourceTypes, omniAggFunctionTypes, + omniAggOutputTypes, omniAggInputRaws, omniAggOutputPartials, resultIdxToOmniResultIdxMap) } def genProjectOutput(project: ColumnarProjectExec) = { @@ -834,8 +854,8 @@ case class ColumnarMultipleOperatorExec( val omniCodegenTime = longMetric("omniJitTime") val getOutputTime = longMetric("outputTime") - val (omniGroupByChanel, omniAggChannels, omniAggSourceTypes, omniAggFunctionTypes, omniAggOutputTypes, - omniAggInputRaw, omniAggOutputPartial, resultIdxToOmniResultIdxMap) = genAggOutput(aggregate) + val (omniGroupByChanel, omniAggChannels, omniAggChannelsFilter, omniAggSourceTypes, omniAggFunctionTypes, + omniAggOutputTypes, omniAggInputRaw, omniAggOutputPartial, resultIdxToOmniResultIdxMap) = genAggOutput(aggregate) val (proj1OmniExpressions, proj1OmniInputTypes) = genProjectOutput(proj1) val (buildTypes1, buildJoinColsExp1, joinFilter1, probeTypes1, probeOutputCols1, probeHashColsExp1, buildOutputCols1, buildOutputTypes1, relation1) = genJoinOutput(join1) @@ -857,6 +877,7 @@ case class ColumnarMultipleOperatorExec( val aggOperator = OmniAdaptorUtil.getAggOperator(aggregate.groupingExpressions, omniGroupByChanel, omniAggChannels, + omniAggChannelsFilter, omniAggSourceTypes, omniAggFunctionTypes, omniAggOutputTypes, @@ -1181,8 +1202,8 @@ case class ColumnarMultipleOperatorExec1( val omniCodegenTime = longMetric("omniJitTime") val getOutputTime = longMetric("outputTime") - val (omniGroupByChanel, omniAggChannels, omniAggSourceTypes, omniAggFunctionTypes, omniAggOutputTypes, - omniAggInputRaw, omniAggOutputPartial, resultIdxToOmniResultIdxMap) = genAggOutput(aggregate) + val (omniGroupByChanel, omniAggChannels, omniAggChannelsFilter, omniAggSourceTypes, omniAggFunctionTypes, + omniAggOutputTypes, omniAggInputRaw, omniAggOutputPartial, resultIdxToOmniResultIdxMap) = genAggOutput(aggregate) val (proj1OmniExpressions, proj1OmniInputTypes) = genProjectOutput(proj1) val (buildTypes1, buildJoinColsExp1, joinFilter1, probeTypes1, probeOutputCols1, probeHashColsExp1, buildOutputCols1, buildOutputTypes1, relation1, reserved1) = genJoinOutputWithReverse(join1) @@ -1217,6 +1238,7 @@ case class ColumnarMultipleOperatorExec1( val aggOperator = OmniAdaptorUtil.getAggOperator(aggregate.groupingExpressions, omniGroupByChanel, omniAggChannels, + omniAggChannelsFilter, omniAggSourceTypes, omniAggFunctionTypes, omniAggOutputTypes, diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala 
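// Editor's aside, not part of the patch: hedged sketch of the bucket pruning added to the
// columnar file scan above. A data file is read only when its bucket id, parsed from the file
// name, is in the optional pruned-bucket set; files whose names carry no bucket id are kept
// conservatively. The regex is a simplified stand-in for Spark's BucketingUtils.getBucketId and
// the file-name layout is an assumption for illustration.
object BucketPruningSketch {
  // Bucketed files typically look like "part-00000-<uuid>_00003.c000.snappy.parquet",
  // where the trailing _00003 encodes the bucket id.
  private val bucketedFileName = """.*_(\d+)(?:\..*)?$""".r

  def bucketId(fileName: String): Option[Int] = fileName match {
    case bucketedFileName(id) => Some(id.toInt)
    case _                    => None
  }

  def shouldProcess(fileName: String, prunedBuckets: Option[Set[Int]]): Boolean =
    prunedBuckets match {
      case Some(keep) => bucketId(fileName).forall(keep.contains)
      case None       => true
    }
}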
b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala index e2618842a7dd6a2170b8738be2a299cca2a86d47..9c007243db067843c32fcb20e8e7d28380a23a2e 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala @@ -54,6 +54,9 @@ case class ColumnarHashAggregateExec( extends BaseAggregateExec with AliasAwareOutputPartitioning { + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarHashAggregateExec = + copy(child = newChild) + override def verboseStringWithOperatorId(): String = { s""" |$formattedNodeName @@ -92,14 +95,16 @@ case class ColumnarHashAggregateExec( val omniAggFunctionTypes = new Array[FunctionType](aggregateExpressions.size) val omniAggOutputTypes = new Array[Array[DataType]](aggregateExpressions.size) var omniAggChannels = new Array[Array[String]](aggregateExpressions.size) + val omniAggChannelsFilter = new Array[String](aggregateExpressions.size) var index = 0 for (exp <- aggregateExpressions) { if (exp.filter.isDefined) { - throw new UnsupportedOperationException("Unsupported filter in AggregateExpression") + omniAggChannelsFilter(index) = + rewriteToOmniJsonExpressionLiteral(exp.filter.get, attrExpsIdMap) } if (exp.mode == Final) { exp.aggregateFunction match { - case Sum(_) | Min(_) | Max(_) | Count(_) | Average(_) | First(_,_) => + case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) => omniAggFunctionTypes(index) = toOmniAggFunType(exp, true, true) omniAggOutputTypes(index) = toOmniAggInOutType(exp.aggregateFunction.dataType) omniAggChannels(index) = @@ -110,22 +115,19 @@ case class ColumnarHashAggregateExec( } } else if (exp.mode == PartialMerge) { exp.aggregateFunction match { - case Sum(_) | Min(_) | Max(_) | Count(_) | Average(_) | First(_,_) => - omniAggFunctionTypes(index) = toOmniAggFunType(exp, true) + case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) => + omniAggFunctionTypes(index) = toOmniAggFunType(exp, true, true) omniAggOutputTypes(index) = toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes) omniAggChannels(index) = toOmniAggInOutJSonExp(exp.aggregateFunction.inputAggBufferAttributes, attrExpsIdMap) omniInputRaws(index) = false omniOutputPartials(index) = true - if (omniAggFunctionTypes(index) == OMNI_AGGREGATION_TYPE_COUNT_ALL) { - omniAggChannels(index) = null - } case _ => throw new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: ${exp}") } } else if (exp.mode == Partial) { exp.aggregateFunction match { - case Sum(_) | Min(_) | Max(_) | Count(_) | Average(_) | First(_,_) => + case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) => omniAggFunctionTypes(index) = toOmniAggFunType(exp, true) omniAggOutputTypes(index) = toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes) @@ -150,7 +152,7 @@ case class ColumnarHashAggregateExec( omniSourceTypes(i) = sparkTypeToOmniType(attr.dataType, attr.metadata) } - for (aggChannel <-omniAggChannels) { + for (aggChannel <- omniAggChannels) { if (!isSimpleColumnForAll(aggChannel)) { checkOmniJsonWhiteList("", aggChannel.toArray) } @@ -160,6 +162,12 @@ case class ColumnarHashAggregateExec( checkOmniJsonWhiteList("", omniGroupByChanel) } + for (filter <- omniAggChannelsFilter) { + if (filter != null && !isSimpleColumn(filter)) { + 
checkOmniJsonWhiteList(filter, new Array[AnyRef](0)) + } + } + // final steps contail all Final mode aggregate if (aggregateExpressions.filter(_.mode == Final).size == aggregateExpressions.size) { val finalOut = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes @@ -191,6 +199,7 @@ case class ColumnarHashAggregateExec( val omniAggFunctionTypes = new Array[FunctionType](aggregateExpressions.size) val omniAggOutputTypes = new Array[Array[DataType]](aggregateExpressions.size) var omniAggChannels = new Array[Array[String]](aggregateExpressions.size) + val omniAggChannelsFilter = new Array[String](aggregateExpressions.size) val finalStep = (aggregateExpressions.filter (_.mode == Final).size == aggregateExpressions.size) @@ -198,11 +207,12 @@ case class ColumnarHashAggregateExec( var index = 0 for (exp <- aggregateExpressions) { if (exp.filter.isDefined) { - throw new UnsupportedOperationException("Unsupported filter in AggregateExpression") + omniAggChannelsFilter(index) = + rewriteToOmniJsonExpressionLiteral(exp.filter.get, attrExpsIdMap) } if (exp.mode == Final) { exp.aggregateFunction match { - case Sum(_) | Min(_) | Max(_) | Count(_) | Average(_) | First(_,_) => + case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_, _) => omniAggFunctionTypes(index) = toOmniAggFunType(exp, true, true) omniAggOutputTypes(index) = toOmniAggInOutType(exp.aggregateFunction.dataType) @@ -214,22 +224,19 @@ case class ColumnarHashAggregateExec( } } else if (exp.mode == PartialMerge) { exp.aggregateFunction match { - case Sum(_) | Min(_) | Max(_) | Count(_) | Average(_) | First(_,_) => - omniAggFunctionTypes(index) = toOmniAggFunType(exp, true) + case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_, _) => + omniAggFunctionTypes(index) = toOmniAggFunType(exp, true, true) omniAggOutputTypes(index) = toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes) omniAggChannels(index) = toOmniAggInOutJSonExp(exp.aggregateFunction.inputAggBufferAttributes, attrExpsIdMap) omniInputRaws(index) = false omniOutputPartials(index) = true - if (omniAggFunctionTypes(index) == OMNI_AGGREGATION_TYPE_COUNT_ALL) { - omniAggChannels(index) = null - } case _ => throw new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: ${exp}") } } else if (exp.mode == Partial) { exp.aggregateFunction match { - case Sum(_) | Min(_) | Max(_) | Count(_) | Average(_) | First(_,_) => + case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_, _) => omniAggFunctionTypes(index) = toOmniAggFunType(exp, true) omniAggOutputTypes(index) = toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes) @@ -260,6 +267,7 @@ case class ColumnarHashAggregateExec( val operator = OmniAdaptorUtil.getAggOperator(groupingExpressions, omniGroupByChanel, omniAggChannels, + omniAggChannelsFilter, omniSourceTypes, omniAggFunctionTypes, omniAggOutputTypes, diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala index cea0a1438b1c64a0d1372e2a272742aa9be08502..2303d7ee138bbe322b2357e995538ca6f081a2ad 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala @@ -41,6 +41,7 @@ import 
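// Editor's aside, not part of the patch: ColumnarHashAggregateExec above now translates
// AggregateExpression.filter into a per-channel Omni JSON filter (omniAggChannelsFilter) instead
// of rejecting the plan in the build check. Hedged usage sketch of the query shape this enables;
// table and column names are invented for illustration.
import org.apache.spark.sql.SparkSession

object AggregateFilterSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("agg-filter-sketch").getOrCreate()
    import spark.implicits._
    Seq((1, 10L), (2, -5L), (3, 7L)).toDF("id", "amount").createOrReplaceTempView("t")
    // An aggregate with a FILTER clause is the case the new per-channel filter path is meant
    // to keep on the columnar plan.
    spark.sql("SELECT count(*) FILTER (WHERE amount > 0) AS positive_cnt FROM t").show()
    spark.stop()
  }
}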
org.apache.spark.shuffle.ColumnarShuffleDependency import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering +import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.createShuffleWriteProcessor @@ -53,16 +54,17 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.MutablePair +import org.apache.spark.util.random.XORShiftRandom -class ColumnarShuffleExchangeExec( - override val outputPartitioning: Partitioning, - child: SparkPlan, - shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS) - extends ShuffleExchangeExec(outputPartitioning, child, shuffleOrigin) with ShuffleExchangeLike{ +case class ColumnarShuffleExchangeExec( + override val outputPartitioning: Partitioning, + child: SparkPlan, + shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS) + extends ShuffleExchangeLike { private lazy val writeMetrics = SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) - override lazy val readMetrics = + private[sql] lazy val readMetrics = SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) override lazy val metrics: Map[String, SQLMetric] = Map( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), @@ -100,6 +102,16 @@ class ColumnarShuffleExchangeExec( override def numPartitions: Int = columnarShuffleDependency.partitioner.numPartitions + override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[ColumnarBatch] = { + new ShuffledColumnarRDD(columnarShuffleDependency, readMetrics, partitionSpecs) + } + + override def runtimeStatistics: Statistics = { + val dataSize = metrics("dataSize").value + val rowCount = metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN).value + Statistics(dataSize, Some(rowCount)) + } + @transient lazy val columnarShuffleDependency: ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { ColumnarShuffleExchangeExec.prepareShuffleDependency( @@ -155,6 +167,8 @@ class ColumnarShuffleExchangeExec( cachedShuffleRDD } } + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarShuffleExchangeExec = + copy(child = newChild) } object ColumnarShuffleExchangeExec extends Logging { @@ -229,7 +243,8 @@ object ColumnarShuffleExchangeExec extends Logging { (columnarBatch: ColumnarBatch, numPartitions: Int) => { val pidArr = new Array[Int](columnarBatch.numRows()) for (i <- 0 until columnarBatch.numRows()) { - val position = new Random(TaskContext.get().partitionId()).nextInt(numPartitions) + val partitionId = TaskContext.get().partitionId() + val position = new XORShiftRandom(partitionId).nextInt(numPartitions) pidArr(i) = position + 1 } val vec = new IntVec(columnarBatch.numRows()) @@ -324,6 +339,7 @@ object ColumnarShuffleExchangeExec extends Logging { rdd.mapPartitionsWithIndexInternal((_, cbIter) => { cbIter.map { cb => (0, cb) } }, isOrderSensitive = isOrderSensitive) + case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning") } val numCols = outputAttributes.size @@ -341,6 +357,7 @@ object ColumnarShuffleExchangeExec extends Logging { new 
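// Editor's aside, not part of the patch: the round-robin partition id above is now drawn from
// Spark's XORShiftRandom seeded with the task partition id instead of java.util.Random.
// XORShiftRandom scrambles its seed with MurmurHash3, so consecutive task ids start at well
// spread offsets, which java.util.Random does not guarantee for small sequential seeds. Hedged
// sketch of the same idea (XORShiftRandom itself is private[spark], so MurmurHash3 stands in):
import java.nio.ByteBuffer
import java.util.Random
import scala.util.hashing.MurmurHash3

object RoundRobinStartSketch {
  def startPartition(taskPartitionId: Int, numPartitions: Int): Int = {
    val seedBytes = ByteBuffer.allocate(8).putLong(taskPartitionId.toLong).array()
    new Random(MurmurHash3.bytesHash(seedBytes).toLong).nextInt(numPartitions)
  }

  def main(args: Array[String]): Unit =
    (0 until 5).foreach(id => println(s"task $id -> start partition ${startPartition(id, 8)}"))
}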
PartitionInfo("hash", numPartitions, numCols, intputTypes) case RangePartitioning(ordering, numPartitions) => new PartitionInfo("range", numPartitions, numCols, intputTypes) + case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning") } new ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch]( diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala index 7c7001dbc1c468465a0115946aeff9849d51a3df..49f2451112f66915d95b57e76dfdd8203a2af635 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala @@ -56,6 +56,9 @@ case class ColumnarSortExec( override def outputPartitioning: Partitioning = child.outputPartitioning + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarSortExec = + copy(child = newChild) + override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala index 6fec9f9a054f345a83bc20f278c2ed3be57e6dbd..92efd4d539e04b141c378bb0e3d0b8d52fd73f4a 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala @@ -49,6 +49,9 @@ case class ColumnarTakeOrderedAndProjectExec( override def nodeName: String = "OmniColumnarTakeOrderedAndProject" + override protected def withNewChildInternal(newChild: SparkPlan): + ColumnarTakeOrderedAndProjectExec = copy(child = newChild) + val serializer: Serializer = new ColumnarBatchSerializer( longMetric("avgReadBatchNumRows"), longMetric("numOutputRows")) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala index e5534d3c67680a747f1b4d92fcb2c377c81577c9..4b5da24b793b3ac571900340355bf06e81503202 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala @@ -46,10 +46,16 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], orderSpec: Seq[SortOrder], child: SparkPlan) extends WindowExecBase { + override def output: Seq[Attribute] = + child.output ++ windowExpression.map(_.toAttribute) + override def nodeName: String = "OmniColumnarWindow" override def supportsColumnar: Boolean = true + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarWindowExec = + copy(child = newChild) + override lazy val metrics = Map( "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), "numInputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatchs"), @@ 
-59,25 +65,6 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs")) - override def output: Seq[Attribute] = - child.output ++ windowExpression.map(_.toAttribute) - - override def requiredChildDistribution: Seq[Distribution] = { - if (partitionSpec.isEmpty) { - // Only show warning when the number of bytes is larger than 100 MiB? - logWarning("No Partition Defined for Window operation! Moving all data to a single " - + "partition, this can cause serious performance degradation.") - AllTuples :: Nil - } else ClusteredDistribution(partitionSpec) :: Nil - } - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) - - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - override def outputPartitioning: Partitioning = child.outputPartitioning - override protected def doExecute(): RDD[InternalRow] = { throw new UnsupportedOperationException(s"This operator doesn't support doExecute().") } @@ -217,7 +204,7 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], val winExpressions: Seq[Expression] = windowFrameExpressionFactoryPairs.flatMap(_._1) val windowFunType = new Array[FunctionType](winExpressions.size) val omminPartitionChannels = new Array[Int](partitionSpec.size) - val preGroupedChannels = new Array[Int](winExpressions.size) + val preGroupedChannels = new Array[Int](0) var windowArgKeys = new Array[String](winExpressions.size) val windowFunRetType = new Array[DataType](winExpressions.size) val omniAttrExpsIdMap = getExprIdMap(child.output) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarRDD.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarRDD.scala index 1e728239b988592aa7ef8e89cba4cccf0751c065..7f664121bc7f309d3c3f1226ba49a3afb2e231b9 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarRDD.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarRDD.scala @@ -24,6 +24,43 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsRe import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch +sealed trait ShufflePartitionSpec + +// A partition that reads data of one or more reducers, from `startReducerIndex` (inclusive) to +// `endReducerIndex` (exclusive). +case class CoalescedPartitionSpec( + startReducerIndex: Int, + endReducerIndex: Int, + @transient dataSize: Option[Long] = None) extends ShufflePartitionSpec + +object CoalescedPartitionSpec { + def apply(startReducerIndex: Int, + endReducerIndex: Int, + dataSize: Long): CoalescedPartitionSpec = { + CoalescedPartitionSpec(startReducerIndex, endReducerIndex, Some(dataSize)) + } +} + +// A partition that reads partial data of one reducer, from `startMapIndex` (inclusive) to +// `endMapIndex` (exclusive). +case class PartialReducerPartitionSpec( + reducerIndex: Int, + startMapIndex: Int, + endMapIndex: Int, + @transient dataSize: Long) extends ShufflePartitionSpec + +// A partition that reads partial data of one mapper, from `startReducerIndex` (inclusive) to +// `endReducerIndex` (exclusive). 
+case class PartialMapperPartitionSpec( + mapIndex: Int, + startReducerIndex: Int, + endReducerIndex: Int) extends ShufflePartitionSpec + +case class CoalescedMapperPartitionSpec( + startMapIndex: Int, + endMapIndex: Int, + numReducers: Int) extends ShufflePartitionSpec + /** * The [[Partition]] used by [[ShuffledRowRDD]]. */ @@ -70,7 +107,7 @@ class ShuffledColumnarRDD( override def getPreferredLocations(partition: Partition): Seq[String] = { val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] partition.asInstanceOf[ShuffledColumnarRDDPartition].spec match { - case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) => + case CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) => startReducerIndex.until(endReducerIndex).flatMap { reducerIndex => tracker.getPreferredLocationsForShuffle(dependency, reducerIndex) } @@ -80,6 +117,9 @@ class ShuffledColumnarRDD( case PartialMapperPartitionSpec(mapIndex, _, _) => tracker.getMapLocation(dependency, mapIndex, mapIndex + 1) + + case CoalescedMapperPartitionSpec(startMapIndex, endMapIndex, numReducers) => + tracker.getMapLocation(dependency, startMapIndex, endMapIndex) } } @@ -89,7 +129,7 @@ class ShuffledColumnarRDD( // as well as the `tempMetrics` for basic shuffle metrics. val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics) val reader = split.asInstanceOf[ShuffledColumnarRDDPartition].spec match { - case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) => + case CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) => SparkEnv.get.shuffleManager.getReader( dependency.shuffleHandle, startReducerIndex, @@ -116,7 +156,22 @@ class ShuffledColumnarRDD( endReducerIndex, context, sqlMetricsReporter) + + case CoalescedMapperPartitionSpec(startMapIndex, endMapIndex, numReducers) => + SparkEnv.get.shuffleManager.getReader( + dependency.shuffleHandle, + startMapIndex, + endMapIndex, + 0, + numReducers, + context, + sqlMetricsReporter) } reader.read().asInstanceOf[Iterator[Product2[Int, ColumnarBatch]]].map(_._2) } + + override def clearDependencies(): Unit = { + super.clearDependencies() + dependency = null + } } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala index d34b93e5b0da5b61ac35c0824acbf817f1a5e938..be4efd90cfc9a1617d4be7bf5121b4ac97d1e50d 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.adaptive import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition, UnknownPartitioning} +import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike} import 
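// Editor's aside, not part of the patch: hedged summary of the ShufflePartitionSpec variants
// declared above, expressed as the (map range, reducer range) each one requests from the shuffle
// reader in ShuffledColumnarRDD.compute. Local case classes stand in for the real specs; the
// CoalescedMapper mapping follows the new getReader(startMapIndex, endMapIndex, 0, numReducers)
// call added in this change.
object PartitionSpecSketch {
  sealed trait Spec
  case class Coalesced(startReducer: Int, endReducer: Int) extends Spec
  case class PartialReducer(reducer: Int, startMap: Int, endMap: Int) extends Spec
  case class PartialMapper(map: Int, startReducer: Int, endReducer: Int) extends Spec
  case class CoalescedMapper(startMap: Int, endMap: Int, numReducers: Int) extends Spec

  /** (startMapIndex, endMapIndex, startReducerIndex, endReducerIndex) handed to the shuffle reader. */
  def readerRange(spec: Spec, numMappers: Int): (Int, Int, Int, Int) = spec match {
    case Coalesced(sr, er)          => (0, numMappers, sr, er) // all map outputs, reducers [sr, er)
    case PartialReducer(r, sm, em)  => (sm, em, r, r + 1)      // a slice of one reducer's input
    case PartialMapper(m, sr, er)   => (m, m + 1, sr, er)      // one mapper's output, some reducers
    case CoalescedMapper(sm, em, n) => (sm, em, 0, n)          // several mappers, all reducers
  }
}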
org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -36,7 +37,7 @@ import scala.collection.mutable.ArrayBuffer * node during canonicalization. * @param partitionSpecs The partition specs that defines the arrangement. */ -case class ColumnarCustomShuffleReaderExec( +case class OmniAQEShuffleReadExec( child: SparkPlan, partitionSpecs: Seq[ShufflePartitionSpec]) extends UnaryExecNode { @@ -57,9 +58,9 @@ case class ColumnarCustomShuffleReaderExec( partitionSpecs.map(_.asInstanceOf[PartialMapperPartitionSpec].mapIndex).toSet.size == partitionSpecs.length) { child match { - case ShuffleQueryStageExec(_, s: ShuffleExchangeLike) => + case ShuffleQueryStageExec(_, s: ShuffleExchangeLike, _) => s.child.outputPartitioning - case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeLike)) => + case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeLike), _) => s.child.outputPartitioning match { case e: Expression => r.updateAttr(e).asInstanceOf[Partitioning] case other => other @@ -67,13 +68,34 @@ case class ColumnarCustomShuffleReaderExec( case _ => throw new IllegalStateException("operating on canonicalization plan") } + } else if (isCoalescedRead) { + // For coalesced shuffle read, the data distribution is not changed, only the number of + // partitions is changed. + child.outputPartitioning match { + case h: HashPartitioning => + CurrentOrigin.withOrigin(h.origin)(h.copy(numPartitions = partitionSpecs.length)) + case r: RangePartitioning => + CurrentOrigin.withOrigin(r.origin)(r.copy(numPartitions = partitionSpecs.length)) + // This can only happen for `REBALANCE_PARTITIONS_BY_NONE`, which uses + // `RoundRobinPartitioning` but we don't need to retain the number of partitions. + case r: RoundRobinPartitioning => + r.copy(numPartitions = partitionSpecs.length) + case other @ SinglePartition => + throw new IllegalStateException( + "Unexpected partitioning for coalesced shuffle read: " + other) + case _ => + // Spark plugins may have custom partitioning and may replace this operator + // during the postStageOptimization phase, so return UnknownPartitioning here + // rather than throw an exception + UnknownPartitioning(partitionSpecs.length) + } } else { UnknownPartitioning(partitionSpecs.length) } } override def stringArgs: Iterator[Any] = { - val desc = if (isLocalReader) { + val desc = if (isLocalRead) { "local" } else if (hasCoalescedPartition && hasSkewedPartition) { "coalesced and skewed" @@ -87,14 +109,38 @@ case class ColumnarCustomShuffleReaderExec( Iterator(desc) } - def hasCoalescedPartition: Boolean = - partitionSpecs.exists(_.isInstanceOf[CoalescedPartitionSpec]) + /** + * Returns true iff some partitions were actually combined + */ + private def isCoalescedSpec(spec: ShufflePartitionSpec) = spec match { + case CoalescedPartitionSpec(0, 0, _) => true + case s: CoalescedPartitionSpec => s.endReducerIndex - s.startReducerIndex > 1 + case _ => false + } + + /** + * Returns true iff some non-empty partitions were combined + */ + def hasCoalescedPartition: Boolean = { + partitionSpecs.exists(isCoalescedSpec) + } def hasSkewedPartition: Boolean = partitionSpecs.exists(_.isInstanceOf[PartialReducerPartitionSpec]) - def isLocalReader: Boolean = - partitionSpecs.exists(_.isInstanceOf[PartialMapperPartitionSpec]) + def isLocalRead: Boolean = + partitionSpecs.exists(_.isInstanceOf[PartialMapperPartitionSpec]) || + partitionSpecs.exists(_.isInstanceOf[CoalescedMapperPartitionSpec]) + + def isCoalescedRead: Boolean = { + partitionSpecs.sliding(2).forall 
{ + // A single partition spec which is `CoalescedPartitionSpec` also means coalesced read. + case Seq(_: CoalescedPartitionSpec) => true + case Seq(l: CoalescedPartitionSpec, r: CoalescedPartitionSpec) => + l.endReducerIndex <= r.startReducerIndex + case _ => false + } + } private def shuffleStage = child match { case stage: ShuffleQueryStageExec => Some(stage) @@ -102,13 +148,13 @@ case class ColumnarCustomShuffleReaderExec( } @transient private lazy val partitionDataSizes: Option[Seq[Long]] = { - if (partitionSpecs.nonEmpty && !isLocalReader && shuffleStage.get.mapStats.isDefined) { - val bytesByPartitionId = shuffleStage.get.mapStats.get.bytesByPartitionId + if (!isLocalRead && shuffleStage.get.mapStats.isDefined) { Some(partitionSpecs.map { - case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) => - startReducerIndex.until(endReducerIndex).map(bytesByPartitionId).sum + case p: CoalescedPartitionSpec => + assert(p.dataSize.isDefined) + p.dataSize.get case p: PartialReducerPartitionSpec => p.dataSize - case p => throw new IllegalStateException("unexpected " + p) + case p => throw new IllegalStateException(s"unexpected $p") }) } else { None @@ -141,6 +187,13 @@ case class ColumnarCustomShuffleReaderExec( driverAccumUpdates += (skewedSplits.id -> numSplits) } + if (hasCoalescedPartition) { + val numCoalescedPartitionsMetric = metrics("numCoalescedPartitions") + val x = partitionSpecs.count(isCoalescedSpec) + numCoalescedPartitionsMetric.set(x) + driverAccumUpdates += numCoalescedPartitionsMetric.id -> x + } + partitionDataSizes.foreach { dataSizes => val partitionDataSizeMetrics = metrics("partitionDataSize") driverAccumUpdates ++= dataSizes.map(partitionDataSizeMetrics.id -> _) @@ -154,8 +207,8 @@ case class ColumnarCustomShuffleReaderExec( @transient override lazy val metrics: Map[String, SQLMetric] = { if (shuffleStage.isDefined) { Map("numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions")) ++ { - if (isLocalReader) { - // We split the mapper partition evenly when creating local shuffle reader, so no + if (isLocalRead) { + // We split the mapper partition evenly when creating local shuffle read, so no // data size info is available. Map.empty } else { @@ -171,6 +224,13 @@ case class ColumnarCustomShuffleReaderExec( } else { Map.empty } + } ++ { + if (hasCoalescedPartition) { + Map("numCoalescedPartitions" -> + SQLMetrics.createMetric(sparkContext, "number of coalesced partitions")) + } else { + Map.empty + } } } else { // It's a canonicalized plan, no need to report metrics. 
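Side note on the `isCoalescedRead` guard added above: it reports a coalesced read only when every adjacent pair of `CoalescedPartitionSpec`s covers ordered, non-overlapping reducer ranges. The following self-contained sketch mirrors that sliding-window check; `CoalescedSpec` and `CoalescedReadCheck` are illustrative stand-ins for this example, not the Spark classes used in the patch.

```scala
// Sketch of the sliding-window contiguity check behind isCoalescedRead.
// CoalescedSpec is an illustrative stand-in for Spark's CoalescedPartitionSpec;
// only the reducer-index bounds matter for this check.
object CoalescedReadCheck {
  final case class CoalescedSpec(startReducerIndex: Int, endReducerIndex: Int)

  // True only when the specs form an ordered, non-overlapping sequence of
  // reducer ranges, i.e. the read merely coalesces contiguous shuffle blocks.
  def isCoalescedRead(specs: Seq[CoalescedSpec]): Boolean =
    specs.sliding(2).forall {
      case Seq(_)    => true // a single coalesced spec still counts as a coalesced read
      case Seq(l, r) => l.endReducerIndex <= r.startReducerIndex
      case _         => false
    }

  def main(args: Array[String]): Unit = {
    val contiguous  = Seq(CoalescedSpec(0, 3), CoalescedSpec(3, 5), CoalescedSpec(5, 8))
    val overlapping = Seq(CoalescedSpec(0, 3), CoalescedSpec(2, 5))
    println(isCoalescedRead(contiguous))  // true
    println(isCoalescedRead(overlapping)) // false
  }
}
```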
@@ -178,24 +238,19 @@ case class ColumnarCustomShuffleReaderExec( } } - private var cachedShuffleRDD: RDD[ColumnarBatch] = null - private lazy val shuffleRDD: RDD[_] = { - sendDriverMetrics() - if (cachedShuffleRDD == null) { - cachedShuffleRDD = child match { - case stage: ShuffleQueryStageExec => - new ShuffledColumnarRDD( - stage.shuffle - .asInstanceOf[ColumnarShuffleExchangeExec] - .columnarShuffleDependency, - stage.shuffle.asInstanceOf[ColumnarShuffleExchangeExec].readMetrics, - partitionSpecs.toArray) - case _ => - throw new IllegalStateException("operating on canonicalized plan") - } + shuffleStage match { + case Some(stage) => + sendDriverMetrics() + new ShuffledColumnarRDD( + stage.shuffle + .asInstanceOf[ColumnarShuffleExchangeExec] + .columnarShuffleDependency, + stage.shuffle.asInstanceOf[ColumnarShuffleExchangeExec].readMetrics, + partitionSpecs.toArray) + case _ => + throw new IllegalStateException("operating on canonicalized plan") } - cachedShuffleRDD } override protected def doExecute(): RDD[InternalRow] = { @@ -205,4 +260,7 @@ case class ColumnarCustomShuffleReaderExec( override protected def doExecuteColumnar(): RDD[ColumnarBatch] = { shuffleRDD.asInstanceOf[RDD[ColumnarBatch]] } + + override protected def withNewChildInternal(newChild: SparkPlan): OmniAQEShuffleReadExec = + new OmniAQEShuffleReadExec(newChild, this.partitionSpecs) } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala index 0e5a7eae6efaac64d59c9effdcd8304d30c5c9fe..ce32bb25d50a080be91f460ae9c608fb363fb34d 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala @@ -93,7 +93,7 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ } else { // ORC predicate pushdown if (orcFilterPushDown) { - OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach { + OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach { fileSchema => OrcFilters.createFilter(fileSchema, filters).foreach { f => OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames) } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala index aeb7d4ccb64b4ec83e5fa476c6949fbf61e4e525..644bfe1ff80f138a440ed6fe3ce4e52d6dc75745 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala @@ -25,7 +25,7 @@ import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{checkOmniJsonWhiteList, isSimpleColumn, isSimpleColumnForAll} import com.huawei.boostkit.spark.util.OmniAdaptorUtil -import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import 
com.huawei.boostkit.spark.util.OmniAdaptorUtil.{getIndexArray, pruneOutput, reorderVecs, transColBatchToOmniVecs} import nova.hetu.omniruntime.`type`.DataType import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} import nova.hetu.omniruntime.operator.join.{OmniHashBuilderWithExprOperatorFactory, OmniLookupJoinWithExprOperatorFactory} @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{CodegenSupport, ColumnarHashedRelation, SparkPlan} +import org.apache.spark.sql.execution.{CodegenSupport, ColumnarHashedRelation, ExplainUtils, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.util.{MergeIterator, SparkMemoryUtils} import org.apache.spark.sql.execution.vectorized.OmniColumnVector @@ -65,6 +65,24 @@ case class ColumnarBroadcastHashJoinExec( projectList: Seq[NamedExpression] = Seq.empty) extends HashJoin { + override def verboseStringWithOperatorId(): String = { + val joinCondStr = if (condition.isDefined) { + s"${condition.get}${condition.get.dataType}" + } else "None" + s""" + |$formattedNodeName + |$simpleStringWithNodeId + |${ExplainUtils.generateFieldString("buildOutput", buildOutput ++ buildOutput.map(_.dataType))} + |${ExplainUtils.generateFieldString("streamedOutput", streamedOutput ++ streamedOutput.map(_.dataType))} + |${ExplainUtils.generateFieldString("leftKeys", leftKeys ++ leftKeys.map(_.dataType))} + |${ExplainUtils.generateFieldString("rightKeys", rightKeys ++ rightKeys.map(_.dataType))} + |${ExplainUtils.generateFieldString("condition", joinCondStr)} + |${ExplainUtils.generateFieldString("projectList", projectList.map(_.toAttribute) ++ projectList.map(_.toAttribute).map(_.dataType))} + |${ExplainUtils.generateFieldString("output", output ++ output.map(_.dataType))} + |Condition : $condition + |""".stripMargin + } + if (isNullAwareAntiJoin) { require(leftKeys.length == 1, "leftKeys length should be 1") require(rightKeys.length == 1, "rightKeys length should be 1") @@ -97,6 +115,9 @@ case class ColumnarBroadcastHashJoinExec( override def nodeName: String = "OmniColumnarBroadcastHashJoin" + override protected def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): + ColumnarBroadcastHashJoinExec = copy(left = newLeft, right = newRight) + override def requiredChildDistribution: Seq[Distribution] = { val mode = HashedRelationBroadcastMode(buildBoundKeys, isNullAwareAntiJoin) buildSide match { @@ -109,7 +130,7 @@ case class ColumnarBroadcastHashJoinExec( override lazy val outputPartitioning: Partitioning = { joinType match { - case Inner if sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 => + case Inner if session.sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 => streamedPlan.outputPartitioning match { case h: HashPartitioning => expandOutputPartitioning(h) case c: PartitioningCollection => expandOutputPartitioning(c) @@ -150,7 +171,7 @@ case class ColumnarBroadcastHashJoinExec( // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y"). // The expanded expressions are returned as PartitioningCollection. 
private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = { - val maxNumCombinations = sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit + val maxNumCombinations = session.sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit var currentNumCombinations = 0 def generateExprCombinations( @@ -315,9 +336,20 @@ case class ColumnarBroadcastHashJoinExec( val buildOp = buildOpFactory.createOperator() buildCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildCodegen) + val startLookupCodegen = System.nanoTime() + val lookupJoinType = OmniExpressionAdaptor.toOmniJoinType(joinType) + val lookupOpFactory = new OmniLookupJoinWithExprOperatorFactory(probeTypes, probeOutputCols, + probeHashColsExp, buildOutputCols, buildOutputTypes, lookupJoinType, buildOpFactory, + new OperatorConfig(SpillConfig.NONE, + new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + val lookupOp = lookupOpFactory.createOperator() + lookupCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startLookupCodegen) + // close operator SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + lookupOp.close() buildOp.close() + lookupOpFactory.close() buildOpFactory.close() }) @@ -331,21 +363,6 @@ case class ColumnarBroadcastHashJoinExec( buildOp.getOutput buildGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildGetOp) - val startLookupCodegen = System.nanoTime() - val lookupJoinType = OmniExpressionAdaptor.toOmniJoinType(joinType) - val lookupOpFactory = new OmniLookupJoinWithExprOperatorFactory(probeTypes, probeOutputCols, - probeHashColsExp, buildOutputCols, buildOutputTypes, lookupJoinType, buildOpFactory, - new OperatorConfig(SpillConfig.NONE, - new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) - val lookupOp = lookupOpFactory.createOperator() - lookupCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startLookupCodegen) - - // close operator - SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { - lookupOp.close() - lookupOpFactory.close() - }) - val streamedPlanOutput = pruneOutput(streamedPlan.output, projectList) val prunedOutput = streamedPlanOutput ++ prunedBuildOutput val resultSchema = this.schema @@ -422,9 +439,11 @@ case class ColumnarBroadcastHashJoinExec( index += 1 } } - numOutputRows += result.getRowCount + val rowCnt: Int = result.getRowCount + numOutputRows += rowCnt numOutputVecBatchs += 1 - new ColumnarBatch(vecs.toArray, result.getRowCount) + result.close() + new ColumnarBatch(vecs.toArray, rowCnt) } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala index 792bddcf1e933f8bd101482dc66f68e6a055e89e..8e3f8383f3a20d2e9720c405fbd86c8d506097d8 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala @@ -19,25 +19,23 @@ package org.apache.spark.sql.execution.joins import java.util.Optional import java.util.concurrent.TimeUnit.NANOSECONDS - import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor import 
com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{checkOmniJsonWhiteList, isSimpleColumn, isSimpleColumnForAll} import com.huawei.boostkit.spark.util.OmniAdaptorUtil -import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{getIndexArray, pruneOutput, reorderVecs, transColBatchToOmniVecs} import nova.hetu.omniruntime.`type`.DataType import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} -import nova.hetu.omniruntime.operator.join._ +import nova.hetu.omniruntime.operator.join.{OmniHashBuilderWithExprOperatorFactory, OmniLookupJoinWithExprOperatorFactory, OmniLookupOuterJoinWithExprOperatorFactory} import nova.hetu.omniruntime.vector.VecBatch - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildSide} -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi} +import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, Inner, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{ExplainUtils, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.util.SparkMemoryUtils import org.apache.spark.sql.execution.vectorized.OmniColumnVector @@ -50,9 +48,29 @@ case class ColumnarShuffledHashJoinExec( buildSide: BuildSide, condition: Option[Expression], left: SparkPlan, - right: SparkPlan) + right: SparkPlan, + isSkewJoin: Boolean, + projectList: Seq[NamedExpression] = Seq.empty) extends HashJoin with ShuffledJoin { + override def verboseStringWithOperatorId(): String = { + val joinCondStr = if (condition.isDefined) { + s"${condition.get}${condition.get.dataType}" + } else "None" + s""" + |$formattedNodeName + |$simpleStringWithNodeId + |${ExplainUtils.generateFieldString("buildOutput", buildOutput ++ buildOutput.map(_.dataType))} + |${ExplainUtils.generateFieldString("streamedOutput", streamedOutput ++ streamedOutput.map(_.dataType))} + |${ExplainUtils.generateFieldString("leftKeys", leftKeys ++ leftKeys.map(_.dataType))} + |${ExplainUtils.generateFieldString("rightKeys", rightKeys ++ rightKeys.map(_.dataType))} + |${ExplainUtils.generateFieldString("condition", joinCondStr)} + |${ExplainUtils.generateFieldString("projectList", projectList.map(_.toAttribute) ++ projectList.map(_.toAttribute).map(_.dataType))} + |${ExplainUtils.generateFieldString("output", output ++ output.map(_.dataType))} + |Condition : $condition + |""".stripMargin + } + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "lookupAddInputTime" -> SQLMetrics.createTimingMetric(sparkContext, @@ -77,10 +95,19 @@ case class ColumnarShuffledHashJoinExec( override def nodeName: String = "OmniColumnarShuffledHashJoin" - override def output: Seq[Attribute] = super[ShuffledJoin].output + override def output: Seq[Attribute] = { + if (projectList.nonEmpty) { + projectList.map(_.toAttribute) + } else { + super[ShuffledJoin].output + 
} + } override def outputPartitioning: Partitioning = super[ShuffledJoin].outputPartitioning + override protected def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): + ColumnarShuffledHashJoinExec = copy(left = newLeft, right = newRight) + override def outputOrdering: Seq[SortOrder] = joinType match { case FullOuter => Nil case _ => super.outputOrdering @@ -159,7 +186,7 @@ case class ColumnarShuffledHashJoinExec( val buildOutputCols: Array[Int] = joinType match { case Inner | FullOuter | LeftOuter => - buildOutput.indices.toArray + getIndexArray(buildOutput, projectList) case LeftExistence(_) => Array[Int]() case x => @@ -171,11 +198,17 @@ case class ColumnarShuffledHashJoinExec( OmniExpressionAdaptor.getExprIdMap(buildOutput.map(_.toAttribute))) }.toArray + val prunedBuildOutput = pruneOutput(buildOutput, projectList) + val buildOutputTypes = new Array[DataType](prunedBuildOutput.size) // {2,2}, buildOutput:col1#12,col2#13 + prunedBuildOutput.zipWithIndex.foreach { case (att, i) => + buildOutputTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata) + } + val probeTypes = new Array[DataType](streamedOutput.size) streamedOutput.zipWithIndex.foreach { case (attr, i) => probeTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) } - val probeOutputCols = streamedOutput.indices.toArray + val probeOutputCols = getIndexArray(streamedOutput, projectList) val probeHashColsExp = streamedKeys.map { x => OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, OmniExpressionAdaptor.getExprIdMap(streamedOutput.map(_.toAttribute))) @@ -197,8 +230,19 @@ case class ColumnarShuffledHashJoinExec( val buildOp = buildOpFactory.createOperator() buildCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildCodegen) + val startLookupCodegen = System.nanoTime() + val lookupJoinType = OmniExpressionAdaptor.toOmniJoinType(joinType) + val lookupOpFactory = new OmniLookupJoinWithExprOperatorFactory(probeTypes, + probeOutputCols, probeHashColsExp, buildOutputCols, buildOutputTypes, lookupJoinType, + buildOpFactory, new OperatorConfig(SpillConfig.NONE, + new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + val lookupOp = lookupOpFactory.createOperator() + lookupCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startLookupCodegen) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + lookupOp.close() buildOp.close() + lookupOpFactory.close() buildOpFactory.close() }) @@ -219,32 +263,19 @@ case class ColumnarShuffledHashJoinExec( buildOp.getOutput buildGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildGetOp) - val startLookupCodegen = System.nanoTime() - val lookupJoinType = OmniExpressionAdaptor.toOmniJoinType(joinType) - val lookupOpFactory = new OmniLookupJoinWithExprOperatorFactory(probeTypes, - probeOutputCols, probeHashColsExp, buildOutputCols, buildTypes, lookupJoinType, - buildOpFactory, new OperatorConfig(SpillConfig.NONE, - new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) - - val lookupOp = lookupOpFactory.createOperator() - lookupCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startLookupCodegen) - - SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { - lookupOp.close() - lookupOpFactory.close() - }) - + val streamedPlanOutput = pruneOutput(streamedPlan.output, projectList) + val prunedOutput = streamedPlanOutput ++ prunedBuildOutput val resultSchema = this.schema val reverse = buildSide == BuildLeft var left = 0 - var 
leftLen = streamedPlan.output.size - var right = streamedPlan.output.size + var leftLen = streamedPlanOutput.size + var right = streamedPlanOutput.size var rightLen = output.size if (reverse) { - left = streamedPlan.output.size + left = streamedPlanOutput.size leftLen = output.size right = 0 - rightLen = streamedPlan.output.size + rightLen = streamedPlanOutput.size } val joinIter: Iterator[ColumnarBatch] = new Iterator[ColumnarBatch] { @@ -287,28 +318,34 @@ case class ColumnarShuffledHashJoinExec( val resultVecs = result.getVectors val vecs = OmniColumnVector .allocateColumns(result.getRowCount, resultSchema, false) - var index = 0 - for (i <- left until leftLen) { - val v = vecs(index) - v.reset() - v.setVec(resultVecs(i)) - index += 1 - } - for (i <- right until rightLen) { - val v = vecs(index) - v.reset() - v.setVec(resultVecs(i)) - index += 1 + if (projectList.nonEmpty) { + reorderVecs(prunedOutput, projectList, resultVecs, vecs) + } else { + var index = 0 + for (i <- left until leftLen) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(i)) + index += 1 + } + for (i <- right until rightLen) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(i)) + index += 1 + } } - numOutputRows += result.getRowCount + val rowCnt: Int = result.getRowCount + numOutputRows += rowCnt numOutputVecBatchs += 1 - new ColumnarBatch(vecs.toArray, result.getRowCount) + result.close() + new ColumnarBatch(vecs.toArray, rowCnt) } } if ("FULL OUTER" == joinType.sql) { val lookupOuterOpFactory = new OmniLookupOuterJoinWithExprOperatorFactory(probeTypes, probeOutputCols, - probeHashColsExp, buildOutputCols, buildTypes, buildOpFactory, + probeHashColsExp, buildOutputCols, buildOutputTypes, buildOpFactory, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) @@ -334,18 +371,22 @@ case class ColumnarShuffledHashJoinExec( val resultVecs = result.getVectors val vecs = OmniColumnVector .allocateColumns(result.getRowCount, resultSchema, false) - var index = 0 - for (i <- left until leftLen) { - val v = vecs(index) - v.reset() - v.setVec(resultVecs(i)) - index += 1 - } - for (i <- right until rightLen) { - val v = vecs(index) - v.reset() - v.setVec(resultVecs(i)) - index += 1 + if (projectList.nonEmpty) { + reorderVecs(prunedOutput, projectList, resultVecs, vecs) + } else { + var index = 0 + for (i <- left until leftLen) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(i)) + index += 1 + } + for (i <- right until rightLen) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(i)) + index += 1 + } } numOutputRows += result.getRowCount numOutputVecBatchs += 1 diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala index 8925d05bfb89773d6d4dec000877b84f3cadd28b..bd93e0a5442de46c8318d9dce98022bebcea9511 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala @@ -25,15 +25,19 @@ import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{checkOmniJsonWhiteList, isSimpleColumn, isSimpleColumnForAll} 
import com.huawei.boostkit.spark.util.OmniAdaptorUtil -import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{getIndexArray, pruneOutput, reorderVecs, transColBatchToOmniVecs} import nova.hetu.omniruntime.`type`.DataType import nova.hetu.omniruntime.constants.JoinType._ import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} import nova.hetu.omniruntime.operator.join.{OmniSmjBufferedTableWithExprOperatorFactory, OmniSmjStreamedTableWithExprOperatorFactory} import nova.hetu.omniruntime.vector.{BooleanVec, Decimal128Vec, DoubleVec, IntVec, LongVec, VarcharVec, Vec, VecBatch, ShortVec} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.util.{MergeIterator, SparkMemoryUtils} @@ -43,22 +47,16 @@ import org.apache.spark.sql.vectorized.ColumnarBatch /** * Performs a sort merge join of two child relations. */ -class ColumnarSortMergeJoinExec( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan, - isSkewJoin: Boolean = false) - extends SortMergeJoinExec( +case class ColumnarSortMergeJoinExec( leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, condition: Option[Expression], left: SparkPlan, right: SparkPlan, - isSkewJoin: Boolean) with CodegenSupport { + isSkewJoin: Boolean = false, + projectList: Seq[NamedExpression] = Seq.empty) + extends ShuffledJoin with CodegenSupport { override def supportsColumnar: Boolean = true @@ -68,6 +66,67 @@ class ColumnarSortMergeJoinExec( if (isSkewJoin) "OmniColumnarSortMergeJoin(skew=true)" else "OmniColumnarSortMergeJoin" } + override protected def withNewChildrenInternal(newLeft: SparkPlan, + newRight: SparkPlan): + ColumnarSortMergeJoinExec = copy(left = newLeft, right = newRight) + + override def stringArgs: Iterator[Any] = super.stringArgs.toSeq.dropRight(1).iterator + + override def requiredChildDistribution: Seq[Distribution] = { + if (isSkewJoin) { + UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + } else { + super.requiredChildDistribution + } + } + + override def outputOrdering: Seq[SortOrder] = joinType match { + case _: InnerLike => + val leftKeyOrdering = getKeyOrdering(leftKeys, left.outputOrdering) + val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering) + leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) => + val sameOrderExpressions = ExpressionSet(lKey.sameOrderExpressions ++ rKey.children) + SortOrder(lKey.child, Ascending, sameOrderExpressions.toSeq) + } + case LeftOuter => getKeyOrdering(leftKeys, left.outputOrdering) + case RightOuter => getKeyOrdering(rightKeys, right.outputOrdering) + case FullOuter => Nil + case LeftExistence(_) => getKeyOrdering(leftKeys, left.outputOrdering) + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } + + private def getKeyOrdering(keys: Seq[Expression], childOutputOrdering: Seq[SortOrder]) + : Seq[SortOrder] = { + val requiredOrdering = 
requiredOrders(keys) + if (SortOrder.orderingSatisfies(childOutputOrdering, requiredOrdering)) { + keys.zip(childOutputOrdering).map { case (key, childOrder) => + val sameOrderExpressionSet = ExpressionSet(childOrder.children) - key + SortOrder(key, Ascending, sameOrderExpressionSet.toSeq) + } + } else { + requiredOrdering + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil + + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { + keys.map(SortOrder(_, Ascending)) + } + + override def output : Seq[Attribute] = { + if (projectList.nonEmpty) { + projectList.map(_.toAttribute) + } else { + super[ShuffledJoin].output + } + } + + override def needCopyResult: Boolean = true + val SMJ_NEED_ADD_STREAM_TBL_DATA = 2 val SMJ_NEED_ADD_BUFFERED_TBL_DATA = 3 val SCAN_FINISH = 4 @@ -94,6 +153,37 @@ class ColumnarSortMergeJoinExec( "numBufferVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of buffered vecBatchs") ) + override def verboseStringWithOperatorId(): String = { + val joinCondStr = if (condition.isDefined) { + s"${condition.get}${condition.get.dataType}" + } else "None" + + s""" + |$formattedNodeName + |$simpleStringWithNodeId + |${ExplainUtils.generateFieldString("Stream input", left.output ++ left.output.map(_.dataType))} + |${ExplainUtils.generateFieldString("Buffer input", right.output ++ right.output.map(_.dataType))} + |${ExplainUtils.generateFieldString("Left keys", leftKeys ++ leftKeys.map(_.dataType))} + |${ExplainUtils.generateFieldString("Right keys", rightKeys ++ rightKeys.map(_.dataType))} + |${ExplainUtils.generateFieldString("Join condition", joinCondStr)} + |${ExplainUtils.generateFieldString("Project List", projectList ++ projectList.map(_.dataType))} + |${ExplainUtils.generateFieldString("Output", output ++ output.map(_.dataType))} + |Condition : $condition + |""".stripMargin + } + + protected override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException(s"This operator doesn't support doExecute.") + } + + protected override def doProduce(ctx: CodegenContext): String = { + throw new UnsupportedOperationException(s"This operator doesn't support doProduce.") + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + left.execute() :: right.execute() :: Nil + } + def buildCheck(): Unit = { joinType match { case Inner | LeftOuter | FullOuter | LeftSemi | LeftAnti => @@ -160,7 +250,7 @@ class ColumnarSortMergeJoinExec( OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, OmniExpressionAdaptor.getExprIdMap(left.output.map(_.toAttribute))) }.toArray - val streamedOutputChannel = left.output.indices.toArray + val streamedOutputChannel = getIndexArray(left.output, projectList) val bufferedTypes = new Array[DataType](right.output.size) right.output.zipWithIndex.foreach { case (attr, i) => @@ -172,7 +262,7 @@ class ColumnarSortMergeJoinExec( }.toArray val bufferedOutputChannel: Array[Int] = joinType match { case Inner | LeftOuter | FullOuter => - right.output.indices.toArray + getIndexArray(right.output, projectList) case LeftExistence(_) => Array[Int]() case x => @@ -214,6 +304,9 @@ class ColumnarSortMergeJoinExec( streamedOpFactory.close() }) + val prunedStreamOutput = pruneOutput(left.output, projectList) + val prunedBufferOutput = pruneOutput(right.output, projectList) + val prunedOutput = prunedStreamOutput ++ prunedBufferOutput val resultSchema = this.schema val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf val 
enableSortMergeJoinBatchMerge: Boolean = columnarConf.enableSortMergeJoinBatchMerge @@ -321,10 +414,14 @@ class ColumnarSortMergeJoinExec( getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOutputTime) val resultVecs = result.getVectors val vecs = OmniColumnVector.allocateColumns(result.getRowCount, resultSchema, false) - for (index <- output.indices) { - val v = vecs(index) - v.reset() - v.setVec(resultVecs(index)) + if (projectList.nonEmpty) { + reorderVecs(prunedOutput, projectList, resultVecs, vecs) + } else { + for (index <- output.indices) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(index)) + } } numOutputVecBatchs += 1 numOutputRows += result.getRowCount @@ -345,7 +442,7 @@ class ColumnarSortMergeJoinExec( case DataType.DataTypeId.OMNI_BOOLEAN => new BooleanVec(0) case DataType.DataTypeId.OMNI_CHAR | DataType.DataTypeId.OMNI_VARCHAR => - new VarcharVec(0, 0) + new VarcharVec(0) case DataType.DataTypeId.OMNI_DECIMAL128 => new Decimal128Vec(0) case DataType.DataTypeId.OMNI_SHORT => diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala index c67d45032589b74ee414625010cba01ba716465b..68ac49cec66b2da845c14b085f2f40316ce0c24b 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala @@ -57,8 +57,7 @@ class MergeIterator(iter: Iterator[ColumnarBatch], localSchema: StructType, vecs(index) = new BooleanVec(columnSize) case StringType => val vecType: DataType = sparkTypeToOmniType(field.dataType, field.metadata) - vecs(index) = new VarcharVec(vecType.asInstanceOf[VarcharDataType].getWidth * columnSize, - columnSize) + vecs(index) = new VarcharVec(columnSize) case dt: DecimalType => if (DecimalType.is64BitDecimalType(dt)) { vecs(index) = new LongVec(columnSize) @@ -98,6 +97,8 @@ class MergeIterator(iter: Iterator[ColumnarBatch], localSchema: StructType, src.close() } } + // close bufferedBatch + bufferedBatch.foreach(batch => batch.close()) } private def flush(): Unit = { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala index 6012da931bb3b93ef8a3e6690d42ba3d1e4949e0..946c90a9baf346dc4e47253ced50a53def22374b 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql.execution.util -import nova.hetu.omniruntime.vector.VecAllocator - +import nova.hetu.omniruntime.memory +import nova.hetu.omniruntime.memory.MemoryManager import org.apache.spark.{SparkEnv, TaskContext} object SparkMemoryUtils { private val max: Long = SparkEnv.get.conf.getSizeAsBytes("spark.memory.offHeap.size", "1g") - VecAllocator.setRootAllocatorLimit(max) + MemoryManager.setGlobalMemoryLimit(max) def init(): Unit = {} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleCompressionTest.java 
b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleCompressionTest.java deleted file mode 100644 index d95be18832b926500b599821b6b6fd0baa8861c5..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleCompressionTest.java +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.huawei.boostkit.spark; - -import com.huawei.boostkit.spark.jni.SparkJniWrapper; - -import java.io.File; -import nova.hetu.omniruntime.type.DataType; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_CHAR; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE32; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE64; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL128; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL64; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DOUBLE; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; -import nova.hetu.omniruntime.type.DataTypeSerializer; -import nova.hetu.omniruntime.vector.VecBatch; - -import org.junit.After; -import org.junit.AfterClass; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Test; - -import java.io.IOException; - -public class ColumnShuffleCompressionTest extends ColumnShuffleTest { - private static String shuffleDataFile = ""; - - @BeforeClass - public static void runOnceBeforeClass() { - File folder = new File(shuffleTestDir); - if (!folder.exists() && !folder.isDirectory()) { - folder.mkdirs(); - } - } - - @AfterClass - public static void runOnceAfterClass() { - File folder = new File(shuffleTestDir); - if (folder.exists()) { - deleteDir(folder); - } - } - - @Before - public void runBeforeTestMethod() { - - } - - @After - public void runAfterTestMethod() { - File file = new File(shuffleDataFile); - if (file.exists()) { - file.delete(); - } - } - - @Test - public void columnShuffleUncompressedTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_uncompressed_test"; - columnShuffleTestCompress("uncompressed", shuffleDataFile); - } - - @Test - public void columnShuffleSnappyCompressTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_snappy_test"; - columnShuffleTestCompress("snappy", shuffleDataFile); - } - - @Test - public void columnShuffleLz4CompressTest() throws IOException { - 
shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_lz4_test"; - columnShuffleTestCompress("lz4", shuffleDataFile); - } - - @Test - public void columnShuffleZlibCompressTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_zlib_test"; - columnShuffleTestCompress("zlib", shuffleDataFile); - } - - public void columnShuffleTestCompress(String compressType, String dataFile) throws IOException { - DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR, OMNI_CHAR, - OMNI_DATE32, OMNI_DATE64, OMNI_DECIMAL64, OMNI_DECIMAL128}; - DataType[] types = dataTypeId2DataType(idTypes); - String inputType = DataTypeSerializer.serialize(types); - SparkJniWrapper jniWrapper = new SparkJniWrapper(); - int partitionNum = 4; - long splitterId = jniWrapper.nativeMake( - "hash", - partitionNum, - inputType, - types.length, - 1024, //shuffle value_buffer init size - compressType, - dataFile, - 0, - shuffleTestDir, - 64 * 1024, - 4096, - 1024 * 1024 * 1024); - for (int i = 0; i < 999; i++) { - VecBatch vecBatchTmp = buildVecBatch(idTypes, 1000, partitionNum, true, true); - jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); - } - jniWrapper.stop(splitterId); - jniWrapper.close(splitterId); - } - -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffPartitionTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffPartitionTest.java deleted file mode 100644 index c8fd474137a93ea8831d3dc3ab432e409018cc55..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffPartitionTest.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.huawei.boostkit.spark; - -import com.huawei.boostkit.spark.jni.SparkJniWrapper; - -import java.io.File; -import nova.hetu.omniruntime.type.DataType; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_CHAR; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE32; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE64; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL128; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL64; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DOUBLE; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; -import nova.hetu.omniruntime.type.DataTypeSerializer; -import nova.hetu.omniruntime.vector.VecBatch; - -import org.junit.After; -import org.junit.AfterClass; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Test; - -import java.io.IOException; - -public class ColumnShuffleDiffPartitionTest extends ColumnShuffleTest { - private static String shuffleDataFile = ""; - - @BeforeClass - public static void runOnceBeforeClass() { - File folder = new File(shuffleTestDir); - if (!folder.exists() && !folder.isDirectory()) { - folder.mkdirs(); - } - } - - @AfterClass - public static void runOnceAfterClass() { - File folder = new File(shuffleTestDir); - if (folder.exists()) { - deleteDir(folder); - } - } - - @Before - public void runBeforeTestMethod() { - - } - - @After - public void runAfterTestMethod() { - File file = new File(shuffleDataFile); - if (file.exists()) { - file.delete(); - } - } - - @Test - public void columnShuffleSinglePartitionTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_singlePartition_test"; - columnShufflePartitionTest("single", shuffleDataFile); - } - - @Test - public void columnShuffleHashPartitionTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_hashPartition_test"; - columnShufflePartitionTest("hash", shuffleDataFile); - } - - @Test - public void columnShuffleRangePartitionTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_rangePartition_test"; - columnShufflePartitionTest("range", shuffleDataFile); - } - - public void columnShufflePartitionTest(String partitionType, String dataFile) throws IOException { - DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR, OMNI_CHAR, - OMNI_DATE32, OMNI_DATE64, OMNI_DECIMAL64, OMNI_DECIMAL128}; - DataType[] types = dataTypeId2DataType(idTypes); - String tmpStr = DataTypeSerializer.serialize(types); - SparkJniWrapper jniWrapper = new SparkJniWrapper(); - int partitionNum = 1; - boolean pidVec = true; - if (partitionType.equals("single")){ - pidVec = false; - } - long splitterId = jniWrapper.nativeMake( - partitionType, - 1, - tmpStr, - types.length, - 3, - "lz4", - dataFile, - 0, - shuffleTestDir, - 64 * 1024, - 4096, - 1024 * 1024 * 1024); - for (int i = 0; i < 99; i++) { - VecBatch vecBatchTmp = buildVecBatch(idTypes, 999, partitionNum, true, pidVec); - jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); - } - jniWrapper.stop(splitterId); - jniWrapper.close(splitterId); - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffRowVBTest.java 
b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffRowVBTest.java deleted file mode 100644 index dc53fda8a1a04a15bf7ffb9919926d4812208fc0..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffRowVBTest.java +++ /dev/null @@ -1,303 +0,0 @@ -/* - * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.huawei.boostkit.spark; - -import com.huawei.boostkit.spark.jni.SparkJniWrapper; - -import java.io.File; -import nova.hetu.omniruntime.type.DataType; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_CHAR; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE32; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE64; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL128; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL64; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DOUBLE; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; -import nova.hetu.omniruntime.type.DataTypeSerializer; -import nova.hetu.omniruntime.vector.VecBatch; - -import org.junit.After; -import org.junit.AfterClass; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Test; - -import java.io.IOException; - -public class ColumnShuffleDiffRowVBTest extends ColumnShuffleTest { - private static String shuffleDataFile = ""; - - @BeforeClass - public static void runOnceBeforeClass() { - File folder = new File(shuffleTestDir); - if (!folder.exists() && !folder.isDirectory()) { - folder.mkdirs(); - } - } - - @AfterClass - public static void runOnceAfterClass() { - File folder = new File(shuffleTestDir); - if (folder.exists()) { - deleteDir(folder); - } - } - - @Before - public void runBeforeTestMethod() { - - } - - @After - public void runAfterTestMethod() { - File file = new File(shuffleDataFile); - if (file.exists()) { - file.delete(); - } - } - - @Test - public void columnShuffleMixColTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_MixCol_test"; - DataType.DataTypeId[] idTypes = {OMNI_LONG, OMNI_DOUBLE, OMNI_INT, OMNI_VARCHAR, OMNI_CHAR, - OMNI_DATE32, OMNI_DATE64, OMNI_DECIMAL64, OMNI_DECIMAL128}; - DataType[] types = dataTypeId2DataType(idTypes); - String tmpStr = DataTypeSerializer.serialize(types); - SparkJniWrapper jniWrapper = new SparkJniWrapper(); - int partitionNum = 4; - long splitterId = jniWrapper.nativeMake( - 
"hash", - partitionNum, - tmpStr, - types.length, - 3, //shuffle value_buffer init size - "lz4", - shuffleDataFile, - 0, - shuffleTestDir, - 64 * 1024, - 4096, - 1024 * 1024 * 1024); - for (int i = 0; i < 999; i++) { - VecBatch vecBatchTmp = buildVecBatch(idTypes, 999, partitionNum, true, true); - jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); - } - jniWrapper.stop(splitterId); - jniWrapper.close(splitterId); - } - - @Test - public void columnShuffleVarCharFirstTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_varCharFirst_test"; - DataType.DataTypeId[] idTypes = {OMNI_VARCHAR, OMNI_LONG, OMNI_DOUBLE, OMNI_INT, OMNI_CHAR, - OMNI_DATE32, OMNI_DATE64, OMNI_DECIMAL64, OMNI_DECIMAL128}; - DataType[] types = dataTypeId2DataType(idTypes); - String tmpStr = DataTypeSerializer.serialize(types); - SparkJniWrapper jniWrapper = new SparkJniWrapper(); - int partitionNum = 4; - long splitterId = jniWrapper.nativeMake( - "hash", - partitionNum, - tmpStr, - types.length, - 3, //shuffle value_buffer init size - "lz4", - shuffleDataFile, - 0, - shuffleTestDir, - 0, - 4096, - 1024*1024*1024); - for (int i = 0; i < 999; i++) { - VecBatch vecBatchTmp = buildVecBatch(idTypes, 999, partitionNum, true, true); - jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); - } - jniWrapper.stop(splitterId); - jniWrapper.close(splitterId); - } - - @Test - public void columnShuffle1Row1024VBTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_1row1024vb_test"; - DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR, OMNI_CHAR, - OMNI_DATE32, OMNI_DATE64, OMNI_DECIMAL64, OMNI_DECIMAL128}; - DataType[] types = dataTypeId2DataType(idTypes); - String tmpStr = DataTypeSerializer.serialize(types); - SparkJniWrapper jniWrapper = new SparkJniWrapper(); - int partitionNum = 4; - long splitterId = jniWrapper.nativeMake( - "hash", - partitionNum, - tmpStr, - types.length, - 3, //shuffle value_buffer init size - "lz4", - shuffleDataFile, - 0, - shuffleTestDir, - 64 * 1024, - 4096, - 1024 * 1024 * 1024); - for (int i = 0; i < 1024; i++) { - VecBatch vecBatchTmp = buildVecBatch(idTypes, 1, partitionNum, false, true); - jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); - } - jniWrapper.stop(splitterId); - jniWrapper.close(splitterId); - } - - @Test - public void columnShuffle1024Row1VBTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_1024row1vb_test"; - DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR, OMNI_CHAR, - OMNI_DATE32, OMNI_DATE64, OMNI_DECIMAL64, OMNI_DECIMAL128}; - DataType[] types = dataTypeId2DataType(idTypes); - String tmpStr = DataTypeSerializer.serialize(types); - SparkJniWrapper jniWrapper = new SparkJniWrapper(); - int partitionNum = 4; - long splitterId = jniWrapper.nativeMake( - "hash", - partitionNum, - tmpStr, - types.length, - 3, //shuffle value_buffer init size - "lz4", - shuffleDataFile, - 0, - shuffleTestDir, - 64 * 1024, - 4096, - 1024 * 1024 * 1024); - for (int i = 0; i < 1; i++) { - VecBatch vecBatchTmp = buildVecBatch(idTypes, 1024, partitionNum, false, true); - jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); - } - jniWrapper.stop(splitterId); - jniWrapper.close(splitterId); - } - - @Test - public void columnShuffleChangeRowVBTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_changeRow_test"; - DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, 
OMNI_DOUBLE, OMNI_VARCHAR, OMNI_CHAR}; - DataType[] types = dataTypeId2DataType(idTypes); - String tmpStr = DataTypeSerializer.serialize(types); - SparkJniWrapper jniWrapper = new SparkJniWrapper(); - int numPartition = 4; - long splitterId = jniWrapper.nativeMake( - "hash", - numPartition, - tmpStr, - types.length, - 3, //shuffle value_buffer init size - "lz4", - shuffleDataFile, - 0, - shuffleTestDir, - 64 * 1024, - 4096, - 1024 * 1024 * 1024); - for (int i = 1; i < 1000; i++) { - VecBatch vecBatchTmp = buildVecBatch(idTypes, i, numPartition, false, true); - jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); - } - jniWrapper.stop(splitterId); - jniWrapper.close(splitterId); - } - - @Test - public void columnShuffleVarChar1RowVBTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_varChar1Row_test"; - DataType.DataTypeId[] idTypes = {OMNI_VARCHAR}; - DataType[] types = dataTypeId2DataType(idTypes); - String tmpStr = DataTypeSerializer.serialize(types); - SparkJniWrapper jniWrapper = new SparkJniWrapper(); - int partitionNum = 4; - long splitterId = jniWrapper.nativeMake( - "hash", - partitionNum, - tmpStr, - types.length, - 3, //shuffle value_buffer init size - "lz4", - shuffleDataFile, - 0, - shuffleTestDir, - 64 * 1024, - 4096, - 1024 * 1024 * 1024); - VecBatch vecBatchTmp1 = new VecBatch(buildValChar(3, "N")); - jniWrapper.split(splitterId, vecBatchTmp1.getNativeVectorBatch()); - VecBatch vecBatchTmp2 = new VecBatch(buildValChar(2, "F")); - jniWrapper.split(splitterId, vecBatchTmp2.getNativeVectorBatch()); - VecBatch vecBatchTmp3 = new VecBatch(buildValChar(3, "N")); - jniWrapper.split(splitterId, vecBatchTmp3.getNativeVectorBatch()); - VecBatch vecBatchTmp4 = new VecBatch(buildValChar(2, "F")); - jniWrapper.split(splitterId, vecBatchTmp4.getNativeVectorBatch()); - VecBatch vecBatchTmp5 = new VecBatch(buildValChar(2, "F")); - jniWrapper.split(splitterId, vecBatchTmp5.getNativeVectorBatch()); - VecBatch vecBatchTmp6 = new VecBatch(buildValChar(2, "F")); - jniWrapper.split(splitterId, vecBatchTmp6.getNativeVectorBatch()); - VecBatch vecBatchTmp7 = new VecBatch(buildValChar(1, "R")); - jniWrapper.split(splitterId, vecBatchTmp7.getNativeVectorBatch()); - jniWrapper.stop(splitterId); - jniWrapper.close(splitterId); - } - - @Test - public void columnShuffleFix1RowVBTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_fix1Row_test"; - DataType.DataTypeId[] idTypes = {OMNI_INT}; - DataType[] types = dataTypeId2DataType(idTypes); - String tmpStr = DataTypeSerializer.serialize(types); - SparkJniWrapper jniWrapper = new SparkJniWrapper(); - int partitionNum = 4; - long splitterId = jniWrapper.nativeMake( - "hash", - partitionNum, - tmpStr, - types.length, - 3, //shuffle value_buffer init size - "lz4", - shuffleDataFile, - 0, - shuffleTestDir, - 64 * 1024, - 4096, - 1024 * 1024 * 1024); - VecBatch vecBatchTmp1 = new VecBatch(buildValInt(3, 1)); - jniWrapper.split(splitterId, vecBatchTmp1.getNativeVectorBatch()); - VecBatch vecBatchTmp2 = new VecBatch(buildValInt(2, 2)); - jniWrapper.split(splitterId, vecBatchTmp2.getNativeVectorBatch()); - VecBatch vecBatchTmp3 = new VecBatch(buildValInt(3, 3)); - jniWrapper.split(splitterId, vecBatchTmp3.getNativeVectorBatch()); - VecBatch vecBatchTmp4 = new VecBatch(buildValInt(2, 4)); - jniWrapper.split(splitterId, vecBatchTmp4.getNativeVectorBatch()); - VecBatch vecBatchTmp5 = new VecBatch(buildValInt(2, 5)); - jniWrapper.split(splitterId, vecBatchTmp5.getNativeVectorBatch()); - 
VecBatch vecBatchTmp6 = new VecBatch(buildValInt(1, 6)); - jniWrapper.split(splitterId, vecBatchTmp6.getNativeVectorBatch()); - VecBatch vecBatchTmp7 = new VecBatch(buildValInt(3, 7)); - jniWrapper.split(splitterId, vecBatchTmp7.getNativeVectorBatch()); - jniWrapper.stop(splitterId); - jniWrapper.close(splitterId); - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleGBSizeTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleGBSizeTest.java deleted file mode 100644 index 2ef81ac49e545aa617136b9d4f3e7e769ea34652..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleGBSizeTest.java +++ /dev/null @@ -1,255 +0,0 @@ -/* - * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.huawei.boostkit.spark; - -import com.huawei.boostkit.spark.jni.SparkJniWrapper; - -import java.io.File; -import nova.hetu.omniruntime.type.DataType; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_CHAR; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE32; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE64; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL128; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL64; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DOUBLE; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; -import nova.hetu.omniruntime.type.DataTypeSerializer; -import nova.hetu.omniruntime.vector.VecBatch; - -import org.junit.After; -import org.junit.AfterClass; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Ignore; -import org.junit.Test; - -import java.io.IOException; - -public class ColumnShuffleGBSizeTest extends ColumnShuffleTest { - private static String shuffleDataFile = ""; - - @BeforeClass - public static void runOnceBeforeClass() { - File folder = new File(shuffleTestDir); - if (!folder.exists() && !folder.isDirectory()) { - folder.mkdirs(); - } - } - - @AfterClass - public static void runOnceAfterClass() { - File folder = new File(shuffleTestDir); - if (folder.exists()) { - deleteDir(folder); - } - } - - @Before - public void runBeforeTestMethod() { - - } - - @After - public void runAfterTestMethod() { - File file = new File(shuffleDataFile); - if (file.exists()) { - file.delete(); - } - } - - @Test - public void columnShuffleFixed1GBTest() throws 
IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_fixed1GB_test"; - DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE}; - DataType[] types = dataTypeId2DataType(idTypes); - String tmpStr = DataTypeSerializer.serialize(types); - SparkJniWrapper jniWrapper = new SparkJniWrapper(); - int partitionNum = 3; - long splitterId = jniWrapper.nativeMake( - "hash", - partitionNum, - tmpStr, - types.length, - 4096, //shuffle value_buffer init size - "lz4", - shuffleDataFile, - 0, - shuffleTestDir, - 64 * 1024, - 4096, - 1024 * 1024 * 1024); - for (int i = 0; i < 6 * 1024; i++) { - VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, partitionNum, false, true); - jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); - } - jniWrapper.stop(splitterId); - jniWrapper.close(splitterId); - } - - @Ignore - public void columnShuffleFixed10GBTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_fixed10GB_test"; - DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE}; - DataType[] types = dataTypeId2DataType(idTypes); - String tmpStr = DataTypeSerializer.serialize(types); - SparkJniWrapper jniWrapper = new SparkJniWrapper(); - int partitionNum = 3; - long splitterId = jniWrapper.nativeMake( - "hash", - partitionNum, - tmpStr, - types.length, - 4096, //shuffle value_buffer init size - "lz4", - shuffleDataFile, - 0, - shuffleTestDir, - 64 * 1024, - 4096, - 1024 * 1024 * 1024); - for (int i = 0; i < 10 * 8 * 1024; i++) { - VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, partitionNum, false, true); - jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); - } - jniWrapper.stop(splitterId); - jniWrapper.close(splitterId); - } - - @Test - public void columnShuffleVarChar1GBTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_varChar1GB_test"; - DataType.DataTypeId[] idTypes = {OMNI_VARCHAR, OMNI_VARCHAR, OMNI_VARCHAR}; - DataType[] types = dataTypeId2DataType(idTypes); - String tmpStr = DataTypeSerializer.serialize(types); - SparkJniWrapper jniWrapper = new SparkJniWrapper(); - int partitionNum = 4; - long splitterId = jniWrapper.nativeMake( - "hash", - partitionNum, - tmpStr, - types.length, - 1024, //shuffle value_buffer init size - "lz4", - shuffleDataFile, - 0, - shuffleTestDir, - 64 * 1024, - 4096, - 1024 * 1024 * 1024); - // 不能重复split同一个vb,接口有释放vb内存,重复split会导致重复释放内存而Core - for (int i = 0; i < 99; i++) { - VecBatch vecBatchTmp = buildVecBatch(idTypes, 999, partitionNum, false, true); - jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); - } - jniWrapper.stop(splitterId); - jniWrapper.close(splitterId); - } - - @Ignore - public void columnShuffleVarChar10GBTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_varChar10GB_test"; - DataType.DataTypeId[] idTypes = {OMNI_VARCHAR, OMNI_VARCHAR, OMNI_VARCHAR}; - DataType[] types = dataTypeId2DataType(idTypes); - String tmpStr = DataTypeSerializer.serialize(types); - SparkJniWrapper jniWrapper = new SparkJniWrapper(); - int partitionNum = 4; - long splitterId = jniWrapper.nativeMake( - "hash", - partitionNum, - tmpStr, - types.length, - 1024, //shuffle value_buffer init size - "lz4", - shuffleDataFile, - 0, - shuffleTestDir, - 64 * 1024, - 4096, - 1024 * 1024 * 1024); - for (int i = 0; i < 10 * 3 * 999; i++) { - VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, partitionNum, false, true); - jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); - } - jniWrapper.stop(splitterId); - 
jniWrapper.close(splitterId); - } - - @Test - public void columnShuffleMix1GBTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_mix1GB_test"; - DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR, OMNI_CHAR, - OMNI_DATE32, OMNI_DATE64, OMNI_DECIMAL64, OMNI_DECIMAL128}; - DataType[] types = dataTypeId2DataType(idTypes); - String tmpStr = DataTypeSerializer.serialize(types); - SparkJniWrapper jniWrapper = new SparkJniWrapper(); - int partitionNum = 4; - long splitterId = jniWrapper.nativeMake( - "hash", - partitionNum, - tmpStr, - types.length, - 4096, //shuffle value_buffer init size - "lz4", - shuffleDataFile, - 0, - shuffleTestDir, - 64 * 1024, - 4096, - 1024 * 1024 * 1024); - // 不能重复split同一个vb,接口有释放vb内存,重复split会导致重复释放内存而Core - for (int i = 0; i < 6 * 999; i++) { - VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, partitionNum, false, true); - jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); - } - jniWrapper.stop(splitterId); - jniWrapper.close(splitterId); - } - - @Ignore - public void columnShuffleMix10GBTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_mix10GB_test"; - DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR, OMNI_CHAR, - OMNI_DATE32, OMNI_DATE64, OMNI_DECIMAL64, OMNI_DECIMAL128}; - DataType[] types = dataTypeId2DataType(idTypes); - String tmpStr = DataTypeSerializer.serialize(types); - SparkJniWrapper jniWrapper = new SparkJniWrapper(); - int partitionNum = 4; - long splitterId = jniWrapper.nativeMake( - "hash", - partitionNum, - tmpStr, - types.length, - 4096, //shuffle value_buffer init size - "lz4", - shuffleDataFile, - 0, - shuffleTestDir, - 64 * 1024, - 4096, - 1024 * 1024 * 1024); - for (int i = 0; i < 3 * 9 * 999; i++) { - VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, partitionNum, false, true); - jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); - } - jniWrapper.stop(splitterId); - jniWrapper.close(splitterId); - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleNullTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleNullTest.java deleted file mode 100644 index 98fc18dd8f3237928cc066887e6fcb2205686692..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleNullTest.java +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.huawei.boostkit.spark; - -import com.huawei.boostkit.spark.jni.SparkJniWrapper; - -import java.io.File; -import nova.hetu.omniruntime.type.DataType; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_CHAR; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE32; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE64; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL128; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL64; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DOUBLE; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; -import nova.hetu.omniruntime.type.DataTypeSerializer; -import nova.hetu.omniruntime.vector.VecBatch; - -import org.junit.After; -import org.junit.AfterClass; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Test; - -import java.io.IOException; - -public class ColumnShuffleNullTest extends ColumnShuffleTest { - private static String shuffleDataFile = ""; - - @BeforeClass - public static void runOnceBeforeClass() { - File folder = new File(shuffleTestDir); - if (!folder.exists() && !folder.isDirectory()) { - folder.mkdirs(); - } - } - - @AfterClass - public static void runOnceAfterClass() { - File folder = new File(shuffleTestDir); - if (folder.exists()) { - deleteDir(folder); - } - } - - @Before - public void runBeforeTestMethod() { - - } - - @After - public void runAfterTestMethod() { - File file = new File(shuffleDataFile); - if (file.exists()) { - file.delete(); - } - } - - @Test - public void columnShuffleFixNullTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_fixNull_test"; - DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE}; - DataType[] types = dataTypeId2DataType(idTypes); - String tmpStr = DataTypeSerializer.serialize(types); - SparkJniWrapper jniWrapper = new SparkJniWrapper(); - int numPartition = 4; - long splitterId = jniWrapper.nativeMake( - "hash", - numPartition, - tmpStr, - types.length, - 3, //shuffle value_buffer init size - "lz4", - shuffleDataFile, - 0, - shuffleTestDir, - 64 * 1024, - 4096, - 1024 * 1024 * 1024); - // 不能重复split同一个vb,接口有释放vb内存,重复split会导致重复释放内存而Core - for (int i = 0; i < 1; i++) { - VecBatch vecBatchTmp = buildVecBatch(idTypes, 9, numPartition, true, true); - jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); - } - jniWrapper.stop(splitterId); - jniWrapper.close(splitterId); - } - - @Test - public void columnShuffleVarCharNullTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_fixNull_test"; - DataType.DataTypeId[] idTypes = {OMNI_VARCHAR, OMNI_VARCHAR, OMNI_VARCHAR,OMNI_VARCHAR}; - DataType[] types = dataTypeId2DataType(idTypes); - String tmpStr = DataTypeSerializer.serialize(types); - SparkJniWrapper jniWrapper = new SparkJniWrapper(); - int numPartition = 4; - long splitterId = jniWrapper.nativeMake( - "hash", - numPartition, - tmpStr, - types.length, - 3, //shuffle value_buffer init size - "lz4", - shuffleDataFile, - 0, - shuffleTestDir, - 64 * 1024, - 4096, - 1024 * 1024 * 1024); - // 不能重复split同一个vb,接口有释放vb内存,重复split会导致重复释放内存而Core - for (int i = 0; i < 1; i++) { - VecBatch vecBatchTmp = buildVecBatch(idTypes, 9, numPartition, true, true); - jniWrapper.split(splitterId, 
vecBatchTmp.getNativeVectorBatch()); - } - jniWrapper.stop(splitterId); - jniWrapper.close(splitterId); - } - - @Test - public void columnShuffleMixNullTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_MixNull_test"; - DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE,OMNI_VARCHAR, OMNI_CHAR, - OMNI_DATE32, OMNI_DATE64, OMNI_DECIMAL64, OMNI_DECIMAL128}; - DataType[] types = dataTypeId2DataType(idTypes); - String tmpStr = DataTypeSerializer.serialize(types); - SparkJniWrapper jniWrapper = new SparkJniWrapper(); - int numPartition = 4; - long splitterId = jniWrapper.nativeMake( - "hash", - numPartition, - tmpStr, - types.length, - 3, //shuffle value_buffer init size - "lz4", - shuffleDataFile, - 0, - shuffleTestDir, - 64 * 1024, - 4096, - 1024 * 1024 * 1024); - // 不能重复split同一个vb,接口有释放vb内存,重复split会导致重复释放内存而Core - for (int i = 0; i < 1; i++) { - VecBatch vecBatchTmp = buildVecBatch(idTypes, 9, numPartition, true, true); - jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); - } - jniWrapper.stop(splitterId); - jniWrapper.close(splitterId); - } - - @Test - public void columnShuffleMixNullFullTest() throws IOException { - shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_MixNullFull_test"; - DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE,OMNI_VARCHAR, OMNI_CHAR, - OMNI_DATE32, OMNI_DATE64, OMNI_DECIMAL64, OMNI_DECIMAL128}; - DataType[] types = dataTypeId2DataType(idTypes); - String tmpStr = DataTypeSerializer.serialize(types); - SparkJniWrapper jniWrapper = new SparkJniWrapper(); - int numPartition = 4; - long splitterId = jniWrapper.nativeMake( - "hash", - numPartition, - tmpStr, - types.length, - 3, //shuffle value_buffer init size - "lz4", - shuffleDataFile, - 0, - shuffleTestDir, - 64 * 1024, - 4096, - 1024 * 1024 * 1024); - for (int i = 0; i < 1; i++) { - VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, numPartition, true, true); - jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); - } - jniWrapper.stop(splitterId); - jniWrapper.close(splitterId); - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleTest.java deleted file mode 100644 index 74fccca66fad64dac9c96ae5f60591de40e92012..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleTest.java +++ /dev/null @@ -1,220 +0,0 @@ -/* - * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.huawei.boostkit.spark; - -import java.io.File; -import nova.hetu.omniruntime.type.CharDataType; -import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.type.Date32DataType; -import nova.hetu.omniruntime.type.Date64DataType; -import nova.hetu.omniruntime.type.Decimal128DataType; -import nova.hetu.omniruntime.type.Decimal64DataType; -import nova.hetu.omniruntime.type.DoubleDataType; -import nova.hetu.omniruntime.type.IntDataType; -import nova.hetu.omniruntime.type.LongDataType; -import nova.hetu.omniruntime.type.VarcharDataType; -import nova.hetu.omniruntime.vector.Decimal128Vec; -import nova.hetu.omniruntime.vector.DoubleVec; -import nova.hetu.omniruntime.vector.IntVec; -import nova.hetu.omniruntime.vector.LongVec; -import nova.hetu.omniruntime.vector.VarcharVec; -import nova.hetu.omniruntime.vector.Vec; -import nova.hetu.omniruntime.vector.VecBatch; - -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.List; - -abstract class ColumnShuffleTest { - public static String shuffleTestDir = "/tmp/shuffleTests"; - - public DataType[] dataTypeId2DataType(DataType.DataTypeId[] idTypes) { - DataType[] types = new DataType[idTypes.length]; - for(int i = 0; i < idTypes.length; i++) { - switch (idTypes[i]) { - case OMNI_INT: { - types[i] = IntDataType.INTEGER; - break; - } - case OMNI_LONG: { - types[i] = LongDataType.LONG; - break; - } - case OMNI_DOUBLE: { - types[i] = DoubleDataType.DOUBLE; - break; - } - case OMNI_VARCHAR: { - types[i] = VarcharDataType.VARCHAR; - break; - } - case OMNI_CHAR: { - types[i] = CharDataType.CHAR; - break; - } - case OMNI_DATE32: { - types[i] = Date32DataType.DATE32; - break; - } - case OMNI_DATE64: { - types[i] = Date64DataType.DATE64; - break; - } - case OMNI_DECIMAL64: { - types[i] = Decimal64DataType.DECIMAL64; - break; - } - case OMNI_DECIMAL128: { - types[i] = Decimal128DataType.DECIMAL128; // Or types[i] = new Decimal128DataType(2, 0); - break; - } - default: { - throw new UnsupportedOperationException("Unsupported type : " + idTypes[i]); - } - } - } - return types; - } - - public VecBatch buildVecBatch(DataType.DataTypeId[] idTypes, int rowNum, int partitionNum, boolean mixHalfNull, boolean withPidVec) { - List columns = new ArrayList<>(); - Vec tmpVec = null; - // prepare pidVec - if (withPidVec) { - IntVec pidVec = new IntVec(rowNum); - for (int i = 0; i < rowNum; i++) { - pidVec.set(i, i % partitionNum); - } - columns.add(pidVec); - } - - for(int i = 0; i < idTypes.length; i++) { - switch (idTypes[i]) { - case OMNI_INT: - case OMNI_DATE32:{ - tmpVec = new IntVec(rowNum); - for (int j = 0; j < rowNum; j++) { - ((IntVec)tmpVec).set(j, j + 1); - if (mixHalfNull && (j % 2) == 0) { - tmpVec.setNull(j); - } - } - break; - } - case OMNI_LONG: - case OMNI_DECIMAL64: - case OMNI_DATE64: { - tmpVec = new LongVec(rowNum); - for (int j = 0; j < rowNum; j++) { - ((LongVec)tmpVec).set(j, j + 1); - if (mixHalfNull && (j % 2) == 0) { - tmpVec.setNull(j); - } - } - break; - } - case OMNI_DOUBLE: { - tmpVec = new DoubleVec(rowNum); - for (int j = 0; j < rowNum; j++) { - ((DoubleVec)tmpVec).set(j, j + 1); - if (mixHalfNull && (j % 2) == 0) { - tmpVec.setNull(j); - } - } - break; - } - case OMNI_VARCHAR: - case OMNI_CHAR: { - tmpVec = new VarcharVec(rowNum * 16, rowNum); - for (int j = 0; j < rowNum; j++) { - ((VarcharVec)tmpVec).set(j, ("VAR_" + (j + 1) + "_END").getBytes(StandardCharsets.UTF_8)); - if (mixHalfNull && (j % 2) == 0) { - tmpVec.setNull(j); - } - } - break; - } - case OMNI_DECIMAL128: { - 
long[][] arr = new long[rowNum][2]; - for (int j = 0; j < rowNum; j++) { - arr[j][0] = 2 * j; - arr[j][1] = 2 * j + 1; - if (mixHalfNull && (j % 2) == 0) { - arr[j] = null; - } - } - tmpVec = createDecimal128Vec(arr); - break; - } - default: { - throw new UnsupportedOperationException("Unsupported type : " + idTypes[i]); - } - } - columns.add(tmpVec); - } - return new VecBatch(columns); - } - - public Decimal128Vec createDecimal128Vec(long[][] data) { - Decimal128Vec result = new Decimal128Vec(data.length); - for (int i = 0; i < data.length; i++) { - if (data[i] == null) { - result.setNull(i); - } else { - result.set(i, new long[]{data[i][0], data[i][1]}); - } - } - return result; - } - - public List buildValInt(int pid, int val) { - IntVec c0 = new IntVec(1); - IntVec c1 = new IntVec(1); - c0.set(0, pid); - c1.set(0, val); - List columns = new ArrayList<>(); - columns.add(c0); - columns.add(c1); - return columns; - } - - public List buildValChar(int pid, String varChar) { - IntVec c0 = new IntVec(1); - VarcharVec c1 = new VarcharVec(8, 1); - c0.set(0, pid); - c1.set(0, varChar.getBytes(StandardCharsets.UTF_8)); - List columns = new ArrayList<>(); - columns.add(c0); - columns.add(c1); - return columns; - } - - public static boolean deleteDir(File dir) { - if (dir.isDirectory()) { - String[] children = dir.list(); - for (int i=0; i includedColumns = new ArrayList(); - // type long - includedColumns.add("i_item_sk"); - // type char 16 - includedColumns.add("i_item_id"); - // type char 200 - includedColumns.add("i_item_desc"); - // type int - includedColumns.add("i_current_price"); - job.put("includedColumns", includedColumns.toArray()); - - orcColumnarBatchJniReader.recordReader = orcColumnarBatchJniReader.initializeRecordReader(orcColumnarBatchJniReader.reader, job); - assertTrue(orcColumnarBatchJniReader.recordReader != 0); - } - - public void initBatch() { - orcColumnarBatchJniReader.batchReader = orcColumnarBatchJniReader.initializeBatch(orcColumnarBatchJniReader.recordReader, 4096); - assertTrue(orcColumnarBatchJniReader.batchReader != 0); - } - - @Test - public void testNext() { - int[] typeId = new int[4]; - long[] vecNativeId = new long[4]; - long rtn = orcColumnarBatchJniReader.recordReaderNext(orcColumnarBatchJniReader.recordReader, orcColumnarBatchJniReader.batchReader, typeId, vecNativeId); - assertTrue(rtn == 4096); - LongVec vec1 = new LongVec(vecNativeId[0]); - VarcharVec vec2 = new VarcharVec(vecNativeId[1]); - VarcharVec vec3 = new VarcharVec(vecNativeId[2]); - IntVec vec4 = new IntVec(vecNativeId[3]); - assertTrue(vec1.get(10) == 11); - String tmp1 = new String(vec2.get(4080)); - assertTrue(tmp1.equals("AAAAAAAABPPAAAAA")); - String tmp2 = new String(vec3.get(4070)); - assertTrue(tmp2.equals("Particular, arab cases shall like less current, different names. Computers start for the changes. 
Scottish, trying exercises operate marks; long, supreme miners may ro")); - assertTrue(0 == vec4.get(1000)); - vec1.close(); - vec2.close(); - vec3.close(); - vec4.close(); - } -} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java deleted file mode 100644 index d9fe13683343f4299ad2b4b2290b0cbf47d761e1..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.huawei.boostkit.spark.jni; - -import junit.framework.TestCase; -import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.vector.IntVec; -import nova.hetu.omniruntime.vector.LongVec; -import nova.hetu.omniruntime.vector.VarcharVec; -import nova.hetu.omniruntime.vector.Vec; -import org.json.JSONObject; -import org.junit.After; -import org.junit.Before; -import org.junit.FixMethodOrder; -import org.junit.Test; -import org.junit.runners.MethodSorters; - -import java.io.File; -import java.util.ArrayList; - -import static org.junit.Assert.*; - -@FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) -public class OrcColumnarBatchJniReaderNotPushDownTest extends TestCase { - public OrcColumnarBatchJniReader orcColumnarBatchJniReader; - - @Before - public void setUp() throws Exception { - orcColumnarBatchJniReader = new OrcColumnarBatchJniReader(); - initReaderJava(); - initRecordReaderJava(); - initBatch(); - } - - @After - public void tearDown() throws Exception { - System.out.println("orcColumnarBatchJniReader test finished"); - } - - public void initReaderJava() { - JSONObject job = new JSONObject(); - job.put("serializedTail",""); - job.put("tailLocation",9223372036854775807L); - File directory = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0"); - System.out.println(directory.getAbsolutePath()); - orcColumnarBatchJniReader.reader = orcColumnarBatchJniReader.initializeReader(directory.getAbsolutePath(), job); - assertTrue(orcColumnarBatchJniReader.reader != 0); - } - - public void initRecordReaderJava() { - JSONObject job = new JSONObject(); - job.put("include",""); - job.put("offset", 0); - job.put("length", 3345152); - - ArrayList includedColumns = new ArrayList(); - includedColumns.add("i_item_sk"); - includedColumns.add("i_item_id"); - job.put("includedColumns", includedColumns.toArray()); - - 
orcColumnarBatchJniReader.recordReader = orcColumnarBatchJniReader.initializeRecordReader(orcColumnarBatchJniReader.reader, job); - assertTrue(orcColumnarBatchJniReader.recordReader != 0); - } - - public void initBatch() { - orcColumnarBatchJniReader.batchReader = orcColumnarBatchJniReader.initializeBatch(orcColumnarBatchJniReader.recordReader, 4096); - assertTrue(orcColumnarBatchJniReader.batchReader != 0); - } - - @Test - public void testNext() { - int[] typeId = new int[2]; - long[] vecNativeId = new long[2]; - long rtn = orcColumnarBatchJniReader.recordReaderNext(orcColumnarBatchJniReader.recordReader, orcColumnarBatchJniReader.batchReader, typeId, vecNativeId); - assertTrue(rtn == 4096); - LongVec vec1 = new LongVec(vecNativeId[0]); - VarcharVec vec2 = new VarcharVec(vecNativeId[1]); - assertTrue(vec1.get(4090) == 4091); - assertTrue(vec1.get(4000) == 4001); - String tmp1 = new String(vec2.get(4090)); - String tmp2 = new String(vec2.get(4000)); - assertTrue(tmp1.equals("AAAAAAAAKPPAAAAA")); - assertTrue(tmp2.equals("AAAAAAAAAKPAAAAA")); - vec1.close(); - vec2.close(); - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java deleted file mode 100644 index 87f0cc1d2920982de3b73d9046d173a8f2c8fbb8..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.huawei.boostkit.spark.jni; - -import static org.junit.Assert.*; -import junit.framework.TestCase; -import org.apache.hadoop.mapred.join.ArrayListBackedIterator; -import org.apache.orc.OrcFile.ReaderOptions; -import org.apache.orc.Reader.Options; -import org.hamcrest.Condition; -import org.json.JSONObject; -import org.junit.After; -import org.junit.Before; -import org.junit.FixMethodOrder; -import org.junit.Test; -import org.junit.runners.MethodSorters; -import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.vector.IntVec; -import nova.hetu.omniruntime.vector.LongVec; -import nova.hetu.omniruntime.vector.VarcharVec; -import nova.hetu.omniruntime.vector.Vec; - -import java.io.File; -import java.lang.reflect.Array; -import java.util.ArrayList; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -@FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) -public class OrcColumnarBatchJniReaderPushDownTest extends TestCase { - public OrcColumnarBatchJniReader orcColumnarBatchJniReader; - - @Before - public void setUp() throws Exception { - orcColumnarBatchJniReader = new OrcColumnarBatchJniReader(); - initReaderJava(); - initRecordReaderJava(); - initBatch(); - } - - @After - public void tearDown() throws Exception { - System.out.println("orcColumnarBatchJniReader test finished"); - } - - public void initReaderJava() { - JSONObject job = new JSONObject(); - job.put("serializedTail",""); - job.put("tailLocation",9223372036854775807L); - File directory = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0"); - System.out.println(directory.getAbsolutePath()); - orcColumnarBatchJniReader.reader = orcColumnarBatchJniReader.initializeReader(directory.getAbsolutePath(), job); - assertTrue(orcColumnarBatchJniReader.reader != 0); - } - - public void initRecordReaderJava() { - JSONObject job = new JSONObject(); - job.put("include",""); - job.put("offset", 0); - job.put("length", 3345152); - - ArrayList childList1 = new ArrayList(); - JSONObject child1 = new JSONObject(); - child1.put("op", 3); - child1.put("leaf", "leaf-0"); - childList1.add(child1); - JSONObject subChild1 = new JSONObject(); - subChild1.put("op", 2); - subChild1.put("child", childList1); - - ArrayList childList2 = new ArrayList(); - JSONObject child2 = new JSONObject(); - child2.put("op", 3); - child2.put("leaf", "leaf-1"); - childList2.add(child2); - JSONObject subChild2 = new JSONObject(); - subChild2.put("op", 2); - subChild2.put("child", childList2); - - ArrayList childs = new ArrayList(); - childs.add(subChild1); - childs.add(subChild2); - - JSONObject expressionTree = new JSONObject(); - expressionTree.put("op", 1); - expressionTree.put("child", childs); - job.put("expressionTree", expressionTree); - - JSONObject leaves = new JSONObject(); - JSONObject leaf0 = new JSONObject(); - leaf0.put("op", 6); - leaf0.put("name", "i_item_sk"); - leaf0.put("type", 0); - leaf0.put("literal", ""); - leaf0.put("literalList", new ArrayList()); - - JSONObject leaf1 = new JSONObject(); - leaf1.put("op", 3); - leaf1.put("name", "i_item_sk"); - leaf1.put("type", 0); - leaf1.put("literal", "100"); - leaf1.put("literalList", new ArrayList()); - - leaves.put("leaf-0", leaf0); - leaves.put("leaf-1", leaf1); - job.put("leaves", leaves); - - ArrayList includedColumns = new ArrayList(); - includedColumns.add("i_item_sk"); - includedColumns.add("i_item_id"); - job.put("includedColumns", includedColumns.toArray()); - - orcColumnarBatchJniReader.recordReader = 
orcColumnarBatchJniReader.initializeRecordReader(orcColumnarBatchJniReader.reader, job); - assertTrue(orcColumnarBatchJniReader.recordReader != 0); - } - - public void initBatch() { - orcColumnarBatchJniReader.batchReader = orcColumnarBatchJniReader.initializeBatch(orcColumnarBatchJniReader.recordReader, 4096); - assertTrue(orcColumnarBatchJniReader.batchReader != 0); - } - - @Test - public void testNext() { - int[] typeId = new int[2]; - long[] vecNativeId = new long[2]; - long rtn = orcColumnarBatchJniReader.recordReaderNext(orcColumnarBatchJniReader.recordReader, orcColumnarBatchJniReader.batchReader, typeId, vecNativeId); - assertTrue(rtn == 4096); - LongVec vec1 = new LongVec(vecNativeId[0]); - VarcharVec vec2 = new VarcharVec(vecNativeId[1]); - assertTrue(11 == vec1.get(10)); - assertTrue(21 == vec1.get(20)); - String tmp1 = new String(vec2.get(10)); - String tmp2 = new String(vec2.get(20)); - assertTrue(tmp1.equals("AAAAAAAAKAAAAAAA")); - assertTrue(tmp2.equals("AAAAAAAAEBAAAAAA")); - vec1.close(); - vec2.close(); - } - -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java deleted file mode 100644 index 484365c537231b46816e139b090d2384f08b5588..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.huawei.boostkit.spark.jni; - -import junit.framework.TestCase; -import nova.hetu.omniruntime.vector.IntVec; -import nova.hetu.omniruntime.vector.LongVec; -import nova.hetu.omniruntime.vector.VarcharVec; -import org.json.JSONObject; -import org.junit.After; -import org.junit.Before; -import org.junit.FixMethodOrder; -import org.junit.Test; -import org.junit.runners.MethodSorters; - -import java.io.File; -import java.util.ArrayList; - -import static org.junit.Assert.*; - -@FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) -public class OrcColumnarBatchJniReaderSparkORCNotPushDownTest extends TestCase { - public OrcColumnarBatchJniReader orcColumnarBatchJniReader; - - @Before - public void setUp() throws Exception { - orcColumnarBatchJniReader = new OrcColumnarBatchJniReader(); - initReaderJava(); - initRecordReaderJava(); - initBatch(); - } - - @After - public void tearDown() throws Exception { - System.out.println("orcColumnarBatchJniReader test finished"); - } - - public void initReaderJava() { - JSONObject job = new JSONObject(); - job.put("serializedTail",""); - job.put("tailLocation",9223372036854775807L); - File directory = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/part-00000-2d6ca713-08b0-4b40-828c-f7ee0c81bb9a-c000.snappy.orc"); - System.out.println(directory.getAbsolutePath()); - orcColumnarBatchJniReader.reader = orcColumnarBatchJniReader.initializeReader(directory.getAbsolutePath(), job); - assertTrue(orcColumnarBatchJniReader.reader != 0); - } - - public void initRecordReaderJava() { - JSONObject job = new JSONObject(); - job.put("include",""); - job.put("offset", 0); - job.put("length", 3345152); - - ArrayList includedColumns = new ArrayList(); - // type long - includedColumns.add("i_item_sk"); - // type char 16 - includedColumns.add("i_item_id"); - // type char 200 - includedColumns.add("i_item_desc"); - // type int - includedColumns.add("i_current_price"); - job.put("includedColumns", includedColumns.toArray()); - - orcColumnarBatchJniReader.recordReader = orcColumnarBatchJniReader.initializeRecordReader(orcColumnarBatchJniReader.reader, job); - assertTrue(orcColumnarBatchJniReader.recordReader != 0); - } - - public void initBatch() { - orcColumnarBatchJniReader.batchReader = orcColumnarBatchJniReader.initializeBatch(orcColumnarBatchJniReader.recordReader, 4096); - assertTrue(orcColumnarBatchJniReader.batchReader != 0); - } - - @Test - public void testNext() { - int[] typeId = new int[4]; - long[] vecNativeId = new long[4]; - long rtn = orcColumnarBatchJniReader.recordReaderNext(orcColumnarBatchJniReader.recordReader, orcColumnarBatchJniReader.batchReader, typeId, vecNativeId); - assertTrue(rtn == 4096); - LongVec vec1 = new LongVec(vecNativeId[0]); - VarcharVec vec2 = new VarcharVec(vecNativeId[1]); - VarcharVec vec3 = new VarcharVec(vecNativeId[2]); - IntVec vec4 = new IntVec(vecNativeId[3]); - - assertTrue(vec1.get(4095) == 4096); - String tmp1 = new String(vec2.get(4095)); - assertTrue(tmp1.equals("AAAAAAAAAAABAAAA")); - String tmp2 = new String(vec3.get(4095)); - assertTrue(tmp2.equals("Find")); - assertTrue(vec4.get(4095) == 6); - vec1.close(); - vec2.close(); - vec3.close(); - vec4.close(); - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java deleted file mode 100644 index 
b03d60aac4b61291c614bce9f7a52503918a1106..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.huawei.boostkit.spark.jni; - -import junit.framework.TestCase; -import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.vector.IntVec; -import nova.hetu.omniruntime.vector.LongVec; -import nova.hetu.omniruntime.vector.VarcharVec; -import nova.hetu.omniruntime.vector.Vec; -import org.json.JSONObject; -import org.junit.After; -import org.junit.Before; -import org.junit.FixMethodOrder; -import org.junit.Test; -import org.junit.runners.MethodSorters; - -import java.io.File; -import java.util.ArrayList; - -import static org.junit.Assert.*; - -@FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) -public class OrcColumnarBatchJniReaderSparkORCPushDownTest extends TestCase { - public OrcColumnarBatchJniReader orcColumnarBatchJniReader; - - @Before - public void setUp() throws Exception { - orcColumnarBatchJniReader = new OrcColumnarBatchJniReader(); - initReaderJava(); - initRecordReaderJava(); - initBatch(); - } - - @After - public void tearDown() throws Exception { - System.out.println("orcColumnarBatchJniReader test finished"); - } - - public void initReaderJava() { - JSONObject job = new JSONObject(); - job.put("serializedTail",""); - job.put("tailLocation",9223372036854775807L); - File directory = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/part-00000-2d6ca713-08b0-4b40-828c-f7ee0c81bb9a-c000.snappy.orc"); - System.out.println(directory.getAbsolutePath()); - orcColumnarBatchJniReader.reader = orcColumnarBatchJniReader.initializeReader(directory.getAbsolutePath(), job); - assertTrue(orcColumnarBatchJniReader.reader != 0); - } - - public void initRecordReaderJava() { - JSONObject job = new JSONObject(); - job.put("include",""); - job.put("offset", 0); - job.put("length", 3345152); - - ArrayList childList1 = new ArrayList(); - JSONObject child1 = new JSONObject(); - child1.put("op", 3); - child1.put("leaf", "leaf-0"); - childList1.add(child1); - JSONObject subChild1 = new JSONObject(); - subChild1.put("op", 2); - subChild1.put("child", childList1); - - ArrayList childList2 = new ArrayList(); - JSONObject child2 = new JSONObject(); - child2.put("op", 3); - child2.put("leaf", "leaf-1"); - childList2.add(child2); - JSONObject subChild2 = new JSONObject(); - subChild2.put("op", 2); - subChild2.put("child", childList2); - - ArrayList childs = new ArrayList(); - childs.add(subChild1); - childs.add(subChild2); - - JSONObject 
expressionTree = new JSONObject(); - expressionTree.put("op", 1); - expressionTree.put("child", childs); - job.put("expressionTree", expressionTree); - - JSONObject leaves = new JSONObject(); - JSONObject leaf0 = new JSONObject(); - leaf0.put("op", 6); - leaf0.put("name", "i_item_sk"); - leaf0.put("type", 0); - leaf0.put("literal", ""); - leaf0.put("literalList", new ArrayList()); - - JSONObject leaf1 = new JSONObject(); - leaf1.put("op", 3); - leaf1.put("name", "i_item_sk"); - leaf1.put("type", 0); - leaf1.put("literal", "100"); - leaf1.put("literalList", new ArrayList()); - - leaves.put("leaf-0", leaf0); - leaves.put("leaf-1", leaf1); - job.put("leaves", leaves); - - ArrayList includedColumns = new ArrayList(); - // type long - includedColumns.add("i_item_sk"); - // type char 16 - includedColumns.add("i_item_id"); - // type char 200 - includedColumns.add("i_item_desc"); - // type int - includedColumns.add("i_current_price"); - job.put("includedColumns", includedColumns.toArray()); - - orcColumnarBatchJniReader.recordReader = orcColumnarBatchJniReader.initializeRecordReader(orcColumnarBatchJniReader.reader, job); - assertTrue(orcColumnarBatchJniReader.recordReader != 0); - } - - public void initBatch() { - orcColumnarBatchJniReader.batchReader = orcColumnarBatchJniReader.initializeBatch(orcColumnarBatchJniReader.recordReader, 4096); - assertTrue(orcColumnarBatchJniReader.batchReader != 0); - } - - @Test - public void testNext() { - int[] typeId = new int[4]; - long[] vecNativeId = new long[4]; - long rtn = orcColumnarBatchJniReader.recordReaderNext(orcColumnarBatchJniReader.recordReader, orcColumnarBatchJniReader.batchReader, typeId, vecNativeId); - assertTrue(rtn == 4096); - LongVec vec1 = new LongVec(vecNativeId[0]); - VarcharVec vec2 = new VarcharVec(vecNativeId[1]); - VarcharVec vec3 = new VarcharVec(vecNativeId[2]); - IntVec vec4 = new IntVec(vecNativeId[3]); - - assertTrue(vec1.get(10) == 11); - String tmp1 = new String(vec2.get(4080)); - assertTrue(tmp1.equals("AAAAAAAABPPAAAAA")); - String tmp2 = new String(vec3.get(4070)); - assertTrue(tmp2.equals("Particular, arab cases shall like less current, different names. Computers start for the changes. Scottish, trying exercises operate marks; long, supreme miners may ro")); - assertTrue(vec4.get(1000) == 0); - vec1.close(); - vec2.close(); - vec3.close(); - vec4.close(); - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java deleted file mode 100644 index 99801bcfb86567a5a2cb44dc43e4428496b00ed3..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.huawei.boostkit.spark.jni; - -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.io.Input; -import junit.framework.TestCase; -import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.vector.IntVec; -import nova.hetu.omniruntime.vector.LongVec; -import nova.hetu.omniruntime.vector.VarcharVec; -import nova.hetu.omniruntime.vector.Vec; -import org.apache.commons.codec.binary.Base64; -import org.apache.hadoop.hive.ql.io.sarg.SearchArgument; -import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentImpl; -import org.apache.orc.OrcConf; -import org.apache.orc.OrcFile; -import org.apache.orc.Reader; -import org.apache.orc.TypeDescription; -import org.apache.orc.mapred.OrcInputFormat; -import org.json.JSONObject; -import org.junit.After; -import org.junit.Before; -import org.junit.FixMethodOrder; -import org.junit.Test; -import org.junit.runners.MethodSorters; -import org.apache.hadoop.conf.Configuration; -import java.io.File; -import java.util.ArrayList; -import org.apache.orc.Reader.Options; - -import static org.junit.Assert.*; - -@FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) -public class OrcColumnarBatchJniReaderTest extends TestCase { - public Configuration conf = new Configuration(); - public OrcColumnarBatchJniReader orcColumnarBatchJniReader; - public int batchSize = 4096; - - @Before - public void setUp() throws Exception { - Configuration conf = new Configuration(); - TypeDescription schema = - TypeDescription.fromString("struct<`i_item_sk`:bigint,`i_item_id`:string>"); - Options options = new Options(conf) - .range(0, Integer.MAX_VALUE) - .useZeroCopy(false) - .skipCorruptRecords(false) - .tolerateMissingSchema(true); - - options.schema(schema); - options.include(OrcInputFormat.parseInclude(schema, - null)); - String kryoSarg = "AQEAb3JnLmFwYWNoZS5oYWRvb3AuaGl2ZS5xbC5pby5zYXJnLkV4cHJlc3Npb25UcmXlAQEBamF2YS51dGlsLkFycmF5TGlz9AECAQABAQEBAQEAAQAAAAEEAAEBAwEAAQEBAQEBAAEAAAIIAAEJAAEBAgEBAQIBAscBb3JnLmFwYWNoZS5oYWRvb3AuaGl2ZS5xbC5pby5zYXJnLlNlYXJjaEFyZ3VtZW50SW1wbCRQcmVkaWNhdGVMZWFmSW1wbAEBaV9pdGVtX3PrAAABBwEBAQIBEAkAAAEEEg=="; - String sargColumns = "i_item_sk,i_item_id,i_rec_start_date,i_rec_end_date,i_item_desc,i_current_price,i_wholesale_cost,i_brand_id,i_brand,i_class_id,i_class,i_category_id,i_category,i_manufact_id,i_manufact,i_size,i_formulation,i_color,i_units,i_container,i_manager_id,i_product_name"; - if (kryoSarg != null && sargColumns != null) { - byte[] sargBytes = Base64.decodeBase64(kryoSarg); - SearchArgument sarg = - new Kryo().readObject(new Input(sargBytes), SearchArgumentImpl.class); - options.searchArgument(sarg, sargColumns.split(",")); - sarg.getExpression().toString(); - } - - orcColumnarBatchJniReader = new OrcColumnarBatchJniReader(); - initReaderJava(); - initRecordReaderJava(options); - initBatch(options); - } - - @After - public void tearDown() throws Exception { - System.out.println("orcColumnarBatchJniReader test finished"); - } - - public void initReaderJava() { - OrcFile.ReaderOptions readerOptions = OrcFile.readerOptions(conf); - File directory = new 
File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0"); - String path = directory.getAbsolutePath(); - orcColumnarBatchJniReader.reader = orcColumnarBatchJniReader.initializeReaderJava(path, readerOptions); - assertTrue(orcColumnarBatchJniReader.reader != 0); - } - - public void initRecordReaderJava(Options options) { - orcColumnarBatchJniReader.recordReader = orcColumnarBatchJniReader.initializeRecordReaderJava(options); - assertTrue(orcColumnarBatchJniReader.recordReader != 0); - } - - public void initBatch(Options options) { - orcColumnarBatchJniReader.initBatchJava(batchSize); - assertTrue(orcColumnarBatchJniReader.batchReader != 0); - } - - @Test - public void testNext() { - Vec[] vecs = new Vec[2]; - long rtn = orcColumnarBatchJniReader.next(vecs); - assertTrue(rtn == 4096); - assertTrue(((LongVec) vecs[0]).get(0) == 1); - String str = new String(((VarcharVec) vecs[1]).get(0)); - assertTrue(str.equals("AAAAAAAABAAAAAAA")); - vecs[0].close(); - vecs[1].close(); - } - - @Test - public void testGetProgress() { - String tmp = ""; - try { - double progressValue = orcColumnarBatchJniReader.getProgress(); - } catch (Exception e) { - tmp = e.getMessage(); - } finally { - assertTrue(tmp.equals("recordReaderGetProgress is unsupported")); - } - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0 b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0 deleted file mode 100644 index 65e4e602cebab6ce7ca576d8ab20b1ae8841c981..0000000000000000000000000000000000000000 Binary files a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0 and /dev/null differ diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/orcsrc/part-00000-2d6ca713-08b0-4b40-828c-f7ee0c81bb9a-c000.snappy.orc b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/orcsrc/part-00000-2d6ca713-08b0-4b40-828c-f7ee0c81bb9a-c000.snappy.orc deleted file mode 100644 index a79c7be758d63ce6f56b16deed0765e85c96a866..0000000000000000000000000000000000000000 Binary files a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/orcsrc/part-00000-2d6ca713-08b0-4b40-828c-f7ee0c81bb9a-c000.snappy.orc and /dev/null differ diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarNativeReaderTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarNativeReaderTest.java deleted file mode 100644 index fc7a2e2d9518ea2c6f556123a2645e808d79d373..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarNativeReaderTest.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.orc;
-
-import junit.framework.TestCase;
-import org.apache.orc.Reader.Options;
-import org.apache.hadoop.conf.Configuration;
-import org.junit.After;
-import org.junit.Before;
-import org.junit.FixMethodOrder;
-import org.junit.Test;
-import org.junit.runners.MethodSorters;
-
-@FixMethodOrder(value = MethodSorters.NAME_ASCENDING )
-public class OrcColumnarNativeReaderTest extends TestCase{
-
-    @Before
-    public void setUp() throws Exception {
-    }
-
-    @After
-    public void tearDown() throws Exception {
-        System.out.println("OrcColumnarNativeReaderTest test finished");
-    }
-
-    @Test
-    public void testBuildOptions() {
-        Configuration conf = new Configuration();
-        Options options = OrcColumnarNativeReader.buildOptions(conf,0,1024);
-        assertTrue(options.getLength() == 1024L);
-        assertTrue(options.getOffset() == 0L);
-    }
-
-}
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/org/apache/spark/sql/execution/vectorized/OmniColumnVectorTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/org/apache/spark/sql/execution/vectorized/OmniColumnVectorTest.java
deleted file mode 100644
index 4a36d7b3fda8948958383bfc00e5a50136537248..0000000000000000000000000000000000000000
--- a/omnioperator/omniop-spark-extension/java/src/test/java/org/apache/spark/sql/execution/vectorized/OmniColumnVectorTest.java
+++ /dev/null
@@ -1,135 +0,0 @@
-/*
- * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved.
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.vectorized;
-
-import junit.framework.TestCase;
-import nova.hetu.omniruntime.vector.*;
-import org.apache.orc.Reader.Options;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.spark.sql.execution.datasources.orc.OrcColumnarNativeReader;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.DataTypes;
-import org.junit.After;
-import org.junit.Before;
-import org.junit.FixMethodOrder;
-import org.junit.Test;
-import org.junit.runners.MethodSorters;
-
-import javax.validation.constraints.AssertTrue;
-
-import static org.junit.Assert.*;
-
-@FixMethodOrder(value = MethodSorters.NAME_ASCENDING)
-public class OmniColumnVectorTest extends TestCase {
-
-    @Before
-    public void setUp() throws Exception {
-    }
-
-    @After
-    public void tearDown() throws Exception {
-        System.out.println("OmniColumnVectorTest test finished");
-    }
-
-
-    @Test
-    public void testNewOmniColumnVector() {
-        OmniColumnVector vecTmp = new OmniColumnVector(4096, DataTypes.LongType, true);
-        LongVec vecLong = new LongVec(4096);
-        vecTmp.setVec(vecLong);
-        vecTmp.putLong(0, 123L);
-        assertTrue(vecTmp.getLong(0) == 123L);
-        assertTrue(vecTmp.getVec() != null);
-        vecTmp.close();
-
-        OmniColumnVector vecTmp1 = new OmniColumnVector(4096, DataTypes.IntegerType, true);
-        IntVec vecInt = new IntVec(4096);
-        vecTmp1.setVec(vecInt);
-        vecTmp1.putInt(0, 123);
-        assertTrue(vecTmp1.getInt(0) == 123);
-        assertTrue(vecTmp1.getVec() != null);
-        vecTmp1.close();
-
-        OmniColumnVector vecTmp3 = new OmniColumnVector(4096, DataTypes.BooleanType, true);
-        BooleanVec vecBoolean = new BooleanVec(4096);
-        vecTmp3.setVec(vecBoolean);
-        vecTmp3.putBoolean(0, true);
-        assertTrue(vecTmp3.getBoolean(0) == true);
-        assertTrue(vecTmp3.getVec() != null);
-        vecTmp3.close();
-
-        OmniColumnVector vecTmp4 = new OmniColumnVector(4096, DataTypes.BooleanType, false);
-        BooleanVec vecBoolean1 = new BooleanVec(4096);
-        vecTmp4.setVec(vecBoolean1);
-        vecTmp4.putBoolean(0, true);
-        assertTrue(vecTmp4.getBoolean(0) == true);
-        assertTrue(vecTmp4.getVec() != null);
-        vecTmp4.close();
-    }
-
-    @Test
-    public void testGetsPuts() {
-        OmniColumnVector vecTmp = new OmniColumnVector(4096, DataTypes.LongType, true);
-        LongVec vecLong = new LongVec(4096);
-        vecTmp.setVec(vecLong);
-        vecTmp.putLongs(0, 10, 123L);
-        long[] gets = vecTmp.getLongs(0, 10);
-        for (long i : gets) {
-            assertTrue(i == 123L);
-        }
-        assertTrue(vecTmp.getVec() != null);
-        vecTmp.close();
-
-        OmniColumnVector vecTmp1 = new OmniColumnVector(4096, DataTypes.IntegerType, true);
-        IntVec vecInt = new IntVec(4096);
-        vecTmp1.setVec(vecInt);
-        vecTmp1.putInts(0, 10, 123);
-        int[] getInts = vecTmp1.getInts(0, 10);
-        for (int i : getInts) {
-            assertTrue(i == 123);
-        }
-        assertTrue(vecTmp1.getVec() != null);
-        vecTmp1.close();
-
-        OmniColumnVector vecTmp3 = new OmniColumnVector(4096, DataTypes.BooleanType, true);
-        BooleanVec vecBoolean = new BooleanVec(4096);
-        vecTmp3.setVec(vecBoolean);
-        vecTmp3.putBooleans(0, 10, true);
-        boolean[] getBools = vecTmp3.getBooleans(0, 10);
-        for (boolean i : getBools) {
-            assertTrue(i == true);
-        }
-        assertTrue(vecTmp3.getVec() != null);
-        vecTmp3.close();
-
-        OmniColumnVector vecTmp4 = new OmniColumnVector(4096, DataTypes.BooleanType, false);
-        BooleanVec vecBoolean1 = new BooleanVec(4096);
-        vecTmp4.setVec(vecBoolean1);
-        vecTmp4.putBooleans(0, 10, true);
-        boolean[] getBools1 = vecTmp4.getBooleans(0, 10);
-        for (boolean i : getBools1) {
-            assertTrue(i == true);
-        }
-        System.out.println(vecTmp4.getBoolean(0));
-        assertTrue(vecTmp4.getBoolean(0) == true);
-        assertTrue(vecTmp4.getVec() != null);
-        vecTmp4.close();
-    }
-
-}
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties b/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties
deleted file mode 100644
index 89eabe8e6a2fa36aa9523293e7368a3076856dd1..0000000000000000000000000000000000000000
--- a/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties
+++ /dev/null
@@ -1,12 +0,0 @@
-#
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
-#
-
-hive.metastore.uris=thrift://server1:9083
-spark.sql.warehouse.dir=/user/hive/warehouse
-spark.memory.offHeap.size=8G
-spark.sql.codegen.wholeStage=false
-spark.sql.extensions=com.huawei.boostkit.spark.ColumnarPlugin
-spark.shuffle.manager=org.apache.spark.shuffle.sort.OmniColumnarShuffleManager
-spark.sql.orc.impl=native
-hive.db=tpcds_bin_partitioned_orc_2
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q1.sql b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q1.sql
deleted file mode 100644
index 6478818e67814d2bc3c3a2239bb7a0000f27b1e0..0000000000000000000000000000000000000000
--- a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q1.sql
+++ /dev/null
@@ -1,14 +0,0 @@
-select i_item_id
-      ,i_item_desc
-      ,i_current_price
-from item, inventory, date_dim, store_sales
-where i_current_price between 76 and 76+30
-and inv_item_sk = i_item_sk
-and d_date_sk=inv_date_sk
-and d_date between cast('1998-06-29' as date) and cast('1998-08-29' as date)
-and i_manufact_id in (512,409,677,16)
-and inv_quantity_on_hand between 100 and 500
-and ss_item_sk = i_item_sk
-group by i_item_id,i_item_desc,i_current_price
-order by i_item_id
-limit 100;
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q10.sql b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q10.sql
deleted file mode 100644
index 9ac4277eba4447c7205ed294fb038bf4c17955a5..0000000000000000000000000000000000000000
--- a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q10.sql
+++ /dev/null
@@ -1,36 +0,0 @@
-select
-  i_brand_id brand_id,
-  i_brand brand,
-  i_manufact_id,
-  i_manufact,
-  sum(ss_ext_sales_price) ext_price
-from
-  date_dim,
-  store_sales,
-  item,
-  customer,
-  customer_address,
-  store
-where
-  d_date_sk = ss_sold_date_sk
-  and ss_item_sk = i_item_sk
-  and i_manager_id = 7
-  and d_moy = 11
-  and d_year = 1999
-  and ss_customer_sk = c_customer_sk
-  and c_current_addr_sk = ca_address_sk
-  and substr(ca_zip,1,5) <> substr(s_zip,1,5)
-  and ss_store_sk = s_store_sk
-  and ss_sold_date_sk between 2451484 and 2451513 -- partition key filter
-group by
-  i_brand,
-  i_brand_id,
-  i_manufact_id,
-  i_manufact
-order by
-  ext_price desc,
-  i_brand,
-  i_brand_id,
-  i_manufact_id,
-  i_manufact
-limit 100;
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q2.sql b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q2.sql
deleted file mode 100644
index 5a2ade87aa05decff9262402dfe547910211d730..0000000000000000000000000000000000000000
--- a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q2.sql
+++ /dev/null
@@ -1,48 +0,0 @@
-with v1 as (
- select i_category,
i_brand, - s_store_name, s_company_name, - d_year, d_moy, - sum(ss_sales_price) sum_sales, - avg(sum(ss_sales_price)) over - (partition by i_category, i_brand, - s_store_name,s_company_name,d_year) - avg_monthly_sales, - rank() over - (partition by i_category, i_brand, - s_store_name,s_company_name - order by d_year,d_moy) rn - from item, store_sales, date_dim, store - where ss_item_sk = i_item_sk and - ss_sold_date_sk = d_date_sk and - ss_store_sk = s_store_sk and - ( - d_year = 2000 or - ( d_year = 2000-1 and d_moy =12) or - ( d_year = 2000+1 and d_moy =1) - ) - group by i_category, i_brand, - s_store_name, s_company_name, - d_year, d_moy), - v2 as( - select v1.i_category, v1.i_brand - ,v1.d_year - ,v1.avg_monthly_sales - ,v1.sum_sales, v1_lag.sum_sales psum, v1_lead.sum_sales nsum - from v1, v1 v1_lag, v1 v1_lead - where v1.i_category = v1_lag.i_category and - v1.i_category = v1_lead.i_category and - v1.i_brand = v1_lag.i_brand and - v1.i_brand = v1_lead.i_brand and - v1.s_store_name = v1_lag.s_store_name and - v1.s_store_name = v1_lead.s_store_name and - v1.s_company_name = v1_lag.s_company_name and - v1.s_company_name = v1_lead.s_company_name and - v1.rn = v1_lag.rn + 1 and - v1.rn = v1_lead.rn -1) -select * -from v2 -where d_year = 2000 and - avg_monthly_sales > 0 and - case when avg_monthly_sales > 0 then abs(sum_sales - avg_monthly_sales) / avg_monthly_sales else null end > 0.1 -order by sum_sales - avg_monthly_sales, d_year -limit 100; \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q3.sql b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q3.sql deleted file mode 100644 index 33bd52ce6e07c6b7d214f23ce4cc2ab6bc23c707..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q3.sql +++ /dev/null @@ -1,34 +0,0 @@ -select - * -from - (select - i_manufact_id, - sum(ss_sales_price) sum_sales, - avg(sum(ss_sales_price)) over (partition by i_manufact_id) avg_quarterly_sales - from - item, - store_sales, - date_dim, - store - where - ss_item_sk = i_item_sk - and ss_sold_date_sk = d_date_sk - and ss_store_sk = s_store_sk - and d_month_seq in (1212, 1212 + 1, 1212 + 2, 1212 + 3, 1212 + 4, 1212 +5, 1212 + 6, 1212+7, 1212 + 8, 1212 + 9, 1212 + 10, 1212 + 11) - and ((i_category in ('Books', 'Children', 'Electronics') - and i_class in ('personal', 'portable', 'reference', 'self-help') - and i_brand in ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')) - or (i_category in ('Women', 'Music', 'Men') - and i_class in ('accessories', 'classical', 'fragrances', 'pants') - and i_brand in ('amalgimporto #1', 'edu packscholar #1', 'exportiimporto #1', 'importoamalg #1'))) - group by - i_manufact_id, - d_qoy - ) tmp1 -where - case when avg_quarterly_sales > 0 then abs (sum_sales -avg_quarterly_sales) / avg_quarterly_sales else null end > 0.1 -order by - avg_quarterly_sales, - sum_sales, - i_manufact_id -limit 100; \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q4.sql b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q4.sql deleted file mode 100644 index 258c73813f4fb2f1f911c678c7f6996c05f9c15d..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q4.sql +++ /dev/null @@ -1,35 +0,0 @@ -select i_brand_id brand_id, i_brand brand,t_hour,t_minute, sum(ext_price) 
ext_price -from item, (select ws_ext_sales_price as ext_price, - ws_sold_date_sk as sold_date_sk, - ws_item_sk as sold_item_sk, - ws_sold_time_sk as time_sk - from web_sales,date_dim - where d_date_sk = ws_sold_date_sk - and d_moy=12 - and d_year=2001 - union all - select cs_ext_sales_price as ext_price, - cs_sold_date_sk as sold_date_sk, - cs_item_sk as sold_item_sk, - cs_sold_time_sk as time_sk - from catalog_sales,date_dim - where d_date_sk = cs_sold_date_sk - and d_moy=12 - and d_year=2001 - union all - select ss_ext_sales_price as ext_price, - ss_sold_date_sk as sold_date_sk, - ss_item_sk as sold_item_sk, - ss_sold_time_sk as time_sk - from store_sales,date_dim - where d_date_sk = ss_sold_date_sk - and d_moy=12 - and d_year=2001 - ) as tmp,time_dim -where - sold_item_sk = i_item_sk - and time_sk = t_time_sk - and i_manager_id=1 - and (t_meal_time = 'breakfast' or t_meal_time = 'dinner') -group by i_brand, i_brand_id,t_hour,t_minute -order by ext_price desc, brand_id; \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q5.sql b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q5.sql deleted file mode 100644 index 4a8c7bc9d70ba5c1a3c11ede3881206851066e56..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q5.sql +++ /dev/null @@ -1,20 +0,0 @@ -select - c_customer_id as customer_id - ,c_last_name || ', ' || c_first_name as customername - from - customer - ,customer_address - ,customer_demographics - ,household_demographics - ,income_band - ,store_returns - where ca_city = 'Hopewell' - and c_current_addr_sk = ca_address_sk - and ib_lower_bound >= 32287 - and ib_upper_bound <= 82287 - and ib_income_band_sk = hd_income_band_sk - and cd_demo_sk = c_current_cdemo_sk - and hd_demo_sk = c_current_hdemo_sk - and sr_cdemo_sk = cd_demo_sk - order by customer_id - limit 100; \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q6.sql b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q6.sql deleted file mode 100644 index 221c169e32482c68ebbe5b9011cc4b43934ee8c2..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q6.sql +++ /dev/null @@ -1,25 +0,0 @@ -select * -from (select i_manager_id - ,sum(ss_sales_price) sum_sales - ,avg(sum(ss_sales_price)) over (partition by i_manager_id) avg_monthly_sales - from item - ,store_sales - ,date_dim - ,store - where ss_item_sk = i_item_sk -and ss_sold_date_sk = d_date_sk -and ss_sold_date_sk between 2452123 and 2452487 -and ss_store_sk = s_store_sk -and d_month_seq in (1219,1219+1,1219+2,1219+3,1219+4,1219+5,1219+6,1219+7,1219+8,1219+9,1219+10,1219+11) -and (( i_category in ('Books','Children','Electronics') - and i_class in ('personal','portable','reference','self-help') - and i_brand in ('scholaramalgamalg #14','scholaramalgamalg #7', 'exportiunivamalg #9','scholaramalgamalg #9')) -or( i_category in ('Women','Music','Men') - and i_class in ('accessories','classical','fragrances','pants') - and i_brand in ('amalgimporto #1','edu packscholar #1','exportiimporto #1', 'importoamalg #1'))) -group by i_manager_id, d_moy) tmp1 -where case when avg_monthly_sales > 0 then abs (sum_sales - avg_monthly_sales) / avg_monthly_sales else null end > 0.1 -order by i_manager_id - ,avg_monthly_sales - ,sum_sales -limit 100; \ No newline at end of file diff --git 
a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q7.sql b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q7.sql deleted file mode 100644 index a42e5d9887c3e53bfd1570c496f2aab4b41ed5a3..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q7.sql +++ /dev/null @@ -1,33 +0,0 @@ -select - substr(w_warehouse_name,1,20) - ,sm_type - ,cc_name - ,sum(case when (cs_ship_date_sk - cs_sold_date_sk <= 30 ) then 1 else 0 end) as D30_days - ,sum(case when (cs_ship_date_sk - cs_sold_date_sk > 30) and - (cs_ship_date_sk - cs_sold_date_sk <= 60) then 1 else 0 end ) as D31_60_days - ,sum(case when (cs_ship_date_sk - cs_sold_date_sk > 60) and - (cs_ship_date_sk - cs_sold_date_sk <= 90) then 1 else 0 end) as D61_90_days - ,sum(case when (cs_ship_date_sk - cs_sold_date_sk > 90) and - (cs_ship_date_sk - cs_sold_date_sk <= 120) then 1 else 0 end) as D91_120_days - ,sum(case when (cs_ship_date_sk - cs_sold_date_sk > 120) then 1 else 0 end) as D120_days -from - catalog_sales - ,warehouse - ,ship_mode - ,call_center - ,date_dim -where - d_month_seq between 1202 and 1202 + 11 --- equivalent to 2451605 2451969 -and cs_ship_date_sk = d_date_sk -and cs_warehouse_sk = w_warehouse_sk -and cs_ship_mode_sk = sm_ship_mode_sk -and cs_call_center_sk = cc_call_center_sk -group by - substr(w_warehouse_name,1,20) - ,sm_type - ,cc_name -order by substr(w_warehouse_name,1,20) - ,sm_type - ,cc_name -limit 100 ; \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q8.sql b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q8.sql deleted file mode 100644 index 564b59b2460ae127197a808dbcff39e69c10e649..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q8.sql +++ /dev/null @@ -1,41 +0,0 @@ -select - * -from - (select - i_category, - i_class, - i_brand, - s_store_name, - s_company_name, - d_moy, - sum(ss_sales_price) sum_sales, - avg(sum(ss_sales_price)) over (partition by i_category, i_brand, s_store_name, s_company_name) avg_monthly_sales - from - item, - store_sales, - date_dim, - store - where - ss_item_sk = i_item_sk - and ss_sold_date_sk = d_date_sk - and ss_store_sk = s_store_sk - and d_year in (2000) - and ((i_category in ('Home', 'Books', 'Electronics') - and i_class in ('wallpaper', 'parenting', 'musical')) - or (i_category in ('Shoes', 'Jewelry', 'Men') - and i_class in ('womens', 'birdal', 'pants'))) - and ss_sold_date_sk between 2451545 and 2451910 -- partition key filter - group by - i_category, - i_class, - i_brand, - s_store_name, - s_company_name, - d_moy - ) tmp1 -where - case when (avg_monthly_sales <> 0) then (abs(sum_sales - avg_monthly_sales) / avg_monthly_sales) else null end > 0.1 -order by - sum_sales - avg_monthly_sales, - s_store_name -limit 100; \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q9.sql b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q9.sql deleted file mode 100644 index 26350730a79bebcd470f73806282efd707b4d7df..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q9.sql +++ /dev/null @@ -1,44 +0,0 @@ -select - c_last_name, - c_first_name, - substr(s_city,1,30), - ss_ticket_number, - amt, - profit -from - (select - ss_ticket_number, - ss_customer_sk, - store.s_city, - sum(ss_coupon_amt) 
amt, - sum(ss_net_profit) profit - from - store_sales, - date_dim, - store, - household_demographics - where - store_sales.ss_sold_date_sk = date_dim.d_date_sk - and store_sales.ss_store_sk = store.s_store_sk - and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk - and (household_demographics.hd_dep_count = 8 - or household_demographics.hd_vehicle_count >0) - and date_dim.d_dow = 1 - and date_dim.d_year in (1998,1998+1,1998+2) - and store.s_number_employees between 200 and 295 - and ss_sold_date_sk between 2450819 and 2451904 - group by - ss_ticket_number, - ss_customer_sk, - ss_addr_sk, - store.s_city - ) ms, - customer -where - ss_customer_sk = c_customer_sk -order by - c_last_name, - c_first_name, - substr(s_city,1,30), - profit -limit 100; \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/Vectorized/OmniColumnVectorSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/Vectorized/OmniColumnVectorSuite.scala deleted file mode 100644 index a9fc4452338869610ed8800d0417eddf9a148989..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/Vectorized/OmniColumnVectorSuite.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.huawei.boostkit.spark.vectorized - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.vectorized.OmniColumnVector -import org.apache.spark.sql.types._ - -class OmniColumnVectorSuite extends SparkFunSuite { - test("int") { - val schema = new StructType().add("int", IntegerType); - val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns(4, schema, true) - vectors(0).putInt(0, 1) - vectors(0).putInt(1, 2) - vectors(0).putInt(2, 3) - vectors(0).putInt(3, 4) - assert(1 == vectors(0).getInt(0)) - assert(2 == vectors(0).getInt(1)) - assert(3 == vectors(0).getInt(2)) - assert(4 == vectors(0).getInt(3)) - vectors(0).close() - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptorSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptorSuite.scala deleted file mode 100644 index bf8e24dd53841fee21ba4cab3c8bd11172eba76a..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptorSuite.scala +++ /dev/null @@ -1,395 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. 
- * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.huawei.boostkit.spark.expression - -import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{getExprIdMap, procCaseWhenExpression, procLikeExpression, rewriteToOmniExpressionLiteral, rewriteToOmniJsonExpressionLiteral} -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Max, Min, Sum} -import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType, StringType} - -/** - * 功能描述 - * - * @author w00630100 - * @since 2022-02-21 - */ -class OmniExpressionAdaptorSuite extends SparkFunSuite { - var allAttribute = Seq(AttributeReference("a", IntegerType)(), - AttributeReference("b", IntegerType)(), AttributeReference("c", BooleanType)(), - AttributeReference("d", BooleanType)(), AttributeReference("e", IntegerType)(), - AttributeReference("f", StringType)(), AttributeReference("g", StringType)()) - - test("expression rewrite") { - checkExpressionRewrite("$operator$ADD:1(#0,#1)", Add(allAttribute(0), allAttribute(1))) - checkExpressionRewrite("$operator$ADD:1(#0,1:1)", Add(allAttribute(0), Literal(1))) - - checkExpressionRewrite("$operator$SUBTRACT:1(#0,#1)", - Subtract(allAttribute(0), allAttribute(1))) - checkExpressionRewrite("$operator$SUBTRACT:1(#0,1:1)", Subtract(allAttribute(0), Literal(1))) - - checkExpressionRewrite("$operator$MULTIPLY:1(#0,#1)", - Multiply(allAttribute(0), allAttribute(1))) - checkExpressionRewrite("$operator$MULTIPLY:1(#0,1:1)", Multiply(allAttribute(0), Literal(1))) - - checkExpressionRewrite("$operator$DIVIDE:1(#0,#1)", Divide(allAttribute(0), allAttribute(1))) - checkExpressionRewrite("$operator$DIVIDE:1(#0,1:1)", Divide(allAttribute(0), Literal(1))) - - checkExpressionRewrite("$operator$MODULUS:1(#0,#1)", - Remainder(allAttribute(0), allAttribute(1))) - checkExpressionRewrite("$operator$MODULUS:1(#0,1:1)", Remainder(allAttribute(0), Literal(1))) - - checkExpressionRewrite("$operator$GREATER_THAN:4(#0,#1)", - GreaterThan(allAttribute(0), allAttribute(1))) - checkExpressionRewrite("$operator$GREATER_THAN:4(#0,1:1)", - GreaterThan(allAttribute(0), Literal(1))) - - checkExpressionRewrite("$operator$GREATER_THAN_OR_EQUAL:4(#0,#1)", - GreaterThanOrEqual(allAttribute(0), allAttribute(1))) - checkExpressionRewrite("$operator$GREATER_THAN_OR_EQUAL:4(#0,1:1)", - GreaterThanOrEqual(allAttribute(0), Literal(1))) - - checkExpressionRewrite("$operator$LESS_THAN:4(#0,#1)", - LessThan(allAttribute(0), allAttribute(1))) - checkExpressionRewrite("$operator$LESS_THAN:4(#0,1:1)", - LessThan(allAttribute(0), Literal(1))) - - checkExpressionRewrite("$operator$LESS_THAN_OR_EQUAL:4(#0,#1)", - LessThanOrEqual(allAttribute(0), allAttribute(1))) - 
checkExpressionRewrite("$operator$LESS_THAN_OR_EQUAL:4(#0,1:1)", - LessThanOrEqual(allAttribute(0), Literal(1))) - - checkExpressionRewrite("$operator$EQUAL:4(#0,#1)", EqualTo(allAttribute(0), allAttribute(1))) - checkExpressionRewrite("$operator$EQUAL:4(#0,1:1)", EqualTo(allAttribute(0), Literal(1))) - - checkExpressionRewrite("OR:4(#2,#3)", Or(allAttribute(2), allAttribute(3))) - checkExpressionRewrite("OR:4(#2,3:1)", Or(allAttribute(2), Literal(3))) - - checkExpressionRewrite("AND:4(#2,#3)", And(allAttribute(2), allAttribute(3))) - checkExpressionRewrite("AND:4(#2,3:1)", And(allAttribute(2), Literal(3))) - - checkExpressionRewrite("not:4(#3)", Not(allAttribute(3))) - - checkExpressionRewrite("IS_NOT_NULL:4(#4)", IsNotNull(allAttribute(4))) - - checkExpressionRewrite("substr:15(#5,#0,#1)", - Substring(allAttribute(5), allAttribute(0), allAttribute(1))) - - checkExpressionRewrite("CAST:2(#1)", Cast(allAttribute(1), LongType)) - - checkExpressionRewrite("abs:1(#0)", Abs(allAttribute(0))) - - checkExpressionRewrite("SUM:2(#0)", Sum(allAttribute(0))) - - checkExpressionRewrite("MAX:1(#0)", Max(allAttribute(0))) - - checkExpressionRewrite("AVG:3(#0)", Average(allAttribute(0))) - - checkExpressionRewrite("MIN:1(#0)", Min(allAttribute(0))) - - checkExpressionRewrite("IN:4(#0,#0,#1)", - In(allAttribute(0), Seq(allAttribute(0), allAttribute(1)))) - - // checkExpressionRewrite("IN:4(#0, #0, #1)", InSet(allAttribute(0), Set(allAttribute(0), allAttribute(1)))) - } - - test("json expression rewrite") { - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"ADD\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}}", - Add(allAttribute(0), allAttribute(1))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"ADD\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", - Add(allAttribute(0), Literal(1))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"SUBTRACT\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}}", - Subtract(allAttribute(0), allAttribute(1))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"SUBTRACT\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", - Subtract(allAttribute(0), Literal(1))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"MULTIPLY\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}}", - Multiply(allAttribute(0), allAttribute(1))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"MULTIPLY\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", - Multiply(allAttribute(0), Literal(1))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"DIVIDE\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - 
"\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}}", - Divide(allAttribute(0), allAttribute(1))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"DIVIDE\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", - Divide(allAttribute(0), Literal(1))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"MODULUS\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}}", - Remainder(allAttribute(0), allAttribute(1))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"MODULUS\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", - Remainder(allAttribute(0), Literal(1))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4," + - "\"operator\":\"GREATER_THAN\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}}", - GreaterThan(allAttribute(0), allAttribute(1))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4," + - "\"operator\":\"GREATER_THAN\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", - GreaterThan(allAttribute(0), Literal(1))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4," + - "\"operator\":\"GREATER_THAN_OR_EQUAL\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}}", - GreaterThanOrEqual(allAttribute(0), allAttribute(1))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4," + - "\"operator\":\"GREATER_THAN_OR_EQUAL\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", - GreaterThanOrEqual(allAttribute(0), Literal(1))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}}", - LessThan(allAttribute(0), allAttribute(1))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", - LessThan(allAttribute(0), Literal(1))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4," + - "\"operator\":\"LESS_THAN_OR_EQUAL\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}}", - LessThanOrEqual(allAttribute(0), allAttribute(1))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4," + - "\"operator\":\"LESS_THAN_OR_EQUAL\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, 
\"value\":1}}", - LessThanOrEqual(allAttribute(0), Literal(1))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"EQUAL\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}}", - EqualTo(allAttribute(0), allAttribute(1))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"EQUAL\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", - EqualTo(allAttribute(0), Literal(1))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"OR\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":4,\"colVal\":2}," + - "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":4,\"colVal\":3}}", - Or(allAttribute(2), allAttribute(3))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"OR\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":4,\"colVal\":2}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":3}}", - Or(allAttribute(2), Literal(3))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"AND\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":4,\"colVal\":2}," + - "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":4,\"colVal\":3}}", - And(allAttribute(2), allAttribute(3))) - - checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"AND\"," + - "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":4,\"colVal\":2}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":3}}", - And(allAttribute(2), Literal(3))) - - checkJsonExprRewrite("{\"exprType\":\"UNARY\",\"returnType\":4, \"operator\":\"not\"," + - "\"expr\":{\"exprType\":\"IS_NULL\",\"returnType\":4," + - "\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":4}]}}", - IsNotNull(allAttribute(4))) - - checkJsonExprRewrite("{\"exprType\":\"FUNCTION\",\"returnType\":2,\"function_name\":\"CAST\"," + - "\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}]}", - Cast(allAttribute(1), LongType)) - - checkJsonExprRewrite("{\"exprType\":\"FUNCTION\",\"returnType\":1,\"function_name\":\"abs\"," + - " \"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}]}", - Abs(allAttribute(0))) - - checkJsonExprRewrite("{\"exprType\":\"FUNCTION\",\"returnType\":1,\"function_name\":\"round\"," + - " \"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0},{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":2}]}", - Round(allAttribute(0), Literal(2))) - } - - protected def checkExpressionRewrite(expected: Any, expression: Expression): Unit = { - { - val runResult = rewriteToOmniExpressionLiteral(expression, getExprIdMap(allAttribute)) - if (!expected.equals(runResult)) { - fail(s"expression($expression) not match with expected value:$expected," + - s"running value:$runResult") - } - } - } - - protected def checkJsonExprRewrite(expected: Any, expression: Expression): Unit = { - val runResult = rewriteToOmniJsonExpressionLiteral(expression, getExprIdMap(allAttribute)) - if (!expected.equals(runResult)) { - fail(s"expression($expression) not match with expected value:$expected," + - s"running value:$runResult") - } - } - - test("json expression rewrite 
support Chinese") { - val cnAttribute = Seq(AttributeReference("char_1", StringType)(), AttributeReference("char_20", StringType)(), - AttributeReference("varchar_1", StringType)(), AttributeReference("varchar_20", StringType)()) - - val like = Like(cnAttribute(2), Literal("我_"), '\\'); - val likeResult = procLikeExpression(like, getExprIdMap(cnAttribute)) - val likeExp = "{\"exprType\":\"FUNCTION\",\"returnType\":4,\"function_name\":\"LIKE\", \"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":2000}, {\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"^我.$\",\"width\":4}]}" - if (!likeExp.equals(likeResult)) { - fail(s"expression($like) not match with expected value:$likeExp," + - s"running value:$likeResult") - } - - val startsWith = StartsWith(cnAttribute(2), Literal("我")); - val startsWithResult = procLikeExpression(startsWith, getExprIdMap(cnAttribute)) - val startsWithExp = "{\"exprType\":\"FUNCTION\",\"returnType\":4,\"function_name\":\"LIKE\", \"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":2000}, {\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"^我.*$\",\"width\":5}]}" - if (!startsWithExp.equals(startsWithResult)) { - fail(s"expression($startsWith) not match with expected value:$startsWithExp," + - s"running value:$startsWithResult") - } - - val endsWith = EndsWith(cnAttribute(2), Literal("我")); - val endsWithResult = procLikeExpression(endsWith, getExprIdMap(cnAttribute)) - val endsWithExp = "{\"exprType\":\"FUNCTION\",\"returnType\":4,\"function_name\":\"LIKE\", \"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":2000}, {\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"^.*我$\",\"width\":5}]}" - if (!endsWithExp.equals(endsWithResult)) { - fail(s"expression($endsWith) not match with expected value:$endsWithExp," + - s"running value:$endsWithResult") - } - - val contains = Contains(cnAttribute(2), Literal("我")); - val containsResult = procLikeExpression(contains, getExprIdMap(cnAttribute)) - val containsExp = "{\"exprType\":\"FUNCTION\",\"returnType\":4,\"function_name\":\"LIKE\", \"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":2000}, {\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"^.*我.*$\",\"width\":7}]}" - if (!containsExp.equals(containsResult)) { - fail(s"expression($contains) not match with expected value:$containsExp," + - s"running value:$containsResult") - } - - val t1 = new Tuple2(Not(EqualTo(cnAttribute(0), Literal("新"))), Not(EqualTo(cnAttribute(1), Literal("官方爸爸")))) - val t2 = new Tuple2(Not(EqualTo(cnAttribute(2), Literal("爱你三千遍"))), Not(EqualTo(cnAttribute(2), Literal("新")))) - val branch = Seq(t1, t2) - val elseValue = Some(Not(EqualTo(cnAttribute(3), Literal("啊水水水水")))) - val caseWhen = CaseWhen(branch, elseValue); - val caseWhenResult = rewriteToOmniJsonExpressionLiteral(caseWhen, getExprIdMap(cnAttribute)) - val caseWhenExp = "{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":2000},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, 
\"value\":\"新\",\"width\":1}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":2000},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"官方爸爸\",\"width\":4}},\"if_false\":{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":2000},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"爱你三千遍\",\"width\":5}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":2000},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"新\",\"width\":1}},\"if_false\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":3,\"width\":2000},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"啊水水水水\",\"width\":5}}}}" - if (!caseWhenExp.equals(caseWhenResult)) { - fail(s"expression($caseWhen) not match with expected value:$caseWhenExp," + - s"running value:$caseWhenResult") - } - - val isNull = IsNull(cnAttribute(0)); - val isNullResult = rewriteToOmniJsonExpressionLiteral(isNull, getExprIdMap(cnAttribute)) - val isNullExp = "{\"exprType\":\"IS_NULL\",\"returnType\":4,\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":2000}]}" - if (!isNullExp.equals(isNullResult)) { - fail(s"expression($isNull) not match with expected value:$isNullExp," + - s"running value:$isNullResult") - } - - val children = Seq(cnAttribute(0), cnAttribute(1)) - val coalesce = Coalesce(children); - val coalesceResult = rewriteToOmniJsonExpressionLiteral(coalesce, getExprIdMap(cnAttribute)) - val coalesceExp = "{\"exprType\":\"COALESCE\",\"returnType\":15,\"width\":2000, \"value1\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":2000},\"value2\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":2000}}" - if (!coalesceExp.equals(coalesceResult)) { - fail(s"expression($coalesce) not match with expected value:$coalesceExp," + - s"running value:$coalesceResult") - } - - val children2 = Seq(cnAttribute(0), cnAttribute(1), cnAttribute(2)) - val coalesce2 = Coalesce(children2); - try { - rewriteToOmniJsonExpressionLiteral(coalesce2, getExprIdMap(cnAttribute)) - } catch { - case ex: UnsupportedOperationException => { - println(ex) - } - } - - } - - test("procCaseWhenExpression") { - val caseWhenAttribute = Seq(AttributeReference("char_1", StringType)(), AttributeReference("char_20", StringType)(), - AttributeReference("varchar_1", StringType)(), AttributeReference("varchar_20", StringType)(), - AttributeReference("a", IntegerType)(), AttributeReference("b", IntegerType)()) - - val t1 = new Tuple2(Not(EqualTo(caseWhenAttribute(0), Literal("新"))), Not(EqualTo(caseWhenAttribute(1), Literal("官方爸爸")))) - val t2 = new Tuple2(Not(EqualTo(caseWhenAttribute(2), Literal("爱你三千遍"))), Not(EqualTo(caseWhenAttribute(2), Literal("新")))) - val branch = Seq(t1, t2) - val elseValue = Some(Not(EqualTo(caseWhenAttribute(3), Literal("啊水水水水")))) - val expression = CaseWhen(branch, elseValue); - val runResult = procCaseWhenExpression(expression, getExprIdMap(caseWhenAttribute)) - val filterExp = 
"{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":2000},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"新\",\"width\":1}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":2000},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"官方爸爸\",\"width\":4}},\"if_false\":{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":2000},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"爱你三千遍\",\"width\":5}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":2000},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"新\",\"width\":1}},\"if_false\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":3,\"width\":2000},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"啊水水水水\",\"width\":5}}}}" - if (!filterExp.equals(runResult)) { - fail(s"expression($expression) not match with expected value:$filterExp," + - s"running value:$runResult") - } - - val t3 = new Tuple2(Not(EqualTo(caseWhenAttribute(4), Literal(5))), Not(EqualTo(caseWhenAttribute(5), Literal(10)))) - val t4 = new Tuple2(LessThan(caseWhenAttribute(4), Literal(15)), GreaterThan(caseWhenAttribute(5), Literal(20))) - val branch2 = Seq(t3, t4) - val elseValue2 = Some(Not(EqualTo(caseWhenAttribute(5), Literal(25)))) - val numExpression = CaseWhen(branch2, elseValue2); - val numResult = procCaseWhenExpression(numExpression, getExprIdMap(caseWhenAttribute)) - val numFilterExp = "{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":4},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":5}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":5},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":10}},\"if_false\":{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":4},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":15}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"GREATER_THAN\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":5},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":20}},\"if_false\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":5},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":25}}}}" - if (!numFilterExp.equals(numResult)) { - fail(s"expression($numExpression) not match with expected value:$numFilterExp," + - s"running value:$numResult") - } 
- - val t5 = new Tuple2(Not(EqualTo(caseWhenAttribute(4), Literal(5))), Not(EqualTo(caseWhenAttribute(5), Literal(10)))) - val t6 = new Tuple2(LessThan(caseWhenAttribute(4), Literal(15)), GreaterThan(caseWhenAttribute(5), Literal(20))) - val branch3 = Seq(t5, t6) - val elseValue3 = None - val noneExpression = CaseWhen(branch3, elseValue3); - val noneResult = procCaseWhenExpression(noneExpression, getExprIdMap(caseWhenAttribute)) - val noneFilterExp = "{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":4},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":5}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":5},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":10}},\"if_false\":{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":4},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":15}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"GREATER_THAN\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":5},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":20}},\"if_false\":{\"exprType\":\"LITERAL\",\"dataType\":4,\"isNull\":true}}}" - if (!noneFilterExp.equals(noneResult)) { - fail(s"expression($noneExpression) not match with expected value:$noneFilterExp," + - s"running value:$noneResult") - } - } - - -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala deleted file mode 100644 index 0a08416ff04ba4772f8141e418e51757bc237f12..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright (C) 2021-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.huawei.boostkit.spark.hive - -import java.util.Properties - -import com.huawei.boostkit.spark.hive.util.HiveResourceRunner -import org.apache.log4j.{Level, LogManager} -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.sql.SparkSession - -/** - * @since 2021/12/15 - */ -class HiveResourceSuite extends SparkFunSuite { - private val QUERY_SQLS = "query-sqls" - private var spark: SparkSession = _ - private var runner: HiveResourceRunner = _ - - override def beforeAll(): Unit = { - val properties = new Properties() - properties.load(this.getClass.getClassLoader.getResourceAsStream("HiveResource.properties")) - - spark = SparkSession.builder() - .appName("test-sql-context") - .master("local[2]") - .config(readConf(properties)) - .enableHiveSupport() - .getOrCreate() - LogManager.getRootLogger.setLevel(Level.WARN) - runner = new HiveResourceRunner(spark, QUERY_SQLS) - - val hiveDb = properties.getProperty("hive.db") - spark.sql(if (hiveDb == null) "use default" else s"use $hiveDb") - } - - override def afterAll(): Unit = { - super.afterAll() - } - - test("queryBySparkSql-HiveDataSource") { - runner.runQuery("q1", 1) - runner.runQuery("q2", 1) - runner.runQuery("q3", 1) - runner.runQuery("q4", 1) - runner.runQuery("q5", 1) - runner.runQuery("q6", 1) - runner.runQuery("q7", 1) - runner.runQuery("q8", 1) - runner.runQuery("q9", 1) - runner.runQuery("q10", 1) - } - - def readConf(properties: Properties): SparkConf = { - val conf = new SparkConf() - val wholeStage = properties.getProperty("spark.sql.codegen.wholeStage") - val offHeapSize = properties.getProperty("spark.memory.offHeap.size") - conf.set("hive.metastore.uris", properties.getProperty("hive.metastore.uris")) - .set("spark.sql.warehouse.dir", properties.getProperty("spark.sql.warehouse.dir")) - .set("spark.memory.offHeap.size", if (offHeapSize == null) "8G" else offHeapSize) - .set("spark.sql.codegen.wholeStage", if (wholeStage == null) "false" else wholeStage) - .set("spark.sql.extensions", properties.getProperty("spark.sql.extensions")) - .set("spark.shuffle.manager", properties.getProperty("spark.shuffle.manager")) - .set("spark.sql.orc.impl", properties.getProperty("spark.sql.orc.impl")) - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/hive/util/HiveResourceRunner.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/hive/util/HiveResourceRunner.scala deleted file mode 100644 index 84e12f6bd5b6f63b57ecba62a27c20cc5e6698fc..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/hive/util/HiveResourceRunner.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright (C) 2021-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.huawei.boostkit.spark.hive.util - -import java.io.{File, FilenameFilter} -import java.nio.charset.StandardCharsets - -import org.apache.commons.io.FileUtils -import org.apache.spark.sql.{Row, SparkSession} - -class HiveResourceRunner(val spark: SparkSession, val resource: String) { - val caseIds = HiveResourceRunner.parseCaseIds(HiveResourceRunner.locateResourcePath(resource), - ".sql") - - def runQuery(caseId: String, roundId: Int, explain: Boolean = false): Unit = { - val path = "%s/%s.sql".format(resource, caseId) - val absolute = HiveResourceRunner.locateResourcePath(path) - val sql = FileUtils.readFileToString(new File(absolute), StandardCharsets.UTF_8) - println("Running query %s (round %d)... ".format(caseId, roundId)) - val df = spark.sql(sql) - if (explain) { - df.explain(extended = true) - } - val result: Array[Row] = df.head(100) - result.foreach(row => println(row)) - } -} - -object HiveResourceRunner { - private def parseCaseIds(dir: String, suffix: String): List[String] = { - val folder = new File(dir) - if (!folder.exists()) { - throw new IllegalArgumentException("dir does not exist: " + dir) - } - folder - .listFiles(new FilenameFilter { - override def accept(dir: File, name: String): Boolean = name.endsWith(suffix) - }) - .map(f => f.getName) - .map(n => n.substring(0, n.lastIndexOf(suffix))) - .sortBy(s => { - //fill with leading zeros - "%s%s".format(new String((0 until 16 - s.length).map(_ => '0').toArray), s) - }) - .toList - } - - private def locateResourcePath(resource: String): String = { - classOf[HiveResourceRunner].getClassLoader.getResource("") - .getPath.concat(File.separator).concat(resource) - } -} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerDisableCompressSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerDisableCompressSuite.scala deleted file mode 100644 index 237321f5921726da1b80119936b8bcadaa6f1c95..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerDisableCompressSuite.scala +++ /dev/null @@ -1,246 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle - -import java.io.{File, FileInputStream} - -import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer -import com.huawei.boostkit.spark.vectorized.PartitionInfo -import nova.hetu.omniruntime.`type`.{DataType, _} -import nova.hetu.omniruntime.vector._ -import org.apache.spark.{HashPartitioner, SparkConf, TaskContext} -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.shuffle.sort.ColumnarShuffleHandle -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.vectorized.OmniColumnVector -import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.Utils -import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.ArgumentMatchers.{any, anyInt, anyLong} -import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Mockito.{doAnswer, when} -import org.mockito.invocation.InvocationOnMock - -class ColumnShuffleSerializerDisableCompressSuite extends SharedSparkSession { - @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _ - @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ - @Mock(answer = RETURNS_SMART_NULLS) private var dependency - : ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = _ - - override def sparkConf: SparkConf = - super.sparkConf - .setAppName("test shuffle serializer disable compress") - .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.OmniColumnarShuffleManager") - .set("spark.shuffle.compress", "false") - - private var taskMetrics: TaskMetrics = _ - private var tempDir: File = _ - private var outputFile: File = _ - - private var shuffleHandle: ColumnarShuffleHandle[Int, ColumnarBatch] = _ - private val numPartitions = 1 - - protected var avgBatchNumRows: SQLMetric = _ - protected var outputNumRows: SQLMetric = _ - - override def beforeEach(): Unit = { - super.beforeEach() - - avgBatchNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, - "test serializer avg read batch num rows") - outputNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, - "test serializer number of output rows") - - tempDir = Utils.createTempDir() - outputFile = File.createTempFile("shuffle", null, tempDir) - taskMetrics = new TaskMetrics - - MockitoAnnotations.initMocks(this) - - shuffleHandle = - new ColumnarShuffleHandle[Int, ColumnarBatch](shuffleId = 0, dependency = dependency) - - val types : Array[DataType] = Array[DataType]( - IntDataType.INTEGER, - ShortDataType.SHORT, - LongDataType.LONG, - DoubleDataType.DOUBLE, - new Decimal64DataType(18, 3), - new Decimal128DataType(28, 11), - VarcharDataType.VARCHAR, - BooleanDataType.BOOLEAN) - val inputTypes = DataTypeSerializer.serialize(types) - - when(dependency.partitioner).thenReturn(new HashPartitioner(numPartitions)) - when(dependency.serializer).thenReturn(new JavaSerializer(sparkConf)) - when(dependency.partitionInfo).thenReturn( - new PartitionInfo("hash", numPartitions, types.length, inputTypes)) - when(dependency.dataSize) - .thenReturn(SQLMetrics.createSizeMetric(spark.sparkContext, "data size")) - when(dependency.bytesSpilled) - .thenReturn(SQLMetrics.createSizeMetric(spark.sparkContext, "shuffle bytes spilled")) - when(dependency.numInputRows) - .thenReturn(SQLMetrics.createMetric(spark.sparkContext, "number of input rows")) - when(dependency.splitTime) - 
.thenReturn(SQLMetrics.createNanoTimingMetric(spark.sparkContext, "totaltime_split")) - when(dependency.spillTime) - .thenReturn(SQLMetrics.createNanoTimingMetric(spark.sparkContext, "totaltime_spill")) - when(taskContext.taskMetrics()).thenReturn(taskMetrics) - when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) - - doAnswer { (invocationOnMock: InvocationOnMock) => - val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] - if (tmp != null) { - outputFile.delete - tmp.renameTo(outputFile) - } - null - }.when(blockResolver) - .writeIndexFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])) - } - - override def afterEach(): Unit = { - try { - Utils.deleteRecursively(tempDir) - } finally { - super.afterEach() - } - } - - override def afterAll(): Unit = { - super.afterAll() - } - - test("write shuffle compress for none with null value last") { - val pidArray: Array[java.lang.Integer] = Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) - val intArray: Array[java.lang.Integer] = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, null) - val shortArray: Array[java.lang.Integer] = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, null) - val longArray: Array[java.lang.Long] = Array(0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L, - 17L, 18L, 19L, null) - val doubleArray: Array[java.lang.Double] = Array(0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.10, 11.11, 12.12, - 13.13, 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, null) - val decimal64Array: Array[java.lang.Long] = Array(0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L, - 17L, 18L, 19L, null) - val decimal128Array: Array[Array[Long]] = Array( - Array(0L, 0L), Array(1L, 1L), Array(2L, 2L), Array(3L, 3L), Array(4L, 4L), Array(5L, 5L), Array(6L, 6L), - Array(7L, 7L), Array(8L, 8L), Array(9L, 9L), Array(10L, 10L), Array(11L, 11L), Array(12L, 12L), Array(13L, 13L), - Array(14L, 14L), Array(15L, 15L), Array(16L, 16L), Array(17L, 17L), Array(18L, 18L), Array(19L, 19L), null) - val stringArray: Array[java.lang.String] = Array("", "a", "bb", "ccc", "dddd", "eeeee", "ffffff", "ggggggg", - "hhhhhhhh", "iiiiiiiii", "jjjjjjjjjj", "kkkkkkkkkkk", "llllllllllll", "mmmmmmmmmmmmm", "nnnnnnnnnnnnnn", - "ooooooooooooooo", "pppppppppppppppp", "qqqqqqqqqqqqqqqqq", "rrrrrrrrrrrrrrrrrr", "sssssssssssssssssss", null) - val booleanArray: Array[java.lang.Boolean] = Array(true, true, true, true, true, true, true, true, true, true, - false, false, false, false, false, false, false, false, false, false, null) - - val pidVector0 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(pidArray) - val intVector0 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(intArray) - val shortVector0 = ColumnarShuffleWriterSuite.initOmniColumnShortVector(shortArray) - val longVector0 = ColumnarShuffleWriterSuite.initOmniColumnLongVector(longArray) - val doubleVector0 = ColumnarShuffleWriterSuite.initOmniColumnDoubleVector(doubleArray) - val decimal64Vector0 = ColumnarShuffleWriterSuite.initOmniColumnDecimal64Vector(decimal64Array) - val decimal128Vector0 = ColumnarShuffleWriterSuite.initOmniColumnDecimal128Vector(decimal128Array) - val varcharVector0 = ColumnarShuffleWriterSuite.initOmniColumnVarcharVector(stringArray) - val booleanVector0 = ColumnarShuffleWriterSuite.initOmniColumnBooleanVector(booleanArray) - - val cb0 = ColumnarShuffleWriterSuite.makeColumnarBatch( - pidVector0.getVec.getSize, - List(pidVector0, 
intVector0, shortVector0, longVector0, doubleVector0, - decimal64Vector0, decimal128Vector0, varcharVector0, booleanVector0) - ) - - val pidVector1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(pidArray) - val intVector1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(intArray) - val shortVector1 = ColumnarShuffleWriterSuite.initOmniColumnShortVector(shortArray) - val longVector1 = ColumnarShuffleWriterSuite.initOmniColumnLongVector(longArray) - val doubleVector1 = ColumnarShuffleWriterSuite.initOmniColumnDoubleVector(doubleArray) - val decimal64Vector1 = ColumnarShuffleWriterSuite.initOmniColumnDecimal64Vector(decimal64Array) - val decimal128Vector1 = ColumnarShuffleWriterSuite.initOmniColumnDecimal128Vector(decimal128Array) - val varcharVector1 = ColumnarShuffleWriterSuite.initOmniColumnVarcharVector(stringArray) - val booleanVector1 = ColumnarShuffleWriterSuite.initOmniColumnBooleanVector(booleanArray) - - val cb1 = ColumnarShuffleWriterSuite.makeColumnarBatch( - pidVector1.getVec.getSize, - List(pidVector1, intVector1, shortVector1, longVector1, doubleVector1, - decimal64Vector1, decimal128Vector1, varcharVector1, booleanVector1) - ) - - def records: Iterator[(Int, ColumnarBatch)] = Iterator((0, cb0), (0, cb1)) - - val writer = new ColumnarShuffleWriter[Int, ColumnarBatch]( - blockResolver, - shuffleHandle, - 0L, // MapId - taskContext.taskMetrics().shuffleWriteMetrics) - - writer.write(records) - writer.stop(success = true) - - assert(writer.getPartitionLengths.sum === outputFile.length()) - assert(writer.getPartitionLengths.count(_ == 0L) === 0) - // should be (numPartitions - 2) zero length files - - val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics - assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) - assert(shuffleWriteMetrics.recordsWritten === records.length) - - assert(taskMetrics.diskBytesSpilled === 0) - assert(taskMetrics.memoryBytesSpilled === 0) - - val serializer = new ColumnarBatchSerializer(avgBatchNumRows, outputNumRows).newInstance() - val deserializedStream = serializer.deserializeStream(new FileInputStream(outputFile)) - - try { - val kv = deserializedStream.asKeyValueIterator - var length = 0 - kv.foreach { - case (_, batch: ColumnarBatch) => - length += 1 - assert(batch.numRows == 42) - assert(batch.numCols == 8) - assert(batch.column(0).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[IntVec].get(0) == 0) - assert(batch.column(0).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[IntVec].get(19) == 19) - assert(batch.column(1).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[ShortVec].get(0) == 0) - assert(batch.column(1).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[ShortVec].get(19) == 19) - assert(batch.column(2).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(0) == 0) - assert(batch.column(2).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(19) == 19) - assert(batch.column(3).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[DoubleVec].get(0) == 0.0) - assert(batch.column(3).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[DoubleVec].get(19) == 19.19) - assert(batch.column(4).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(0) == 0L) - assert(batch.column(4).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(19) == 19L) - assert(batch.column(5).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[Decimal128Vec].get(0) sameElements Array(0L, 0L)) - assert(batch.column(5).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[Decimal128Vec].get(19) 
sameElements Array(19L, 19L)) - assert(batch.column(6).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[VarcharVec].get(0) sameElements "") - assert(batch.column(6).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[VarcharVec].get(19) sameElements "sssssssssssssssssss") - assert(batch.column(7).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[BooleanVec].get(0) == true) - assert(batch.column(7).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[BooleanVec].get(19) == false) - (0 until batch.numCols).foreach { i => - val valueVector = batch.column(i).asInstanceOf[OmniColumnVector].getVec - assert(valueVector.getSize == batch.numRows) - assert(valueVector.isNull(20)) - } - batch.close() - } - assert(length == 1) - } finally { - deserializedStream.close() - } - - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerLz4Suite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerLz4Suite.scala deleted file mode 100644 index 8f0329248c9cf8e75b277b9cae4a3bd3e5a2e361..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerLz4Suite.scala +++ /dev/null @@ -1,247 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle - -import java.io.{File, FileInputStream} - -import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer -import com.huawei.boostkit.spark.vectorized.PartitionInfo -import nova.hetu.omniruntime.`type`.{DataType, _} -import nova.hetu.omniruntime.vector._ -import org.apache.spark.{HashPartitioner, SparkConf, TaskContext} -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.shuffle.sort.ColumnarShuffleHandle -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.vectorized.OmniColumnVector -import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.Utils -import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.ArgumentMatchers.{any, anyInt, anyLong} -import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Mockito.{doAnswer, when} -import org.mockito.invocation.InvocationOnMock - -class ColumnShuffleSerializerLz4Suite extends SharedSparkSession { - @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _ - @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ - @Mock(answer = RETURNS_SMART_NULLS) private var dependency - : ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = _ - - override def sparkConf: SparkConf = - super.sparkConf - .setAppName("test shuffle serializer for lz4") - .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.OmniColumnarShuffleManager") - .set("spark.shuffle.compress", "true") - .set("spark.io.compression.codec", "lz4") - - private var taskMetrics: TaskMetrics = _ - private var tempDir: File = _ - private var outputFile: File = _ - - private var shuffleHandle: ColumnarShuffleHandle[Int, ColumnarBatch] = _ - private val numPartitions = 1 - - protected var avgBatchNumRows: SQLMetric = _ - protected var outputNumRows: SQLMetric = _ - - override def beforeEach(): Unit = { - super.beforeEach() - - avgBatchNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, - "test serializer avg read batch num rows") - outputNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, - "test serializer number of output rows") - - tempDir = Utils.createTempDir() - outputFile = File.createTempFile("shuffle", null, tempDir) - taskMetrics = new TaskMetrics - - MockitoAnnotations.initMocks(this) - - shuffleHandle = - new ColumnarShuffleHandle[Int, ColumnarBatch](shuffleId = 0, dependency = dependency) - - val types : Array[DataType] = Array[DataType]( - IntDataType.INTEGER, - ShortDataType.SHORT, - LongDataType.LONG, - DoubleDataType.DOUBLE, - new Decimal64DataType(18, 3), - new Decimal128DataType(28, 11), - VarcharDataType.VARCHAR, - BooleanDataType.BOOLEAN) - val inputTypes = DataTypeSerializer.serialize(types) - - when(dependency.partitioner).thenReturn(new HashPartitioner(numPartitions)) - when(dependency.serializer).thenReturn(new JavaSerializer(sparkConf)) - when(dependency.partitionInfo).thenReturn( - new PartitionInfo("hash", numPartitions, types.length, inputTypes)) - when(dependency.dataSize) - .thenReturn(SQLMetrics.createSizeMetric(spark.sparkContext, "data size")) - when(dependency.bytesSpilled) - .thenReturn(SQLMetrics.createSizeMetric(spark.sparkContext, "shuffle bytes spilled")) - when(dependency.numInputRows) - .thenReturn(SQLMetrics.createMetric(spark.sparkContext, "number of input rows")) - when(dependency.splitTime) - 
.thenReturn(SQLMetrics.createNanoTimingMetric(spark.sparkContext, "totaltime_split")) - when(dependency.spillTime) - .thenReturn(SQLMetrics.createNanoTimingMetric(spark.sparkContext, "totaltime_spill")) - when(taskContext.taskMetrics()).thenReturn(taskMetrics) - when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) - - doAnswer { (invocationOnMock: InvocationOnMock) => - val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] - if (tmp != null) { - outputFile.delete - tmp.renameTo(outputFile) - } - null - }.when(blockResolver) - .writeIndexFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])) - } - - override def afterEach(): Unit = { - try { - Utils.deleteRecursively(tempDir) - } finally { - super.afterEach() - } - } - - override def afterAll(): Unit = { - super.afterAll() - } - - test("write shuffle compress for lz4 with no null value") { - val pidArray: Array[java.lang.Integer] = Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) - val intArray: Array[java.lang.Integer] = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20) - val shortArray: Array[java.lang.Integer] = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20) - val longArray: Array[java.lang.Long] = Array(0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L, - 17L, 18L, 19L, 20L) - val doubleArray: Array[java.lang.Double] = Array(0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.10, 11.11, 12.12, - 13.13, 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.20) - val decimal64Array: Array[java.lang.Long] = Array(0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L, - 17L, 18L, 19L, 20L) - val decimal128Array: Array[Array[Long]] = Array( - Array(0L, 0L), Array(1L, 1L), Array(2L, 2L), Array(3L, 3L), Array(4L, 4L), Array(5L, 5L), Array(6L, 6L), - Array(7L, 7L), Array(8L, 8L), Array(9L, 9L), Array(10L, 10L), Array(11L, 11L), Array(12L, 12L), Array(13L, 13L), - Array(14L, 14L), Array(15L, 15L), Array(16L, 16L), Array(17L, 17L), Array(18L, 18L), Array(19L, 19L), Array(20L, 20L)) - val stringArray: Array[java.lang.String] = Array("", "a", "bb", "ccc", "dddd", "eeeee", "ffffff", "ggggggg", - "hhhhhhhh", "iiiiiiiii", "jjjjjjjjjj", "kkkkkkkkkkk", "llllllllllll", "mmmmmmmmmmmmm", "nnnnnnnnnnnnnn", - "ooooooooooooooo", "pppppppppppppppp", "qqqqqqqqqqqqqqqqq", "rrrrrrrrrrrrrrrrrr", "sssssssssssssssssss", - "tttttttttttttttttttt") - val booleanArray: Array[java.lang.Boolean] = Array(true, true, true, true, true, true, true, true, true, true, - false, false, false, false, false, false, false, false, false, false, false) - - val pidVector0 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(pidArray) - val intVector0 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(intArray) - val shortVector0 = ColumnarShuffleWriterSuite.initOmniColumnShortVector(shortArray) - val longVector0 = ColumnarShuffleWriterSuite.initOmniColumnLongVector(longArray) - val doubleVector0 = ColumnarShuffleWriterSuite.initOmniColumnDoubleVector(doubleArray) - val decimal64Vector0 = ColumnarShuffleWriterSuite.initOmniColumnDecimal64Vector(decimal64Array) - val decimal128Vector0 = ColumnarShuffleWriterSuite.initOmniColumnDecimal128Vector(decimal128Array) - val varcharVector0 = ColumnarShuffleWriterSuite.initOmniColumnVarcharVector(stringArray) - val booleanVector0 = ColumnarShuffleWriterSuite.initOmniColumnBooleanVector(booleanArray) - - val cb0 = ColumnarShuffleWriterSuite.makeColumnarBatch( - pidVector0.getVec.getSize, - 
List(pidVector0, intVector0, shortVector0, longVector0, doubleVector0, - decimal64Vector0, decimal128Vector0, varcharVector0, booleanVector0) - ) - - val pidVector1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(pidArray) - val intVector1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(intArray) - val shortVector1 = ColumnarShuffleWriterSuite.initOmniColumnShortVector(shortArray) - val longVector1 = ColumnarShuffleWriterSuite.initOmniColumnLongVector(longArray) - val doubleVector1 = ColumnarShuffleWriterSuite.initOmniColumnDoubleVector(doubleArray) - val decimal64Vector1 = ColumnarShuffleWriterSuite.initOmniColumnDecimal64Vector(decimal64Array) - val decimal128Vector1 = ColumnarShuffleWriterSuite.initOmniColumnDecimal128Vector(decimal128Array) - val varcharVector1 = ColumnarShuffleWriterSuite.initOmniColumnVarcharVector(stringArray) - val booleanVector1 = ColumnarShuffleWriterSuite.initOmniColumnBooleanVector(booleanArray) - - val cb1 = ColumnarShuffleWriterSuite.makeColumnarBatch( - pidVector1.getVec.getSize, - List(pidVector1, intVector1, shortVector1, longVector1, doubleVector1, - decimal64Vector1, decimal128Vector1, varcharVector1, booleanVector1) - ) - - def records: Iterator[(Int, ColumnarBatch)] = Iterator((0, cb0), (0, cb1)) - - val writer = new ColumnarShuffleWriter[Int, ColumnarBatch]( - blockResolver, - shuffleHandle, - 0L, // MapId - taskContext.taskMetrics().shuffleWriteMetrics) - - writer.write(records) - writer.stop(success = true) - - assert(writer.getPartitionLengths.sum === outputFile.length()) - assert(writer.getPartitionLengths.count(_ == 0L) === 0) - // should be (numPartitions - 2) zero length files - - val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics - assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) - assert(shuffleWriteMetrics.recordsWritten === records.length) - - assert(taskMetrics.diskBytesSpilled === 0) - assert(taskMetrics.memoryBytesSpilled === 0) - - val serializer = new ColumnarBatchSerializer(avgBatchNumRows, outputNumRows).newInstance() - val deserializedStream = serializer.deserializeStream(new FileInputStream(outputFile)) - - try { - val kv = deserializedStream.asKeyValueIterator - var length = 0 - kv.foreach { - case (_, batch: ColumnarBatch) => - length += 1 - assert(batch.numRows == 42) - assert(batch.numCols == 8) - assert(batch.column(0).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[IntVec].get(0) == 0) - assert(batch.column(0).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[IntVec].get(19) == 19) - assert(batch.column(1).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[ShortVec].get(0) == 0) - assert(batch.column(1).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[ShortVec].get(19) == 19) - assert(batch.column(2).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(0) == 0) - assert(batch.column(2).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(19) == 19) - assert(batch.column(3).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[DoubleVec].get(0) == 0.0) - assert(batch.column(3).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[DoubleVec].get(19) == 19.19) - assert(batch.column(4).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(0) == 0L) - assert(batch.column(4).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(19) == 19L) - assert(batch.column(5).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[Decimal128Vec].get(0) sameElements Array(0L, 0L)) - 
assert(batch.column(5).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[Decimal128Vec].get(19) sameElements Array(19L, 19L)) - assert(batch.column(6).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[VarcharVec].get(0) sameElements "") - assert(batch.column(6).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[VarcharVec].get(19) sameElements "sssssssssssssssssss") - assert(batch.column(7).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[BooleanVec].get(0) == true) - assert(batch.column(7).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[BooleanVec].get(19) == false) - (0 until batch.numCols).foreach { i => - val valueVector = batch.column(i).asInstanceOf[OmniColumnVector].getVec - assert(valueVector.getSize == batch.numRows) - } - batch.close() - } - assert(length == 1) - } finally { - deserializedStream.close() - } - - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerSnappySuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerSnappySuite.scala deleted file mode 100644 index 5b6811b03362294e35ca39a65de42592a9385aa8..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerSnappySuite.scala +++ /dev/null @@ -1,247 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle - -import java.io.{File, FileInputStream} - -import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer -import com.huawei.boostkit.spark.vectorized.PartitionInfo -import nova.hetu.omniruntime.`type`.{DataType, _} -import nova.hetu.omniruntime.vector._ -import org.apache.spark.{HashPartitioner, SparkConf, TaskContext} -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.shuffle.sort.ColumnarShuffleHandle -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.vectorized.OmniColumnVector -import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.Utils -import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.ArgumentMatchers.{any, anyInt, anyLong} -import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Mockito.{doAnswer, when} -import org.mockito.invocation.InvocationOnMock - -class ColumnShuffleSerializerSnappySuite extends SharedSparkSession { - @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _ - @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ - @Mock(answer = RETURNS_SMART_NULLS) private var dependency - : ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = _ - - override def sparkConf: SparkConf = - super.sparkConf - .setAppName("test shuffle serializer for snappy") - .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.OmniColumnarShuffleManager") - .set("spark.shuffle.compress", "true") - .set("spark.io.compression.codec", "snappy") - - private var taskMetrics: TaskMetrics = _ - private var tempDir: File = _ - private var outputFile: File = _ - - private var shuffleHandle: ColumnarShuffleHandle[Int, ColumnarBatch] = _ - private val numPartitions = 1 - - protected var avgBatchNumRows: SQLMetric = _ - protected var outputNumRows: SQLMetric = _ - - override def beforeEach(): Unit = { - super.beforeEach() - - avgBatchNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, - "test serializer avg read batch num rows") - outputNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, - "test serializer number of output rows") - - tempDir = Utils.createTempDir() - outputFile = File.createTempFile("shuffle", null, tempDir) - taskMetrics = new TaskMetrics - - MockitoAnnotations.initMocks(this) - - shuffleHandle = - new ColumnarShuffleHandle[Int, ColumnarBatch](shuffleId = 0, dependency = dependency) - - val types : Array[DataType] = Array[DataType]( - IntDataType.INTEGER, - ShortDataType.SHORT, - LongDataType.LONG, - DoubleDataType.DOUBLE, - new Decimal64DataType(18, 3), - new Decimal128DataType(28, 11), - VarcharDataType.VARCHAR, - BooleanDataType.BOOLEAN) - val inputTypes = DataTypeSerializer.serialize(types) - - when(dependency.partitioner).thenReturn(new HashPartitioner(numPartitions)) - when(dependency.serializer).thenReturn(new JavaSerializer(sparkConf)) - when(dependency.partitionInfo).thenReturn( - new PartitionInfo("hash", numPartitions, types.length, inputTypes)) - when(dependency.dataSize) - .thenReturn(SQLMetrics.createSizeMetric(spark.sparkContext, "data size")) - when(dependency.bytesSpilled) - .thenReturn(SQLMetrics.createSizeMetric(spark.sparkContext, "shuffle bytes spilled")) - when(dependency.numInputRows) - .thenReturn(SQLMetrics.createMetric(spark.sparkContext, "number of input rows")) - 
when(dependency.splitTime) - .thenReturn(SQLMetrics.createNanoTimingMetric(spark.sparkContext, "totaltime_split")) - when(dependency.spillTime) - .thenReturn(SQLMetrics.createNanoTimingMetric(spark.sparkContext, "totaltime_spill")) - when(taskContext.taskMetrics()).thenReturn(taskMetrics) - when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) - - doAnswer { (invocationOnMock: InvocationOnMock) => - val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] - if (tmp != null) { - outputFile.delete - tmp.renameTo(outputFile) - } - null - }.when(blockResolver) - .writeIndexFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])) - } - - override def afterEach(): Unit = { - try { - Utils.deleteRecursively(tempDir) - } finally { - super.afterEach() - } - } - - override def afterAll(): Unit = { - super.afterAll() - } - - test("write shuffle compress for snappy") { - val pidArray: Array[java.lang.Integer] = Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) - val intArray: Array[java.lang.Integer] = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20) - val shortArray: Array[java.lang.Integer] = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20) - val longArray: Array[java.lang.Long] = Array(0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L, - 17L, 18L, 19L, 20L) - val doubleArray: Array[java.lang.Double] = Array(0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.10, 11.11, 12.12, - 13.13, 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.20) - val decimal64Array: Array[java.lang.Long] = Array(0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L, - 17L, 18L, 19L, 20L) - val decimal128Array: Array[Array[Long]] = Array( - Array(0L, 0L), Array(1L, 1L), Array(2L, 2L), Array(3L, 3L), Array(4L, 4L), Array(5L, 5L), Array(6L, 6L), - Array(7L, 7L), Array(8L, 8L), Array(9L, 9L), Array(10L, 10L), Array(11L, 11L), Array(12L, 12L), Array(13L, 13L), - Array(14L, 14L), Array(15L, 15L), Array(16L, 16L), Array(17L, 17L), Array(18L, 18L), Array(19L, 19L), Array(20L, 20L)) - val stringArray: Array[java.lang.String] = Array("", "a", "bb", "ccc", "dddd", "eeeee", "ffffff", "ggggggg", - "hhhhhhhh", "iiiiiiiii", "jjjjjjjjjj", "kkkkkkkkkkk", "llllllllllll", "mmmmmmmmmmmmm", "nnnnnnnnnnnnnn", - "ooooooooooooooo", "pppppppppppppppp", "qqqqqqqqqqqqqqqqq", "rrrrrrrrrrrrrrrrrr", "sssssssssssssssssss", - "tttttttttttttttttttt") - val booleanArray: Array[java.lang.Boolean] = Array(true, true, true, true, true, true, true, true, true, true, - false, false, false, false, false, false, false, false, false, false, false) - - val pidVector0 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(pidArray) - val intVector0 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(intArray) - val shortVector0 = ColumnarShuffleWriterSuite.initOmniColumnShortVector(shortArray) - val longVector0 = ColumnarShuffleWriterSuite.initOmniColumnLongVector(longArray) - val doubleVector0 = ColumnarShuffleWriterSuite.initOmniColumnDoubleVector(doubleArray) - val decimal64Vector0 = ColumnarShuffleWriterSuite.initOmniColumnDecimal64Vector(decimal64Array) - val decimal128Vector0 = ColumnarShuffleWriterSuite.initOmniColumnDecimal128Vector(decimal128Array) - val varcharVector0 = ColumnarShuffleWriterSuite.initOmniColumnVarcharVector(stringArray) - val booleanVector0 = ColumnarShuffleWriterSuite.initOmniColumnBooleanVector(booleanArray) - - val cb0 = ColumnarShuffleWriterSuite.makeColumnarBatch( - 
pidVector0.getVec.getSize, - List(pidVector0, intVector0, shortVector0, longVector0, doubleVector0, - decimal64Vector0, decimal128Vector0, varcharVector0, booleanVector0) - ) - - val pidVector1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(pidArray) - val intVector1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(intArray) - val shortVector1 = ColumnarShuffleWriterSuite.initOmniColumnShortVector(shortArray) - val longVector1 = ColumnarShuffleWriterSuite.initOmniColumnLongVector(longArray) - val doubleVector1 = ColumnarShuffleWriterSuite.initOmniColumnDoubleVector(doubleArray) - val decimal64Vector1 = ColumnarShuffleWriterSuite.initOmniColumnDecimal64Vector(decimal64Array) - val decimal128Vector1 = ColumnarShuffleWriterSuite.initOmniColumnDecimal128Vector(decimal128Array) - val varcharVector1 = ColumnarShuffleWriterSuite.initOmniColumnVarcharVector(stringArray) - val booleanVector1 = ColumnarShuffleWriterSuite.initOmniColumnBooleanVector(booleanArray) - - val cb1 = ColumnarShuffleWriterSuite.makeColumnarBatch( - pidVector1.getVec.getSize, - List(pidVector1, intVector1, shortVector1, longVector1, doubleVector1, - decimal64Vector1, decimal128Vector1, varcharVector1, booleanVector1) - ) - - def records: Iterator[(Int, ColumnarBatch)] = Iterator((0, cb0), (0, cb1)) - - val writer = new ColumnarShuffleWriter[Int, ColumnarBatch]( - blockResolver, - shuffleHandle, - 0L, // MapId - taskContext.taskMetrics().shuffleWriteMetrics) - - writer.write(records) - writer.stop(success = true) - - assert(writer.getPartitionLengths.sum === outputFile.length()) - assert(writer.getPartitionLengths.count(_ == 0L) === 0) - // should be (numPartitions - 2) zero length files - - val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics - assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) - assert(shuffleWriteMetrics.recordsWritten === records.length) - - assert(taskMetrics.diskBytesSpilled === 0) - assert(taskMetrics.memoryBytesSpilled === 0) - - val serializer = new ColumnarBatchSerializer(avgBatchNumRows, outputNumRows).newInstance() - val deserializedStream = serializer.deserializeStream(new FileInputStream(outputFile)) - - try { - val kv = deserializedStream.asKeyValueIterator - var length = 0 - kv.foreach { - case (_, batch: ColumnarBatch) => - length += 1 - assert(batch.numRows == 42) - assert(batch.numCols == 8) - assert(batch.column(0).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[IntVec].get(0) == 0) - assert(batch.column(0).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[IntVec].get(19) == 19) - assert(batch.column(1).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[ShortVec].get(0) == 0) - assert(batch.column(1).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[ShortVec].get(19) == 19) - assert(batch.column(2).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(0) == 0) - assert(batch.column(2).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(19) == 19) - assert(batch.column(3).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[DoubleVec].get(0) == 0.0) - assert(batch.column(3).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[DoubleVec].get(19) == 19.19) - assert(batch.column(4).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(0) == 0L) - assert(batch.column(4).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(19) == 19L) - assert(batch.column(5).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[Decimal128Vec].get(0) sameElements Array(0L, 0L)) - 
assert(batch.column(5).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[Decimal128Vec].get(19) sameElements Array(19L, 19L)) - assert(batch.column(6).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[VarcharVec].get(0) sameElements "") - assert(batch.column(6).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[VarcharVec].get(19) sameElements "sssssssssssssssssss") - assert(batch.column(7).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[BooleanVec].get(0) == true) - assert(batch.column(7).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[BooleanVec].get(19) == false) - (0 until batch.numCols).foreach { i => - val valueVector = batch.column(i).asInstanceOf[OmniColumnVector].getVec - assert(valueVector.getSize == batch.numRows) - } - batch.close() - } - assert(length == 1) - } finally { - deserializedStream.close() - } - - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerZlibSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerZlibSuite.scala deleted file mode 100644 index a9924a95d42310d1f784088f1b67e015f45d1ca3..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerZlibSuite.scala +++ /dev/null @@ -1,248 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle - -import java.io.{File, FileInputStream} - -import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer -import com.huawei.boostkit.spark.vectorized.PartitionInfo -import nova.hetu.omniruntime.`type`.{DataType, _} -import nova.hetu.omniruntime.vector._ -import org.apache.spark.{HashPartitioner, SparkConf, TaskContext} -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.shuffle.sort.ColumnarShuffleHandle -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.vectorized.OmniColumnVector -import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.Utils -import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.ArgumentMatchers.{any, anyInt, anyLong} -import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Mockito.{doAnswer, when} -import org.mockito.invocation.InvocationOnMock - -class ColumnShuffleSerializerZlibSuite extends SharedSparkSession { - @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _ - @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ - @Mock(answer = RETURNS_SMART_NULLS) private var dependency - : ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = _ - - override def sparkConf: SparkConf = - super.sparkConf - .setAppName("test shuffle serializer for zlib") - .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.OmniColumnarShuffleManager") - .set("spark.shuffle.compress", "true") - .set("spark.io.compression.codec", "zlib") - - private var taskMetrics: TaskMetrics = _ - private var tempDir: File = _ - private var outputFile: File = _ - - private var shuffleHandle: ColumnarShuffleHandle[Int, ColumnarBatch] = _ - private val numPartitions = 1 - - protected var avgBatchNumRows: SQLMetric = _ - protected var outputNumRows: SQLMetric = _ - - override def beforeEach(): Unit = { - super.beforeEach() - - avgBatchNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, - "test serializer avg read batch num rows") - outputNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, - "test serializer number of output rows") - - tempDir = Utils.createTempDir() - outputFile = File.createTempFile("shuffle", null, tempDir) - taskMetrics = new TaskMetrics - - MockitoAnnotations.initMocks(this) - - shuffleHandle = - new ColumnarShuffleHandle[Int, ColumnarBatch](shuffleId = 0, dependency = dependency) - - val types : Array[DataType] = Array[DataType]( - IntDataType.INTEGER, - ShortDataType.SHORT, - LongDataType.LONG, - DoubleDataType.DOUBLE, - new Decimal64DataType(18, 3), - new Decimal128DataType(28, 11), - VarcharDataType.VARCHAR, - BooleanDataType.BOOLEAN) - val inputTypes = DataTypeSerializer.serialize(types) - - when(dependency.partitioner).thenReturn(new HashPartitioner(numPartitions)) - when(dependency.serializer).thenReturn(new JavaSerializer(sparkConf)) - when(dependency.partitionInfo).thenReturn( - new PartitionInfo("hash", numPartitions, types.length, inputTypes)) - when(dependency.dataSize) - .thenReturn(SQLMetrics.createSizeMetric(spark.sparkContext, "data size")) - when(dependency.bytesSpilled) - .thenReturn(SQLMetrics.createSizeMetric(spark.sparkContext, "shuffle bytes spilled")) - when(dependency.numInputRows) - .thenReturn(SQLMetrics.createMetric(spark.sparkContext, "number of input rows")) - when(dependency.splitTime) - 
.thenReturn(SQLMetrics.createNanoTimingMetric(spark.sparkContext, "totaltime_split")) - when(dependency.spillTime) - .thenReturn(SQLMetrics.createNanoTimingMetric(spark.sparkContext, "totaltime_spill")) - when(taskContext.taskMetrics()).thenReturn(taskMetrics) - when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) - - doAnswer { (invocationOnMock: InvocationOnMock) => - val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] - if (tmp != null) { - outputFile.delete - tmp.renameTo(outputFile) - } - null - }.when(blockResolver) - .writeIndexFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])) - } - - override def afterEach(): Unit = { - try { - Utils.deleteRecursively(tempDir) - } finally { - super.afterEach() - } - } - - override def afterAll(): Unit = { - super.afterAll() - } - - test("write shuffle compress for zlib with null value middle") { - val pidArray: Array[java.lang.Integer] = Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) - val intArray: Array[java.lang.Integer] = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, null, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20) - val shortArray: Array[java.lang.Integer] = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, null, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20) - val longArray: Array[java.lang.Long] = Array(0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, null, 11L, 12L, 13L, 14L, 15L, 16L, - 17L, 18L, 19L, 20L) - val doubleArray: Array[java.lang.Double] = Array(0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, null, 11.11, 12.12, - 13.13, 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.20) - val decimal64Array: Array[java.lang.Long] = Array(0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, null, 11L, 12L, 13L, 14L, 15L, 16L, - 17L, 18L, 19L, 20L) - val decimal128Array: Array[Array[Long]] = Array( - Array(0L, 0L), Array(1L, 1L), Array(2L, 2L), Array(3L, 3L), Array(4L, 4L), Array(5L, 5L), Array(6L, 6L), - Array(7L, 7L), Array(8L, 8L), Array(9L, 9L), null, Array(11L, 11L), Array(12L, 12L), Array(13L, 13L), - Array(14L, 14L), Array(15L, 15L), Array(16L, 16L), Array(17L, 17L), Array(18L, 18L), Array(19L, 19L), Array(20L, 20L)) - val stringArray: Array[java.lang.String] = Array("", "a", "bb", "ccc", "dddd", "eeeee", "ffffff", "ggggggg", - "hhhhhhhh", "iiiiiiiii", null, "kkkkkkkkkkk", "llllllllllll", "mmmmmmmmmmmmm", "nnnnnnnnnnnnnn", - "ooooooooooooooo", "pppppppppppppppp", "qqqqqqqqqqqqqqqqq", "rrrrrrrrrrrrrrrrrr", "sssssssssssssssssss", - "tttttttttttttttttttt") - val booleanArray: Array[java.lang.Boolean] = Array(true, true, true, true, true, true, true, true, true, true, - null, false, false, false, false, false, false, false, false, false, false) - - val pidVector0 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(pidArray) - val intVector0 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(intArray) - val shortVector0 = ColumnarShuffleWriterSuite.initOmniColumnShortVector(shortArray) - val longVector0 = ColumnarShuffleWriterSuite.initOmniColumnLongVector(longArray) - val doubleVector0 = ColumnarShuffleWriterSuite.initOmniColumnDoubleVector(doubleArray) - val decimal64Vector0 = ColumnarShuffleWriterSuite.initOmniColumnDecimal64Vector(decimal64Array) - val decimal128Vector0 = ColumnarShuffleWriterSuite.initOmniColumnDecimal128Vector(decimal128Array) - val varcharVector0 = ColumnarShuffleWriterSuite.initOmniColumnVarcharVector(stringArray) - val booleanVector0 = ColumnarShuffleWriterSuite.initOmniColumnBooleanVector(booleanArray) - - val cb0 = ColumnarShuffleWriterSuite.makeColumnarBatch( - pidVector0.getVec.getSize, - 
List(pidVector0, intVector0, shortVector0, longVector0, doubleVector0, - decimal64Vector0, decimal128Vector0, varcharVector0, booleanVector0) - ) - - val pidVector1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(pidArray) - val intVector1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(intArray) - val shortVector1 = ColumnarShuffleWriterSuite.initOmniColumnShortVector(shortArray) - val longVector1 = ColumnarShuffleWriterSuite.initOmniColumnLongVector(longArray) - val doubleVector1 = ColumnarShuffleWriterSuite.initOmniColumnDoubleVector(doubleArray) - val decimal64Vector1 = ColumnarShuffleWriterSuite.initOmniColumnDecimal64Vector(decimal64Array) - val decimal128Vector1 = ColumnarShuffleWriterSuite.initOmniColumnDecimal128Vector(decimal128Array) - val varcharVector1 = ColumnarShuffleWriterSuite.initOmniColumnVarcharVector(stringArray) - val booleanVector1 = ColumnarShuffleWriterSuite.initOmniColumnBooleanVector(booleanArray) - - val cb1 = ColumnarShuffleWriterSuite.makeColumnarBatch( - pidVector1.getVec.getSize, - List(pidVector1, intVector1, shortVector1, longVector1, doubleVector1, - decimal64Vector1, decimal128Vector1, varcharVector1, booleanVector1) - ) - - def records: Iterator[(Int, ColumnarBatch)] = Iterator((0, cb0), (0, cb1)) - - val writer = new ColumnarShuffleWriter[Int, ColumnarBatch]( - blockResolver, - shuffleHandle, - 0L, // MapId - taskContext.taskMetrics().shuffleWriteMetrics) - - writer.write(records) - writer.stop(success = true) - - assert(writer.getPartitionLengths.sum === outputFile.length()) - assert(writer.getPartitionLengths.count(_ == 0L) === 0) - // should be (numPartitions - 2) zero length files - - val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics - assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) - assert(shuffleWriteMetrics.recordsWritten === records.length) - - assert(taskMetrics.diskBytesSpilled === 0) - assert(taskMetrics.memoryBytesSpilled === 0) - - val serializer = new ColumnarBatchSerializer(avgBatchNumRows, outputNumRows).newInstance() - val deserializedStream = serializer.deserializeStream(new FileInputStream(outputFile)) - - try { - val kv = deserializedStream.asKeyValueIterator - var length = 0 - kv.foreach { - case (_, batch: ColumnarBatch) => - length += 1 - assert(batch.numRows == 42) - assert(batch.numCols == 8) - assert(batch.column(0).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[IntVec].get(0) == 0) - assert(batch.column(0).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[IntVec].get(19) == 19) - assert(batch.column(1).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[ShortVec].get(0) == 0) - assert(batch.column(1).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[ShortVec].get(19) == 19) - assert(batch.column(2).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(0) == 0) - assert(batch.column(2).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(19) == 19) - assert(batch.column(3).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[DoubleVec].get(0) == 0.0) - assert(batch.column(3).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[DoubleVec].get(19) == 19.19) - assert(batch.column(4).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(0) == 0L) - assert(batch.column(4).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(19) == 19L) - assert(batch.column(5).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[Decimal128Vec].get(0) sameElements Array(0L, 0L)) - 
assert(batch.column(5).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[Decimal128Vec].get(19) sameElements Array(19L, 19L)) - assert(batch.column(6).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[VarcharVec].get(0) sameElements "") - assert(batch.column(6).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[VarcharVec].get(19) sameElements "sssssssssssssssssss") - assert(batch.column(7).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[BooleanVec].get(0) == true) - assert(batch.column(7).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[BooleanVec].get(19) == false) - (0 until batch.numCols).foreach { i => - val valueVector = batch.column(i).asInstanceOf[OmniColumnVector].getVec - assert(valueVector.getSize == batch.numRows) - assert(valueVector.isNull(10)) - } - batch.close() - } - assert(length == 1) - } finally { - deserializedStream.close() - } - - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala deleted file mode 100644 index 00adf145979e33f7dd7b1c49873fd72cdff18756..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala +++ /dev/null @@ -1,377 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle - -import java.io.{File, FileInputStream} - -import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer -import com.huawei.boostkit.spark.vectorized.PartitionInfo -import nova.hetu.omniruntime.`type`.{DataType, _} -import nova.hetu.omniruntime.vector._ -import org.apache.spark.{HashPartitioner, SparkConf, TaskContext} -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.shuffle.sort.ColumnarShuffleHandle -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.vectorized.OmniColumnVector -import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch} -import org.apache.spark.util.Utils -import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.ArgumentMatchers.{any, anyInt, anyLong} -import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Mockito.{doAnswer, when} -import org.mockito.invocation.InvocationOnMock - -class ColumnarShuffleWriterSuite extends SharedSparkSession { - @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _ - @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ - @Mock(answer = RETURNS_SMART_NULLS) private var dependency - : ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = _ - - override def sparkConf: SparkConf = - super.sparkConf - .setAppName("test ColumnarShuffleWriter") - .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.OmniColumnarShuffleManager") - .set("spark.shuffle.compress", "false") - - private var taskMetrics: TaskMetrics = _ - private var tempDir: File = _ - private var outputFile: File = _ - - private var shuffleHandle: ColumnarShuffleHandle[Int, ColumnarBatch] = _ - private val numPartitions = 11 - - protected var avgBatchNumRows: SQLMetric = _ - protected var outputNumRows: SQLMetric = _ - - override def beforeEach(): Unit = { - super.beforeEach() - - avgBatchNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, - "test serializer avg read batch num rows") - outputNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, - "test serializer number of output rows") - - tempDir = Utils.createTempDir() - outputFile = File.createTempFile("shuffle", null, tempDir) - taskMetrics = new TaskMetrics - - MockitoAnnotations.initMocks(this) - - shuffleHandle = - new ColumnarShuffleHandle[Int, ColumnarBatch](shuffleId = 0, dependency = dependency) - - val types : Array[DataType] = Array[DataType]( - IntDataType.INTEGER, - IntDataType.INTEGER, - new Decimal64DataType(18, 3), - new Decimal128DataType(28, 11)) - val inputTypes = DataTypeSerializer.serialize(types) - - when(dependency.partitioner).thenReturn(new HashPartitioner(numPartitions)) - when(dependency.serializer).thenReturn(new JavaSerializer(sparkConf)) - when(dependency.partitionInfo).thenReturn( - new PartitionInfo("hash", numPartitions, 4, inputTypes)) - // inputTypes e.g: - // [{"id":"OMNI_INT","width":0,"precision":0,"scale":0,"dateUnit":"DAY","timeUnit":"SEC"}, - // {"id":"OMNI_INT","width":0,"precision":0,"scale":0,"dateUnit":"DAY","timeUnit":"SEC"}] - when(dependency.dataSize) - .thenReturn(SQLMetrics.createSizeMetric(spark.sparkContext, "data size")) - when(dependency.bytesSpilled) - .thenReturn(SQLMetrics.createSizeMetric(spark.sparkContext, "shuffle bytes spilled")) - when(dependency.numInputRows) - 
.thenReturn(SQLMetrics.createMetric(spark.sparkContext, "number of input rows")) - when(dependency.splitTime) - .thenReturn(SQLMetrics.createNanoTimingMetric(spark.sparkContext, "totaltime_split")) - when(dependency.spillTime) - .thenReturn(SQLMetrics.createNanoTimingMetric(spark.sparkContext, "totaltime_spill")) - when(taskContext.taskMetrics()).thenReturn(taskMetrics) - when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) - - doAnswer { (invocationOnMock: InvocationOnMock) => - val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] - if (tmp != null) { - outputFile.delete - tmp.renameTo(outputFile) - } - null - }.when(blockResolver) - .writeIndexFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])) - } - - override def afterEach(): Unit = { - try { - Utils.deleteRecursively(tempDir) - } finally { - super.afterEach() - } - } - - override def afterAll(): Unit = { - super.afterAll() - } - - test("write empty iterator") { - val writer = new ColumnarShuffleWriter[Int, ColumnarBatch]( - blockResolver, - shuffleHandle, - 0, // MapId - taskContext.taskMetrics().shuffleWriteMetrics) - writer.write(Iterator.empty) - writer.stop( /* success = */ true) - - assert(writer.getPartitionLengths.sum === 0) - assert(outputFile.exists()) - assert(outputFile.length() === 0) - val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics - assert(shuffleWriteMetrics.bytesWritten === 0) - assert(shuffleWriteMetrics.recordsWritten === 0) - assert(taskMetrics.diskBytesSpilled === 0) - assert(taskMetrics.memoryBytesSpilled === 0) - } - - test("write empty column batch") { - val vectorPid0 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(Array()) - val vector0_1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(Array()) - val vector0_2 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(Array()) - val vector0_3 = ColumnarShuffleWriterSuite.initOmniColumnDecimal64Vector(Array()) - val vector0_4 = ColumnarShuffleWriterSuite.initOmniColumnDecimal128Vector(Array()) - - val vectorPid1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(Array()) - val vector1_1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(Array()) - val vector1_2 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(Array()) - val vector1_3 = ColumnarShuffleWriterSuite.initOmniColumnDecimal64Vector(Array()) - val vector1_4 = ColumnarShuffleWriterSuite.initOmniColumnDecimal128Vector(Array()) - - val cb0 = ColumnarShuffleWriterSuite.makeColumnarBatch( - vectorPid0.getVec.getSize, List(vectorPid0, vector0_1, vector0_2, vector0_3, vector0_4)) - val cb1 = ColumnarShuffleWriterSuite.makeColumnarBatch( - vectorPid1.getVec.getSize, List(vectorPid1, vector1_1, vector1_2, vector1_3, vector1_4)) - - def records: Iterator[(Int, ColumnarBatch)] = Iterator((0, cb0), (0, cb1)) - - val writer = new ColumnarShuffleWriter[Int, ColumnarBatch]( - blockResolver, - shuffleHandle, - 0L, // MapId - taskContext.taskMetrics().shuffleWriteMetrics) - - writer.write(records) - writer.stop(success = true) - assert(writer.getPartitionLengths.sum === 0) - assert(outputFile.exists()) - assert(outputFile.length() === 0) - val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics - assert(shuffleWriteMetrics.bytesWritten === 0) - assert(shuffleWriteMetrics.recordsWritten === 0) - assert(taskMetrics.diskBytesSpilled === 0) - assert(taskMetrics.memoryBytesSpilled === 0) - } - - test("write with some empty partitions") { - val vectorPid0 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(Array(0, 0, 1, 1)) - 
val vector0_1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(Array(null, null, null, null)) - val vector0_2 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(Array(100, 100, null, null)) - val vector0_3 = ColumnarShuffleWriterSuite.initOmniColumnDecimal64Vector(Array(100L, 100L, 100L, 100L)) - val vector0_4 = ColumnarShuffleWriterSuite.initOmniColumnDecimal128Vector(Array(Array(100L, 100L), Array(100L, 100L), null, null)) - val cb0 = ColumnarShuffleWriterSuite.makeColumnarBatch( - vectorPid0.getVec.getSize, List(vectorPid0, vector0_1, vector0_2, vector0_3, vector0_4)) - - val vectorPid1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(Array(0, 0, 1, 1)) - val vector1_1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(Array(null, null, null, null)) - val vector1_2 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(Array(100, 100, null, null)) - val vector1_3 = ColumnarShuffleWriterSuite.initOmniColumnDecimal64Vector(Array(100L, 100L, 100L, 100L)) - val vector1_4 = ColumnarShuffleWriterSuite.initOmniColumnDecimal128Vector(Array(Array(100L, 100L), Array(100L, 100L), null, null)) - val cb1 = ColumnarShuffleWriterSuite.makeColumnarBatch( - vectorPid1.getVec.getSize, List(vectorPid1, vector1_1, vector1_2, vector1_3, vector1_4)) - - def records: Iterator[(Int, ColumnarBatch)] = Iterator((0, cb0), (0, cb1)) - - val writer = new ColumnarShuffleWriter[Int, ColumnarBatch]( - blockResolver, - shuffleHandle, - 0L, // MapId - taskContext.taskMetrics().shuffleWriteMetrics) - - writer.write(records) - writer.stop(success = true) - - assert(writer.getPartitionLengths.sum === outputFile.length()) - assert(writer.getPartitionLengths.count(_ == 0L) === (numPartitions - 2)) - // should be (numPartitions - 2) zero length files - - val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics - assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) - assert(shuffleWriteMetrics.recordsWritten === records.length) - - assert(taskMetrics.diskBytesSpilled === 0) - assert(taskMetrics.memoryBytesSpilled === 0) - - val serializer = new ColumnarBatchSerializer(avgBatchNumRows, outputNumRows).newInstance() - val deserializedStream = serializer.deserializeStream(new FileInputStream(outputFile)) - - try { - val kv = deserializedStream.asKeyValueIterator - var length = 0 - kv.foreach { - case (_, batch: ColumnarBatch) => - length += 1 - assert(batch.numRows == 4) - assert(batch.numCols == 4) - (0 until batch.numCols).foreach { i => - val valueVector = batch.column(i).asInstanceOf[OmniColumnVector].getVec - assert(valueVector.getSize == batch.numRows) - } - batch.close() - } - assert(length == 2) - } finally { - deserializedStream.close() - } - - } -} - -object ColumnarShuffleWriterSuite { - def initOmniColumnBooleanVector(values: Array[java.lang.Boolean]): OmniColumnVector = { - val length = values.length - val vecTmp = new BooleanVec(length) - (0 until length).foreach { i => - if (values(i) != null) { - vecTmp.set(i, values(i)) - } else { - vecTmp.setNull(i) - } - } - val colVecTmp = new OmniColumnVector(length, BooleanType, false) - colVecTmp.setVec(vecTmp) - colVecTmp - } - - def initOmniColumnIntVector(values: Array[java.lang.Integer]): OmniColumnVector = { - val length = values.length - val vecTmp = new IntVec(length) - (0 until length).foreach { i => - if (values(i) != null) { - vecTmp.set(i, values(i)) - } else { - vecTmp.setNull(i) - } - } - val colVecTmp = new OmniColumnVector(length, IntegerType, false) - colVecTmp.setVec(vecTmp) - colVecTmp - } - - def 
initOmniColumnShortVector(values: Array[java.lang.Integer]): OmniColumnVector = {
-    val length = values.length
-    val vecTmp = new ShortVec(length)
-    (0 until length).foreach { i =>
-      if (values(i) != null) {
-        vecTmp.set(i, values(i).shortValue())
-      } else {
-        vecTmp.setNull(i)
-      }
-    }
-    val colVecTmp = new OmniColumnVector(length, ShortType, false)
-    colVecTmp.setVec(vecTmp)
-    colVecTmp
-  }
-
-  def initOmniColumnLongVector(values: Array[java.lang.Long]): OmniColumnVector = {
-    val length = values.length
-    val vecTmp = new LongVec(length)
-    (0 until length).foreach { i =>
-      if (values(i) != null) {
-        vecTmp.set(i, values(i))
-      } else {
-        vecTmp.setNull(i)
-      }
-    }
-    val colVecTmp = new OmniColumnVector(length, LongType, false)
-    colVecTmp.setVec(vecTmp)
-    colVecTmp
-  }
-
-  def initOmniColumnDoubleVector(values: Array[java.lang.Double]): OmniColumnVector = {
-    val length = values.length
-    val vecTmp = new DoubleVec(length)
-    (0 until length).foreach { i =>
-      if (values(i) != null) {
-        vecTmp.set(i, values(i))
-      } else {
-        vecTmp.setNull(i)
-      }
-    }
-    val colVecTmp = new OmniColumnVector(length, DoubleType, false)
-    colVecTmp.setVec(vecTmp)
-    colVecTmp
-  }
-
-  def initOmniColumnVarcharVector(values: Array[java.lang.String]): OmniColumnVector = {
-    val length = values.length
-    val vecTmp = new VarcharVec(1024, length)
-    (0 until length).foreach { i =>
-      if (values(i) != null) {
-        vecTmp.set(i, values(i).getBytes())
-      } else {
-        vecTmp.setNull(i)
-      }
-    }
-    val colVecTmp = new OmniColumnVector(length, StringType, false)
-    colVecTmp.setVec(vecTmp)
-    colVecTmp
-  }
-
-  def initOmniColumnDecimal64Vector(values: Array[java.lang.Long]): OmniColumnVector = {
-    val length = values.length
-    val vecTmp = new LongVec(length)
-    (0 until length).foreach { i =>
-      if (values(i) != null) {
-        vecTmp.set(i, values(i))
-      } else {
-        vecTmp.setNull(i)
-      }
-    }
-    val colVecTmp = new OmniColumnVector(length, DecimalType(18, 3), false)
-    colVecTmp.setVec(vecTmp)
-    colVecTmp
-  }
-
-  def initOmniColumnDecimal128Vector(values: Array[Array[Long]]): OmniColumnVector = {
-    val length = values.length
-    val vecTmp = new Decimal128Vec(length)
-    (0 until length).foreach { i =>
-      if (values(i) != null) {
-        vecTmp.set(i, values(i))
-      } else {
-        vecTmp.setNull(i)
-      }
-    }
-    val colVecTmp = new OmniColumnVector(length, DecimalType(28, 11), false)
-    colVecTmp.setVec(vecTmp)
-    colVecTmp
-  }
-
-  def makeColumnarBatch(rowNum: Int, vectors: List[ColumnVector]): ColumnarBatch = {
-    new ColumnarBatch(vectors.toArray, rowNum)
-  }
-}
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarAggregateBenchmark.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarAggregateBenchmark.scala
deleted file mode 100644
index a1f113b1dfa515c28d780892f835278985eea01b..0000000000000000000000000000000000000000
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarAggregateBenchmark.scala
+++ /dev/null
@@ -1,39 +0,0 @@
-/*
- * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved.
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.benchmark - - -object ColumnarAggregateBenchmark extends ColumnarBasedBenchmark { - - override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { - - val N = if (mainArgs.isEmpty) { - 500L << 20 - } else { - mainArgs(0).toLong - } - - runBenchmark("stat functions") { - spark.range(N).groupBy().agg("id" -> "sum").explain() - columnarBenchmark(s"spark.range(${N}).groupBy().agg(id -> sum)", N) { - spark.range(N).groupBy().agg("id" -> "sum").noop() - } - } - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarBasedBenchmark.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarBasedBenchmark.scala deleted file mode 100644 index 402932161e87e095e1a50ec19f77652a91629b78..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarBasedBenchmark.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-package org.apache.spark.sql.benchmark
-
-import org.apache.spark.benchmark.Benchmark
-import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark
-import org.apache.spark.sql.internal.SQLConf
-
-/**
- * Common basic scenario for running benchmarks
- */
-abstract class ColumnarBasedBenchmark extends SqlBasedBenchmark {
-  /** Runs function `f` under three scenarios (Spark WSCG on, Spark WSCG off, and omni-columnar processing) */
-  final def columnarBenchmark(name: String, cardinality: Long)(f: => Unit): Unit = {
-    val benchmark = new Benchmark(name, cardinality, output = output)
-    if (getSparkSession.conf.getOption("spark.sql.extensions").isDefined)
-    {
-      benchmark.addCase(s"$name omniruntime wholestage off", numIters = 5) { _ =>
-        withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
-          f
-        }
-      }
-    }
-    else
-    {
-      benchmark.addCase(s"$name Spark wholestage off", numIters = 5) { _ =>
-        withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
-          f
-        }
-      }
-      benchmark.addCase(s"$name Spark wholestage on", numIters = 5) { _ =>
-        withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") {
-          f
-        }
-      }
-    }
-
-    benchmark.run()
-  }
-}
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarFilterBenchmark.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarFilterBenchmark.scala
deleted file mode 100644
index 98e8596fe2b6c44cf869bd071385998cfa27bb54..0000000000000000000000000000000000000000
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarFilterBenchmark.scala
+++ /dev/null
@@ -1,39 +0,0 @@
-/*
- * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved.
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -package org.apache.spark.sql.benchmark - -import org.apache.spark.sql.benchmark.ColumnarAggregateBenchmark.spark - -object ColumnarFilterBenchmark extends ColumnarBasedBenchmark { - override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { - - val N = if (mainArgs.isEmpty) { - 500L << 20 - } else { - mainArgs(0).toLong - } - - runBenchmark("filter with API") { - spark.range(N).filter("id > 100").explain() - columnarBenchmark(s"spark.range(${N}).filter(id > 100)", N) { - spark.range(N).filter("id > 100").noop() - } - } - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarJoinBenchmark.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarJoinBenchmark.scala deleted file mode 100644 index 55eda5db03c85a9a6bff206555ebb40518764132..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarJoinBenchmark.scala +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.benchmark - -import org.apache.spark.sql.functions._ - -object ColumnarJoinBenchmark extends ColumnarBasedBenchmark { - def broadcastHashJoinLongKey(rowsA: Long): Unit = { - val rowsB = 1 << 16 - val dim = spark.range(rowsB).selectExpr("id as k", "id as v") - val df = spark.range(rowsA).join(dim.hint("broadcast"), (col("id") % rowsB) === col("k")) - df.explain() - columnarBenchmark(s"broadcastHashJoinLongKey spark.range(${rowsA}).join(spark.range(${rowsB}))", rowsA) { - df.noop() - } - } - - def sortMergeJoin(rowsA: Long, rowsB: Long): Unit = { - val df1 = spark.range(rowsA).selectExpr(s"id * 2 as k1") - val df2 = spark.range(rowsB).selectExpr(s"id * 3 as k2") - val df = df1.join(df2.hint("mergejoin"), col("k1") === col("k2")) - df.explain() - columnarBenchmark(s"sortMergeJoin spark.range(${rowsA}).join(spark.range(${rowsB}))", rowsA) { - df.noop() - } - } - - override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { - - val rowsA = if (mainArgs.isEmpty) { - 20 << 20 - } else { - mainArgs(0).toLong - } - - val rowsB = if (mainArgs.isEmpty || mainArgs.length < 2) { - 1 << 16 - } else { - mainArgs(1).toLong - } - - runBenchmark("Join Benchmark") { - broadcastHashJoinLongKey(rowsA) - sortMergeJoin(rowsA, rowsB) - } - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarProjectBenchmark.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarProjectBenchmark.scala deleted file mode 100644 index 2540ccbc243fc415ae43ad6275c5c4e6c0b27304..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarProjectBenchmark.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.benchmark - -object ColumnarProjectBenchmark extends ColumnarBasedBenchmark { - override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { - val N = if (mainArgs.isEmpty) { - 500L << 18 - } else { - mainArgs(0).toLong - } - - runBenchmark("project with API") { - spark.range(N).selectExpr("id as p").explain() - columnarBenchmark(s"spark.range(${N}).selectExpr(id as p)", N) { - spark.range(N).selectExpr("id as p").noop() - } - } - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarRangeBenchmark.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarRangeBenchmark.scala deleted file mode 100644 index 134ec158ee5cf80a235e05373324191ee8fc701a..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarRangeBenchmark.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.benchmark - -object ColumnarRangeBenchmark extends ColumnarBasedBenchmark { - override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { - - val N = if (mainArgs.isEmpty) { - 500L << 20 - } else { - mainArgs(0).toLong - } - - runBenchmark("range with API") { - spark.range(N).explain() - columnarBenchmark(s"spark.range(${N})", N) { - spark.range(N).noop() - } - } - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarSortBenchmark.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarSortBenchmark.scala deleted file mode 100644 index 99781adf1f5e81663d2d3a0310bcb197ad84a980..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarSortBenchmark.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.benchmark - -object ColumnarSortBenchmark extends ColumnarBasedBenchmark { - override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { - - val N = if (mainArgs.isEmpty) { - 500L << 20 - } else { - mainArgs(0).toLong - } - - runBenchmark("sort with API") { - val value = spark.range(N) - value.sort(value("id").desc).explain() - columnarBenchmark(s"spark.range(${N}).sort(id.desc)", N) { - val value = spark.range(N) - value.sort(value("id").desc).noop() - } - } - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarTopNBenchmark.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarTopNBenchmark.scala deleted file mode 100644 index 64d3d0ee4dcfe2234e8a4db49e87e5102dc85181..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarTopNBenchmark.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.benchmark - -object ColumnarTopNBenchmark extends ColumnarBasedBenchmark { - override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { - - val N = if (mainArgs.isEmpty) { - 500L << 20 - } else { - mainArgs(0).toLong - } - - runBenchmark("topN with API") { - val value = spark.range(N) - value.sort(value("id").desc).limit(20).explain() - - columnarBenchmark(s"spark.range(${N}).sort(id.desc).limit(20)", N) { - val value = spark.range(N) - value.sort(value("id").desc).limit(20).noop() - } - } - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarUnionBenchmark.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarUnionBenchmark.scala deleted file mode 100644 index 24b98c9d572a727289b91e61c4c48a2a639dda58..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarUnionBenchmark.scala +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. 
- * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.benchmark - -object ColumnarUnionBenchmark extends ColumnarBasedBenchmark { - override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { - val N = if (mainArgs.isEmpty) { - 5L << 15 - } else { - mainArgs(0).toLong - } - - val M = if (mainArgs.isEmpty || mainArgs.length < 2) { - 10L << 15 - } else { - mainArgs(1).toLong - } - - runBenchmark("union with API") { - val rangeM = spark.range(M) - spark.range(N).union(rangeM).explain() - columnarBenchmark(s"spark.range(${N}).union(spark.range(${M}))", N) { - spark.range(N).union(rangeM).noop() - } - } - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala deleted file mode 100644 index 9f4ae359e1cc8459841dc9757fb52a514a1cbfb4..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala +++ /dev/null @@ -1,413 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution - -import org.scalatest.BeforeAndAfterAll -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.internal.config.UI.UI_ENABLED -import org.apache.spark.sql._ -import org.apache.spark.sql.execution.adaptive._ -import org.apache.spark.sql.execution.exchange.ReusedExchangeExec -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} - -class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAll { - - private var originalActiveSparkSession: Option[SparkSession] = _ - private var originalInstantiatedSparkSession: Option[SparkSession] = _ - - override protected def beforeAll(): Unit = { - super.beforeAll() - originalActiveSparkSession = SparkSession.getActiveSession - originalInstantiatedSparkSession = SparkSession.getDefaultSession - - SparkSession.clearActiveSession() - SparkSession.clearDefaultSession() - } - - override protected def afterAll(): Unit = { - try { - // Set these states back. - originalActiveSparkSession.foreach(ctx => SparkSession.setActiveSession(ctx)) - originalInstantiatedSparkSession.foreach(ctx => SparkSession.setDefaultSession(ctx)) - } finally { - super.afterAll() - } - } - - val numInputPartitions: Int = 10 - - def withSparkSession( - f: SparkSession => Unit, - targetPostShuffleInputSize: Int, - minNumPostShufflePartitions: Option[Int]): Unit = { - val sparkConf = - new SparkConf(false) - .setMaster("local[*]") - .setAppName("test") - .set(UI_ENABLED, false) - .set(SQLConf.SHUFFLE_PARTITIONS.key, "5") - .set(SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key, "5") - .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true") - .set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1") - .set( - SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key, - targetPostShuffleInputSize.toString) - .set(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key, "com.huawei.boostkit.spark.ColumnarPlugin") - .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.OmniColumnarShuffleManager") - minNumPostShufflePartitions match { - case Some(numPartitions) => - sparkConf.set(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key, numPartitions.toString) - case None => - sparkConf.set(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key, "1") - } - - val spark = SparkSession.builder() - .config(sparkConf) - .getOrCreate() - try f(spark) finally spark.stop() - } - - Seq(Some(5), None).foreach { minNumPostShufflePartitions => - val testNameNote = minNumPostShufflePartitions match { - case Some(numPartitions) => "(minNumPostShufflePartitions: " + numPartitions + ")" - case None => "" - } - - test(s"determining the number of reducers: aggregate operator$testNameNote") { - val test = { spark: SparkSession => - val df = - spark - .range(0, 1000, 1, numInputPartitions) - .selectExpr("id % 20 as key", "id as value") - val agg = df.groupBy("key").count() - - // Check the answer first. - QueryTest.checkAnswer( - agg, - spark.range(0, 20).selectExpr("id", "50 as cnt").collect()) - - // Then, let's look at the number of post-shuffle partitions estimated - // by the ExchangeCoordinator. 
-        val finalPlan = agg.queryExecution.executedPlan
-          .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
-        val shuffleReaders = finalPlan.collect {
-          case r @ ColumnarCoalescedShuffleReader() => r
-        }
-        assert(shuffleReaders.length === 1)
-        minNumPostShufflePartitions match {
-          case Some(numPartitions) =>
-            shuffleReaders.foreach { reader =>
-              assert(reader.outputPartitioning.numPartitions === numPartitions)
-            }
-          case None =>
-            shuffleReaders.foreach { reader =>
-              assert(reader.outputPartitioning.numPartitions === 3)
-            }
-        }
-      }
-      // The columnar partition byte size is small, so a smaller threshold value should be used
-      withSparkSession(test, 1500, minNumPostShufflePartitions)
-    }
-
-    test(s"determining the number of reducers: join operator$testNameNote") {
-      val test = { spark: SparkSession =>
-        val df1 =
-          spark
-            .range(0, 1000, 1, numInputPartitions)
-            .selectExpr("id % 500 as key1", "id as value1")
-        val df2 =
-          spark
-            .range(0, 1000, 1, numInputPartitions)
-            .selectExpr("id % 500 as key2", "id as value2")
-
-        val join = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("value2"))
-
-        // Check the answer first.
-        val expectedAnswer =
-          spark
-            .range(0, 1000)
-            .selectExpr("id % 500 as key", "id as value")
-            .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
-        QueryTest.checkAnswer(
-          join,
-          expectedAnswer.collect())
-
-        // Then, let's look at the number of post-shuffle partitions estimated
-        // by the ExchangeCoordinator.
-        val finalPlan = join.queryExecution.executedPlan
-          .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
-        val shuffleReaders = finalPlan.collect {
-          case r @ ColumnarCoalescedShuffleReader() => r
-        }
-        assert(shuffleReaders.length === 2)
-        minNumPostShufflePartitions match {
-          case Some(numPartitions) =>
-            shuffleReaders.foreach { reader =>
-              assert(reader.outputPartitioning.numPartitions === numPartitions)
-            }
-          case None =>
-            shuffleReaders.foreach { reader =>
-              assert(reader.outputPartitioning.numPartitions === 2)
-            }
-        }
-      }
-      // The columnar partition byte size is small, so a smaller threshold value should be used
-      withSparkSession(test, 11384, minNumPostShufflePartitions)
-    }
-
-    test(s"determining the number of reducers: complex query 1$testNameNote") {
-      val test: (SparkSession) => Unit = { spark: SparkSession =>
-        val df1 =
-          spark
-            .range(0, 1000, 1, numInputPartitions)
-            .selectExpr("id % 500 as key1", "id as value1")
-            .groupBy("key1")
-            .count()
-            .toDF("key1", "cnt1")
-        val df2 =
-          spark
-            .range(0, 1000, 1, numInputPartitions)
-            .selectExpr("id % 500 as key2", "id as value2")
-            .groupBy("key2")
-            .count()
-            .toDF("key2", "cnt2")
-
-        val join = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("cnt2"))
-
-        // Check the answer first.
-        val expectedAnswer =
-          spark
-            .range(0, 500)
-            .selectExpr("id", "2 as cnt")
-        QueryTest.checkAnswer(
-          join,
-          expectedAnswer.collect())
-
-        // Then, let's look at the number of post-shuffle partitions estimated
-        // by the ExchangeCoordinator.
-        val finalPlan = join.queryExecution.executedPlan
-          .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
-        val shuffleReaders = finalPlan.collect {
-          case r @ ColumnarCoalescedShuffleReader() => r
-        }
-        assert(shuffleReaders.length === 2)
-        minNumPostShufflePartitions match {
-          case Some(numPartitions) =>
-            shuffleReaders.foreach { reader =>
-              assert(reader.outputPartitioning.numPartitions === numPartitions)
-            }
-          case None =>
-            shuffleReaders.foreach { reader =>
-              assert(reader.outputPartitioning.numPartitions === 3)
-            }
-        }
-      }
-      // The columnar partition byte size is small, so a smaller threshold value should be used
-      withSparkSession(test, 7384, minNumPostShufflePartitions)
-    }
-
-    test(s"determining the number of reducers: complex query 2$testNameNote") {
-      val test: (SparkSession) => Unit = { spark: SparkSession =>
-        val df1 =
-          spark
-            .range(0, 1000, 1, numInputPartitions)
-            .selectExpr("id % 500 as key1", "id as value1")
-            .groupBy("key1")
-            .count()
-            .toDF("key1", "cnt1")
-        val df2 =
-          spark
-            .range(0, 1000, 1, numInputPartitions)
-            .selectExpr("id % 500 as key2", "id as value2")
-
-        val join =
-          df1
-            .join(df2, col("key1") === col("key2"))
-            .select(col("key1"), col("cnt1"), col("value2"))
-
-        // Check the answer first.
-        val expectedAnswer =
-          spark
-            .range(0, 1000)
-            .selectExpr("id % 500 as key", "2 as cnt", "id as value")
-        QueryTest.checkAnswer(
-          join,
-          expectedAnswer.collect())
-
-        // Then, let's look at the number of post-shuffle partitions estimated
-        // by the ExchangeCoordinator.
-        val finalPlan = join.queryExecution.executedPlan
-          .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
-        val shuffleReaders = finalPlan.collect {
-          case r @ ColumnarCoalescedShuffleReader() => r
-        }
-        assert(shuffleReaders.length === 2)
-        minNumPostShufflePartitions match {
-          case Some(numPartitions) =>
-            shuffleReaders.foreach { reader =>
-              assert(reader.outputPartitioning.numPartitions === numPartitions)
-            }
-          case None =>
-            shuffleReaders.foreach { reader =>
-              assert(reader.outputPartitioning.numPartitions === 2)
-            }
-        }
-      }
-      // The columnar partition byte size is small, so a smaller threshold value should be used
-      withSparkSession(test, 10000, minNumPostShufflePartitions)
-    }
-
-    test(s"determining the number of reducers: plan already partitioned$testNameNote") {
-      val test: SparkSession => Unit = { spark: SparkSession =>
-        try {
-          spark.range(1000).write.bucketBy(30, "id").saveAsTable("t")
-          // `df1` is hash partitioned by `id`.
-          val df1 = spark.read.table("t")
-          val df2 =
-            spark
-              .range(0, 1000, 1, numInputPartitions)
-              .selectExpr("id % 500 as key2", "id as value2")
-
-          val join = df1.join(df2, col("id") === col("key2")).select(col("id"), col("value2"))
-
-          // Check the answer first.
-          val expectedAnswer = spark.range(0, 500).selectExpr("id % 500", "id as value")
-            .union(spark.range(500, 1000).selectExpr("id % 500", "id as value"))
-          QueryTest.checkAnswer(
-            join,
-            expectedAnswer.collect())
-
-          // Then, let's make sure we do not reduce the number of post-shuffle partitions.
- val finalPlan = join.queryExecution.executedPlan - .asInstanceOf[AdaptiveSparkPlanExec].executedPlan - val shuffleReaders = finalPlan.collect { - case r @ ColumnarCoalescedShuffleReader() => r - } - assert(shuffleReaders.length === 0) - } finally { - spark.sql("drop table t") - } - } - withSparkSession(test, 12000, minNumPostShufflePartitions) - } - } - - ignore("SPARK-24705 adaptive query execution works correctly when exchange reuse enabled") { - val test: SparkSession => Unit = { spark: SparkSession => - spark.sql("SET spark.sql.exchange.reuse=true") - val df = spark.range(1).selectExpr("id AS key", "id AS value") - - // test case 1: a query stage has 3 child stages but they are the same stage. - // Final Stage 1 - // ShuffleQueryStage 0 - // ReusedQueryStage 0 - // ReusedQueryStage 0 - val resultDf = df.join(df, "key").join(df, "key") - QueryTest.checkAnswer(resultDf, Row(0, 0, 0, 0) :: Nil) - val finalPlan = resultDf.queryExecution.executedPlan - .asInstanceOf[AdaptiveSparkPlanExec].executedPlan - assert(finalPlan.collect { - case ShuffleQueryStageExec(_, r: ReusedExchangeExec) => r - }.length == 2) - assert( - finalPlan.collect { - case r @ ColumnarCoalescedShuffleReader() => r - }.length == 3) - - - // test case 2: a query stage has 2 parent stages. - // Final Stage 3 - // ShuffleQueryStage 1 - // ShuffleQueryStage 0 - // ShuffleQueryStage 2 - // ReusedQueryStage 0 - val grouped = df.groupBy("key").agg(max("value").as("value")) - val resultDf2 = grouped.groupBy(col("key") + 1).max("value") - .union(grouped.groupBy(col("key") + 2).max("value")) - QueryTest.checkAnswer(resultDf2, Row(1, 0) :: Row(2, 0) :: Nil) - - val finalPlan2 = resultDf2.queryExecution.executedPlan - .asInstanceOf[AdaptiveSparkPlanExec].executedPlan - - // The result stage has 2 children - val level1Stages = finalPlan2.collect { case q: QueryStageExec => q } - assert(level1Stages.length == 2) - - val leafStages = level1Stages.flatMap { stage => - // All of the child stages of result stage have only one child stage. - val children = stage.plan.collect { case q: QueryStageExec => q } - assert(children.length == 1) - children - } - assert(leafStages.length == 2) - - val reusedStages = level1Stages.flatMap { stage => - stage.plan.collect { - case ShuffleQueryStageExec(_, r: ReusedExchangeExec) => r - } - } - assert(reusedStages.length == 1) - } - withSparkSession(test, 4, None) - } - - test("Do not reduce the number of shuffle partition for repartition") { - val test: SparkSession => Unit = { spark: SparkSession => - val ds = spark.range(3) - val resultDf = ds.repartition(2, ds.col("id")).toDF() - - QueryTest.checkAnswer(resultDf, - Seq(0, 1, 2).map(i => Row(i))) - val finalPlan = resultDf.queryExecution.executedPlan - .asInstanceOf[AdaptiveSparkPlanExec].executedPlan - assert( - finalPlan.collect { - case r @ ColumnarCoalescedShuffleReader() => r - }.isEmpty) - } - withSparkSession(test, 200, None) - } - - test("Union two datasets with different pre-shuffle partition number") { - val test: SparkSession => Unit = { spark: SparkSession => - val df1 = spark.range(3).join(spark.range(3), "id").toDF() - val df2 = spark.range(3).groupBy().sum() - - val resultDf = df1.union(df2) - - QueryTest.checkAnswer(resultDf, Seq((0), (1), (2), (3)).map(i => Row(i))) - - val finalPlan = resultDf.queryExecution.executedPlan - .asInstanceOf[AdaptiveSparkPlanExec].executedPlan - // As the pre-shuffle partition number are different, we will skip reducing - // the shuffle partition numbers. 
-      assert(
-        finalPlan.collect {
-          case r @ ColumnarCoalescedShuffleReader() => r
-        }.isEmpty)
-    }
-    withSparkSession(test, 100, None)
-  }
-}
-
-object ColumnarCoalescedShuffleReader {
-  def unapply(reader: ColumnarCustomShuffleReaderExec): Boolean = {
-    !reader.isLocalReader && !reader.hasSkewedPartition && reader.hasCoalescedPartition
-  }
-}
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExecSuite.scala
deleted file mode 100644
index 19c44656e5bc4f2091e631347d612e6ba0f168e7..0000000000000000000000000000000000000000
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExecSuite.scala
+++ /dev/null
@@ -1,96 +0,0 @@
-/*
- * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved.
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, StructType}
-
-class ColumnarExecSuite extends ColumnarSparkPlanTest {
-  private var dealer: DataFrame = _
-
-  protected override def beforeAll(): Unit = {
-    super.beforeAll()
-
-    dealer = spark.createDataFrame(
-      sparkContext.parallelize(Seq(
-        Row(1, 2.0, false),
-        Row(1, 2.0, false),
-        Row(2, 1.0, false),
-        Row(null, null, false),
-        Row(null, 5.0, false),
-        Row(6, null, false)
-      )), new StructType().add("a", IntegerType).add("b", DoubleType)
-        .add("c", BooleanType))
-    dealer.createOrReplaceTempView("dealer")
-  }
-
-  test("validate columnar transfer exec happened") {
-    val sql1 = "SELECT a + 1 FROM dealer"
-    assertColumnarToRowOmniAndSparkResultEqual(sql1)
-  }
-
-  test("spark limit with columnarToRow as child") {
-
-    // fetch partial
-    val sql1 = "select * from (select a, b+2 from dealer order by a, b+2) limit 2"
-    assertColumnarToRowOmniAndSparkResultEqual(sql1)
-
-    // fetch all
-    val sql2 = "select a, b+2 from dealer limit 6"
-    assertColumnarToRowOmniAndSparkResultEqual(sql2)
-
-    // fetch all
-    val sql3 = "select a, b+2 from dealer limit 10"
-    assertColumnarToRowOmniAndSparkResultEqual(sql3)
-
-    // fetch partial
-    val sql4 = "select a, b+2 from dealer order by a limit 2"
-    assertColumnarToRowOmniAndSparkResultEqual(sql4)
-
-    // fetch all
-    val sql5 = "select a, b+2 from dealer order by a limit 6"
-    assertColumnarToRowOmniAndSparkResultEqual(sql5)
-
-    // fetch all
-    val sql6 = "select a, b+2 from dealer order by a limit 10"
-    assertColumnarToRowOmniAndSparkResultEqual(sql6)
-  }
-
-  private def assertColumnarToRowOmniAndSparkResultEqual(sql: String): Unit = {
-
- 
spark.conf.set("spark.omni.sql.columnar.takeOrderedAndProject", true) - spark.conf.set("spark.omni.sql.columnar.project", true) - val omniResult = spark.sql(sql) - val omniPlan = omniResult.queryExecution.executedPlan - assert(omniPlan.find(_.isInstanceOf[OmniColumnarToRowExec]).isDefined, - s"SQL:${sql}\n@OmniEnv no OmniColumnarToRowExec,omniPlan:${omniPlan}") - - spark.conf.set("spark.omni.sql.columnar.takeOrderedAndProject", false) - spark.conf.set("spark.omni.sql.columnar.project", false) - val sparkResult = spark.sql(sql) - val sparkPlan = sparkResult.queryExecution.executedPlan - assert(sparkPlan.find(_.isInstanceOf[OmniColumnarToRowExec]).isEmpty, - s"SQL:${sql}\n@SparkEnv have OmniColumnarToRowExec,sparkPlan:${sparkPlan}") - - assert(omniResult.except(sparkResult).isEmpty, - s"SQL:${sql}\nomniResult:${omniResult.show()}\nsparkResult:${sparkResult.show()}\n") - spark.conf.set("spark.omni.sql.columnar.takeOrderedAndProject", true) - spark.conf.set("spark.omni.sql.columnar.project", true) - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExpandExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExpandExecSuite.scala deleted file mode 100644 index 3af1849f8b05e35bae5cde6f93f6b33b10f6a25d..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExpandExecSuite.scala +++ /dev/null @@ -1,358 +0,0 @@ -package org.apache.spark.sql.execution - -import org.apache.spark.sql.{DataFrame, Row} - -class ColumnarExpandExecSuite extends ColumnarSparkPlanTest { - - import testImplicits.{localSeqToDatasetHolder, newProductEncoder} - - private var dealer: DataFrame = _ - private var floatDealer: DataFrame = _ - private var nullDealer: DataFrame = _ - - override def beforeAll(): Unit = { - super.beforeAll() - - dealer = Seq[(Int, String, String, Int)]( - (100, "Fremont", "Honda Civic", 10), - (100, "Fremont", "Honda Accord", 15), - (100, "Fremont", "Honda CRV", 7), - (200, "Dublin", "Honda Civic", 20), - (200, "Dublin", "Honda Accord", 10), - (200, "Dublin", "Honda CRV", 3), - (300, "San Jose", "Honda Civic", 5), - (300, "San Jose", "Honda Accord", 8), - ).toDF("id", "city", "car_model", "quantity") - dealer.createOrReplaceTempView("dealer") - - floatDealer = Seq[(Int, String, String, Float)]( - (100, "Fremont", "Honda Civic", 10.00F), - (100, "Fremont", "Honda Accord", 15.00F), - (100, "Fremont", "Honda CRV", 7.00F), - (200, "Dublin", "Honda Civic", 20.00F), - (200, "Dublin", "Honda Accord", 10.00F), - (200, "Dublin", "Honda CRV", 3.00F), - (300, "San Jose", "Honda Civic", 5.00F), - (300, "San Jose", "Honda Accord", 8.00F), - ).toDF("id", "city", "car_model", "quantity") - floatDealer.createOrReplaceTempView("float_dealer") - - nullDealer = Seq[(Int, String, String, Int)]( - (100, null, "Honda Civic", 10), - (100, "Fremont", "Honda Accord", 15), - (100, "Fremont", null, 7), - (200, "Dublin", "Honda Civic", 20), - (200, null, "Honda Accord", 10), - (200, "Dublin", "Honda CRV", 3), - (300, "San Jose", null, 5), - (300, "San Jose", "Honda Accord", 8), - (300, null, null, 8), - ).toDF("id", "city", "car_model", "quantity") - nullDealer.createOrReplaceTempView("null_dealer") - - } - - test("use ColumnarExpandExec in Grouping Sets clause when default") { - val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM dealer " + - "GROUP BY GROUPING SETS ((city, car_model), (city), 
(car_model), ()) ORDER BY city, car_model;") - val plan = result.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) - assert(plan.find(_.isInstanceOf[ExpandExec]).isEmpty) - } - - test("use ExpandExec in Grouping Sets clause when SparkExtension rollback") { - val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM float_dealer " + - "GROUP BY GROUPING SETS ((city, car_model), (city), (car_model), ()) ORDER BY city, car_model;") - val plan = result.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isEmpty) - assert(plan.find(_.isInstanceOf[ExpandExec]).isDefined) - } - - test("use ExpandExec in Grouping Sets clause when spark.omni.sql.columnar.expand=false") { - spark.conf.set("spark.omni.sql.columnar.expand", false) - val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM float_dealer " + - "GROUP BY GROUPING SETS ((city, car_model), (city), (car_model), ()) ORDER BY city, car_model;") - val plan = result.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isEmpty) - assert(plan.find(_.isInstanceOf[ExpandExec]).isDefined) - spark.conf.set("spark.omni.sql.columnar.expand", true) - } - - test("use ColumnarExpandExec in Rollup clause when default") { - val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM dealer " + - "GROUP BY ROLLUP(city, car_model) ORDER BY city, car_model;") - val plan = result.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) - assert(plan.find(_.isInstanceOf[ExpandExec]).isEmpty) - } - - test("use ExpandExec in Rollup clause when SparkExtension rollback") { - val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM float_dealer " + - "GROUP BY ROLLUP(city, car_model) ORDER BY city, car_model;") - val plan = result.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isEmpty) - assert(plan.find(_.isInstanceOf[ExpandExec]).isDefined) - } - - test("use ExpandExec in Rollup clause when spark.omni.sql.columnar.expand=false") { - spark.conf.set("spark.omni.sql.columnar.expand", false) - val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM float_dealer " + - "GROUP BY ROLLUP(city, car_model) ORDER BY city, car_model;") - val plan = result.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isEmpty) - assert(plan.find(_.isInstanceOf[ExpandExec]).isDefined) - spark.conf.set("spark.omni.sql.columnar.expand", true) - } - - test("use ColumnarExpandExec in Cube clause when default") { - val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM dealer " + - "GROUP BY CUBE(city, car_model) ORDER BY city, car_model;") - val plan = result.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) - assert(plan.find(_.isInstanceOf[ExpandExec]).isEmpty) - } - - test("use ExpandExec in Cube clause when SparkExtension rollback") { - val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM float_dealer " + - "GROUP BY CUBE(city, car_model) ORDER BY city, car_model;") - val plan = result.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isEmpty) - assert(plan.find(_.isInstanceOf[ExpandExec]).isDefined) - } - - test("use ExpandExec in Cube clause when spark.omni.sql.columnar.expand=false") { - spark.conf.set("spark.omni.sql.columnar.expand", false) - val result = spark.sql("SELECT city, 
car_model, sum(quantity) AS sum FROM float_dealer " + - "GROUP BY CUBE(city, car_model) ORDER BY city, car_model;") - val plan = result.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isEmpty) - assert(plan.find(_.isInstanceOf[ExpandExec]).isDefined) - spark.conf.set("spark.omni.sql.columnar.expand", true) - } - - test("ColumnarExpandExec exec correctly in Grouping Sets clause") { - val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM dealer " + - "GROUP BY GROUPING SETS ((city, car_model), (city), (car_model), ()) ORDER BY city, car_model;") - val plan = result.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) - - val expect = Seq( - Row(null, null, 78), - Row(null, "Honda Accord", 33), - Row(null, "Honda CRV", 10), - Row(null, "Honda Civic", 35), - Row("Dublin", null, 33), - Row("Dublin", "Honda Accord", 10), - Row("Dublin", "Honda CRV", 3), - Row("Dublin", "Honda Civic", 20), - Row("Fremont", null, 32), - Row("Fremont", "Honda Accord", 15), - Row("Fremont", "Honda CRV", 7), - Row("Fremont", "Honda Civic", 10), - Row("San Jose", null, 13), - Row("San Jose", "Honda Accord", 8), - Row("San Jose", "Honda Civic", 5), - ) - checkAnswer(result, expect) - } - - test("ColumnarExpandExec exec correctly in Rollup clause") { - val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM dealer " + - "GROUP BY ROLLUP (city, car_model) ORDER BY city, car_model;") - val plan = result.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) - - val expect = Seq( - Row(null, null, 78), - Row("Dublin", null, 33), - Row("Dublin", "Honda Accord", 10), - Row("Dublin", "Honda CRV", 3), - Row("Dublin", "Honda Civic", 20), - Row("Fremont", null, 32), - Row("Fremont", "Honda Accord", 15), - Row("Fremont", "Honda CRV", 7), - Row("Fremont", "Honda Civic", 10), - Row("San Jose", null, 13), - Row("San Jose", "Honda Accord", 8), - Row("San Jose", "Honda Civic", 5), - ) - checkAnswer(result, expect) - } - - test("ColumnarExpandExec exec correctly in Cube clause") { - val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM dealer " + - "GROUP BY CUBE (city, car_model) ORDER BY city, car_model;") - val plan = result.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) - - val expect = Seq( - Row(null, null, 78), - Row(null, "Honda Accord", 33), - Row(null, "Honda CRV", 10), - Row(null, "Honda Civic", 35), - Row("Dublin", null, 33), - Row("Dublin", "Honda Accord", 10), - Row("Dublin", "Honda CRV", 3), - Row("Dublin", "Honda Civic", 20), - Row("Fremont", null, 32), - Row("Fremont", "Honda Accord", 15), - Row("Fremont", "Honda CRV", 7), - Row("Fremont", "Honda Civic", 10), - Row("San Jose", null, 13), - Row("San Jose", "Honda Accord", 8), - Row("San Jose", "Honda Civic", 5), - ) - checkAnswer(result, expect) - } - - test("ColumnarExpandExec exec correctly in Grouping Sets clause with GROUPING__ID column") { - val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum, GROUPING__ID FROM dealer " + - "GROUP BY GROUPING SETS ((city, car_model), (city), (car_model), ()) ORDER BY city, car_model;") - val plan = result.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) - - val expect = Seq( - Row(null, null, 78, 3), - Row(null, "Honda Accord", 33, 2), - Row(null, "Honda CRV", 10, 2), - Row(null, "Honda Civic", 35, 2), - Row("Dublin", null, 33, 1), - Row("Dublin", "Honda 
Accord", 10, 0), - Row("Dublin", "Honda CRV", 3, 0), - Row("Dublin", "Honda Civic", 20, 0), - Row("Fremont", null, 32, 1), - Row("Fremont", "Honda Accord", 15, 0), - Row("Fremont", "Honda CRV", 7, 0), - Row("Fremont", "Honda Civic", 10, 0), - Row("San Jose", null, 13, 1), - Row("San Jose", "Honda Accord", 8, 0), - Row("San Jose", "Honda Civic", 5, 0), - ) - checkAnswer(result, expect) - } - - test("ColumnarExpandExec exec correctly in Rollup clause with GROUPING__ID column") { - val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum, GROUPING__ID FROM dealer " + - "GROUP BY ROLLUP (city, car_model) ORDER BY city, car_model;") - val plan = result.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) - - val expect = Seq( - Row(null, null, 78, 3), - Row("Dublin", null, 33, 1), - Row("Dublin", "Honda Accord", 10, 0), - Row("Dublin", "Honda CRV", 3, 0), - Row("Dublin", "Honda Civic", 20, 0), - Row("Fremont", null, 32, 1), - Row("Fremont", "Honda Accord", 15, 0), - Row("Fremont", "Honda CRV", 7, 0), - Row("Fremont", "Honda Civic", 10, 0), - Row("San Jose", null, 13, 1), - Row("San Jose", "Honda Accord", 8, 0), - Row("San Jose", "Honda Civic", 5, 0), - ) - checkAnswer(result, expect) - } - - test("ColumnarExpandExec exec correctly in Cube clause with GROUPING__ID column") { - val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum, GROUPING__ID FROM dealer " + - "GROUP BY CUBE (city, car_model) ORDER BY city, car_model;") - val plan = result.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) - - val expect = Seq( - Row(null, null, 78, 3), - Row(null, "Honda Accord", 33, 2), - Row(null, "Honda CRV", 10, 2), - Row(null, "Honda Civic", 35, 2), - Row("Dublin", null, 33, 1), - Row("Dublin", "Honda Accord", 10, 0), - Row("Dublin", "Honda CRV", 3, 0), - Row("Dublin", "Honda Civic", 20, 0), - Row("Fremont", null, 32, 1), - Row("Fremont", "Honda Accord", 15, 0), - Row("Fremont", "Honda CRV", 7, 0), - Row("Fremont", "Honda Civic", 10, 0), - Row("San Jose", null, 13, 1), - Row("San Jose", "Honda Accord", 8, 0), - Row("San Jose", "Honda Civic", 5, 0), - ) - checkAnswer(result, expect) - } - - - test("ColumnarExpandExec and ExpandExec return the same result when use Grouping Sets clause") { - val sql = "SELECT city, car_model, sum(quantity) AS sum FROM dealer " + - "GROUP BY GROUPING SETS ((city, car_model), (city), (car_model), ()) ORDER BY city, car_model;" - checkExpandExecAndColumnarExpandExecAgree(sql) - } - - test("ColumnarExpandExec and ExpandExec return the same result when use Rollup clause") { - val sql = "SELECT city, car_model, sum(quantity) AS sum FROM dealer " + - "GROUP BY ROLLUP(city, car_model) ORDER BY city, car_model;" - checkExpandExecAndColumnarExpandExecAgree(sql) - } - - test("ColumnarExpandExec and ExpandExec return the same result when use Cube clause") { - val sql = "SELECT city, car_model, sum(quantity) AS sum FROM dealer " + - "GROUP BY CUBE (city, car_model) ORDER BY city, car_model;" - checkExpandExecAndColumnarExpandExecAgree(sql) - } - - test("ColumnarExpandExec and ExpandExec return the same result when use Grouping Sets clause with null value") { - val sql = "SELECT city, car_model, sum(quantity) AS sum FROM null_dealer " + - "GROUP BY GROUPING SETS ((city, car_model), (city), (car_model), ()) ORDER BY city, car_model;" - checkExpandExecAndColumnarExpandExecAgree(sql) - } - - test("ColumnarExpandExec and ExpandExec return the same result when use Rollup clause with 
null value") { - val sql = "SELECT city, car_model, sum(quantity) AS sum FROM null_dealer " + - "GROUP BY ROLLUP(city, car_model) ORDER BY city, car_model;" - checkExpandExecAndColumnarExpandExecAgree(sql) - } - - test("ColumnarExpandExec and ExpandExec return the same result when use Cube clause with null value") { - val sql = "SELECT city, car_model, sum(quantity) AS sum FROM null_dealer " + - "GROUP BY CUBE (city, car_model) ORDER BY city, car_model;" - checkExpandExecAndColumnarExpandExecAgree(sql) - } - - test("ColumnarExpandExec and ExpandExec return the same result when use Grouping Sets clause with GROUPING__ID column") { - val sql = "SELECT city, car_model, sum(quantity) AS sum, GROUPING__ID FROM dealer " + - "GROUP BY GROUPING SETS ((city, car_model), (city), (car_model), ()) ORDER BY city, car_model;" - checkExpandExecAndColumnarExpandExecAgree(sql) - } - - test("ColumnarExpandExec and ExpandExec return the same result when use Rollup clause with GROUPING__ID column") { - val sql = "SELECT city, car_model, sum(quantity) AS sum, GROUPING__ID FROM dealer " + - "GROUP BY ROLLUP(city, car_model) ORDER BY city, car_model;" - checkExpandExecAndColumnarExpandExecAgree(sql) - } - test("ColumnarExpandExec and ExpandExec return the same result when use Cube clause with GROUPING__ID column") { - val sql = "SELECT city, car_model, sum(quantity) AS sum, GROUPING__ID FROM dealer " + - "GROUP BY CUBE (city, car_model) ORDER BY city, car_model;" - checkExpandExecAndColumnarExpandExecAgree(sql) - } - - // check ExpandExec and ColumnarExpandExec return the same result - def checkExpandExecAndColumnarExpandExecAgree(sql: String): Unit = { - spark.conf.set("spark.omni.sql.columnar.expand", true) - val omniResult = spark.sql(sql) - val omniPlan = omniResult.queryExecution.executedPlan - assert(omniPlan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) - assert(omniPlan.find(_.isInstanceOf[ExpandExec]).isEmpty) - - spark.conf.set("spark.omni.sql.columnar.expand", false) - val sparkResult = spark.sql(sql) - val sparkPlan = sparkResult.queryExecution.executedPlan - assert(sparkPlan.find(_.isInstanceOf[ColumnarExpandExec]).isEmpty) - assert(sparkPlan.find(_.isInstanceOf[ExpandExec]).isDefined) - - // DataFrame do not support comparing with equals method, use DataFrame.except instead - assert(omniResult.except(sparkResult).isEmpty) - - spark.conf.set("spark.omni.sql.columnar.expand", true) - } - -} - diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarFilterExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarFilterExecSuite.scala deleted file mode 100644 index 3c06bd0c865c7ab651a462bce2f96cdddc7ad65a..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarFilterExecSuite.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.catalyst.expressions.Expression - -class ColumnarFilterExecSuite extends ColumnarSparkPlanTest { - import testImplicits.{localSeqToDatasetHolder, newProductEncoder} - - private var inputDf: DataFrame = _ - private var inputDfWithNull: DataFrame = _ - - protected override def beforeAll(): Unit = { - super.beforeAll() - inputDf = Seq[(String, String, java.lang.Integer, java.lang.Double)]( - ("abc", "", 4, 2.0), - ("", "Hello", 1, 1.0), - (" add", "World", 8, 3.0), - (" yeah ", "yeah", 10, 8.0) - ).toDF("a", "b", "c", "d") - - inputDfWithNull = Seq[(String, String, java.lang.Integer, java.lang.Double)]( - (null, "", 4, 2.0), - (null, null, 1, 1.0), - (" add", "World", 8, null), - (" yeah ", "yeah", 10, 8.0), - (" yeah ", "yeah", 10, 8.0) - ).toDF("a", "b", "c", "d") - } - - test("validate columnar filter exec happened") { - val res = inputDf.filter("c > 1") - print(res.queryExecution.executedPlan) - val isColumnarFilterHappen = res.queryExecution.executedPlan - .find(_.isInstanceOf[ColumnarFilterExec]).isDefined - val isColumnarConditionProjectHappen = res.queryExecution.executedPlan - .find(_.isInstanceOf[ColumnarConditionProjectExec]).isDefined - assert(isColumnarFilterHappen || isColumnarConditionProjectHappen, s"ColumnarFilterExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") - } - - test("columnar filter is equal to native") { - val expr: Expression = (inputDf.col("c") > 3).expr - checkThatPlansAgreeTemplate(expr = expr, df = inputDf) - } - - test("columnar filter is equal to native with null") { - val expr: Expression = (inputDfWithNull.col("c") > 3 && inputDfWithNull.col("d").isNotNull).expr - checkThatPlansAgreeTemplate(expr = expr, df = inputDfWithNull) - } - - test("ColumnarFilterExec is not rolled back with not_equal filter expr") { - val res = inputDf.filter("c != d") - val isColumnarFilterHappen = res.queryExecution.executedPlan - .find(_.isInstanceOf[ColumnarFilterExec]).isDefined - val isColumnarConditionProjectHappen = res.queryExecution.executedPlan - .find(_.isInstanceOf[ColumnarConditionProjectExec]).isDefined - assert(isColumnarFilterHappen || isColumnarConditionProjectHappen, s"ColumnarFilterExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") - } - - def checkThatPlansAgreeTemplate(expr: Expression, df: DataFrame): Unit = { - checkThatPlansAgree( - df, - (child: SparkPlan) => - ColumnarFilterExec(expr, child = child), - (child: SparkPlan) => - FilterExec(expr, child = child), - sortAnswers = false) - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateDistinctOperatorSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateDistinctOperatorSuite.scala deleted file mode 100644 index 1c996800f186dcdee15c239ea1228cfda45e9cea..0000000000000000000000000000000000000000 --- 
a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateDistinctOperatorSuite.scala +++ /dev/null @@ -1,354 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.execution.aggregate.HashAggregateExec -import org.apache.spark.sql.types._ - -class ColumnarHashAggregateDistinctOperatorSuite extends ColumnarSparkPlanTest { - - private var dealer: DataFrame = _ - private var dealer_decimal: DataFrame = _ - - protected override def beforeAll(): Unit = { - super.beforeAll() - - dealer = spark.createDataFrame( - sparkContext.parallelize(Seq( - Row(100, "Fremont", "Honda Civic", 10), - Row(100, "Fremont", "Honda Accord", null), - Row(100, "Fremont", "Honda CRV", 7), - Row(200, "Dublin", "Honda Civic", 20), - Row(200, "Dublin", "Honda Civic", null), - Row(200, "Dublin", "Honda Accord", 3), - Row(300, "San Jose", "Honda Civic", 5), - Row(300, "San Jose", "Honda Accord", null) - )), new StructType() - .add("id", IntegerType) - .add("city", StringType) - .add("car_model", StringType) - .add("quantity", IntegerType)) - dealer.createOrReplaceTempView("dealer") - - dealer_decimal = spark.createDataFrame( - sparkContext.parallelize(Seq( - Row(100, "Fremont", "Honda Civic", BigDecimal("123456.78"), null, BigDecimal("1234567891234567.89")), - Row(100, "Fremont", "Honda Accord", BigDecimal("456.78"), BigDecimal("456789.12"), null), - Row(100, "Fremont", "Honda CRV", BigDecimal("6.78"), BigDecimal("6789.12"), BigDecimal("67891234567.89")), - Row(200, "Dublin", "Honda Civic", BigDecimal("123456.78"), null, BigDecimal("1234567891234567.89")), - Row(200, "Dublin", "Honda Accord", BigDecimal("6.78"), BigDecimal("9.12"), BigDecimal("567.89")), - Row(200, "Dublin", "Honda CRV", BigDecimal("123456.78"), BigDecimal("123456789.12"), null), - Row(300, "San Jose", "Honda Civic", BigDecimal("3456.78"), null, BigDecimal("34567891234567.89")), - Row(300, "San Jose", "Honda Accord", BigDecimal("56.78"), BigDecimal("56789.12"), null) - )), new StructType() - .add("id", IntegerType) - .add("city", StringType) - .add("car_model", StringType) - .add("quantity_dec8_2", DecimalType(8, 2)) - .add("quantity_dec11_2", DecimalType(11, 2)) - .add("quantity_dec18_2", DecimalType(18, 2))) - dealer_decimal.createOrReplaceTempView("dealer_decimal") - } - - test("Test HashAgg with 1 distinct:") { - val sql1 = "SELECT car_model, count(DISTINCT quantity) AS count FROM dealer" + - " GROUP BY car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql1) - - val sql2 = "SELECT car_model, avg(DISTINCT quantity) AS count FROM dealer" + - " 
GROUP BY car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql2) - - val sql3 = "SELECT car_model, sum(DISTINCT quantity) AS count FROM dealer" + - " GROUP BY car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql3) - - val sql4 = "SELECT car_model, count(DISTINCT quantity) AS count FROM dealer" + - " GROUP BY car_model order by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql4) - - val sql5 = "SELECT car_model, avg(DISTINCT quantity) AS count FROM dealer" + - " GROUP BY car_model order by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql5) - - val sql6 = "SELECT car_model, sum(DISTINCT quantity) AS count FROM dealer" + - " GROUP BY car_model order by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql6) - } - - test("Test HashAgg with 1 distinct + 1 without distinct:") { - val sql1 = "SELECT car_model, max(id), count(DISTINCT quantity) AS count FROM dealer" + - " GROUP BY car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql1) - - val sql2 = "SELECT car_model, count(id), avg(DISTINCT quantity) AS count FROM dealer" + - " GROUP BY car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql2) - - val sql3 = "SELECT car_model, min(id), sum(DISTINCT quantity) AS count FROM dealer" + - " GROUP BY car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql3) - - val sql4 = "SELECT car_model, max(id), count(DISTINCT quantity) AS count FROM dealer" + - " GROUP BY car_model order by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql4) - - val sql5 = "SELECT car_model, count(id), avg(DISTINCT quantity) AS count FROM dealer" + - " GROUP BY car_model order by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql5) - - val sql6 = "SELECT car_model, min(id), sum(DISTINCT quantity) AS count FROM dealer" + - " GROUP BY car_model order by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql6) - } - - test("Test HashAgg with multi distinct + multi without distinct:") { - val sql1 = "select car_model, min(id), max(quantity), count(distinct city) from dealer" + - " group by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql1) - - val sql2 = "select car_model, avg(DISTINCT quantity), count(DISTINCT city) from dealer" + - " group by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql2) - - val sql3 = "select car_model, sum(DISTINCT quantity), count(DISTINCT city) from dealer" + - " group by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql3) - - val sql4 = "select car_model, avg(DISTINCT quantity), sum(DISTINCT city) from dealer" + - " group by car_model;" - // sum(DISTINCT city) have knownfloatingpointnormalized(normalizenanandzero(cast(city as double))) - // not support, HashAggExec will partial replace - assertHashAggregateExecOmniAndSparkResultEqual(sql4, false) - - val sql5 = "select car_model, count(DISTINCT city), avg(DISTINCT quantity), sum(DISTINCT city) from dealer" + - " group by car_model;" - // sum(DISTINCT city) have knownfloatingpointnormalized(normalizenanandzero(cast(city as double))) - // not support, HashAggExec will partial replace - assertHashAggregateExecOmniAndSparkResultEqual(sql5, false) - - val sql6 = "select car_model, min(id), sum(DISTINCT quantity), count(DISTINCT city) from dealer" + - " group by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql6) - - val sql7 = "select car_model, sum(DISTINCT quantity), count(DISTINCT city), avg(DISTINCT city), min(id), max(id) from dealer" + - " group by car_model;" - 
// avg(DISTINCT city) have knownfloatingpointnormalized(normalizenanandzero(cast(city as double))) - // not support, HashAggExec will partial replace - assertHashAggregateExecOmniAndSparkResultEqual(sql7, false) - - val sql8 = "select car_model, min(id), sum(DISTINCT quantity), count(DISTINCT city), avg(DISTINCT city) from dealer" + - " group by car_model;" - // avg(DISTINCT city) have knownfloatingpointnormalized(normalizenanandzero(cast(city as double))) - // not support, HashAggExec will partial replace - assertHashAggregateExecOmniAndSparkResultEqual(sql8, false) - } - - test("Test HashAgg with decimal distinct:") { - val sql1 = "select car_model, avg(DISTINCT quantity_dec8_2), count(DISTINCT city) from dealer_decimal" + - " group by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql1, hashAggExecFullReplace = false) - - val sql2 = "select car_model, min(id), sum(DISTINCT quantity_dec8_2), count(DISTINCT city) from dealer_decimal" + - " group by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql2) - - val sql3 = "select car_model, count(DISTINCT quantity_dec8_2), count(DISTINCT city), avg(DISTINCT city), min(id), max(id) from dealer_decimal" + - " group by car_model;" - // avg(DISTINCT city) have knownfloatingpointnormalized(normalizenanandzero(cast(city as double))) - // not support, HashAggExec will partial replace - assertHashAggregateExecOmniAndSparkResultEqual(sql3, false) - - val sql4 = "select car_model, avg(DISTINCT quantity_dec11_2), count(DISTINCT city) from dealer_decimal" + - " group by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql4, hashAggExecFullReplace = false) - - val sql5 = "select car_model, min(id), sum(DISTINCT quantity_dec11_2), count(DISTINCT city) from dealer_decimal" + - " group by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql5) - - val sql6 = "select car_model, count(DISTINCT quantity_dec11_2), count (DISTINCT city), avg(DISTINCT city), min(id), max(id) from dealer_decimal" + - " group by car_model;" - // avg(DISTINCT city) have knownfloatingpointnormalized(normalizenanandzero(cast(city as double))) - // not support, HashAggExec will partial replace - assertHashAggregateExecOmniAndSparkResultEqual(sql6, false) - - val sql7 = "select car_model, count(DISTINCT quantity_dec8_2), avg(DISTINCT quantity_dec8_2), sum(DISTINCT quantity_dec8_2) from dealer_decimal" + - " group by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql7, hashAggExecFullReplace = false) - - val sql8 = "select car_model, count(DISTINCT quantity_dec11_2), avg(DISTINCT quantity_dec11_2), sum(DISTINCT quantity_dec11_2) from dealer_decimal" + - " group by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql8, hashAggExecFullReplace = false) - } - - test("Test HashAgg with multi distinct + multi without distinct + order by:") { - val sql1 = "select car_model, min(id), max(quantity), count(distinct city) from dealer" + - " group by car_model order by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql1) - - val sql2 = "select car_model, avg(DISTINCT quantity), count(DISTINCT city) from dealer" + - " group by car_model order by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql2) - - val sql3 = "select car_model, sum(DISTINCT quantity), count(DISTINCT city) from dealer" + - " group by car_model order by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql3) - - val sql4 = "select car_model, avg(DISTINCT quantity), sum(DISTINCT city) from dealer" + - " group by car_model 
order by car_model;" - // sum(DISTINCT city) have knownfloatingpointnormalized(normalizenanandzero(cast(city as double))) - // not support, HashAggExec will partial replace - assertHashAggregateExecOmniAndSparkResultEqual(sql4, false) - - val sql5 = "select car_model, count(DISTINCT city), avg(DISTINCT quantity), sum(DISTINCT city) from dealer" + - " group by car_model order by car_model;" - // sum(DISTINCT city) have knownfloatingpointnormalized(normalizenanandzero(cast(city as double))) - // not support, HashAggExec will partial replace - assertHashAggregateExecOmniAndSparkResultEqual(sql5, false) - - val sql6 = "select car_model, min(id), sum(DISTINCT quantity), count(DISTINCT city) from dealer" + - " group by car_model order by car_model;" - // count(DISTINCT city) have knownfloatingpointnormalized(normalizenanandzero(cast(city as double))) - // not support, HashAggExec will partial replace - assertHashAggregateExecOmniAndSparkResultEqual(sql6, false) - - val sql7 = "select car_model, sum(DISTINCT quantity), count(DISTINCT city), avg(DISTINCT city), min(id), max(id) from dealer" + - " group by car_model order by car_model;" - // avg(DISTINCT city) have knownfloatingpointnormalized(normalizenanandzero(cast(city as double))) - // not support, HashAggExec will partial replace - assertHashAggregateExecOmniAndSparkResultEqual(sql7, false) - - val sql8 = "select car_model, min(id), sum(DISTINCT quantity), count(DISTINCT city), avg(DISTINCT city) from dealer" + - " group by car_model order by car_model;" - // avg(DISTINCT city) have knownfloatingpointnormalized(normalizenanandzero(cast(city as double))) - // not support, HashAggExec will partial replace - assertHashAggregateExecOmniAndSparkResultEqual(sql8, false) - } - - test("Test HashAgg with 1 distinct + order by:") { - val sql1 = "SELECT car_model, count(DISTINCT city) AS count FROM dealer" + - " GROUP BY car_model order by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql1) - } - - test("Test HashAgg with 2 distinct: group by columnar as distinct columnar") { - val sql1 = "SELECT city, car_model, count(DISTINCT city), max(DISTINCT quantity) FROM dealer" + - " GROUP BY city, car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql1) - - val sql2 = "SELECT city, id, count(DISTINCT id), max(DISTINCT quantity) FROM dealer" + - " GROUP BY city, id;" - assertHashAggregateExecOmniAndSparkResultEqual(sql2) - - val sql3 = "SELECT city, quantity_dec8_2, count(DISTINCT quantity_dec8_2), max(DISTINCT id) FROM dealer_decimal" + - " GROUP BY city, quantity_dec8_2;" - assertHashAggregateExecOmniAndSparkResultEqual(sql3) - - val sql4 = "SELECT city, quantity_dec11_2, count(DISTINCT quantity_dec11_2), max(DISTINCT id) FROM dealer_decimal" + - " GROUP BY city, quantity_dec11_2;" - assertHashAggregateExecOmniAndSparkResultEqual(sql4) - - val sql5 = "SELECT city, quantity_dec18_2, count(DISTINCT quantity_dec18_2), max(DISTINCT id) FROM dealer_decimal" + - " GROUP BY city, quantity_dec18_2;" - assertHashAggregateExecOmniAndSparkResultEqual(sql5) - } - - test("Test HashAgg with aggkey expresion + 2 distinct: ") { - val sql1 = "SELECT car_model, city, count(DISTINCT concat(city,car_model)), max(DISTINCT quantity)" + - " FROM dealer GROUP BY city, car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql1) - - val sql2 = "SELECT city, concat(city,car_model), count(DISTINCT concat(city,car_model)), max(DISTINCT quantity)" + - " FROM dealer GROUP BY city, car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql2) - - val 
sql3 = "SELECT city, id, count (DISTINCT id + 3), max(DISTINCT quantity + 1)" + - " FROM dealer GROUP BY city, id;" - assertHashAggregateExecOmniAndSparkResultEqual(sql3) - - val sql4 = "SELECT city, quantity_dec8_2, count(DISTINCT quantity_dec8_2 + 1.00), max(DISTINCT id)" + - " FROM dealer_decimal GROUP BY city, quantity_dec8_2;" - assertHashAggregateExecOmniAndSparkResultEqual(sql4) - - val sql5 = "SELECT city, quantity_dec11_2, count(DISTINCT quantity_dec11_2 + 1.00), max(DISTINCT id)" + - " FROM dealer_decimal GROUP BY city, quantity_dec11_2;" - assertHashAggregateExecOmniAndSparkResultEqual(sql5) - - val sql6 = "SELECT city, quantity_dec18_2, count(DISTINCT quantity_dec18_2 + 1.00), max(DISTINCT id)" + - " FROM dealer_decimal GROUP BY city, quantity_dec18_2;" - assertHashAggregateExecOmniAndSparkResultEqual(sql6) - } - - test("Test HashAgg with 2 distinct + order by: group by columnar as distinct columnar") { - val sql1 = "SELECT city, car_model, count(DISTINCT city), max(DISTINCT quantity) FROM dealer" + - " GROUP BY city, car_model order by city, car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql1) - - val sql2 = "SELECT city, id, count(DISTINCT id), max(DISTINCT quantity) FROM dealer" + - " GROUP BY city, id order by city, id;" - assertHashAggregateExecOmniAndSparkResultEqual(sql2) - - val sql3 = "SELECT city, quantity_dec8_2, count(DISTINCT quantity_dec8_2), max(DISTINCT id)" + - " FROM dealer_decimal GROUP BY city, quantity_dec8_2 order BY city, quantity_dec8_2;" - assertHashAggregateExecOmniAndSparkResultEqual(sql3) - - val sql4 = "SELECT city, quantity_dec11_2, count(DISTINCT quantity_dec11_2), max(DISTINCT id)" + - " FROM dealer_decimal GROUP BY city, quantity_dec11_2 order BY city, quantity_dec11_2;" - assertHashAggregateExecOmniAndSparkResultEqual(sql4) - - val sql5 = "SELECT city, quantity_dec18_2, count(DISTINCT quantity_dec18_2), max(DISTINCT id)" + - " FROM dealer_decimal GROUP BY city, quantity_dec18_2 order BY city, quantity_dec18_2;" - assertHashAggregateExecOmniAndSparkResultEqual(sql5) - } - - test("Test HashAgg with aggkey expresion + 2 distinct + order by:") { - val sql1 = "SELECT city, quantity_dec18_2, count(DISTINCT quantity_dec18_2 + 1.00), max(DISTINCT id)" + - " FROM dealer_decimal GROUP BY city, quantity_dec18_2 order BY city, quantity_dec18_2;" - assertHashAggregateExecOmniAndSparkResultEqual(sql1) - } - - private def assertHashAggregateExecOmniAndSparkResultEqual(sql: String, hashAggExecFullReplace: Boolean = true): Unit = { - // run ColumnarHashAggregateExec config - spark.conf.set("spark.omni.sql.columnar.hashagg", true) - val omniResult = spark.sql(sql) - val omniPlan = omniResult.queryExecution.executedPlan - assert(omniPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, - s"SQL:${sql}\n@OmniEnv no ColumnarHashAggregateExec,omniPlan:${omniPlan}") - if (hashAggExecFullReplace) { - assert(omniPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, - s"SQL:${sql}\n@OmniEnv have HashAggregateExec,omniPlan:${omniPlan}") - } - - // run HashAggregateExec config - spark.conf.set("spark.omni.sql.columnar.hashagg", false) - val sparkResult = spark.sql(sql) - val sparkPlan = sparkResult.queryExecution.executedPlan - assert(sparkPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isEmpty, - s"SQL:${sql}\n@SparkEnv have ColumnarHashAggregateExec,sparkPlan:${sparkPlan}") - assert(sparkPlan.find(_.isInstanceOf[HashAggregateExec]).isDefined, - s"SQL:${sql}\n@SparkEnv no HashAggregateExec,sparkPlan:${sparkPlan}") - // DataFrame do not support 
comparing with equals method, use DataFrame.except instead - // DataFrame.except can do equal for rows misorder(with and without order by are same) - assert(omniResult.except(sparkResult).isEmpty, - s"SQL:${sql}\nomniResult:${omniResult.show()}\nsparkResult:${sparkResult.show()}\n") - spark.conf.set("spark.omni.sql.columnar.hashagg", true) - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala deleted file mode 100644 index 11dfac2cbd2175ebb9218addcc7831bf11b60526..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.functions.{avg, count, first, max, min, sum} -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Row} - -class ColumnarHashAggregateExecSuite extends ColumnarSparkPlanTest { - private var df: DataFrame = _ - - protected override def beforeAll(): Unit = { - super.beforeAll() - df = spark.createDataFrame( - sparkContext.parallelize(Seq( - Row(1, 2.0, 1L, "a"), - Row(1, 2.0, 2L, null), - Row(2, 1.0, 3L, "c"), - Row(null, null, 6L, "e"), - Row(null, 5.0, 7L, "f") - )), new StructType().add("a", IntegerType).add("b", DoubleType) - .add("c", LongType).add("d", StringType)) - } - - test("validate columnar hashAgg exec happened") { - val res = df.groupBy("a").agg(sum("b")) - assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") - } - - test("check columnar hashAgg result") { - val res = testData2.groupBy("a").agg(sum("b")) - checkAnswer( - res, - Seq(Row(1, 3), Row(2, 3), Row(3, 3)) - ) - } - - test("check columnar hashAgg result with null") { - val res = df.filter(df("a").isNotNull && df("d").isNotNull).groupBy("a").agg(sum("b")) - checkAnswer( - res, - Seq(Row(1, 2.0), Row(2, 1.0)) - ) - } - - test("Test ColumnarHashAggregateExec happen and result is correct when execute count(*) api") { - val res = df.agg(count("*")) - assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") - checkAnswer( - res, - Seq(Row(5)) - ) - } - - test("Test ColumnarHashAggregateExec 
happen and result " + - "is correct when execute count(*) api with group by") { - val res = df.groupBy("a").agg(count("*")) - assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") - checkAnswer( - res, - Seq(Row(1, 2), Row(2, 1), Row(null, 2)) - ) - } - - test("test hashAgg null") { - var res = df.filter(df("a").equalTo(3)).groupBy("a").agg(sum("a")) - checkAnswer( - res, - Seq.empty - ) - res = df.filter(df("a").equalTo(3)).groupBy("a").agg(max("a")) - checkAnswer( - res, - Seq.empty - ) - res = df.filter(df("a").equalTo(3)).groupBy("a").agg(min("a")) - checkAnswer( - res, - Seq.empty - ) - res = df.filter(df("a").equalTo(3)).groupBy("a").agg(avg("a")) - checkAnswer( - res, - Seq.empty - ) - res = df.filter(df("a").equalTo(3)).groupBy("a").agg(first("a")) - checkAnswer( - res, - Seq.empty - ) - res = df.filter(df("a").equalTo(3)).groupBy("a").agg(count("a")) - checkAnswer( - res, - Seq.empty - ) - } - test("test agg null") { - var res = df.filter(df("a").equalTo(3)).agg(sum("a")) - checkAnswer( - res, - Seq(Row(null)) - ) - res = df.filter(df("a").equalTo(3)).agg(max("a")) - checkAnswer( - res, - Seq(Row(null)) - ) - res = df.filter(df("a").equalTo(3)).agg(min("a")) - checkAnswer( - res, - Seq(Row(null)) - ) - res = df.filter(df("a").equalTo(3)).agg(avg("a")) - checkAnswer( - res, - Seq(Row(null)) - ) - res = df.filter(df("a").equalTo(3)).agg(first("a")) - checkAnswer( - res, - Seq(Row(null)) - ) - res = df.filter(df("a").equalTo(3)).agg(count("a")) - checkAnswer( - res, - Seq(Row(0)) - ) - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala deleted file mode 100644 index 370aa8b58edbbb14726a6061fd9ae795038c3f4e..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala +++ /dev/null @@ -1,470 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.optimizer.BuildRight -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi} -import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ColumnarBroadcastHashJoinExec, ColumnarShuffledHashJoinExec, ColumnarSortMergeJoinExec, SortMergeJoinExec} -import org.apache.spark.sql.functions.col -import org.apache.spark.sql.types.{IntegerType, StringType, StructType} - -// refer to joins package -class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { - - import testImplicits.{localSeqToDatasetHolder, newProductEncoder} - - private var left: DataFrame = _ - private var right: DataFrame = _ - private var leftWithNull: DataFrame = _ - private var rightWithNull: DataFrame = _ - private var person_test: DataFrame = _ - private var order_test: DataFrame = _ - - protected override def beforeAll(): Unit = { - super.beforeAll() - left = Seq[(String, String, java.lang.Integer, java.lang.Double)]( - ("abc", "", 4, 2.0), - ("", "Hello", 1, 1.0), - (" add", "World", 8, 3.0), - (" yeah ", "yeah", 10, 8.0) - ).toDF("a", "b", "q", "d") - - right = Seq[(String, String, java.lang.Integer, java.lang.Double)]( - ("abc", "", 4, 1.0), - ("", "Hello", 2, 2.0), - (" add", "World", 1, 3.0), - (" yeah ", "yeah", 0, 4.0) - ).toDF("a", "b", "c", "d") - - leftWithNull = Seq[(String, String, java.lang.Integer, java.lang.Double)]( - ("abc", null, 4, 2.0), - ("", "Hello", null, 1.0), - (" add", "World", 8, 3.0), - (" yeah ", "yeah", 10, 8.0) - ).toDF("a", "b", "q", "d") - - rightWithNull = Seq[(String, String, java.lang.Integer, java.lang.Double)]( - ("abc", "", 4, 1.0), - ("", "Hello", 2, 2.0), - (" add", null, 1, null), - (" yeah ", null, null, 4.0) - ).toDF("a", "b", "c", "d") - - person_test = spark.createDataFrame( - sparkContext.parallelize(Seq( - Row(3, "Carter"), - Row(1, "Adams"), - Row(2, "Bush") - )), new StructType() - .add("id_p", IntegerType) - .add("name", StringType)) - person_test.createOrReplaceTempView("person_test") - - order_test = spark.createDataFrame( - sparkContext.parallelize(Seq( - Row(5, 34764, 65), - Row(1, 77895, 3), - Row(2, 44678, 3), - Row(4, 24562, 1), - Row(3, 22456, 1) - )), new StructType() - .add("id_o", IntegerType) - .add("order_no", IntegerType) - .add("id_p", IntegerType)) - order_test.createOrReplaceTempView("order_test") - } - - test("validate columnar broadcastHashJoin exec happened") { - val res = left.join(right.hint("broadcast"), col("q") === col("c")) - assert( - res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarBroadcastHashJoinExec]).isDefined, - s"ColumnarBroadcastHashJoinExec not happened, " + - s"executedPlan as follows: \n${res.queryExecution.executedPlan}") - } - - test("validate columnar sortMergeJoin exec happened") { - val res = left.join(right.hint("mergejoin"), col("q") === col("c")) - assert( - res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarSortMergeJoinExec]).isDefined, - s"ColumnarSortMergeJoinExec not happened, " + - s"executedPlan as follows: \n${res.queryExecution.executedPlan}") - } - - test("columnar broadcastHashJoin is equal to native") { - val df = left.join(right.hint("broadcast"), col("q") === col("c")) - val leftKeys = Seq(left.col("q").expr) - val rightKeys = Seq(right.col("c").expr) - checkThatPlansAgreeTemplateForBHJ(df, leftKeys, rightKeys) - } - - test("columnar sortMergeJoin 
Inner Join is equal to native") { - val df = left.join(right.hint("mergejoin"), col("q") === col("c")) - val leftKeys = Seq(left.col("q").expr) - val rightKeys = Seq(right.col("c").expr) - checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, Inner) - } - - test("columnar sortMergeJoin Inner Join is equal to native With NULL") { - val df = leftWithNull.join(rightWithNull.hint("mergejoin"), col("q") === col("c")) - val leftKeys = Seq(leftWithNull.col("q").expr) - val rightKeys = Seq(rightWithNull.col("c").expr) - checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, Inner) - } - - test("columnar sortMergeJoin LeftOuter Join is equal to native") { - val df = left.join(right.hint("mergejoin"), col("q") === col("c")) - val leftKeys = Seq(left.col("q").expr) - val rightKeys = Seq(right.col("c").expr) - checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, LeftOuter) - } - - test("columnar sortMergeJoin LeftOuter Join is equal to native With NULL") { - val df = leftWithNull.join(rightWithNull.hint("mergejoin"), col("q") === col("c")) - val leftKeys = Seq(leftWithNull.col("q").expr) - val rightKeys = Seq(rightWithNull.col("c").expr) - checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, LeftOuter) - } - - test("columnar sortMergeJoin FullOuter Join is equal to native") { - val df = left.join(right.hint("mergejoin"), col("q") === col("c")) - val leftKeys = Seq(left.col("q").expr) - val rightKeys = Seq(right.col("c").expr) - checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, FullOuter) - } - - test("columnar sortMergeJoin FullOuter Join is equal to native With NULL") { - val df = leftWithNull.join(rightWithNull.hint("mergejoin"), col("q") === col("c")) - val leftKeys = Seq(leftWithNull.col("q").expr) - val rightKeys = Seq(rightWithNull.col("c").expr) - checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, FullOuter) - } - - test("columnar sortMergeJoin LeftSemi Join is equal to native") { - val df = left.join(right.hint("mergejoin"), col("q") === col("c")) - val leftKeys = Seq(left.col("q").expr) - val rightKeys = Seq(right.col("c").expr) - checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, LeftSemi) - } - - test("columnar sortMergeJoin LeftSemi Join is equal to native With NULL") { - val df = leftWithNull.join(rightWithNull.hint("mergejoin"), col("q") === col("c")) - val leftKeys = Seq(leftWithNull.col("q").expr) - val rightKeys = Seq(rightWithNull.col("c").expr) - checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, LeftSemi) - } - - test("columnar sortMergeJoin LeftAnti Join is equal to native") { - val df = left.join(right.hint("mergejoin"), col("q") === col("c")) - val leftKeys = Seq(left.col("q").expr) - val rightKeys = Seq(right.col("c").expr) - checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, LeftAnti) - } - - test("columnar sortMergeJoin LeftAnti Join is equal to native With NULL") { - val df = leftWithNull.join(rightWithNull.hint("mergejoin"), col("q") === col("c")) - val leftKeys = Seq(leftWithNull.col("q").expr) - val rightKeys = Seq(rightWithNull.col("c").expr) - checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, LeftAnti) - } - - test("columnar broadcastHashJoin is equal to native with null") { - val df = leftWithNull.join(rightWithNull.hint("broadcast"), - col("q").isNotNull === col("c").isNotNull) - val leftKeys = Seq(leftWithNull.col("q").isNotNull.expr) - val rightKeys = Seq(rightWithNull.col("c").isNotNull.expr) - checkThatPlansAgreeTemplateForBHJ(df, leftKeys, rightKeys) - } - - test("columnar broadcastHashJoin 
LeftSemi Join is equal to native") { - val df = left.join(right.hint("broadcast"), col("q") === col("c")) - val leftKeys = Seq(left.col("q").expr) - val rightKeys = Seq(right.col("c").expr) - checkThatPlansAgreeTemplateForBHJ(df, leftKeys, rightKeys, LeftSemi) - } - - test("columnar broadcastHashJoin LeftSemi Join is equal to native with null") { - val df = leftWithNull.join(rightWithNull.hint("broadcast"), - col("q").isNotNull === col("c").isNotNull) - val leftKeys = Seq(leftWithNull.col("q").isNotNull.expr) - val rightKeys = Seq(rightWithNull.col("c").isNotNull.expr) - checkThatPlansAgreeTemplateForBHJ(df, leftKeys, rightKeys, LeftSemi) - } - - def checkThatPlansAgreeTemplateForBHJ(df: DataFrame, leftKeys: Seq[Expression], - rightKeys: Seq[Expression], joinType: JoinType = Inner): Unit = { - checkThatPlansAgree( - df, - (child: SparkPlan) => - ColumnarBroadcastHashJoinExec(leftKeys, rightKeys, joinType, - BuildRight, None, child, child), - (child: SparkPlan) => - BroadcastHashJoinExec(leftKeys, rightKeys, joinType, - BuildRight, None, child, child), - sortAnswers = false) - } - - test("validate columnar broadcastHashJoin left outer join happened") { - val res = left.join(right.hint("broadcast"), col("q") === col("c"), "leftouter") - assert( - res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarBroadcastHashJoinExec]).isDefined, - s"ColumnarBroadcastHashJoinExec not happened," + - s" executedPlan as follows: \n${res.queryExecution.executedPlan}") - } - - test("columnar broadcastHashJoin left outer join is equal to native") { - val df = left.join(right.hint("broadcast"), col("q") === col("c"), "leftouter") - checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( - Row("abc", "", 4, 2.0, "abc", "", 4, 1.0), - Row("", "Hello", 1, 1.0, " add", "World", 1, 3.0), - Row(" add", "World", 8, 3.0, null, null, null, null), - Row(" yeah ", "yeah", 10, 8.0, null, null, null, null) - ), false) - } - - test("columnar broadcastHashJoin left outer join is equal to native with null") { - val df = leftWithNull.join(rightWithNull.hint("broadcast"), - col("q").isNotNull === col("c").isNotNull, "leftouter") - checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( - Row("abc", null, 4, 2.0, " add", null, 1, null), - Row("abc", null, 4, 2.0, "", "Hello", 2, 2.0), - Row("abc", null, 4, 2.0, "abc", "", 4, 1.0), - Row("", "Hello", null, 1.0, " yeah ", null, null, 4.0), - Row(" add", "World", 8, 3.0, " add", null, 1, null), - Row(" add", "World", 8, 3.0, "", "Hello", 2, 2.0), - Row(" add", "World", 8, 3.0, "abc", "", 4, 1.0), - Row(" yeah ", "yeah", 10, 8.0, " add", null, 1, null), - Row(" yeah ", "yeah", 10, 8.0, "", "Hello", 2, 2.0), - Row(" yeah ", "yeah", 10, 8.0, "abc", "", 4, 1.0) - ), false) - } - - test("validate columnar shuffledHashJoin full outer join happened") { - val res = left.join(right.hint("SHUFFLE_HASH"), col("q") === col("c"), "fullouter") - assert( - res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffledHashJoinExec]).isDefined, - s"ColumnarShuffledHashJoinExec not happened," + - s" executedPlan as follows: \n${res.queryExecution.executedPlan}") - } - - test("columnar shuffledHashJoin full outer join is equal to native") { - val df = left.join(right.hint("SHUFFLE_HASH"), col("q") === col("c"), "fullouter") - checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( - Row(null, null, null, null, " yeah ", "yeah", 0, 4.0), - Row("abc", "", 4, 2.0, "abc", "", 4, 1.0), - Row(" yeah ", "yeah", 10, 8.0, null, null, null, null), - Row("", "Hello", 1, 1.0, " add", "World", 1, 
3.0), - Row(" add", "World", 8, 3.0, null, null, null, null), - Row(null, null, null, null, "", "Hello", 2, 2.0) - ), false) - } - - test("columnar shuffledHashJoin full outer join is equal to native with null") { - val df = leftWithNull.join(rightWithNull.hint("SHUFFLE_HASH"), - col("q").isNotNull === col("c").isNotNull, "fullouter") - checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( - Row("", "Hello", null, 1.0, " yeah ", null, null, 4.0), - Row("abc", null, 4, 2.0, " add", null, 1, null), - Row("abc", null, 4, 2.0, "", "Hello", 2, 2.0), - Row("abc", null, 4, 2.0, "abc", "", 4, 1.0), - Row(" add", "World", 8, 3.0, " add", null, 1, null), - Row(" add", "World", 8, 3.0, "", "Hello", 2, 2.0), - Row(" add", "World", 8, 3.0, "abc", "", 4, 1.0), - Row(" yeah ", "yeah", 10, 8.0, " add", null, 1, null), - Row(" yeah ", "yeah", 10, 8.0, "", "Hello", 2, 2.0), - Row(" yeah ", "yeah", 10, 8.0, "abc", "", 4, 1.0) - ), false) - } - - test("validate columnar shuffledHashJoin left semi join happened") { - val res = left.join(right.hint("SHUFFLE_HASH"), col("q") === col("c"), "leftsemi") - assert( - res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffledHashJoinExec]).isDefined, - s"ColumnarShuffledHashJoinExec not happened," + - s" executedPlan as follows: \n${res.queryExecution.executedPlan}") - } - - test("columnar shuffledHashJoin left semi join is equal to native") { - val df = left.join(right.hint("SHUFFLE_HASH"), col("q") === col("c"), "leftsemi") - checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( - Row("abc", "", 4, 2.0), - Row("", "Hello", 1, 1.0) - ), false) - } - - test("columnar shuffledHashJoin left semi join is equal to native with null") { - val df = leftWithNull.join(rightWithNull.hint("SHUFFLE_HASH"), - col("q") === col("c"), "leftsemi") - checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( - Row("abc", null, 4, 2.0) - ), false) - } - - test("validate columnar shuffledHashJoin left outer join happened") { - val res = left.join(right.hint("SHUFFLE_HASH"), col("q") === col("c"), "leftouter") - assert( - res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffledHashJoinExec]).isDefined, - s"ColumnarShuffledHashJoinExec not happened," + - s" executedPlan as follows: \n${res.queryExecution.executedPlan}") - } - - test("columnar shuffledHashJoin left outer join is equal to native") { - val df = left.join(right.hint("SHUFFLE_HASH"), col("q") === col("c"), "leftouter") - checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( - Row("abc", "", 4, 2.0, "abc", "", 4, 1.0), - Row(" yeah ", "yeah", 10, 8.0, null, null, null, null), - Row("", "Hello", 1, 1.0, " add", "World", 1, 3.0), - Row(" add", "World", 8, 3.0, null, null, null, null) - ), false) - } - - test("columnar shuffledHashJoin left outer join is equal to native with null") { - val df = leftWithNull.join(rightWithNull.hint("SHUFFLE_HASH"), - col("q") === col("c"), "leftouter") - checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( - Row("", "Hello", null, 1.0, null, null, null, null), - Row("abc", null, 4, 2.0, "abc", "", 4, 1.0), - Row(" yeah ", "yeah", 10, 8.0, null, null, null, null), - Row(" add", "World", 8, 3.0, null, null, null, null) - ), false) - } - - test("ColumnarBroadcastHashJoin is not rolled back with not_equal filter expr") { - val res = left.join(right.hint("broadcast"), left("a") <=> right("a")) - assert( - res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarBroadcastHashJoinExec]).isDefined, - s"ColumnarBroadcastHashJoinExec not happened, " + - s"executedPlan 
as follows: \n${res.queryExecution.executedPlan}") - } - - def checkThatPlansAgreeTemplateForSMJ(df: DataFrame, leftKeys: Seq[Expression], - rightKeys: Seq[Expression], joinType: JoinType): Unit = { - checkThatPlansAgree( - df, - (child: SparkPlan) => - new ColumnarSortMergeJoinExec(leftKeys, rightKeys, joinType, - None, child, child), - (child: SparkPlan) => - SortMergeJoinExec(leftKeys, rightKeys, joinType, - None, child, child), - sortAnswers = true) - } - - test("BroadcastHashJoin and project fusion test") { - val omniResult = person_test.join(order_test.hint("broadcast"), person_test("id_p") === order_test("id_p"), "leftouter") - .select(person_test("name"), order_test("order_no")) - val omniPlan = omniResult.queryExecution.executedPlan - assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, - s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") - checkAnswer(omniResult, _ => omniPlan, Seq( - Row("Carter", 44678), - Row("Carter", 77895), - Row("Adams", 22456), - Row("Adams", 24562), - Row("Bush", null) - ), false) - } - - test("BroadcastHashJoin and project fusion test for duplicate column") { - val omniResult = person_test.join(order_test.hint("broadcast"), person_test("id_p") === order_test("id_p"), "leftouter") - .select(person_test("name"), order_test("order_no"), order_test("id_p")) - val omniPlan = omniResult.queryExecution.executedPlan - assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, - s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") - checkAnswer(omniResult, _ => omniPlan, Seq( - Row("Carter", 44678, 3), - Row("Carter", 77895, 3), - Row("Adams", 22456, 1), - Row("Adams", 24562, 1), - Row("Bush", null, null) - ), false) - } - - test("BroadcastHashJoin and project fusion test for reorder columns") { - val omniResult = person_test.join(order_test.hint("broadcast"), person_test("id_p") === order_test("id_p"), "leftouter") - .select(order_test("order_no"), person_test("name"), order_test("id_p")) - val omniPlan = omniResult.queryExecution.executedPlan - assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, - s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") - checkAnswer(omniResult, _ => omniPlan, Seq( - Row(44678, "Carter", 3), - Row(77895, "Carter", 3), - Row(22456, "Adams", 1), - Row(24562, "Adams", 1), - Row(null, "Bush", null) - ), false) - } - - test("BroadcastHashJoin and project are not fused test") { - val omniResult = person_test.join(order_test.hint("broadcast"), person_test("id_p") === order_test("id_p"), "leftouter") - .select(order_test("order_no").plus(1), person_test("name")) - val omniPlan = omniResult.queryExecution.executedPlan - assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, - s"SQL:\n@OmniEnv have ColumnarProjectExec,omniPlan:${omniPlan}") - checkAnswer(omniResult, _ => omniPlan, Seq( - Row(44679, "Carter"), - Row(77896, "Carter"), - Row(22457, "Adams"), - Row(24563, "Adams"), - Row(null, "Bush") - ), false) - } - - test("BroadcastHashJoin and project fusion test for alias") { - val omniResult = person_test.join(order_test.hint("broadcast"), person_test("id_p") === order_test("id_p"), "leftouter") - .select(person_test("name").as("name1"), order_test("order_no").as("order_no1")) - val omniPlan = omniResult.queryExecution.executedPlan - assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, - s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") - checkAnswer(omniResult, _ => omniPlan, Seq( - Row("Carter", 44678), - Row("Carter", 
77895), - Row("Adams", 22456), - Row("Adams", 24562), - Row("Bush", null) - ), false) - } - - test("validate columnar shuffledHashJoin left anti join happened") { - val res = left.join(right.hint("SHUFFLE_HASH"), col("q") === col("c"), "leftanti") - assert( - res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffledHashJoinExec]).isDefined, - s"ColumnarShuffledHashJoinExec not happened," + - s" executedPlan as follows: \n${res.queryExecution.executedPlan}") - } - - test("columnar shuffledHashJoin left anti join is equal to native") { - val df = left.join(right.hint("SHUFFLE_HASH"), col("q") === col("c"), "leftanti") - checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( - Row(" yeah ", "yeah", 10, 8.0), - Row(" add", "World", 8, 3.0) - ), false) - } - - test("columnar shuffledHashJoin left anti join is equal to native with null") { - val df = leftWithNull.join(rightWithNull.hint("SHUFFLE_HASH"), - col("q") === col("c"), "leftanti") - checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( - Row("", "Hello", null, 1.0), - Row(" yeah ", "yeah", 10, 8.0), - Row(" add", "World", 8, 3.0) - ), false) - } -} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarProjectExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarProjectExecSuite.scala deleted file mode 100644 index ce39461cd5bba1920509eda6be231467f2488e1c..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarProjectExecSuite.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.catalyst.expressions.NamedExpression - -class ColumnarProjectExecSuite extends ColumnarSparkPlanTest { - import testImplicits.{localSeqToDatasetHolder, newProductEncoder} - - private var inputDf: DataFrame = _ - private var inputDfWithNull: DataFrame = _ - - protected override def beforeAll(): Unit = { - super.beforeAll() - inputDf = Seq[(String, String, java.lang.Integer, java.lang.Double)]( - ("abc", "", 4, 2.0), - ("", "Hello", 1, 1.0), - (" add", "World", 8, 3.0), - (" yeah ", "yeah", 10, 8.0) - ).toDF("a", "b", "c", "d") - - inputDfWithNull = Seq[(String, String, java.lang.Integer, java.lang.Double)]( - (null, "", 4, 2.0), - (null, null, 1, 1.0), - (" add", "World", 8, 3.0), - (" yeah ", "yeah", 10, 8.0) - ).toDF("a", "b", "c", "d") - } - - test("validate columnar project exec happened") { - val res = inputDf.selectExpr("a as t") - assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, - s"ColumnarProjectExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") - } - - test("columnar project is equal to native") { - val projectList: Seq[NamedExpression] = Seq(inputDf.col("a").as("abc").expr.asInstanceOf[NamedExpression]) - checkThatPlansAgreeTemplate(projectList, inputDf) - } - - test("columnar project is equal to native with null") { - val projectList: Seq[NamedExpression] = Seq(inputDfWithNull.col("a").as("abc").expr.asInstanceOf[NamedExpression]) - checkThatPlansAgreeTemplate(projectList, inputDfWithNull) - } - - def checkThatPlansAgreeTemplate(projectList: Seq[NamedExpression], df: DataFrame): Unit = { - checkThatPlansAgree( - df, - (child: SparkPlan) => - ColumnarProjectExec(projectList, child = child), - (child: SparkPlan) => - ProjectExec(projectList, child = child), - sortAnswers = false) - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarRangeExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarRangeExecSuite.scala deleted file mode 100644 index 79699046aee4600572ceca73be7f89113045b26e..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarRangeExecSuite.scala +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution - -// refer to DataFrameRangeSuite -class ColumnarRangeSuite extends ColumnarSparkPlanTest { - test("validate columnar range exec happened") { - val res = spark.range(0, 10, 1) - assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarRangeExec]).isDefined, s"ColumnarRangeExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecSuite.scala deleted file mode 100644 index 91fe50455e4f184069adee0be316d63658f4af16..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecSuite.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -class ColumnarShuffleExchangeExecSuite extends ColumnarSparkPlanTest { - import testImplicits.{localSeqToDatasetHolder, newProductEncoder} - - protected override def beforeAll(): Unit = { - super.beforeAll() - } - - test("validate columnar shuffleExchange exec worked") { - val inputDf = Seq[(String, java.lang.Integer, java.lang.Double)] ( - ("Sam", 12, 9.1), - ("Bob", 13, 9.3), - ("Ted", 10, 8.9) - ).toDF("name", "age", "point") - val res = inputDf.sort(inputDf("age").asc) - assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarSortExec]).isDefined, - s"ColumnarSortExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") - - assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffleExchangeExec]).isDefined, - s"ColumnarShuffleExchangeExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarSortExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarSortExecSuite.scala deleted file mode 100644 index cecf846af71efdfe1f25e064f600025c56feaeae..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarSortExecSuite.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.lang - -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.SortOrder - -class ColumnarSortExecSuite extends ColumnarSparkPlanTest { - import testImplicits.{localSeqToDatasetHolder, newProductEncoder} - - test("validate columnar sort exec happened") { - val inputDf = Seq[(String, java.lang.Integer, java.lang.Double)]( - ("Hello", 4, 2.0), - ("Hello", 1, 1.0), - ("World", 8, 3.0) - ).toDF("a", "b", "c") - val res = inputDf.sort(inputDf("b").asc) - assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarSortExec]).isDefined, s"ColumnarSortExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") - } - - test("columnar sort is equal to native sort") { - val df = Seq[(String, java.lang.Integer, java.lang.Double)]( - ("Hello", 4, 2.0), - ("Hello", 1, 1.0), - ("World", 8, 3.0) - ).toDF("a", "b", "c") - val sortOrder = Stream('a.asc, 'b.asc, 'c.asc) - checkThatPlansAgreeTemplate(input = df, sortOrder = sortOrder) - } - - test("columnar sort is equal to native sort with null") { - val dfWithNull = Seq[(String, Integer, lang.Double)]( - ("Hello", 4, 2.0), - (null, 1, 1.0), - ("World", null, 3.0), - ("World", 8, 3.0) - ).toDF("a", "b", "c") - val sortOrder = Stream('a.asc, 'b.asc, 'c.asc) - checkThatPlansAgreeTemplate(input = dfWithNull, sortOrder = sortOrder) - } - - def checkThatPlansAgreeTemplate(input: DataFrame, sortOrder: Seq[SortOrder]): Unit = { - checkThatPlansAgree( - input, - (child: SparkPlan) => - ColumnarSortExec(sortOrder, global = true, child = child), - (child: SparkPlan) => - SortExec(sortOrder, global = true, child = child), - sortAnswers = false) - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarSparkPlanTest.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarSparkPlanTest.scala deleted file mode 100644 index 16ab589578aacc68a1964566ef05f3898d0e406a..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarSparkPlanTest.scala +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import org.apache.spark.SparkConf
-import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row}
-import org.apache.spark.sql.catalyst.util.stackTraceToString
-import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
-import org.apache.spark.sql.test.SharedSparkSession
-
-private[sql] abstract class ColumnarSparkPlanTest extends SparkPlanTest with SharedSparkSession {
-  // setup basic columnar configuration for columnar exec
-  override def sparkConf: SparkConf = super.sparkConf
-    .set(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key, "com.huawei.boostkit.spark.ColumnarPlugin")
-    .set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "false")
-    .set("spark.executorEnv.OMNI_CONNECTED_ENGINE", "Spark")
-    .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.OmniColumnarShuffleManager")
-
-  protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
-    val analyzedDF = try df catch {
-      case ae: AnalysisException =>
-        if (ae.plan.isDefined) {
-          fail(
-            s"""
-               |Failed to analyze query: $ae
-               |${ae.plan.get}
-               |
-               |${stackTraceToString(ae)}
-               |""".stripMargin)
-        } else {
-          throw ae
-        }
-    }
-    assertEmptyMissingInput(analyzedDF)
-    QueryTest.checkAnswer(analyzedDF, expectedAnswer)
-  }
-
-  private def assertEmptyMissingInput(query: Dataset[_]): Unit = {
-    assert(query.queryExecution.analyzed.missingInput.isEmpty,
-      s"The analyzed logical plan has missing inputs:\n${query.queryExecution.analyzed}")
-    assert(query.queryExecution.optimizedPlan.missingInput.isEmpty,
-      s"The optimized logical plan has missing inputs:\n${query.queryExecution.optimizedPlan}")
-    assert(query.queryExecution.executedPlan.missingInput.isEmpty,
-      s"The physical plan has missing inputs:\n${query.queryExecution.executedPlan}")
-  }
-}
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNExecSuite.scala
deleted file mode 100644
index 3fb8a1bf9c6306362687d4d975dabfd44f43b1b7..0000000000000000000000000000000000000000
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNExecSuite.scala
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved.
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.catalyst.dsl.expressions.DslSymbol
-import org.apache.spark.sql.catalyst.expressions.{NamedExpression, SortOrder}
-
-// refer to TakeOrderedAndProjectSuite
-class ColumnarTopNExecSuite extends ColumnarSparkPlanTest {
-  import testImplicits.{localSeqToDatasetHolder, newProductEncoder}
-
-  private var inputDf: DataFrame = _
-  private var inputDfWithNull: DataFrame = _
-
-  protected override def beforeAll(): Unit = {
-    super.beforeAll()
-    inputDf = Seq[(java.lang.Integer, java.lang.Double, String)](
-      (4, 2.0, "abc"),
-      (1, 1.0, "aaa"),
-      (8, 3.0, "ddd"),
-      (10, 8.0, "")
-    ).toDF("a", "b", "c")
-
-    inputDfWithNull = Seq[(String, String, java.lang.Integer, java.lang.Double)](
-      ("abc", "", 4, 2.0),
-      ("", null, 1, 1.0),
-      (" add", "World", 8, null),
-      (" yeah ", "yeah", 10, 8.0)
-    ).toDF("a", "b", "c", "d")
-  }
-
-  test("validate columnar topN exec happened") {
-    val res = inputDf.sort("a").limit(2)
-    assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarTakeOrderedAndProjectExec]).isDefined, s"ColumnarTakeOrderedAndProjectExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}")
-  }
-
-  test("columnar topN is equal to native") {
-    val limit = 3
-    val sortOrder = Stream('a.asc, 'b.desc)
-    val projectList = Seq(inputDf.col("a").as("abc").expr.asInstanceOf[NamedExpression])
-    checkThatPlansAgreeTemplate(inputDf, limit, sortOrder, projectList)
-  }
-
-  test("columnar topN is equal to native with null") {
-    val res = inputDfWithNull.orderBy("a", "b").selectExpr("c + 1", "d + 2").limit(2)
-    checkAnswer(res, Seq(Row(2, 3.0), Row(9, null)))
-  }
-
-  def checkThatPlansAgreeTemplate(df: DataFrame, limit: Int, sortOrder: Seq[SortOrder],
-    projectList: Seq[NamedExpression]): Unit = {
-    checkThatPlansAgree(
-      df,
-      input => ColumnarTakeOrderedAndProjectExec(limit, sortOrder, projectList, input),
-      input => TakeOrderedAndProjectExec(limit, sortOrder, projectList, input),
-      sortAnswers = false)
-  }
-}
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarUnionExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarUnionExecSuite.scala
deleted file mode 100644
index 9539d9448aad5b33a3277225a6b11b69c5b97913..0000000000000000000000000000000000000000
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarUnionExecSuite.scala
+++ /dev/null
@@ -1,80 +0,0 @@
-/*
- * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved.
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import org.apache.spark.sql.{DataFrame, Row}
-
-class ColumnarUnionExecSuite extends ColumnarSparkPlanTest {
-  import testImplicits.{localSeqToDatasetHolder, newProductEncoder}
-
-  private var left: DataFrame = _
-  private var right: DataFrame = _
-
-  protected override def beforeAll(): Unit = {
-    super.beforeAll()
-    left = Seq[(String, String, java.lang.Integer, java.lang.Double)](
-      ("abc", "", 4, 2.0),
-      ("", "Hello", 1, 1.0),
-      (" add", "World", 8, 3.0),
-      (" yeah ", "yeah", 10, 8.0)
-    ).toDF("a", "b", "c", "d")
-
-    right = Seq[(String, String, java.lang.Integer, java.lang.Double)](
-      (null, "", 4, 2.0),
-      (null, null, 1, 1.0),
-      (" add", "World", 8, 3.0),
-      (" yeah ", "yeah", 10, 8.0)
-    ).toDF("a", "b", "c", "d")
-  }
-
-  test("validate columnar union exec happened") {
-    val res = left.union(right)
-    assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarUnionExec]).isDefined, s"ColumnarUnionExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}")
-  }
-
-  test("columnar union is equal to expected") {
-    val expected = Array(Row("abc", "", 4, 2.0),
-      Row("", "Hello", 1, 1.0),
-      Row(" add", "World", 8, 3.0),
-      Row(" yeah ", "yeah", 10, 8.0),
-      Row(null, "", 4, 2.0),
-      Row(null, null, 1, 1.0),
-      Row(" add", "World", 8, 3.0),
-      Row(" yeah ", "yeah", 10, 8.0))
-    val res = left.union(right)
-    val result: Array[Row] = res.head(8)
-    assertResult(expected)(result)
-  }
-
-  test("columnar union is equal to native with null") {
-    val df = left.union(right)
-    val children = Seq(left.queryExecution.executedPlan, right.queryExecution.executedPlan)
-    checkThatPlansAgreeTemplate(df, children)
-  }
-
-  def checkThatPlansAgreeTemplate(df: DataFrame, child: Seq[SparkPlan]): Unit = {
-    checkThatPlansAgree(
-      df,
-      (_: SparkPlan) =>
-        ColumnarUnionExec(child),
-      (_: SparkPlan) =>
-        UnionExec(child),
-      sortAnswers = false)
-  }
-}
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarWindowExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarWindowExecSuite.scala
deleted file mode 100644
index 4f11256f47f32ba8088a6bd3f201e53bddcfeb46..0000000000000000000000000000000000000000
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarWindowExecSuite.scala
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved.
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.expressions.Window
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.SharedSparkSession
-
-// refer to DataFrameWindowFramesSuite
-class ColumnarWindowExecSuite extends ColumnarSparkPlanTest with SharedSparkSession {
-  import testImplicits._
-
-  private var inputDf: DataFrame = _
-
-  protected override def beforeAll(): Unit = {
-    super.beforeAll()
-    inputDf = Seq(
-      ("abc", "", 4, 2.0),
-      ("", "Hello", 1, 1.0),
-      (" add", "World", 8, 3.0),
-      (" yeah ", "yeah", 10, 8.0),
-      ("abc", "", 10, 8.0)
-    ).toDF("a", "b", "c", "d")
-  }
-
-  test("validate columnar window exec happened") {
-    val res1 = Window.partitionBy("a").orderBy('c.desc)
-    val res2 = inputDf.withColumn("max", max("c").over(res1))
-    res2.head(10).foreach(row => println(row))
-    assert(res2.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarWindowExec]).isDefined, s"ColumnarWindowExec not happened, executedPlan as follows: \n${res2.queryExecution.executedPlan}")
-  }
-}
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/DecimalOperationSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/DecimalOperationSuite.scala
deleted file mode 100644
index 2f72a3651cfe5fb286e42b85ad2d07c76bb01bf3..0000000000000000000000000000000000000000
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/DecimalOperationSuite.scala
+++ /dev/null
@@ -1,1128 +0,0 @@
-/*
- * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved.
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.types.Decimal -import org.apache.spark.sql.{Column, DataFrame} - -import java.math.MathContext - -class DecimalOperationSuite extends ColumnarSparkPlanTest { - - import testImplicits.{localSeqToDatasetHolder, newProductEncoder} - - private var deci_overflow: DataFrame = _ - - private def newDecimal(deci: String, precision: Int, scale: Int): Decimal = { - if (deci == null) - null - else - new Decimal().set(BigDecimal(deci, MathContext.UNLIMITED), precision, scale) - } - - private def newRow(id: Int, c_deci5_0: String, c_deci7_2: String, c_deci17_2: String, c_deci18_6: String, - c_deci21_6: String, c_deci22_6: String, c_deci38_0: String, c_deci38_16: String): - (Int, Decimal, Decimal, Decimal, Decimal, Decimal, Decimal, Decimal, Decimal) = { - (id, - newDecimal(c_deci5_0, 5, 0), - newDecimal(c_deci7_2, 7, 2), - newDecimal(c_deci17_2, 17, 2), - newDecimal(c_deci18_6, 18, 6), - newDecimal(c_deci21_6, 21, 6), - newDecimal(c_deci22_6, 22, 6), - newDecimal(c_deci38_0, 38, 0), - newDecimal(c_deci38_16, 38, 16)) - } - - private def checkResult(sql: String, expect: String): Unit = { - val result = spark.sql(sql) - val plan = result.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[ColumnarConditionProjectExec]).isDefined) - val output = result.collect().toSeq.head.getDecimal(0) - assertResult(expect, s"sql: ${sql}")(output.toString) - } - - private def checkResultNull(sql: String): Unit = { - val result = spark.sql(sql) - val plan = result.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[ColumnarConditionProjectExec]).isDefined) - val output = result.collect().toSeq.head.getDecimal(0) - assertResult(null, s"sql: ${sql}")(output) - } - - private def checkAnsiResult(sql: String, expect: String): Unit = { - spark.conf.set("spark.sql.ansi.enabled", true) - checkResult(sql, expect) - spark.conf.set("spark.sql.ansi.enabled", false) - } - - private def checkAnsiResultNull(sql: String): Unit = { - spark.conf.set("spark.sql.ansi.enabled", true) - checkResultNull(sql) - spark.conf.set("spark.sql.ansi.enabled", false) - } - - private def checkAnsiResultException(sql: String, msg: String): Unit = { - spark.conf.set("spark.sql.ansi.enabled", true) - val result = spark.sql(sql) - val plan = result.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[ColumnarConditionProjectExec]).isDefined) - val exception = intercept[Exception]( - result.collect().toSeq.head.getDecimal(0) - ) - assert(exception.getMessage.contains(msg), s"sql: ${sql}") - spark.conf.set("spark.sql.ansi.enabled", false) - } - - private def checkAnsiResultOverflowException(sql: String): Unit = - checkAnsiResultException(sql, "Reason: Decimal overflow") - - private def checkAnsiResultDivideBy0Exception(sql: String): Unit = - checkAnsiResultException(sql, "Reason: Division by zero") - - - override def beforeAll(): Unit = { - super.beforeAll() - - deci_overflow = Seq[(Int, Decimal, Decimal, Decimal, Decimal, - Decimal, Decimal, Decimal, Decimal)]( - newRow(1, "12345", "12345.12", "123456789123456.23", "123456789123.34", - "123456789123456.456789", "1234567891234567.567891", - "123456789123456789123456789", "1234567891234567891234.6789123456"), - newRow(2, "99999", "99999.99", "999999999999999.99", "999999999999.999999", - "-999999999999999.999999", "9999999999999999.999999", - "99999999999999999999999999999999999999", "9999999999999999999999.9999999999999999"), - newRow(3, "99999", "0.99", "0.99", "0.999999", - "0.999999", 
"9999999999999999.999999", - "99999999999999999999999999999999999999", "-9999999999999999999999.9999999999999999"), - newRow(4, "99999", "0", "0.99", "0.999999", - "0", "9999999999999999.999999", - "99999999999999999999999999999999999999", "0"), - newRow(5, "99999", null, "0.99", "0.999999", - null, "0.999999", - "99999999999999999999999999999999999999", null), - newRow(6, "-12345", "12345.12", "-123456789123456.23", "123456789123.34", - "-123456789123456.456789", "1234567891234567.567891", - "123456789123456789123456789", "-1234567891234567891234.6789123456"), - ).toDF("id", "c_deci5_0", "c_deci7_2", "c_deci17_2", "c_deci18_6", "c_deci21_6", - "c_deci22_6", "c_deci38_0", "c_deci38_16") - - // Decimal in DataFrame is decimal(38,16), so need to cast to the target decimal type - deci_overflow = deci_overflow.withColumn("c_deci5_0", Column("c_deci5_0").cast("decimal(5,0)")) - .withColumn("c_deci7_2", Column("c_deci7_2").cast("decimal(7,2)")) - .withColumn("c_deci17_2", Column("c_deci17_2").cast("decimal(17,2)")) - .withColumn("c_deci18_6", Column("c_deci18_6").cast("decimal(18,6)")) - .withColumn("c_deci21_6", Column("c_deci21_6").cast("decimal(21,6)")) - .withColumn("c_deci22_6", Column("c_deci22_6").cast("decimal(22,6)")) - .withColumn("c_deci38_0", Column("c_deci38_0").cast("decimal(38,0)")) - .withColumn("c_deci38_16", Column("c_deci38_16").cast("decimal(38,16)")) - - deci_overflow.createOrReplaceTempView("deci_overflow") - deci_overflow.printSchema() - } - - /* normal positive decimal operation test cases */ - // spark.sql.ansi.enabled=false - test("decimal64+decimal64=decimal64 positive decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci5_0+c_deci7_2 from deci_overflow where id = 1;", "24690.12") - } - - test("decimal64+decimal64=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci17_2+c_deci18_6 from deci_overflow where id = 1;", "123580245912579.570000") - } - - test("decimal64+decimal128=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci17_2+c_deci22_6 from deci_overflow where id = 1;", "1358024680358023.797891") - } - - test("decimal128+decimal128=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci21_6+c_deci22_6 from deci_overflow where id = 1;", "1358024680358024.024680") - } - - test("decimal64-decimal64=decimal64 positive decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci7_2-c_deci5_0 from deci_overflow where id = 1;", "0.12") - } - - test("decimal64-decimal64=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci18_6-c_deci17_2 from deci_overflow where id = 1;", "-123333332334332.890000") - } - - test("decimal64-decimal128=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci22_6-c_deci17_2 from deci_overflow where id = 1;", "1111111102111111.337891") - } - - test("decimal128-decimal128=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci22_6-c_deci21_6 from deci_overflow where id = 1;", "1111111102111111.111102") - } - - test("decimal64*decimal64=decimal64 positive decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci5_0*c_deci7_2 from deci_overflow where id = 1;", "152400506.40") - } - - 
test("decimal64*decimal64=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci17_2*c_deci18_6 from deci_overflow where id = 1;", - "15241578780659191108332561.40820000") - } - - test("decimal64*decimal128=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci17_2*c_deci22_6 from deci_overflow where id = 1;", - "152415787806736055266232119611.091911") - } - - test("decimal128*decimal128=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci21_6*c_deci22_6 from deci_overflow where id = 1;", - "152415787806736335252649604807.436065") - } - - test("decimal64/decimal64=decimal64 positive decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci5_0/c_deci7_2 from deci_overflow where id = 1;", "0.99999028") - } - - test("decimal64/decimal64=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci17_2/c_deci18_6 from deci_overflow where id = 1;", - "1000.00000000094146301") - } - - test("decimal64/decimal128=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci17_2/c_deci22_6 from deci_overflow where id = 1;", - "0.09999999999999957") - } - - test("decimal128/decimal128=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci22_6/c_deci21_6 from deci_overflow where id = 1;", - "10.0000000000000243") - } - - test("decimal64%decimal64=decimal64 positive decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci5_0%c_deci7_2 from deci_overflow where id = 1;", "12345.00") - } - - test("decimal64%decimal128=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci22_6%c_deci17_2 from deci_overflow where id = 1;", "5.267891") - } - - test("decimal128%decimal128=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci22_6%c_deci21_6 from deci_overflow where id = 1;", "3.000001") - } - - - // spark.sql.ansi.enabled=true - test("decimal64+decimal64=decimal64 positive decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci5_0+c_deci7_2 from deci_overflow where id = 1;", "24690.12") - } - - test("decimal64+decimal64=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci17_2+c_deci18_6 from deci_overflow where id = 1;", "123580245912579.570000") - } - - test("decimal64+decimal128=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci17_2+c_deci22_6 from deci_overflow where id = 1;", "1358024680358023.797891") - } - - test("decimal128+decimal128=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci21_6+c_deci22_6 from deci_overflow where id = 1;", "1358024680358024.024680") - } - - test("decimal64-decimal64=decimal64 positive decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci7_2-c_deci5_0 from deci_overflow where id = 1;", "0.12") - } - - test("decimal64-decimal64=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci18_6-c_deci17_2 from deci_overflow where id = 1;", 
"-123333332334332.890000") - } - - test("decimal64-decimal128=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci22_6-c_deci17_2 from deci_overflow where id = 1;", "1111111102111111.337891") - } - - test("decimal128-decimal128=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci22_6-c_deci21_6 from deci_overflow where id = 1;", "1111111102111111.111102") - } - - test("decimal64*decimal64=decimal64 positive decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci5_0*c_deci7_2 from deci_overflow where id = 1;", "152400506.40") - } - - test("decimal64*decimal64=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci17_2*c_deci18_6 from deci_overflow where id = 1;", - "15241578780659191108332561.40820000") - } - - test("decimal64*decimal128=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci17_2*c_deci22_6 from deci_overflow where id = 1;", - "152415787806736055266232119611.091911") - } - - test("decimal128*decimal128=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci21_6*c_deci22_6 from deci_overflow where id = 1;", - "152415787806736335252649604807.436065") - } - - test("decimal64/decimal64=decimal64 positive decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci5_0/c_deci7_2 from deci_overflow where id = 1;", "0.99999028") - } - - test("decimal64/decimal64=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci17_2/c_deci18_6 from deci_overflow where id = 1;", - "1000.00000000094146301") - } - - test("decimal64/decimal128=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci17_2/c_deci22_6 from deci_overflow where id = 1;", - "0.09999999999999957") - } - - test("decimal128/decimal128=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci22_6/c_deci21_6 from deci_overflow where id = 1;", - "10.0000000000000243") - } - - test("decimal64%decimal64=decimal64 positive decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci5_0%c_deci7_2 from deci_overflow where id = 1;", "12345.00") - } - - test("decimal64%decimal128=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci22_6%c_deci17_2 from deci_overflow where id = 1;", "5.267891") - } - - test("decimal128%decimal128=decimal128 positive decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci22_6%c_deci21_6 from deci_overflow where id = 1;", "3.000001") - } - - - /* normal negative decimal operation test cases */ - // spark.sql.ansi.enabled=false - test("decimal64+decimal64=decimal64 negative decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci5_0+c_deci7_2 from deci_overflow where id = 6;", "0.12") - } - - test("decimal64+decimal64=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci17_2+c_deci18_6 from deci_overflow where id = 6;", "-123333332334332.890000") - } - - test("decimal64+decimal128=decimal128 negative decimal operation normal when 
spark.sql.ansi.enabled=false") { - checkResult("select c_deci17_2+c_deci22_6 from deci_overflow where id = 6;", "1111111102111111.337891") - } - - test("decimal128+decimal128=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci21_6+c_deci22_6 from deci_overflow where id = 6;", "1111111102111111.111102") - } - - test("decimal64-decimal64=decimal64 negative decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci7_2-c_deci5_0 from deci_overflow where id = 6;", "24690.12") - } - - test("decimal64-decimal64=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci18_6-c_deci17_2 from deci_overflow where id = 6;", "123580245912579.570000") - } - - test("decimal64-decimal128=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci22_6-c_deci17_2 from deci_overflow where id = 6;", "1358024680358023.797891") - } - - test("decimal128-decimal128=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci22_6-c_deci21_6 from deci_overflow where id = 6;", "1358024680358024.024680") - } - - test("decimal64*decimal64=decimal64 negative decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci5_0*c_deci7_2 from deci_overflow where id = 6;", "-152400506.40") - } - - test("decimal64*decimal64=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci17_2*c_deci18_6 from deci_overflow where id = 6;", - "-15241578780659191108332561.40820000") - } - - test("decimal64*decimal128=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci17_2*c_deci22_6 from deci_overflow where id = 6;", - "-152415787806736055266232119611.091911") - } - - test("decimal128*decimal128=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci21_6*c_deci22_6 from deci_overflow where id = 6;", - "-152415787806736335252649604807.436065") - } - - test("decimal64/decimal64=decimal64 negative decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci5_0/c_deci7_2 from deci_overflow where id = 6;", "-0.99999028") - } - - test("decimal64/decimal64=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci17_2/c_deci18_6 from deci_overflow where id = 6;", - "-1000.00000000094146301") - } - - test("decimal64/decimal128=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci17_2/c_deci22_6 from deci_overflow where id = 6;", - "-0.09999999999999957") - } - - test("decimal128/decimal128=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci22_6/c_deci21_6 from deci_overflow where id = 6;", - "-10.0000000000000243") - } - - test("decimal64%decimal64=decimal64 negative decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci5_0%c_deci7_2 from deci_overflow where id = 6;", "-12345.00") - } - - test("decimal64%decimal128=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=false") { - checkResult("select c_deci22_6%c_deci17_2 from deci_overflow where id = 6;", "5.267891") - } - - test("decimal128%decimal128=decimal128 negative decimal operation normal when 
spark.sql.ansi.enabled=false") { - checkResult("select c_deci21_6%c_deci22_6 from deci_overflow where id = 6;", "-123456789123456.456789") - } - - // spark.sql.ansi.enabled=false - test("decimal64+decimal64=decimal64 negative decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci5_0+c_deci7_2 from deci_overflow where id = 6;", "0.12") - } - - test("decimal64+decimal64=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci17_2+c_deci18_6 from deci_overflow where id = 6;", "-123333332334332.890000") - } - - test("decimal64+decimal128=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci17_2+c_deci22_6 from deci_overflow where id = 6;", "1111111102111111.337891") - } - - test("decimal128+decimal128=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci21_6+c_deci22_6 from deci_overflow where id = 6;", "1111111102111111.111102") - } - - test("decimal64-decimal64=decimal64 negative decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci7_2-c_deci5_0 from deci_overflow where id = 6;", "24690.12") - } - - test("decimal64-decimal64=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci18_6-c_deci17_2 from deci_overflow where id = 6;", "123580245912579.570000") - } - - test("decimal64-decimal128=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci22_6-c_deci17_2 from deci_overflow where id = 6;", "1358024680358023.797891") - } - - test("decimal128-decimal128=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci22_6-c_deci21_6 from deci_overflow where id = 6;", "1358024680358024.024680") - } - - test("decimal64*decimal64=decimal64 negative decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci5_0*c_deci7_2 from deci_overflow where id = 6;", "-152400506.40") - } - - test("decimal64*decimal64=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci17_2*c_deci18_6 from deci_overflow where id = 6;", - "-15241578780659191108332561.40820000") - } - - test("decimal64*decimal128=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci17_2*c_deci22_6 from deci_overflow where id = 6;", - "-152415787806736055266232119611.091911") - } - - test("decimal128*decimal128=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci21_6*c_deci22_6 from deci_overflow where id = 6;", - "-152415787806736335252649604807.436065") - } - - test("decimal64/decimal64=decimal64 negative decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci5_0/c_deci7_2 from deci_overflow where id = 6;", "-0.99999028") - } - - test("decimal64/decimal64=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci17_2/c_deci18_6 from deci_overflow where id = 6;", - "-1000.00000000094146301") - } - - test("decimal64/decimal128=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci17_2/c_deci22_6 from deci_overflow where id = 6;", - "-0.09999999999999957") 
- } - - test("decimal128/decimal128=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci22_6/c_deci21_6 from deci_overflow where id = 6;", - "-10.0000000000000243") - } - - test("decimal64%decimal64=decimal64 negative decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci5_0%c_deci7_2 from deci_overflow where id = 6;", "-12345.00") - } - - test("decimal64%decimal128=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci22_6%c_deci17_2 from deci_overflow where id = 6;", "5.267891") - } - - test("decimal128%decimal128=decimal128 negative decimal operation normal when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci21_6%c_deci22_6 from deci_overflow where id = 6;", "-123456789123456.456789") - } - - /* overflow decimal operation test cases */ - // spark.sql.ansi.enabled=false - test("decimal add operation positive overflow when spark.sql.ansi.enabled=false") { - checkResultNull("select (c_deci22_6*c_deci22_6)+c_deci18_6 from deci_overflow where id = 2;") - } - - test("decimal add operation negative overflow when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci21_6+(0-c_deci22_6*c_deci22_6) from deci_overflow where id = 2;") - } - - test("decimal subtract operation positive overflow when spark.sql.ansi.enabled=false") { - checkResultNull("select (c_deci22_6*c_deci22_6)-c_deci21_6 from deci_overflow where id = 2;") - } - - test("decimal subtract operation negative overflow when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci21_6-(c_deci22_6*c_deci22_6) from deci_overflow where id = 2;") - } - - test("decimal multiple operation positive overflow when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci22_6*c_deci22_6*c_deci5_0 from deci_overflow where id = 2;") - } - - test("decimal multiple operation negative overflow when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci21_6*c_deci22_6*c_deci22_6 from deci_overflow where id = 2;") - } - - test("decimal divide operation positive overflow when spark.sql.ansi.enabled=false") { - checkResultNull("select (c_deci22_6*c_deci22_6)/c_deci7_2 from deci_overflow where id = 3;") - } - - test("decimal divide operation negative overflow when spark.sql.ansi.enabled=false") { - checkResultNull("select (0-c_deci22_6*c_deci22_6)/c_deci7_2 from deci_overflow where id = 3;") - } - - // spark.sql.ansi.enabled=true - test("decimal add operation positive overflow when spark.sql.ansi.enabled=true") { - checkAnsiResultOverflowException("select (c_deci22_6*c_deci22_6)+c_deci18_6 from deci_overflow where id = 2;") - } - - test("decimal add operation negative overflow when spark.sql.ansi.enabled=true") { - checkAnsiResultOverflowException("select c_deci21_6+(0-c_deci22_6*c_deci22_6) from deci_overflow where id = 2;") - } - - test("decimal subtract operation positive overflow when spark.sql.ansi.enabled=true") { - checkAnsiResultOverflowException("select (c_deci22_6*c_deci22_6)-c_deci21_6 from deci_overflow where id = 2;") - } - - test("decimal subtract operation negative overflow when spark.sql.ansi.enabled=true") { - checkAnsiResultOverflowException("select c_deci21_6-(c_deci22_6*c_deci22_6) from deci_overflow where id = 2;") - } - - test("decimal multiple operation positive overflow when spark.sql.ansi.enabled=true") { - checkAnsiResultOverflowException("select c_deci22_6*c_deci22_6*c_deci5_0 from deci_overflow where id = 2;") - } 
- - test("decimal multiple operation negative overflow when spark.sql.ansi.enabled=true") { - checkAnsiResultOverflowException("select c_deci21_6*c_deci22_6*c_deci22_6 from deci_overflow where id = 2;") - } - - test("decimal divide operation positive overflow when spark.sql.ansi.enabled=true") { - checkAnsiResultOverflowException("select (c_deci22_6*c_deci22_6)/c_deci7_2 from deci_overflow where id = 3;") - } - - test("decimal divide operation negative overflow when spark.sql.ansi.enabled=true") { - checkAnsiResultOverflowException("select (0-c_deci22_6*c_deci22_6)/c_deci7_2 from deci_overflow where id = 3;") - } - - /* divide by zero decimal operation test cases */ - // spark.sql.ansi.enabled=false - test("decimal64/decimal64(0) when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci5_0/c_deci7_2 from deci_overflow where id = 4") - } - - test("decimal64/decimal128(0) when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci5_0/c_deci21_6 from deci_overflow where id = 4") - } - - test("decimal128/decimal64(0) when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci22_6/c_deci7_2 from deci_overflow where id = 4") - } - - test("decimal128/decimal128(0) when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci22_6/c_deci21_6 from deci_overflow where id = 4") - } - - test("decimal64/literal(0) when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci17_2/0 from deci_overflow where id = 4") - } - - test("decimal128/literal(0) when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci22_6/0 from deci_overflow where id = 4") - } - - test("decimal64%decimal64(0) when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci5_0%c_deci7_2 from deci_overflow where id = 4") - } - - test("decimal128%decimal64(0) when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci22_6%c_deci7_2 from deci_overflow where id = 4") - } - - test("decimal64%decimal128(0) when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci18_6%c_deci21_6 from deci_overflow where id = 4") - } - - test("decimal128%decimal128(0) when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci22_6%c_deci21_6 from deci_overflow where id = 4") - } - - test("decimal64%literal(0) when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci18_6%0 from deci_overflow where id = 4") - } - - test("decimal128%literal(0) when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci22_6%0 from deci_overflow where id = 4") - } - - // spark.sql.ansi.enabled=true - test("decimal64/decimal64(0) when spark.sql.ansi.enabled=true") { - checkAnsiResultDivideBy0Exception("select c_deci5_0/c_deci7_2 from deci_overflow where id = 4") - } - - test("decimal64/decimal128(0) when spark.sql.ansi.enabled=true") { - checkAnsiResultDivideBy0Exception("select c_deci5_0/c_deci21_6 from deci_overflow where id = 4") - } - - test("decimal128/decimal64(0) when spark.sql.ansi.enabled=true") { - checkAnsiResultDivideBy0Exception("select c_deci22_6/c_deci7_2 from deci_overflow where id = 4") - } - - test("decimal128/decimal128(0) when spark.sql.ansi.enabled=true") { - checkAnsiResultDivideBy0Exception("select c_deci22_6/c_deci21_6 from deci_overflow where id = 4") - } - - test("decimal64/literal(0) when spark.sql.ansi.enabled=true") { - checkAnsiResultDivideBy0Exception("select c_deci17_2/0 from deci_overflow where id = 4") - } - - test("decimal128/literal(0) when spark.sql.ansi.enabled=true") { - checkAnsiResultDivideBy0Exception("select 
c_deci22_6/0 from deci_overflow where id = 4") - } - - test("decimal64%decimal64(0) when spark.sql.ansi.enabled=true") { - checkAnsiResultDivideBy0Exception("select c_deci5_0%c_deci7_2 from deci_overflow where id = 4") - } - - test("decimal128%decimal64(0) when spark.sql.ansi.enabled=true") { - checkAnsiResultDivideBy0Exception("select c_deci22_6%c_deci7_2 from deci_overflow where id = 4") - } - - test("decimal64%decimal128(0) when spark.sql.ansi.enabled=true") { - checkAnsiResultDivideBy0Exception("select c_deci18_6%c_deci21_6 from deci_overflow where id = 4") - } - - test("decimal128%decimal128(0) when spark.sql.ansi.enabled=true") { - checkAnsiResultDivideBy0Exception("select c_deci22_6%c_deci21_6 from deci_overflow where id = 4") - } - - test("decimal64%literal(0) when spark.sql.ansi.enabled=true") { - checkAnsiResultDivideBy0Exception("select c_deci18_6%0 from deci_overflow where id = 4") - } - - test("decimal128%literal(0) when spark.sql.ansi.enabled=true") { - checkAnsiResultDivideBy0Exception("select c_deci22_6%0 from deci_overflow where id = 4") - } - - /* zero decimal operation test cases */ - // spark.sql.ansi.enabled=false - test("decimal64+decimal64(0) when when spark.sql.ansi.enabled=false") { - checkResult("select c_deci5_0+c_deci7_2 from deci_overflow where id = 4;", "99999.00") - } - - test("decimal64+decimal128(0) when spark.sql.ansi.enabled=false") { - checkResult("select c_deci5_0+c_deci21_6 from deci_overflow where id = 4;", "99999.000000") - } - - test("decimal64(0)+decimal128 when spark.sql.ansi.enabled=false") { - checkResult("select c_deci7_2+c_deci22_6 from deci_overflow where id = 4;", "9999999999999999.999999") - } - - test("decimal128(0)+decimal128 when spark.sql.ansi.enabled=false") { - checkResult("select c_deci21_6+c_deci22_6 from deci_overflow where id = 4;", "9999999999999999.999999") - } - - test("decimal64+literal(0) when when spark.sql.ansi.enabled=false") { - checkResult("select c_deci5_0+0 from deci_overflow where id = 4;", "99999") - } - - test("decimal128+literal(0) when when spark.sql.ansi.enabled=false") { - checkResult("select c_deci22_6+0 from deci_overflow where id = 4;", "9999999999999999.999999") - } - - test("decimal64-decimal64(0) when when spark.sql.ansi.enabled=false") { - checkResult("select c_deci5_0-c_deci7_2 from deci_overflow where id = 4;", "99999.00") - } - - test("decimal64-decimal128(0) when spark.sql.ansi.enabled=false") { - checkResult("select c_deci5_0-c_deci21_6 from deci_overflow where id = 4;", "99999.000000") - } - - test("decimal64(0)-decimal128 when spark.sql.ansi.enabled=false") { - checkResult("select c_deci7_2-c_deci22_6 from deci_overflow where id = 4;", "-9999999999999999.999999") - } - - test("decimal128(0)-decimal128 when spark.sql.ansi.enabled=false") { - checkResult("select c_deci21_6-c_deci22_6 from deci_overflow where id = 4;", "-9999999999999999.999999") - } - - test("decimal64-literal(0) when when spark.sql.ansi.enabled=false") { - checkResult("select c_deci5_0-0 from deci_overflow where id = 4;", "99999") - } - - test("literal(0)-decimal128 when when spark.sql.ansi.enabled=false") { - checkResult("select 0-c_deci22_6 from deci_overflow where id = 4;", "-9999999999999999.999999") - } - - test("decimal64*decimal64(0) when when spark.sql.ansi.enabled=false") { - checkResult("select c_deci5_0*c_deci7_2 from deci_overflow where id = 4;", "0.00") - } - - test("decimal64*decimal128(0) when spark.sql.ansi.enabled=false") { - checkResult("select c_deci5_0*c_deci21_6 from deci_overflow where id = 4;", "0.000000") - } 
- - test("decimal64(0)*decimal128 when spark.sql.ansi.enabled=false") { - checkResult("select c_deci7_2*c_deci22_6 from deci_overflow where id = 4;", "0E-8") - } - - test("decimal128(0)*decimal128 when spark.sql.ansi.enabled=false") { - checkResult("select c_deci21_6*c_deci22_6 from deci_overflow where id = 4;", "0.000000") - } - - test("decimal64*literal(0) when when spark.sql.ansi.enabled=false") { - checkResult("select c_deci5_0*0 from deci_overflow where id = 4;", "0") - } - - test("literal(0)*decimal128 when when spark.sql.ansi.enabled=false") { - checkResult("select 0*c_deci22_6 from deci_overflow where id = 4;", "0.000000") - } - - test("decimal64(0)/decimal64 when when spark.sql.ansi.enabled=false") { - checkResult("select c_deci7_2/c_deci5_0 from deci_overflow where id = 4;", "0E-8") - } - - test("decimal128(0)/decimal64 when spark.sql.ansi.enabled=false") { - checkResult("select c_deci21_6/c_deci5_0 from deci_overflow where id = 4;", "0E-12") - } - - test("decimal64(0)/decimal128 when spark.sql.ansi.enabled=false") { - checkResult("select c_deci7_2/c_deci22_6 from deci_overflow where id = 4;", "0E-25") - } - - test("decimal128(0)/decimal128 when spark.sql.ansi.enabled=false") { - checkResult("select c_deci21_6/c_deci22_6 from deci_overflow where id = 4;", "0E-17") - } - - test("literal(0)/decimal64 when when spark.sql.ansi.enabled=false") { - checkResult("select 0/c_deci5_0 from deci_overflow where id = 4;", "0.000000") - } - - test("literal(0)/decimal128 when when spark.sql.ansi.enabled=false") { - checkResult("select 0/c_deci22_6 from deci_overflow where id = 4;", "0E-23") - } - - test("decimal64(0)%decimal64 when when spark.sql.ansi.enabled=false") { - checkResult("select c_deci7_2%c_deci5_0 from deci_overflow where id = 4;", "0.00") - } - - test("decimal128(0)%decimal64 when spark.sql.ansi.enabled=false") { - checkResult("select c_deci21_6%c_deci5_0 from deci_overflow where id = 4;", "0.000000") - } - - test("decimal64(0)%decimal128 when spark.sql.ansi.enabled=false") { - checkResult("select c_deci7_2%c_deci22_6 from deci_overflow where id = 4;", "0.000000") - } - - test("decimal128(0)%decimal128 when spark.sql.ansi.enabled=false") { - checkResult("select c_deci21_6%c_deci22_6 from deci_overflow where id = 4;", "0.000000") - } - - test("literal(0)%decimal64 when when spark.sql.ansi.enabled=false") { - checkResult("select 0%c_deci5_0 from deci_overflow where id = 4;", "0") - } - - test("literal(0)%decimal128 when when spark.sql.ansi.enabled=false") { - checkResult("select 0%c_deci22_6 from deci_overflow where id = 4;", "0.000000") - } - - // spark.sql.ansi.enabled=true - test("decimal64+decimal64(0) when when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci5_0+c_deci7_2 from deci_overflow where id = 4;", "99999.00") - } - - test("decimal64+decimal128(0) when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci5_0+c_deci21_6 from deci_overflow where id = 4;", "99999.000000") - } - - test("decimal64(0)+decimal128 when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci7_2+c_deci22_6 from deci_overflow where id = 4;", "9999999999999999.999999") - } - - test("decimal128(0)+decimal128 when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci21_6+c_deci22_6 from deci_overflow where id = 4;", "9999999999999999.999999") - } - - test("decimal64+literal(0) when when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci5_0+0 from deci_overflow where id = 4;", "99999") - } - - test("decimal128+literal(0) when when 
spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci22_6+0 from deci_overflow where id = 4;", "9999999999999999.999999") - } - - test("decimal64-decimal64(0) when when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci5_0-c_deci7_2 from deci_overflow where id = 4;", "99999.00") - } - - test("decimal64-decimal128(0) when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci5_0-c_deci21_6 from deci_overflow where id = 4;", "99999.000000") - } - - test("decimal64(0)-decimal128 when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci7_2-c_deci22_6 from deci_overflow where id = 4;", "-9999999999999999.999999") - } - - test("decimal128(0)-decimal128 when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci21_6-c_deci22_6 from deci_overflow where id = 4;", "-9999999999999999.999999") - } - - test("decimal64-literal(0) when when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci5_0-0 from deci_overflow where id = 4;", "99999") - } - - test("literal(0)-decimal128 when when spark.sql.ansi.enabled=true") { - checkAnsiResult("select 0-c_deci22_6 from deci_overflow where id = 4;", "-9999999999999999.999999") - } - - test("decimal64*decimal64(0) when when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci5_0*c_deci7_2 from deci_overflow where id = 4;", "0.00") - } - - test("decimal64*decimal128(0) when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci5_0*c_deci21_6 from deci_overflow where id = 4;", "0.000000") - } - - test("decimal64(0)*decimal128 when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci7_2*c_deci22_6 from deci_overflow where id = 4;", "0E-8") - } - - test("decimal128(0)*decimal128 when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci21_6*c_deci22_6 from deci_overflow where id = 4;", "0.000000") - } - - test("decimal64*literal(0) when when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci5_0*0 from deci_overflow where id = 4;", "0") - } - - test("literal(0)*decimal128 when when spark.sql.ansi.enabled=true") { - checkAnsiResult("select 0*c_deci22_6 from deci_overflow where id = 4;", "0.000000") - } - - test("decimal64(0)/decimal64 when when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci7_2/c_deci5_0 from deci_overflow where id = 4;", "0E-8") - } - - test("decimal128(0)/decimal64 when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci21_6/c_deci5_0 from deci_overflow where id = 4;", "0E-12") - } - - test("decimal64(0)/decimal128 when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci7_2/c_deci22_6 from deci_overflow where id = 4;", "0E-25") - } - - test("decimal128(0)/decimal128 when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci21_6/c_deci22_6 from deci_overflow where id = 4;", "0E-17") - } - - test("literal(0)/decimal64 when when spark.sql.ansi.enabled=true") { - checkAnsiResult("select 0/c_deci5_0 from deci_overflow where id = 4;", "0.000000") - } - - test("literal(0)/decimal128 when when spark.sql.ansi.enabled=true") { - checkAnsiResult("select 0/c_deci22_6 from deci_overflow where id = 4;", "0E-23") - } - - test("decimal64(0)%decimal64 when when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci7_2%c_deci5_0 from deci_overflow where id = 4;", "0.00") - } - - test("decimal128(0)%decimal64 when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci21_6%c_deci5_0 from deci_overflow where id = 4;", "0.000000") - } - - test("decimal64(0)%decimal128 when 
spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci7_2%c_deci22_6 from deci_overflow where id = 4;", "0.000000") - } - - test("decimal128(0)%decimal128 when spark.sql.ansi.enabled=true") { - checkAnsiResult("select c_deci21_6%c_deci22_6 from deci_overflow where id = 4;", "0.000000") - } - - test("literal(0)%decimal64 when when spark.sql.ansi.enabled=true") { - checkAnsiResult("select 0%c_deci5_0 from deci_overflow where id = 4;", "0") - } - - test("literal(0)%decimal128 when when spark.sql.ansi.enabled=true") { - checkAnsiResult("select 0%c_deci22_6 from deci_overflow where id = 4;", "0.000000") - } - - /* NULL decimal operation test cases */ - // spark.sql.ansi.enabled=false - test("decimal64+decimal64(NULL) when when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci5_0+c_deci7_2 from deci_overflow where id = 5;") - } - - test("decimal64+decimal128(NULL) when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci5_0+c_deci21_6 from deci_overflow where id = 5;") - } - - test("decimal64(NULL)+decimal128 when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci7_2+c_deci22_6 from deci_overflow where id = 5;") - } - - test("decimal128(NULL)+decimal128 when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci21_6+c_deci22_6 from deci_overflow where id = 5;") - } - - test("literal(NULL)+decimal64 when when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci5_0+NULL from deci_overflow where id = 5;") - } - - test("decimal128+literal(NULL) when when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci22_6+NULL from deci_overflow where id = 5;") - } - - test("decimal64-decimal64(NULL) when when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci5_0-c_deci7_2 from deci_overflow where id = 5;") - } - - test("decimal64-decimal128(NULL) when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci5_0-c_deci21_6 from deci_overflow where id = 5;") - } - - test("decimal64(NULL)-decimal128 when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci7_2-c_deci22_6 from deci_overflow where id = 5;") - } - - test("decimal128(NULL)-decimal128 when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci21_6-c_deci22_6 from deci_overflow where id = 5;") - } - - test("decimal64-literal(NULL) when when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci5_0-NULL from deci_overflow where id = 5;") - } - - test("literal(NULL)-decimal128 when when spark.sql.ansi.enabled=false") { - checkResultNull("select NULL-c_deci22_6 from deci_overflow where id = 5;") - } - - test("decimal64*decimal64(NULL) when when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci5_0*c_deci7_2 from deci_overflow where id = 5;") - } - - test("decimal64*decimal128(NULL) when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci5_0*c_deci21_6 from deci_overflow where id = 5;") - } - - test("decimal64(NULL)*decimal128 when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci7_2*c_deci22_6 from deci_overflow where id = 5;") - } - - test("decimal128*decimal128(NULL) when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci22_6*c_deci21_6 from deci_overflow where id = 5;") - } - - test("decimal64*literal(NULL) when when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci5_0*NULL from deci_overflow where id = 5;") - } - - test("literal(NULL)*decimal128 when when spark.sql.ansi.enabled=false") { - checkResultNull("select NULL*c_deci22_6 
from deci_overflow where id = 5;") - } - - test("decimal64/decimal64(NULL) when when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci5_0/c_deci7_2 from deci_overflow where id = 5;") - } - - test("decimal128(NULL)/decimal64 when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci21_6/c_deci5_0 from deci_overflow where id = 5;") - } - - test("decimal128/decimal64(NULL) when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci22_6/c_deci7_2 from deci_overflow where id = 5;") - } - - test("decimal128(NULL)/decimal128 when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci21_6/c_deci22_6 from deci_overflow where id = 5;") - } - - test("decimal64/literal(NULL) when when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci5_0/NULL from deci_overflow where id = 5;") - } - - test("literal(NULL)/decimal128 when when spark.sql.ansi.enabled=false") { - checkResultNull("select NULL/c_deci22_6 from deci_overflow where id = 5;") - } - - test("decimal64%decimal64(NULL) when when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci5_0%c_deci7_2 from deci_overflow where id = 5;") - } - - test("decimal128(NULL)%decimal64 when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci21_6%c_deci5_0 from deci_overflow where id = 5;") - } - - test("decimal64(NULL)%decimal128 when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci7_2%c_deci22_6 from deci_overflow where id = 5;") - } - - test("decimal128%decimal128(NULL) when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci22_6%c_deci21_6 from deci_overflow where id = 5;") - } - - test("literal(NULL)%decimal64 when when spark.sql.ansi.enabled=false") { - checkResultNull("select NULL%c_deci5_0 from deci_overflow where id = 5;") - } - - test("decimal128%literal(NULL) when when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci22_6%NULL from deci_overflow where id = 5;") - } - - // spark.sql.ansi.enabled=true - test("decimal64+decimal64(NULL) when when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci5_0+c_deci7_2 from deci_overflow where id = 5;") - } - - test("decimal64+decimal128(NULL) when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci5_0+c_deci21_6 from deci_overflow where id = 5;") - } - - test("decimal64(NULL)+decimal128 when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci7_2+c_deci22_6 from deci_overflow where id = 5;") - } - - test("decimal128(NULL)+decimal128 when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci21_6+c_deci22_6 from deci_overflow where id = 5;") - } - - test("literal(NULL)+decimal64 when when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci5_0+NULL from deci_overflow where id = 5;") - } - - test("decimal128+literal(NULL) when when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci22_6+NULL from deci_overflow where id = 5;") - } - - test("decimal64-decimal64(NULL) when when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci5_0-c_deci7_2 from deci_overflow where id = 5;") - } - - test("decimal64-decimal128(NULL) when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci5_0-c_deci21_6 from deci_overflow where id = 5;") - } - - test("decimal64(NULL)-decimal128 when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci7_2-c_deci22_6 from deci_overflow where id = 5;") - } - - test("decimal128(NULL)-decimal128 when spark.sql.ansi.enabled=true") { 
- checkAnsiResultNull("select c_deci21_6-c_deci22_6 from deci_overflow where id = 5;") - } - - test("decimal64-literal(NULL) when when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci5_0-NULL from deci_overflow where id = 5;") - } - - test("literal(NULL)-decimal128 when when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select NULL-c_deci22_6 from deci_overflow where id = 5;") - } - - test("decimal64*decimal64(NULL) when when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci5_0*c_deci7_2 from deci_overflow where id = 5;") - } - - test("decimal64*decimal128(NULL) when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci5_0*c_deci21_6 from deci_overflow where id = 5;") - } - - test("decimal64(NULL)*decimal128 when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci7_2*c_deci22_6 from deci_overflow where id = 5;") - } - - test("decimal128*decimal128(NULL) when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci22_6*c_deci21_6 from deci_overflow where id = 5;") - } - - test("decimal64*literal(NULL) when when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci5_0*NULL from deci_overflow where id = 5;") - } - - test("literal(NULL)*decimal128 when when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select NULL*c_deci22_6 from deci_overflow where id = 5;") - } - - test("decimal64/decimal64(NULL) when when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci5_0/c_deci7_2 from deci_overflow where id = 5;") - } - - test("decimal128(NULL)/decimal64 when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci21_6/c_deci5_0 from deci_overflow where id = 5;") - } - - test("decimal128/decimal64(NULL) when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci22_6/c_deci7_2 from deci_overflow where id = 5;") - } - - test("decimal128(NULL)/decimal128 when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci21_6/c_deci22_6 from deci_overflow where id = 5;") - } - - test("decimal64/literal(NULL) when when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci5_0/NULL from deci_overflow where id = 5;") - } - - test("literal(NULL)/decimal128 when when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select NULL/c_deci22_6 from deci_overflow where id = 5;") - } - - test("decimal64%decimal64(NULL) when when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci5_0%c_deci7_2 from deci_overflow where id = 5;") - } - - test("decimal128(NULL)%decimal64 when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci21_6%c_deci5_0 from deci_overflow where id = 5;") - } - - test("decimal64(NULL)%decimal128 when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci7_2%c_deci22_6 from deci_overflow where id = 5;") - } - - test("decimal128%decimal128(NULL) when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci22_6%c_deci21_6 from deci_overflow where id = 5;") - } - - test("literal(NULL)%decimal64 when when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select NULL%c_deci5_0 from deci_overflow where id = 5;") - } - - test("decimal128%literal(NULL) when when spark.sql.ansi.enabled=true") { - checkAnsiResultNull("select c_deci22_6%NULL from deci_overflow where id = 5;") - } - -} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala 
b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala deleted file mode 100644 index cf2537484aefdcd68214db0046877652847bb34b..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala +++ /dev/null @@ -1,1519 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.adaptive - -import org.apache.log4j.Level -import org.apache.spark.Partition -import org.apache.spark.rdd.RDD -import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} -import org.apache.spark.sql.{Dataset, Row, SparkSession, Strategy} -import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} -import org.apache.spark.sql.execution.command.DataWritingCommandExec -import org.apache.spark.sql.execution.datasources.noop.NoopDataSource -import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec -import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarSparkPlanTest, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledColumnarRDD, SparkPlan, UnaryExecNode} -import org.apache.spark.sql.execution.exchange.{Exchange, REPARTITION, REPARTITION_WITH_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike} -import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, ColumnarBroadcastHashJoinExec, ColumnarSortMergeJoinExec, SortMergeJoinExec} -import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter -import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate -import org.apache.spark.sql.functions.{sum, when} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode -import org.apache.spark.sql.types.{IntegerType, StructType} -import org.apache.spark.sql.util.QueryExecutionListener -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.Utils - -import java.io.File -import java.net.URI - -class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest - with AdaptiveSparkPlanHelper { - - import testImplicits._ - - setupTestData() - - private def runAdaptiveAndVerifyResult(query: String): (SparkPlan, SparkPlan) = { - var finalPlanCnt = 0 - val listener = new SparkListener { - override def onOtherEvent(event: SparkListenerEvent): Unit = { - event match { - case SparkListenerSQLAdaptiveExecutionUpdate(_, _, sparkPlanInfo) => - if (sparkPlanInfo.simpleString.startsWith( - "AdaptiveSparkPlan isFinalPlan=true")) { - 
finalPlanCnt += 1 - } - case _ => // ignore other events - } - } - } - spark.sparkContext.addSparkListener(listener) - - val dfAdaptive = sql(query) - val planBefore = dfAdaptive.queryExecution.executedPlan - assert(planBefore.toString.startsWith("AdaptiveSparkPlan isFinalPlan=false")) - val result = dfAdaptive.collect() - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { - val df = sql(query) - checkAnswer(df, result) - } - val planAfter = dfAdaptive.queryExecution.executedPlan - assert(planAfter.toString.startsWith("AdaptiveSparkPlan isFinalPlan=true")) - val adaptivePlan = planAfter.asInstanceOf[AdaptiveSparkPlanExec].executedPlan - - spark.sparkContext.listenerBus.waitUntilEmpty() - // AQE will post `SparkListenerSQLAdaptiveExecutionUpdate` twice in case of subqueries that - // exist out of query stages. - val expectedFinalPlanCnt = adaptivePlan.find(_.subqueries.nonEmpty).map(_ => 2).getOrElse(1) - assert(finalPlanCnt == expectedFinalPlanCnt) - spark.sparkContext.removeSparkListener(listener) - - val exchanges = adaptivePlan.collect { - case e: Exchange => e - } - assert(exchanges.isEmpty, "The final plan should not contain any Exchange node.") - (dfAdaptive.queryExecution.sparkPlan, adaptivePlan) - } - - private def findTopLevelBroadcastHashJoin(plan: SparkPlan): Seq[BroadcastHashJoinExec] = { - collect(plan) { - case j: BroadcastHashJoinExec => j - } - } - - private def findTopLevelColumnarBroadcastHashJoin(plan: SparkPlan) - : Seq[ColumnarBroadcastHashJoinExec] = { - collect(plan) { - case j: ColumnarBroadcastHashJoinExec => j - } - } - - private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = { - collect(plan) { - case j: SortMergeJoinExec => j - } - } - - private def findTopLevelColumnarSortMergeJoin(plan: SparkPlan): Seq[ColumnarSortMergeJoinExec] = { - collect(plan) { - case j: ColumnarSortMergeJoinExec => j - } - } - - private def findTopLevelBaseJoin(plan: SparkPlan): Seq[BaseJoinExec] = { - collect(plan) { - case j: BaseJoinExec => j - } - } - - private def findReusedExchange(plan: SparkPlan): Seq[ReusedExchangeExec] = { - collectWithSubqueries(plan) { - case ShuffleQueryStageExec(_, e: ReusedExchangeExec) => e - case BroadcastQueryStageExec(_, e: ReusedExchangeExec) => e - } - } - - private def findReusedSubquery(plan: SparkPlan): Seq[ReusedSubqueryExec] = { - collectWithSubqueries(plan) { - case e: ReusedSubqueryExec => e - } - } - - private def checkNumLocalShuffleReaders( - plan: SparkPlan, numShufflesWithoutLocalReader: Int = 0): Unit = { - val numShuffles = collect(plan) { - case s: ShuffleQueryStageExec => s - }.length - - val numLocalReaders = collect(plan) { - case rowReader: CustomShuffleReaderExec if rowReader.isLocalReader => rowReader - case colReader: ColumnarCustomShuffleReaderExec if colReader.isLocalReader => colReader - } - numLocalReaders.foreach { - case rowCus: CustomShuffleReaderExec => - val rdd = rowCus.execute() - val parts = rdd.partitions - assert(parts.forall(rdd.preferredLocations(_).nonEmpty)) - case r => - val columnarCus = r.asInstanceOf[ColumnarCustomShuffleReaderExec] - val rdd: RDD[ColumnarBatch] = columnarCus.executeColumnar() - val parts: Array[Partition] = rdd.partitions - assert(parts.forall(rdd.preferredLocations(_).nonEmpty)) - } - assert(numShuffles === (numLocalReaders.length + numShufflesWithoutLocalReader)) - } - - private def checkInitialPartitionNum(df: Dataset[_], numPartition: Int): Unit = { - // repartition obeys initialPartitionNum when adaptiveExecutionEnabled - val plan = 
df.queryExecution.executedPlan - assert(plan.isInstanceOf[AdaptiveSparkPlanExec]) - val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect { - case s: ShuffleExchangeExec => s - } - assert(shuffle.size == 1) - assert(shuffle(0).outputPartitioning.numPartitions == numPartition) - } - - test("Change merge join to broadcast join") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM testData join testData2 ON key = a where value = '1'") - val smj: Seq[SortMergeJoinExec] = findTopLevelSortMergeJoin(plan) - assert(smj.size == 1) - val bhj: Seq[ColumnarBroadcastHashJoinExec] = - findTopLevelColumnarBroadcastHashJoin(adaptivePlan) - assert(bhj.size == 1) - checkNumLocalShuffleReaders(adaptivePlan) - } - } - - test("Reuse the parallelism of CoalescedShuffleReaderExec in LocalShuffleReaderExec") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", - SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM testData join testData2 ON key = a where value = '1'") - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.size == 1) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) - assert(bhj.size == 1) - val localReaders = collect(adaptivePlan) { - case reader: ColumnarCustomShuffleReaderExec if reader.isLocalReader => reader - } - assert(localReaders.length == 2) - val localShuffleRDD0 = localReaders(0).executeColumnar().asInstanceOf[ShuffledColumnarRDD] - val localShuffleRDD1 = localReaders(1).executeColumnar().asInstanceOf[ShuffledColumnarRDD] - // The pre-shuffle partition size is [0, 0, 0, 72, 0] - // We exclude the 0-size partitions, so only one partition, advisoryParallelism = 1 - // the final parallelism is - // math.max(1, advisoryParallelism / numMappers): math.max(1, 1/2) = 1 - // and the partitions length is 1 * numMappers = 2 - assert(localShuffleRDD0.getPartitions.length == 2) - // The pre-shuffle partition size is [0, 72, 0, 72, 126] - // We exclude the 0-size partitions, so only 3 partition, advisoryParallelism = 3 - // the final parallelism is - // math.max(1, advisoryParallelism / numMappers): math.max(1, 3/2) = 1 - // and the partitions length is 1 * numMappers = 2 - assert(localShuffleRDD1.getPartitions.length == 2) - } - } - - test("Reuse the default parallelism in LocalShuffleReaderExec") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", - SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false") { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM testData join testData2 ON key = a where value = '1'") - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.size == 1) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) - assert(bhj.size == 1) - val localReaders = collect(adaptivePlan) { - case reader: ColumnarCustomShuffleReaderExec if reader.isLocalReader => reader - } - assert(localReaders.length == 2) - val localShuffleRDD0 = localReaders(0).executeColumnar().asInstanceOf[ShuffledColumnarRDD] - val localShuffleRDD1 = localReaders(1).executeColumnar().asInstanceOf[ShuffledColumnarRDD] - // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2 - // and the partitions length is 2 * numMappers = 4 - 
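// A minimal, self-contained sketch (not from the deleted suite) of the arithmetic the
// two comments above walk through, using their own example values (numReduces = 5,
// numMappers = 2); the object and method names are illustrative only.
object LocalReaderParallelismSketch {
  // final parallelism per mapper = math.max(1, numReduces / numMappers);
  // the local shuffle RDD then exposes (parallelism per mapper) * numMappers partitions.
  def expectedLocalPartitions(numReduces: Int, numMappers: Int): Int = {
    val perMapper = math.max(1, numReduces / numMappers) // math.max(1, 5 / 2) = 2
    perMapper * numMappers                               // 2 * 2 = 4
  }

  def main(args: Array[String]): Unit =
    println(expectedLocalPartitions(numReduces = 5, numMappers = 2)) // prints 4
}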
assert(localShuffleRDD0.getPartitions.length == 4) - // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2 - // and the partitions length is 2 * numMappers = 4 - assert(localShuffleRDD1.getPartitions.length == 4) - } - } - - test("Empty stage coalesced to 1-partition RDD") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true") { - val df1 = spark.range(10).withColumn("a", 'id) - val df2 = spark.range(10).withColumn("b", 'id) - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer") - .groupBy('a).count() - checkAnswer(testDf, Seq()) - val plan = testDf.queryExecution.executedPlan - assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined) - val coalescedReaders = collect(plan) { - case r: ColumnarCustomShuffleReaderExec => r - } - assert(coalescedReaders.length == 3) - coalescedReaders.foreach(r => assert(r.partitionSpecs.length == 1)) - } - - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") { - val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer") - .groupBy('a).count() - checkAnswer(testDf, Seq()) - val plan = testDf.queryExecution.executedPlan - print(plan) - assert(find(plan)(_.isInstanceOf[ColumnarBroadcastHashJoinExec]).isDefined) - val coalescedReaders = collect(plan) { - case r: ColumnarCustomShuffleReaderExec => r - } - assert(coalescedReaders.length == 3, s"$plan") - coalescedReaders.foreach(r => assert(r.isLocalReader || r.partitionSpecs.length == 1)) - } - } - } - - test("Scalar subquery") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM testData join testData2 ON key = a " + - "where value = (SELECT max(a) from testData3)") - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.size == 1) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) - assert(bhj.size == 1) - checkNumLocalShuffleReaders(adaptivePlan) - } - } - - // Currently, OmniFilterExec will fall back to Filter, if AQE is enabled, it will cause error - ignore("Scalar subquery in later stages") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM testData join testData2 ON key = a " + - "where (value + a) = (SELECT max(a) from testData3)") - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.size == 1) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) - assert(bhj.size == 1) - checkNumLocalShuffleReaders(adaptivePlan) - } - } - - test("multiple joins") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( - """ - |WITH t4 AS ( - | SELECT * FROM lowercaseData t2 JOIN testData3 t3 ON t2.n = t3.a where t2.n = '1' - |) - |SELECT * FROM testData - |JOIN testData2 t2 ON key = t2.a - |JOIN t4 ON t2.b = t4.a - |WHERE value = 1 - """.stripMargin) - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.size == 3) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) - assert(bhj.size == 3) - - // A possible resulting query plan: - // BroadcastHashJoin - // +- BroadcastExchange - // +- LocalShuffleReader* - // +- ShuffleExchange - // +- 
BroadcastHashJoin - // +- BroadcastExchange - // +- LocalShuffleReader* - // +- ShuffleExchange - // +- LocalShuffleReader* - // +- ShuffleExchange - // +- BroadcastHashJoin - // +- LocalShuffleReader* - // +- ShuffleExchange - // +- BroadcastExchange - // +-LocalShuffleReader* - // +- ShuffleExchange - - // After applied the 'OptimizeLocalShuffleReader' rule, we can convert all the four - // shuffle reader to local shuffle reader in the bottom two 'BroadcastHashJoin'. - // For the top level 'BroadcastHashJoin', the probe side is not shuffle query stage - // and the build side shuffle query stage is also converted to local shuffle reader. - checkNumLocalShuffleReaders(adaptivePlan) - } - } - - test("multiple joins with aggregate") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( - """ - |WITH t4 AS ( - | SELECT * FROM lowercaseData t2 JOIN ( - | select a, sum(b) from testData3 group by a - | ) t3 ON t2.n = t3.a where t2.n = '1' - |) - |SELECT * FROM testData - |JOIN testData2 t2 ON key = t2.a - |JOIN t4 ON t2.b = t4.a - |WHERE value = 1 - """.stripMargin) - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.size == 3) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) - assert(bhj.size == 3) - - // A possible resulting query plan: - // BroadcastHashJoin - // +- BroadcastExchange - // +- LocalShuffleReader* - // +- ShuffleExchange - // +- BroadcastHashJoin - // +- BroadcastExchange - // +- LocalShuffleReader* - // +- ShuffleExchange - // +- LocalShuffleReader* - // +- ShuffleExchange - // +- BroadcastHashJoin - // +- LocalShuffleReader* - // +- ShuffleExchange - // +- BroadcastExchange - // +-HashAggregate - // +- CoalescedShuffleReader - // +- ShuffleExchange - - // The shuffle added by Aggregate can't apply local reader. - checkNumLocalShuffleReaders(adaptivePlan, 1) - } - } - - test("multiple joins with aggregate 2") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( - """ - |WITH t4 AS ( - | SELECT * FROM lowercaseData t2 JOIN ( - | select a, max(b) b from testData2 group by a - | ) t3 ON t2.n = t3.b - |) - |SELECT * FROM testData - |JOIN testData2 t2 ON key = t2.a - |JOIN t4 ON value = t4.a - |WHERE value = 1 - """.stripMargin) - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.size == 3) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) - assert(bhj.size == 2) - - // A possible resulting query plan: - // BroadcastHashJoin - // +- BroadcastExchange - // +- LocalShuffleReader* - // +- ShuffleExchange - // +- BroadcastHashJoin - // +- BroadcastExchange - // +- LocalShuffleReader* - // +- ShuffleExchange - // +- LocalShuffleReader* - // +- ShuffleExchange - // +- BroadcastHashJoin - // +- Filter - // +- HashAggregate - // +- CoalescedShuffleReader - // +- ShuffleExchange - // +- BroadcastExchange - // +-LocalShuffleReader* - // +- ShuffleExchange - - // The shuffle added by Aggregate can't apply local reader. 
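// A standalone sketch (not from the deleted suite) of the sort-merge-to-broadcast
// demotion these join tests assert, assuming a local SparkSession; the 1KB threshold,
// view names and data are illustrative only.
import org.apache.spark.sql.SparkSession

object AqeBroadcastDemotionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("aqe-bhj-sketch").getOrCreate()
    spark.conf.set("spark.sql.adaptive.enabled", "true")
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "1KB")
    spark.range(0, 1000).selectExpr("id AS key", "cast(id AS string) AS value")
      .createOrReplaceTempView("l")
    spark.range(0, 1000).selectExpr("id AS a", "id AS b").createOrReplaceTempView("r")
    // Statically both sides look too large to broadcast, so the initial plan uses a
    // sort-merge join; after the filter runs, the left side shrinks to one row and
    // AQE can replan the join as a broadcast hash join.
    val df = spark.sql("SELECT * FROM l JOIN r ON key = a WHERE value = '1'")
    df.collect() // materialize so the adaptive plan is finalized
    println(df.queryExecution.executedPlan) // typically shows a BroadcastHashJoin in the final plan
    spark.stop()
  }
}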
- checkNumLocalShuffleReaders(adaptivePlan, 1) - } - } - - test("Exchange reuse") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT value FROM testData join testData2 ON key = a " + - "join (SELECT value v from testData join testData3 ON key = a) on value = v") - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.size == 3) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) - assert(bhj.size == 3) - // There is no SMJ - checkNumLocalShuffleReaders(adaptivePlan, 0) - // Even with local shuffle reader, the query stage reuse can also work. - val ex = findReusedExchange(adaptivePlan) - assert(ex.size == 1) - } - } - - test("Exchange reuse with subqueries") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT a FROM testData join testData2 ON key = a " + - "where value = (SELECT max(a) from testData join testData2 ON key = a)") - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.size == 1) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) - assert(bhj.size == 1) - checkNumLocalShuffleReaders(adaptivePlan) - // Even with local shuffle reader, the query stage reuse can also work. - val ex = findReusedExchange(adaptivePlan) - assert(ex.size == 1) - } - } - - test("Exchange reuse across subqueries") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", - SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false") { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT a FROM testData join testData2 ON key = a " + - "where value >= (SELECT max(a) from testData join testData2 ON key = a) " + - "and a <= (SELECT max(a) from testData join testData2 ON key = a)") - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.size == 1) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) - assert(bhj.size == 1) - checkNumLocalShuffleReaders(adaptivePlan) - // Even with local shuffle reader, the query stage reuse can also work. - val ex = findReusedExchange(adaptivePlan) - assert(ex.nonEmpty) - val sub = findReusedSubquery(adaptivePlan) - assert(sub.isEmpty) - } - } - - test("Subquery reuse") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT a FROM testData join testData2 ON key = a " + - "where value >= (SELECT max(a) from testData join testData2 ON key = a) " + - "and a <= (SELECT max(a) from testData join testData2 ON key = a)") - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.size == 1) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) - assert(bhj.size == 1) - checkNumLocalShuffleReaders(adaptivePlan) - // Even with local shuffle reader, the query stage reuse can also work. 
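// A standalone sketch (not from the deleted suite) of the exchange/query-stage reuse
// checked above: the same aggregate subquery appears twice, so its shuffle exchange
// can be planned once and reused. The SparkSession setup and names are illustrative.
import org.apache.spark.sql.SparkSession

object ExchangeReuseSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("reuse-sketch").getOrCreate()
    spark.conf.set("spark.sql.adaptive.enabled", "true")
    spark.range(0, 100).createOrReplaceTempView("t")
    val df = spark.sql(
      """SELECT * FROM
        |  (SELECT id, count(*) AS c FROM t GROUP BY id) a
        |JOIN
        |  (SELECT id, count(*) AS c FROM t GROUP BY id) b
        |ON a.id = b.id""".stripMargin)
    df.collect()
    // The finalized plan usually contains a reused exchange / reused query stage node
    // in place of the second, identical aggregate subtree.
    println(df.queryExecution.executedPlan)
    spark.stop()
  }
}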
- val ex = findReusedExchange(adaptivePlan) - assert(ex.isEmpty) - val sub = findReusedSubquery(adaptivePlan) - assert(sub.nonEmpty) - } - } - - test("Broadcast exchange reuse across subqueries") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000", - SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false") { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT a FROM testData join testData2 ON key = a " + - "where value >= (" + - "SELECT /*+ broadcast(testData2) */ max(key) from testData join testData2 ON key = a) " + - "and a <= (" + - "SELECT /*+ broadcast(testData2) */ max(value) from testData join testData2 ON key = a)") - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.size == 1) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) - assert(bhj.size == 1) - checkNumLocalShuffleReaders(adaptivePlan) - // Even with local shuffle reader, the query stage reuse can also work. - val ex = findReusedExchange(adaptivePlan) - assert(ex.nonEmpty) - assert(ex.head.child.isInstanceOf[ColumnarBroadcastExchangeExec]) - val sub = findReusedSubquery(adaptivePlan) - assert(sub.isEmpty) - } - } - - test("Union/Except/Intersect queries") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - runAdaptiveAndVerifyResult( - """ - |SELECT * FROM testData - |EXCEPT - |SELECT * FROM testData2 - |UNION ALL - |SELECT * FROM testData - |INTERSECT ALL - |SELECT * FROM testData2 - """.stripMargin) - } - } - - test("Subquery de-correlation in Union queries") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - withTempView("a", "b") { - Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("a") - Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("b") - - runAdaptiveAndVerifyResult( - """ - |SELECT id,num,source FROM ( - | SELECT id, num, 'a' as source FROM a - | UNION ALL - | SELECT id, num, 'b' as source FROM b - |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2) - """.stripMargin) - } - } - } - - test("Avoid plan change if cost is greater") { - val origPlan = sql("SELECT * FROM testData " + - "join testData2 t2 ON key = t2.a " + - "join testData2 t3 on t2.a = t3.a where t2.b = 1").queryExecution.executedPlan - - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "25", - SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT.key -> "0") { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM testData " + - "join testData2 t2 ON key = t2.a " + - "join testData2 t3 on t2.a = t3.a where t2.b = 1") - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.size == 2) - val smj2 = findTopLevelSortMergeJoin(adaptivePlan) - assert(smj2.size == 2, origPlan.toString) - } - } - - test("Change merge join to broadcast join without local shuffle reader") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.LOCAL_SHUFFLE_READER_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "25") { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( - """ - |SELECT * FROM testData t1 join testData2 t2 - |ON t1.key = t2.a join testData3 t3 on t2.a = t3.a - |where t1.value = 1 - """.stripMargin - ) - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.size == 2) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) - assert(bhj.size == 1) - checkNumLocalShuffleReaders(adaptivePlan, 2) - } - } - - test("Avoid changing merge join to broadcast join if too many 
empty partitions on build plan") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key -> "0.5") { - // `testData` is small enough to be broadcast but has empty partition ratio over the config. - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM testData join testData2 ON key = a where value = '1'") - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.size == 1) - val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) - assert(bhj.isEmpty) - } - // It is still possible to broadcast `testData2`. - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM testData join testData2 ON key = a where value = '1'") - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.size == 1) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) - assert(bhj.size == 1) - assert(bhj.head.buildSide == BuildRight) - } - } - } - - test("SPARK-29906: AQE should not introduce extra shuffle for outermost limit") { - var numStages = 0 - val listener = new SparkListener { - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - numStages = jobStart.stageInfos.length - } - } - try { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - spark.sparkContext.addSparkListener(listener) - spark.range(0, 100, 1, numPartitions = 10).take(1) - spark.sparkContext.listenerBus.waitUntilEmpty() - // Should be only one stage since there is no shuffle. - assert(numStages == 1) - } - } finally { - spark.sparkContext.removeSparkListener(listener) - } - } - - test("SPARK-30524: Do not optimize skew join if introduce additional shuffle") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100", - SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100") { - withTempView("skewData1", "skewData2") { - spark - .range(0, 1000, 1, 10) - .selectExpr("id % 3 as key1", "id as value1") - .createOrReplaceTempView("skewData1") - spark - .range(0, 1000, 1, 10) - .selectExpr("id % 1 as key2", "id as value2") - .createOrReplaceTempView("skewData2") - - def checkSkewJoin(query: String, optimizeSkewJoin: Boolean): Unit = { - val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(query) - val innerSmj = findTopLevelColumnarSortMergeJoin(innerAdaptivePlan) - assert(innerSmj.size == 1 && innerSmj.head.isSkewJoin == optimizeSkewJoin) - } - - checkSkewJoin( - "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2", true) - // Additional shuffle introduced, so disable the "OptimizeSkewedJoin" optimization - checkSkewJoin( - "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 GROUP BY key1", false) - } - } - } - - ignore("SPARK-29544: adaptive skew join with different join types") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", - SQLConf.SHUFFLE_PARTITIONS.key -> "100", - SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800", - SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") { - withTempView("skewData1", "skewData2") { - spark - .range(0, 1000, 1, 10) - .select( - when('id < 250, 249) - .when('id >= 750, 1000) - .otherwise('id).as("key1"), - 'id as "value1") - .createOrReplaceTempView("skewData1") - spark - 
.range(0, 1000, 1, 10) - .select( - when('id < 250, 249) - .otherwise('id).as("key2"), - 'id as "value2") - .createOrReplaceTempView("skewData2") - - def checkSkewJoin( - joins: Seq[SortMergeJoinExec], - leftSkewNum: Int, - rightSkewNum: Int): Unit = { - assert(joins.size == 1 && joins.head.isSkewJoin) - assert(joins.head.left.collect { - case r: ColumnarCustomShuffleReaderExec => r - }.head.partitionSpecs.collect { - case p: PartialReducerPartitionSpec => p.reducerIndex - }.distinct.length == leftSkewNum) - assert(joins.head.right.collect { - case r: ColumnarCustomShuffleReaderExec => r - }.head.partitionSpecs.collect { - case p: PartialReducerPartitionSpec => p.reducerIndex - }.distinct.length == rightSkewNum) - } - - // skewed inner join optimization - val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM skewData1 join skewData2 ON key1 = key2") - val innerSmj = findTopLevelColumnarSortMergeJoin(innerAdaptivePlan) - checkSkewJoin(innerSmj, 1, 1) - - // skewed left outer join optimization - val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2") - val leftSmj = findTopLevelColumnarSortMergeJoin(leftAdaptivePlan) - checkSkewJoin(leftSmj, 2, 0) - - // skewed right outer join optimization - val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2") - val rightSmj = findTopLevelColumnarSortMergeJoin(rightAdaptivePlan) - checkSkewJoin(rightSmj, 0, 1) - } - } - } - - test("SPARK-30291: AQE should catch the exceptions when doing materialize") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - withTable("bucketed_table") { - val df1 = - (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1") - df1.write.format("orc").bucketBy(8, "i").saveAsTable("bucketed_table") - val warehouseFilePath = new URI(spark.sessionState.conf.warehousePath).getPath - val tableDir = new File(warehouseFilePath, "bucketed_table") - Utils.deleteRecursively(tableDir) - df1.write.orc(tableDir.getAbsolutePath) - - val aggregated = spark.table("bucketed_table").groupBy("i").count() - val error = intercept[Exception] { - aggregated.count() - } - assert(error.getCause.toString contains "Invalid bucket file") - assert(error.getSuppressed.size === 0) - } - } - } - - test("SPARK-30403: AQE should handle InSubquery") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - runAdaptiveAndVerifyResult("SELECT * FROM testData LEFT OUTER join testData2" + - " ON key = a AND key NOT IN (select a from testData3) where value = '1'" - ) - } - } - - test("force apply AQE") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { - val plan = sql("SELECT * FROM testData").queryExecution.executedPlan - assert(plan.isInstanceOf[AdaptiveSparkPlanExec]) - } - } - - test("SPARK-30719: do not log warning if intentionally skip AQE") { - val testAppender = new LogAppender("aqe logging warning test when skip") - withLogAppender(testAppender) { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - val plan = sql("SELECT * FROM testData").queryExecution.executedPlan - assert(!plan.isInstanceOf[AdaptiveSparkPlanExec]) - } - } - assert(!testAppender.loggingEvents - .exists(msg => msg.getRenderedMessage.contains( - s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} is" + - s" enabled but is not supported for"))) - } - - test("test log level") { - def 
verifyLog(expectedLevel: Level): Unit = { - val logAppender = new LogAppender("adaptive execution") - withLogAppender( - logAppender, - loggerName = Some(AdaptiveSparkPlanExec.getClass.getName.dropRight(1)), - level = Some(Level.TRACE)) { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { - sql("SELECT * FROM testData join testData2 ON key = a where value = '1'").collect() - } - } - Seq("Plan changed", "Final plan").foreach { msg => - assert( - logAppender.loggingEvents.exists { event => - event.getRenderedMessage.contains(msg) && event.getLevel == expectedLevel - }) - } - } - - // Verify default log level - verifyLog(Level.DEBUG) - - // Verify custom log level - val levels = Seq( - "TRACE" -> Level.TRACE, - "trace" -> Level.TRACE, - "DEBUG" -> Level.DEBUG, - "debug" -> Level.DEBUG, - "INFO" -> Level.INFO, - "info" -> Level.INFO, - "WARN" -> Level.WARN, - "warn" -> Level.WARN, - "ERROR" -> Level.ERROR, - "error" -> Level.ERROR, - "deBUG" -> Level.DEBUG) - - levels.foreach { level => - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_LOG_LEVEL.key -> level._1) { - verifyLog(level._2) - } - } - } - - test("tree string output") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - val df = sql("SELECT * FROM testData join testData2 ON key = a where value = '1'") - val planBefore = df.queryExecution.executedPlan - assert(!planBefore.toString.contains("== Current Plan ==")) - assert(!planBefore.toString.contains("== Initial Plan ==")) - df.collect() - val planAfter = df.queryExecution.executedPlan - assert(planAfter.toString.contains("== Final Plan ==")) - assert(planAfter.toString.contains("== Initial Plan ==")) - } - } - - test("SPARK-31384: avoid NPE in OptimizeSkewedJoin when there's 0 partition plan") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - withTempView("t2") { - // create DataFrame with 0 partition - spark.createDataFrame(sparkContext.emptyRDD[Row], new StructType().add("b", IntegerType)) - .createOrReplaceTempView("t2") - // should run successfully without NPE - runAdaptiveAndVerifyResult("SELECT * FROM testData2 t1 join t2 ON t1.a=t2.b") - } - } - } - - ignore("metrics of the shuffle reader") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - val (_, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT key FROM testData GROUP BY key") - val readers = collect(adaptivePlan) { - case r: ColumnarCustomShuffleReaderExec => r - } - print(readers.length) - assert(readers.length == 1) - val reader = readers.head - assert(!reader.isLocalReader) - assert(!reader.hasSkewedPartition) - assert(reader.hasCoalescedPartition) - assert(reader.metrics.keys.toSeq.sorted == Seq( - "numPartitions", "partitionDataSize")) - assert(reader.metrics("numPartitions").value == reader.partitionSpecs.length) - assert(reader.metrics("partitionDataSize").value > 0) - - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { - val (_, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM testData join testData2 ON key = a where value = '1'") - val join = collect(adaptivePlan) { - case j: ColumnarBroadcastHashJoinExec => j - }.head - assert(join.buildSide == BuildLeft) - - val readers = collect(join.right) { - case r: ColumnarCustomShuffleReaderExec => r - } - assert(readers.length == 1) - val reader = readers.head - assert(reader.isLocalReader) - assert(reader.metrics.keys.toSeq == Seq("numPartitions")) - 
assert(reader.metrics("numPartitions").value == reader.partitionSpecs.length) - } - - withSQLConf( - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.SHUFFLE_PARTITIONS.key -> "100", - SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800", - SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") { - withTempView("skewData1", "skewData2") { - spark - .range(0, 1000, 1, 10) - .select( - when('id < 250, 249) - .when('id >= 750, 1000) - .otherwise('id).as("key1"), - 'id as "value1") - .createOrReplaceTempView("skewData1") - spark - .range(0, 1000, 1, 10) - .select( - when('id < 250, 249) - .otherwise('id).as("key2"), - 'id as "value2") - .createOrReplaceTempView("skewData2") - val (_, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM skewData1 join skewData2 ON key1 = key2") - val readers = collect(adaptivePlan) { - case r: CustomShuffleReaderExec => r - } - readers.foreach { reader => - assert(!reader.isLocalReader) - assert(reader.hasCoalescedPartition) - assert(reader.hasSkewedPartition) - assert(reader.metrics.contains("numSkewedPartitions")) - } - print(readers(1).metrics("numSkewedPartitions")) - print(readers(1).metrics("numSkewedSplits")) - assert(readers(0).metrics("numSkewedPartitions").value == 2) - assert(readers(0).metrics("numSkewedSplits").value == 15) - assert(readers(1).metrics("numSkewedPartitions").value == 1) - assert(readers(1).metrics("numSkewedSplits").value == 12) - } - } - } - } - - test("control a plan explain mode in listeners via SQLConf") { - - def checkPlanDescription(mode: String, expected: Seq[String]): Unit = { - var checkDone = false - val listener = new SparkListener { - override def onOtherEvent(event: SparkListenerEvent): Unit = { - event match { - case SparkListenerSQLAdaptiveExecutionUpdate(_, planDescription, _) => - assert(expected.forall(planDescription.contains)) - checkDone = true - case _ => // ignore other events - } - } - } - spark.sparkContext.addSparkListener(listener) - withSQLConf(SQLConf.UI_EXPLAIN_MODE.key -> mode, - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { - val dfAdaptive = sql("SELECT * FROM testData JOIN testData2 ON key = a WHERE value = '1'") - try { - checkAnswer(dfAdaptive, Row(1, "1", 1, 1) :: Row(1, "1", 1, 2) :: Nil) - spark.sparkContext.listenerBus.waitUntilEmpty() - assert(checkDone) - } finally { - spark.sparkContext.removeSparkListener(listener) - } - } - } - - Seq(("simple", Seq("== Physical Plan ==")), - ("extended", Seq("== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", - "== Optimized Logical Plan ==", "== Physical Plan ==")), - ("codegen", Seq("WholeStageCodegen subtrees")), - ("cost", Seq("== Optimized Logical Plan ==", "Statistics(sizeInBytes")), - ("formatted", Seq("== Physical Plan ==", "Output", "Arguments"))).foreach { - case (mode, expected) => - checkPlanDescription(mode, expected) - } - } - - test("SPARK-30953: InsertAdaptiveSparkPlan should apply AQE on child plan of write commands") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { - withTable("t1") { - val plan = sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").queryExecution.executedPlan - assert(plan.isInstanceOf[DataWritingCommandExec]) - assert(plan.asInstanceOf[DataWritingCommandExec].child.isInstanceOf[AdaptiveSparkPlanExec]) - } - } - } - - test("AQE should set active session during execution") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - val df = 
spark.range(10).select(sum('id)) - assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec]) - SparkSession.setActiveSession(null) - checkAnswer(df, Seq(Row(45))) - SparkSession.setActiveSession(spark) // recover the active session. - } - } - - test("No deadlock in UI update") { - object TestStrategy extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case _: Aggregate => - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { - spark.range(5).rdd - } - Nil - case _ => Nil - } - } - - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { - try { - spark.experimental.extraStrategies = TestStrategy :: Nil - val df = spark.range(10).groupBy('id).count() - df.collect() - } finally { - spark.experimental.extraStrategies = Nil - } - } - } - - test("SPARK-31658: SQL UI should show write commands") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { - withTable("t1") { - var checkDone = false - val listener = new SparkListener { - override def onOtherEvent(event: SparkListenerEvent): Unit = { - event match { - case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) => - assert(planInfo.nodeName == "Execute CreateDataSourceTableAsSelectCommand") - checkDone = true - case _ => // ignore other events - } - } - } - spark.sparkContext.addSparkListener(listener) - try { - sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect() - spark.sparkContext.listenerBus.waitUntilEmpty() - assert(checkDone) - } finally { - spark.sparkContext.removeSparkListener(listener) - } - } - } - } - - test("SPARK-31220, SPARK-32056: repartition by expression with AQE") { - Seq(true, false).foreach { enableAQE => - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, - SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", - SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10", - SQLConf.SHUFFLE_PARTITIONS.key -> "10") { - - val df1 = spark.range(10).repartition($"id") - val df2 = spark.range(10).repartition($"id" + 1) - - val partitionsNum1 = df1.rdd.collectPartitions().length - val partitionsNum2 = df2.rdd.collectPartitions().length - - if (enableAQE) { - assert(partitionsNum1 < 10) - assert(partitionsNum2 < 10) - - checkInitialPartitionNum(df1, 10) - checkInitialPartitionNum(df2, 10) - } else { - assert(partitionsNum1 === 10) - assert(partitionsNum2 === 10) - } - - - // Don't coalesce partitions if the number of partitions is specified. 
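// A standalone sketch (not from the deleted suite) of the behaviour the comment above
// describes: with AQE coalescing on, repartitioning by a column may end up with fewer
// partitions, while an explicit partition count is left untouched. Assumes a local
// SparkSession; the conf values mirror the ones used in these tests.
import org.apache.spark.sql.SparkSession

object RepartitionCoalesceSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("repartition-sketch").getOrCreate()
    import spark.implicits._
    spark.conf.set("spark.sql.adaptive.enabled", "true")
    spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
    spark.conf.set("spark.sql.adaptive.coalescePartitions.initialPartitionNum", "10")
    spark.conf.set("spark.sql.shuffle.partitions", "10")
    val byColumn = spark.range(10).repartition($"id")     // AQE may coalesce this below 10
    val byNumber = spark.range(10).repartition(10, $"id") // explicit count: stays at 10
    println(byColumn.rdd.getNumPartitions)
    println(byNumber.rdd.getNumPartitions)
    spark.stop()
  }
}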
- val df3 = spark.range(10).repartition(10, $"id") - val df4 = spark.range(10).repartition(10) - assert(df3.rdd.collectPartitions().length == 10) - assert(df4.rdd.collectPartitions().length == 10) - } - } - } - - test("SPARK-31220, SPARK-32056: repartition by range with AQE") { - Seq(true, false).foreach { enableAQE => - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, - SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", - SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10", - SQLConf.SHUFFLE_PARTITIONS.key -> "10") { - - val df1 = spark.range(10).toDF.repartitionByRange($"id".asc) - val df2 = spark.range(10).toDF.repartitionByRange(($"id" + 1).asc) - - val partitionsNum1 = df1.rdd.collectPartitions().length - val partitionsNum2 = df2.rdd.collectPartitions().length - - if (enableAQE) { - assert(partitionsNum1 < 10) - assert(partitionsNum2 < 10) - - checkInitialPartitionNum(df1, 10) - checkInitialPartitionNum(df2, 10) - } else { - assert(partitionsNum1 === 10) - assert(partitionsNum2 === 10) - } - - // Don't coalesce partitions if the number of partitions is specified. - val df3 = spark.range(10).repartitionByRange(10, $"id".asc) - assert(df3.rdd.collectPartitions().length == 10) - } - } - } - - test("SPARK-31220, SPARK-32056: repartition using sql and hint with AQE") { - Seq(true, false).foreach { enableAQE => - withTempView("test") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, - SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", - SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10", - SQLConf.SHUFFLE_PARTITIONS.key -> "10") { - - spark.range(10).toDF.createTempView("test") - - val df1 = spark.sql("SELECT /*+ REPARTITION(id) */ * from test") - val df2 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(id) */ * from test") - val df3 = spark.sql("SELECT * from test DISTRIBUTE BY id") - val df4 = spark.sql("SELECT * from test CLUSTER BY id") - - val partitionsNum1 = df1.rdd.collectPartitions().length - val partitionsNum2 = df2.rdd.collectPartitions().length - val partitionsNum3 = df3.rdd.collectPartitions().length - val partitionsNum4 = df4.rdd.collectPartitions().length - - if (enableAQE) { - assert(partitionsNum1 < 10) - assert(partitionsNum2 < 10) - assert(partitionsNum3 < 10) - assert(partitionsNum4 < 10) - - checkInitialPartitionNum(df1, 10) - checkInitialPartitionNum(df2, 10) - checkInitialPartitionNum(df3, 10) - checkInitialPartitionNum(df4, 10) - } else { - assert(partitionsNum1 === 10) - assert(partitionsNum2 === 10) - assert(partitionsNum3 === 10) - assert(partitionsNum4 === 10) - } - - // Don't coalesce partitions if the number of partitions is specified. 
- val df5 = spark.sql("SELECT /*+ REPARTITION(10, id) */ * from test") - val df6 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(10, id) */ * from test") - assert(df5.rdd.collectPartitions().length == 10) - assert(df6.rdd.collectPartitions().length == 10) - } - } - } - } - - test("SPARK-32573: Eliminate NAAJ when BuildSide is HashedRelationWithAllNullKeys") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString) { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM testData2 t1 WHERE t1.b NOT IN (SELECT b FROM testData3)") - val bhj = findTopLevelBroadcastHashJoin(plan) - assert(bhj.size == 1) - val join = findTopLevelBaseJoin(adaptivePlan) - assert(join.isEmpty) - checkNumLocalShuffleReaders(adaptivePlan) - } - } - - test("SPARK-32717: AQEOptimizer should respect excludedRules configuration") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString, - // This test is a copy of test(SPARK-32573), in order to test the configuration - // `spark.sql.adaptive.optimizer.excludedRules` works as expect. - SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> EliminateJoinToEmptyRelation.ruleName) { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM testData2 t1 WHERE t1.b NOT IN (SELECT b FROM testData3)") - val bhj = findTopLevelBroadcastHashJoin(plan) - assert(bhj.size == 1) - val join = findTopLevelBaseJoin(adaptivePlan) - // this is different compares to test(SPARK-32573) due to the rule - // `EliminateJoinToEmptyRelation` has been excluded. - assert(join.nonEmpty) - checkNumLocalShuffleReaders(adaptivePlan) - } - } - - test("SPARK-32649: Eliminate inner to empty relation") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { - Seq( - // inner join (small table at right side) - "SELECT * FROM testData t1 join testData3 t2 ON t1.key = t2.a WHERE t2.b = 1", - // inner join (small table at left side) - "SELECT * FROM testData3 t1 join testData t2 ON t1.a = t2.key WHERE t1.b = 1", - // left semi join : left join do not has omni impl - // "SELECT * FROM testData t1 left semi join testData3 t2 ON t1.key = t2.a AND t2.b = 1" - ).foreach(query => { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query) - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.size == 1) - val join = findTopLevelBaseJoin(adaptivePlan) - assert(join.isEmpty) - checkNumLocalShuffleReaders(adaptivePlan) - }) - } - } - - test("SPARK-32753: Only copy tags to node with no tags") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - withTempView("v1") { - spark.range(10).union(spark.range(10)).createOrReplaceTempView("v1") - - val (_, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT id FROM v1 GROUP BY id DISTRIBUTE BY id") - assert(collect(adaptivePlan) { - case s: ShuffleExchangeExec => s - }.length == 1) - } - } - } - - test("Logging plan changes for AQE") { - val testAppender = new LogAppender("plan changes") - withLogAppender(testAppender) { - withSQLConf( - SQLConf.PLAN_CHANGE_LOG_LEVEL.key -> "INFO", - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { - sql("SELECT * FROM testData JOIN testData2 ON key = a " + - "WHERE value = (SELECT max(a) FROM testData3)").collect() - } - Seq("=== Result of Batch AQE Preparations ===", - "=== Result of Batch AQE Post Stage Creation 
===", - "=== Result of Batch AQE Replanning ===", - "=== Result of Batch AQE Query Stage Optimization ===", - "=== Result of Batch AQE Final Query Stage Optimization ===").foreach { expectedMsg => - assert(testAppender.loggingEvents.exists(_.getRenderedMessage.contains(expectedMsg))) - } - } - } - - test("SPARK-32932: Do not use local shuffle reader at final stage on write command") { - withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString, - SQLConf.SHUFFLE_PARTITIONS.key -> "5", - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - val data = for ( - i <- 1L to 10L; - j <- 1L to 3L - ) yield (i, j) - - val df = data.toDF("i", "j").repartition($"j") - var noLocalReader: Boolean = false - val listener = new QueryExecutionListener { - override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { - qe.executedPlan match { - case plan@(_: DataWritingCommandExec | _: V2TableWriteExec) => - assert(plan.asInstanceOf[UnaryExecNode].child.isInstanceOf[AdaptiveSparkPlanExec]) - noLocalReader = collect(plan) { - case exec: CustomShuffleReaderExec if exec.isLocalReader => exec - }.isEmpty - case _ => // ignore other events - } - } - override def onFailure(funcName: String, qe: QueryExecution, - exception: Exception): Unit = {} - } - spark.listenerManager.register(listener) - - withTable("t") { - df.write.partitionBy("j").saveAsTable("t") - sparkContext.listenerBus.waitUntilEmpty() - assert(noLocalReader) - noLocalReader = false - } - - // Test DataSource v2 - val format = classOf[NoopDataSource].getName - df.write.format(format).mode("overwrite").save() - sparkContext.listenerBus.waitUntilEmpty() - assert(noLocalReader) - noLocalReader = false - - spark.listenerManager.unregister(listener) - } - } - - test("SPARK-33494: Do not use local shuffle reader for repartition") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - val df = spark.table("testData").repartition('key) - df.collect() - // local shuffle reader breaks partitioning and shouldn't be used for repartition operation - // which is specified by users. - checkNumLocalShuffleReaders(df.queryExecution.executedPlan, numShufflesWithoutLocalReader = 1) - } - } - - test("SPARK-33551: Do not use custom shuffle reader for repartition") { - def hasRepartitionShuffle(plan: SparkPlan): Boolean = { - find(plan) { - case s: ShuffleExchangeLike => - s.shuffleOrigin == REPARTITION || s.shuffleOrigin == REPARTITION_WITH_NUM - case _ => false - }.isDefined - } - - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.SHUFFLE_PARTITIONS.key -> "5") { - val df = sql( - """ - |SELECT * FROM ( - | SELECT * FROM testData WHERE key = 1 - |) - |RIGHT OUTER JOIN testData2 - |ON value = b - """.stripMargin) - - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { - // Repartition with no partition num specified. - val dfRepartition = df.repartition('b) - dfRepartition.collect() - val plan = dfRepartition.queryExecution.executedPlan - // The top shuffle from repartition is optimized out. - assert(!hasRepartitionShuffle(plan)) - val bhj = findTopLevelBroadcastHashJoin(plan) - assert(bhj.length == 1) - checkNumLocalShuffleReaders(plan, 1) - // Probe side is coalesced. - val customReader = bhj.head.right.find(_.isInstanceOf[ColumnarCustomShuffleReaderExec]) - assert(customReader.isDefined) - assert(customReader.get.asInstanceOf[ColumnarCustomShuffleReaderExec].hasCoalescedPartition) - - // Repartition with partition default num specified. 
- val dfRepartitionWithNum = df.repartition(5, 'b) - dfRepartitionWithNum.collect() - val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan - // The top shuffle from repartition is optimized out. - assert(!hasRepartitionShuffle(planWithNum)) - val bhjWithNum = findTopLevelBroadcastHashJoin(planWithNum) - assert(bhjWithNum.length == 1) - checkNumLocalShuffleReaders(planWithNum, 1) - // Probe side is not coalesced. - assert(bhjWithNum.head.right.find(_.isInstanceOf[CustomShuffleReaderExec]).isEmpty) - - // Repartition with partition non-default num specified. - val dfRepartitionWithNum2 = df.repartition(3, 'b) - dfRepartitionWithNum2.collect() - val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan - // The top shuffle from repartition is not optimized out, and this is the only shuffle that - // does not have local shuffle reader. - assert(hasRepartitionShuffle(planWithNum2)) - val bhjWithNum2 = findTopLevelBroadcastHashJoin(planWithNum2) - assert(bhjWithNum2.length == 1) - checkNumLocalShuffleReaders(planWithNum2, 1) - val customReader2 = bhjWithNum2.head.right - .find(_.isInstanceOf[ColumnarCustomShuffleReaderExec]) - assert(customReader2.isDefined) - assert(customReader2.get.asInstanceOf[ColumnarCustomShuffleReaderExec].isLocalReader) - } - - // Force skew join - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.SKEW_JOIN_ENABLED.key -> "true", - SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "1", - SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0", - SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") { - // Repartition with no partition num specified. - val dfRepartition = df.repartition('b) - dfRepartition.collect() - val plan = dfRepartition.queryExecution.executedPlan - // The top shuffle from repartition is optimized out. - assert(!hasRepartitionShuffle(plan)) - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.length == 1) - // No skew join due to the repartition. - assert(!smj.head.isSkewJoin) - // Both sides are coalesced. - val customReaders = collect(smj.head) { - case c: CustomShuffleReaderExec if c.hasCoalescedPartition => c - case c: ColumnarCustomShuffleReaderExec if c.hasCoalescedPartition => c - } - assert(customReaders.length == 2) - - // Repartition with default partition num specified. - val dfRepartitionWithNum = df.repartition(5, 'b) - dfRepartitionWithNum.collect() - val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan - // The top shuffle from repartition is optimized out. - assert(!hasRepartitionShuffle(planWithNum)) - val smjWithNum = findTopLevelSortMergeJoin(planWithNum) - assert(smjWithNum.length == 1) - // No skew join due to the repartition. - assert(!smjWithNum.head.isSkewJoin) - // No coalesce due to the num in repartition. - val customReadersWithNum = collect(smjWithNum.head) { - case c: CustomShuffleReaderExec if c.hasCoalescedPartition => c - } - assert(customReadersWithNum.isEmpty) - - // Repartition with default non-partition num specified. - val dfRepartitionWithNum2 = df.repartition(3, 'b) - dfRepartitionWithNum2.collect() - val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan - // The top shuffle from repartition is not optimized out. - assert(hasRepartitionShuffle(planWithNum2)) - val smjWithNum2 = findTopLevelSortMergeJoin(planWithNum2) - assert(smjWithNum2.length == 1) - // Skew join can apply as the repartition is not optimized out. 
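// A standalone sketch (not from the deleted suite) of the skew-join knobs this block
// toggles, assuming a local SparkSession; the thresholds and view names are
// illustrative, chosen so one heavily repeated key yields a skewed shuffle partition.
import org.apache.spark.sql.SparkSession

object SkewJoinConfSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("skew-sketch").getOrCreate()
    spark.conf.set("spark.sql.adaptive.enabled", "true")
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1") // keep the sort-merge join
    spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
    spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "0")
    spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "100")
    spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "100")
    spark.range(0, 1000).selectExpr("id % 3 AS key1", "id AS value1").createOrReplaceTempView("s1")
    spark.range(0, 1000).selectExpr("id % 1 AS key2", "id AS value2").createOrReplaceTempView("s2")
    val df = spark.sql("SELECT key1 FROM s1 JOIN s2 ON key1 = key2")
    df.collect()
    println(df.queryExecution.executedPlan) // a skew-optimized sort-merge join is typically marked here
    spark.stop()
  }
}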
- assert(smjWithNum2.head.isSkewJoin) - } - } - } - - ignore("SPARK-34091: Batch shuffle fetch in AQE partition coalescing") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.SHUFFLE_PARTITIONS.key -> "10000", - SQLConf.FETCH_SHUFFLE_BLOCKS_IN_BATCH.key -> "true") { - withTable("t1") { - spark.range(100).selectExpr("id + 1 as a").write.format("parquet").saveAsTable("t1") - val query = "SELECT SUM(a) FROM t1 GROUP BY a" - val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query) - val metricName = SQLShuffleReadMetricsReporter.LOCAL_BLOCKS_FETCHED - val blocksFetchedMetric = collectFirst(adaptivePlan) { - case p if p.metrics.contains(metricName) => p.metrics(metricName) - } - assert(blocksFetchedMetric.isDefined) - val blocksFetched = blocksFetchedMetric.get.value - withSQLConf(SQLConf.FETCH_SHUFFLE_BLOCKS_IN_BATCH.key -> "false") { - val (_, adaptivePlan2) = runAdaptiveAndVerifyResult(query) - val blocksFetchedMetric2 = collectFirst(adaptivePlan2) { - case p if p.metrics.contains(metricName) => p.metrics(metricName) - } - assert(blocksFetchedMetric2.isDefined) - val blocksFetched2 = blocksFetchedMetric2.get.value - assert(blocksFetched < blocksFetched2) - } - } - } - } - - test("Do not use column shuffle in AQE") { - def findCustomShuffleReader(plan: SparkPlan): Seq[CustomShuffleReaderExec] ={ - collect(plan) { - case j: CustomShuffleReaderExec => j - } - } - def findShuffleExchange(plan: SparkPlan): Seq[ShuffleExchangeExec] ={ - collect(plan) { - case j: ShuffleExchangeExec => j - } - } - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - "spark.shuffle.manager"-> "sort", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100", - SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true") { - spark - .range(1, 1000, 1).where("id > 995").createOrReplaceTempView("t1") - spark - .range(1, 5, 1)createOrReplaceTempView("t2") - val (_, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM t1 JOIN t2 ON t1.id = t2.id") - val shuffleNum = findShuffleExchange(adaptivePlan) - assert(shuffleNum.length == 2) - val shuffleReaderNum = findCustomShuffleReader(adaptivePlan) - assert(shuffleReaderNum.length == 2) - - } - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarBuiltInFuncSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarBuiltInFuncSuite.scala deleted file mode 100644 index 20879ad520d2a4fc0f6af0f20e3720d64105e03e..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarBuiltInFuncSuite.scala +++ /dev/null @@ -1,638 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.forsql - -import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.execution.{ColumnarProjectExec, ColumnarSparkPlanTest, ProjectExec} - -class ColumnarBuiltInFuncSuite extends ColumnarSparkPlanTest{ - import testImplicits.{localSeqToDatasetHolder, newProductEncoder} - - private var buildInDf: DataFrame = _ - - private var buildInDfNum: DataFrame = _ - - protected override def beforeAll(): Unit = { - super.beforeAll() - buildInDf = Seq[(String, String, String, String, Long, Int, String, String)]( - (null, "ChaR1 R", null, " varchar100 ", 1001L, 1, " 中文1aA ", "varchar100_normal"), - ("char200 ", "char2 ", "varchar2", "", 1002L, 2, "中文2bB", "varchar200_normal"), - ("char300 ", "char3 ", "varchar3", "varchar300", 1003L, 3, "中文3cC", "varchar300_normal"), - (null, "char4 ", "varchar4", "varchar400", 1004L, 4, null, "varchar400_normal") - ).toDF("char_null", "char_normal", "varchar_null", "varchar_empty", "long_col", "int_col", "ch_col", "varchar_normal") - buildInDf.createOrReplaceTempView("builtin_table") - - buildInDfNum = Seq[(Double, Int, Double, Int)]( - (123.12345, 1, -123.12345, 134), - (123.1257, 2, -123.1257, 1267), - (123.12, 3, -123.12, 1650), - (123.1, 4, -123.1, 166667) - ).toDF("double1", "int2", "double3", "int4") - buildInDfNum.createOrReplaceTempView("test_table") - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute lower with normal") { - val sql = "select lower(char_normal) from builtin_table" - val expected = Seq( - Row("char1 r"), - Row("char2 "), - Row("char3 "), - Row("char4 ") - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute lower with null") { - val sql = "select lower(char_null) from builtin_table" - val expected = Seq( - Row(null), - Row("char200 "), - Row("char300 "), - Row(null) - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute lower with space/empty string") { - val sql = "select lower(varchar_empty) from builtin_table" - val expected = Seq( - Row(" varchar100 "), - Row(""), - Row("varchar300"), - Row("varchar400") - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute lower-lower") { - val sql = "select lower(char_null), lower(varchar_null) from builtin_table" - val expected = Seq( - Row(null, null), - Row("char200 ", "varchar2"), - Row("char300 ", "varchar3"), - Row(null, "varchar4"), - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute lower(lower())") { - val sql = "select lower(lower(char_null)) from builtin_table" - val expected = Seq( - Row(null), - Row("char200 "), - Row("char300 "), - Row(null) - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute lower with subQuery") { - val sql = "select lower(l) from (select lower(char_normal) as l from builtin_table)" - val 
expected = Seq( - Row("char1 r"), - Row("char2 "), - Row("char3 "), - Row("char4 ") - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute lower with ch") { - val sql = "select lower(ch_col) from builtin_table" - val expected = Seq( - Row(" 中文1aa "), - Row("中文2bb"), - Row("中文3cc"), - Row(null) - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute length with normal") { - val sql = "select length(char_normal) from builtin_table" - val expected = Seq( - Row(10), - Row(10), - Row(10), - Row(10) - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute length with null") { - val sql = "select length(char_null) from builtin_table" - val expected = Seq( - Row(null), - Row(10), - Row(10), - Row(null) - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute length with space/empty string") { - val sql = "select length(varchar_empty) from builtin_table" - val expected = Seq( - Row(13), - Row(0), - Row(10), - Row(10) - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute length with expr") { - val sql = "select length(char_null) / 2 from builtin_table" - val expected = Seq( - Row(null), - Row(5.0), - Row(5.0), - Row(null) - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute length-length") { - val sql = "select length(char_null),length(varchar_null) from builtin_table" - val expected = Seq( - Row(null, null), - Row(10, 8), - Row(10, 8), - Row(null, 8) - ) - checkResult(sql, expected) - } - - // replace(str, search, replaceStr) - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with matched and replace str") { - val sql = "select replace(varchar_normal,varchar_empty,char_normal) from builtin_table" - val expected = Seq( - Row("varchar100_normal"), - Row("varchar200_normal"), - Row("char3 _normal"), - Row("char4 _normal") - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with not matched") { - val sql = "select replace(char_normal,varchar_normal,char_normal) from builtin_table" - val expected = Seq( - Row("ChaR1 R"), - Row("char2 "), - Row("char3 "), - Row("char4 ") - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with str null") { - val sql = "select replace(varchar_null,char_normal,varchar_normal) from builtin_table" - val expected = Seq( - Row(null), - Row("varchar2"), - Row("varchar3"), - Row("varchar4") - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with str space/empty") { - val sql = "select replace(varchar_empty,varchar_empty,varchar_normal) from builtin_table" - val expected = Seq( - Row("varchar100_normal"), - Row(""), - Row("varchar300_normal"), - Row("varchar400_normal") - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with search null") { - val sql = "select replace(varchar_normal,varchar_null,char_normal) from builtin_table" - val expected = Seq( - Row(null), - Row("char2 
00_normal"), - Row("char3 00_normal"), - Row("char4 00_normal") - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with search space/empty") { - val sql = "select replace(varchar_normal,varchar_empty,char_normal) from builtin_table" - val expected = Seq( - Row("varchar100_normal"), - Row("varchar200_normal"), - Row("char3 _normal"), - Row("char4 _normal") - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with replaceStr null") { - val sql = "select replace(varchar_normal,varchar_empty,varchar_null) from builtin_table" - val expected = Seq( - Row(null), - Row("varchar200_normal"), - Row("varchar3_normal"), - Row("varchar4_normal") - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with replaceStr space/empty") { - val sql = "select replace(varchar_normal,varchar_normal,varchar_empty) from builtin_table" - val expected = Seq( - Row(" varchar100 "), - Row(""), - Row("varchar300"), - Row("varchar400") - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with str/search/replace all null") { - val sql = "select replace(varchar_null,varchar_null,char_null) from builtin_table" - val expected = Seq( - Row(null), - Row("char200 "), - Row("char300 "), - Row(null) - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with replaceStr default") { - val sql = "select replace(varchar_normal,varchar_normal) from builtin_table" - val expected = Seq( - Row(""), - Row(""), - Row(""), - Row("") - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with subReplace(normal,normal,normal)") { - val sql = "select replace(res,'c','ccc') from (select replace(varchar_normal,varchar_empty,char_normal) as res from builtin_table)" - val expected = Seq( - Row("varccchar100_normal"), - Row("varccchar200_normal"), - Row("ccchar3 _normal"), - Row("ccchar4 _normal") - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with subReplace(null,null,null)") { - val sql = "select replace(res,'c','ccc') from (select replace(varchar_null,varchar_null,char_null) as res from builtin_table)" - val expected = Seq( - Row(null), - Row("ccchar200 "), - Row("ccchar300 "), - Row(null) - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace(replace)") { - val sql = "select replace(replace('ABCabc','AB','abc'),'abc','DEF')" - val expected = Seq( - Row("DEFCDEF") - ) - checkResult(sql, expected) - } - - // upper - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute upper with normal") { - val sql = "select upper(char_normal) from builtin_table" - val expected = Seq( - Row("CHAR1 R"), - Row("CHAR2 "), - Row("CHAR3 "), - Row("CHAR4 ") - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute upper with null") { - val sql = "select upper(char_null) from builtin_table" - val expected = Seq( - Row(null), - Row("CHAR200 "), - Row("CHAR300 "), - Row(null) - ) - checkResult(sql, expected) - } - - 
test("Test ColumnarProjectExec happen and result is same as native " + - "when execute upper with space/empty string") { - val sql = "select upper(varchar_empty) from builtin_table" - val expected = Seq( - Row(" VARCHAR100 "), - Row(""), - Row("VARCHAR300"), - Row("VARCHAR400") - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute upper-upper") { - val sql = "select upper(char_null), upper(varchar_null) from builtin_table" - val expected = Seq( - Row(null, null), - Row("CHAR200 ", "VARCHAR2"), - Row("CHAR300 ", "VARCHAR3"), - Row(null, "VARCHAR4"), - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute upper(upper())") { - val sql = "select upper(upper(char_null)) from builtin_table" - val expected = Seq( - Row(null), - Row("CHAR200 "), - Row("CHAR300 "), - Row(null) - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute upper with subQuery") { - val sql = "select upper(l) from (select upper(char_normal) as l from builtin_table)" - val expected = Seq( - Row("CHAR1 R"), - Row("CHAR2 "), - Row("CHAR3 "), - Row("CHAR4 ") - ) - checkResult(sql, expected) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute upper with ch") { - val sql = "select upper(ch_col) from builtin_table" - val expected = Seq( - Row(" 中文1AA "), - Row("中文2BB"), - Row("中文3CC"), - Row(null) - ) - checkResult(sql, expected) - } - - def checkResult(sql: String, expected: Seq[Row], isUseOmni: Boolean = true): Unit = { - def assertOmniProjectHappen(res: DataFrame): Unit = { - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - } - def assertOmniProjectNotHappen(res: DataFrame): Unit = { - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, s"ColumnarProjectExec happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isDefined, s"ProjectExec not happened, executedPlan as follows: \n$executedPlan") - } - val res = spark.sql(sql) - if (isUseOmni) assertOmniProjectHappen(res) else assertOmniProjectNotHappen(res) - checkAnswer(res, expected) - } - - test("Round(int,2)") { - val res = spark.sql("select round(int2,2) as res from test_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(1), - Row(2), - Row(3), - Row(4) - ) - ) - } - - test("Round(double,2)") { - val res = spark.sql("select round(double1,2) as res from test_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: 
\n$executedPlan") - checkAnswer( - res, - Seq( - Row(123.12), - Row(123.13), - Row(123.12), - Row(123.1) - ) - ) - } - - test("Round(int,-1)") { - val res = spark.sql("select round(int2,-1) as res from test_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(0), - Row(0), - Row(0), - Row(0) - ) - ) - } - - test("Round(double,0)") { - val res = spark.sql("select round(double1,0) as res from test_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(123), - Row(123), - Row(123), - Row(123) - ) - ) - } - - test("Round(-double,2)") { - val res = spark.sql("select round(double3,2) as res from test_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(-123.12), - Row(-123.13), - Row(-123.12), - Row(-123.1) - ) - ) - } - - test("Round(int,-2)") { - val res = spark.sql("select round(int4,-2) as res from test_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(100), - Row(1300), - Row(1700), - Row(166700) - ) - ) - } - - test("Round decimal") { - var res = spark.sql("select round(2.5, 0) as res from test_table") - var executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(3), - Row(3), - Row(3), - Row(3) - ) - ) - res = spark.sql("select round(3.5, 0) as res from test_table") - executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(4), - Row(4), - Row(4), - Row(4) - ) - ) - res = spark.sql("select round(-2.5, 0) as res from test_table") - executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - 
assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(-3), - Row(-3), - Row(-3), - Row(-3) - ) - ) - res = spark.sql("select round(-3.5, 0) as res from test_table") - executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(-4), - Row(-4), - Row(-4), - Row(-4) - ) - ) - res = spark.sql("select round(-0.35, 1) as res from test_table") - executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(-0.4), - Row(-0.4), - Row(-0.4), - Row(-0.4) - ) - ) - res = spark.sql("select round(-35, -1) as res from test_table") - executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(-40), - Row(-40), - Row(-40), - Row(-40) - ) - ) - res = spark.sql("select round(null, 0) as res from test_table") - executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(null), - Row(null), - Row(null), - Row(null) - ) - ) - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarDecimalCastSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarDecimalCastSuite.scala deleted file mode 100644 index 2d56cac9dc777417bf9d44995a6b1cb089c67cfd..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarDecimalCastSuite.scala +++ /dev/null @@ -1,700 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.forsql - -import org.apache.spark.sql.execution.{ColumnarProjectExec, ColumnarSparkPlanTest, ProjectExec} -import org.apache.spark.sql.types.Decimal -import org.apache.spark.sql.{DataFrame, Row} - -class ColumnarDecimalCastSuite extends ColumnarSparkPlanTest{ - import testImplicits.{localSeqToDatasetHolder, newProductEncoder} - - private var byteDecimalDf: DataFrame = _ - private var shortDecimalDf: DataFrame = _ - private var intDecimalDf: DataFrame = _ - private var longDecimalDf: DataFrame = _ - private var floatDecimalDf: DataFrame = _ - private var doubleDecimalDf: DataFrame = _ - private var stringDecimalDf: DataFrame = _ - private var decimalDecimalDf: DataFrame = _ - - protected override def beforeAll(): Unit = { - super.beforeAll() - byteDecimalDf = Seq[(java.lang.Byte, java.lang.Byte, Decimal, Decimal, Decimal)]( - (127.toByte, 12.toByte, Decimal(100000, 7, 0), Decimal(128.99, 7, 2), Decimal(11223344.123, 21, 6)), - ((-12).toByte, null, Decimal(25, 7, 0), Decimal(25.55, 7, 2), Decimal(-99999999999.1234, 21, 6)), - (9.toByte, (-11).toByte, Decimal(-25, 7, 0), null, null), - ((-9).toByte, null, Decimal(145, 7, 0), Decimal(256.66, 7, 2), null) - ).toDF("c_byte_normal", "c_byte_null", "c_deci7_0", "c_deci7_2_null", "c_deci21_6_null") - - shortDecimalDf = Seq[(java.lang.Short, java.lang.Short, Decimal, Decimal)]( - (10.toShort, 15.toShort, Decimal(130.6, 17, 2), Decimal(128.99, 21, 6)), - ((-10).toShort, null, null, Decimal(32723.55, 21, 6)), - (1000.toShort, null, Decimal(-30.8, 17, 2), null), - ((-1000).toShort, 2000.toShort, null, Decimal(-99999.19, 21, 6)), - ).toDF("c_short_normal", "c_short_null", "c_deci17_2_null", "c_deci21_6_null") - - intDecimalDf = Seq[(java.lang.Integer, java.lang.Integer, Decimal, Decimal)]( - (1272763, 1111, null, Decimal(1234.555431, 21, 6)), - (22723, 2222, Decimal(32728543.12, 17, 2), Decimal(99999999.999, 21, 6)), - (9, null, Decimal(-195010407800.34, 17, 2), Decimal(-99999999.999, 21, 6)), - (345, -4444, Decimal(12000.56, 17, 2), null) - ).toDF("c_int_normal", "c_int_null", "c_deci17_2_null", "c_deci21_6_null") - - longDecimalDf = Seq[(java.lang.Long, java.lang.Long, Decimal, Decimal)]( - (922337203L, 1231313L, null, Decimal(1922337203685.99, 38, 2)), - (22723L, null, Decimal(2233720368.12, 17, 2), Decimal(54775800.55, 38, 2)), - (9L, -123131, Decimal(-2192233720.34, 17, 2), null) - ).toDF("c_long_normal", "c_long_null", "c_deci17_2_null", "c_deci38_2_null") - - floatDecimalDf = Seq[(java.lang.Float, java.lang.Float, Decimal, Decimal)]( - (1234.4129F, 123.12F, null, Decimal(10000.99, 38, 2)), - (1234.4125F, 123.34F, Decimal(1234.11, 17, 2), Decimal(10000.99, 38, 2)), - (3.4E10F, null, Decimal(999999999999.22, 17, 2), Decimal(999999999999.99, 38, 2)), - (-3.4E-10F, -1123.1113F, Decimal(-999999999999.33, 17, 2), Decimal(-999999999999.99, 38, 2)) - ).toDF("c_float_normal", "c_float_null", "c_deci17_2_null", "c_deci38_2_null") - - doubleDecimalDf = Seq[(java.lang.Double, java.lang.Double, Decimal, Decimal)]( - (1234.4129, 1123, Decimal(10000.99, 17, 2), Decimal(1234.14, 38, 2)), - (1234.4125, null, Decimal(10000.99, 17, 2), Decimal(1234.14, 38, 2)), - (1234.4124, 1234, Decimal(10000.99, 17, 2), null) - ).toDF("c_double_normal", "c_double_null", "c_deci17_2_null", "c_deci38_2_null") - - stringDecimalDf = Seq[(String, String, Decimal, Decimal)]( - (" 99999 ", "111 ", null, Decimal(128.99, 38, 2)), - 
("-1234.15 ", "222.2 ", Decimal(99999.11, 17, 2), Decimal(99999.99, 38, 2)), - ("abc ", "-333.33 ", Decimal(-11111.22, 17, 2), Decimal(-99999.19, 38, 2)), - ("999999 ", null, Decimal(99999.33, 17, 2), null), - ).toDF("c_string_normal", "c_string_null", "c_deci17_2_null", "c_deci38_2_null") - - decimalDecimalDf = Seq[(Decimal, Decimal)]( - (Decimal(128.99, 17, 2), Decimal(1234.555431, 21, 6)), - (null, Decimal(99999999.999, 21, 6)), - (Decimal(25.19, 17, 2), Decimal(-99999999.999, 21, 6)), - (Decimal(-195010407800.19, 17, 2), Decimal(-999999.99999, 21, 6)) - ).toDF("c_deci17_2_null", "c_deci21_6") - - byteDecimalDf.createOrReplaceTempView("deci_byte") - shortDecimalDf.createOrReplaceTempView("deci_short") - intDecimalDf.createOrReplaceTempView("deci_int") - longDecimalDf.createOrReplaceTempView("deci_long") - floatDecimalDf.createOrReplaceTempView("deci_float") - doubleDecimalDf.createOrReplaceTempView("deci_double") - stringDecimalDf.createOrReplaceTempView("deci_string") - decimalDecimalDf.createOrReplaceTempView("deci_decimal") - } - - // byte - test("Test ColumnarProjectExec not happen and result is same as native " + - "when cast byte to decimal") { - val res = spark.sql("select c_byte_normal, cast(c_byte_normal as decimal(10, 2))," + - "cast(c_byte_normal as decimal(19,1)) from deci_byte") - assertOmniProjectNotHappened(res) - checkAnswer( - res, - Seq( - Row(-12, -12.00, -12.0), - Row(127, 127.00, 127.0), - Row(-9, -9.00, -9.0), - Row(9, 9.00, 9.0) - ) - ) - } - - test("Test ColumnarProjectExec not happen and result is same as native " + - "when cast byte to decimal overflow with spark.sql.ansi.enabled=false") { - val res = spark.sql("select c_byte_normal, cast(c_byte_normal as decimal(2, 1))," + - "cast(c_byte_normal as decimal(19,18)) from deci_byte") - assertOmniProjectNotHappened(res) - checkAnswer( - res, - Seq( - Row(-12, null, null), - Row(127, null, null), - Row(-9, -9.0, -9.000000000000000000), - Row(9, 9.0, 9.000000000000000000) - ) - ) - } - - test("Test ColumnarProjectExec not happen and result is same as native " + - "when cast byte to decimal with null") { - val res = spark.sql("select c_byte_null, cast(c_byte_null as decimal(10, 2))," + - "cast(c_byte_null as decimal(19,3)) from deci_byte") - assertOmniProjectNotHappened(res) - checkAnswer( - res, - Seq( - Row(null, null, null), - Row(12, 12.00, 12.000), - Row(null, null, null), - Row(-11, -11.00, -11.000) - ) - ) - } - - test("Test ColumnarProjectExec not happen and result is same as native " + - "when cast decimal to byte") { - val res = spark.sql("select cast(c_deci7_0 as byte)," + - "cast(c_deci7_2_null as byte), cast(c_deci21_6_null as byte) from deci_byte") - assertOmniProjectNotHappened(res) - checkAnswer( - res, - Seq( - Row(25.toByte, 25.toByte, 1.toByte), - Row((-96).toByte, (-128).toByte, 48.toByte), - Row((-111).toByte, 0.toByte, null), - Row((-25).toByte, null, null) - ) - ) - } - - // short - test("Test ColumnarProjectExec not happen and result is same as native " + - "when cast short to decimal") { - val res = spark.sql("select c_short_normal, cast(c_short_normal as decimal(13, 3))," + - "cast(c_short_normal as decimal(20,2)) from deci_short") - assertOmniProjectNotHappened(res) - checkAnswer( - res, - Seq( - Row(-1000, -1000.000, -1000.00), - Row(10, 10.000, 10.00), - Row(-10, -10.000, -10.00), - Row(1000, 1000.000, 1000.00) - ) - ) - } - - test("Test ColumnarProjectExec not happen and result is same as native " + - "when cast short to decimal overflow with spark.sql.ansi.enabled=false") { - val res = 
spark.sql("select c_short_normal, cast(c_short_normal as decimal(2, 1))," + - "cast(c_short_normal as decimal(20,18)) from deci_short") - assertOmniProjectNotHappened(res) - checkAnswer( - res, - Seq( - Row(-1000, null, null), - Row(10, null, 10.000000000000000000), - Row(-10, null, -10.000000000000000000), - Row(1000, null, null) - ) - ) - } - - test("Test ColumnarProjectExec not happen and result is same as native " + - "when cast short to decimal with null") { - val res = spark.sql("select c_short_null, cast(c_short_null as decimal(14, 2))," + - "cast(c_short_null as decimal(20,3)) from deci_short") - assertOmniProjectNotHappened(res) - checkAnswer( - res, - Seq( - Row(2000, 2000.00, 2000.000), - Row(15, 15.00, 15.000), - Row(null, null, null), - Row(null, null, null) - ) - ) - } - - test("Test ColumnarProjectExec not happen and result is same as native " + - "when cast decimal to short") { - val res = spark.sql("select cast(c_deci17_2_null as short)," + - "cast(c_deci21_6_null as short) from deci_short") - assertOmniProjectNotHappened(res) - checkAnswer( - res, - Seq( - Row(null, 31073.toShort), - Row(130.toShort, 128.toShort), - Row(null, 32723.toShort), - Row((-30).toShort, null) - ) - ) - } - - // int - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast int to decimal") { - val res = spark.sql("select c_int_normal, cast(c_int_normal as decimal(16, 3))," + - "cast(c_int_normal as decimal(22,4)) from deci_int") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( - Row(22723, 22723.000, 22723.0000), - Row(1272763, 1272763.000, 1272763.0000), - Row(9, 9.000, 9.0000), - Row(345, 345.000, 345.0000) - ) - ) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast int to decimal overflow with spark.sql.ansi.enabled=false") { - val res = spark.sql("select c_int_normal, cast(c_int_normal as decimal(2, 1))," + - "cast(c_int_normal as decimal(22,19)) from deci_int") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( - Row(22723, null, null), - Row(1272763, null, null), - Row(9, 9.0, 9.0000000000000000000), - Row(345, null, 345.0000000000000000000) - ) - ) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast int to decimal with null") { - val res = spark.sql("select c_int_null, cast(c_int_null as decimal(16, 4))," + - "cast(c_int_null as decimal(23,5)) from deci_int") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( - Row(2222, 2222.0000, 2222.00000), - Row(1111, 1111.0000, 1111.00000), - Row(null, null, null), - Row(-4444, -4444.0000, -4444.00000) - ) - ) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast decimal to int") { - val res = spark.sql("select cast(c_deci17_2_null as int)," + - "cast(c_deci21_6_null as int) from deci_int") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( - Row(32728543, 99999999), - Row(null, 1234), - Row(-1736879480, -99999999), - Row(12000, null) - ) - ) - } - - // long - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast long to decimal") { - val res = spark.sql("select c_long_normal, cast(c_long_normal as decimal(16, 2))," + - "cast(c_long_normal as decimal(26,7)) from deci_long") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( - Row(922337203L, 922337203.00, 922337203.0000000), - Row(22723L, 22723.00, 22723.0000000), - Row(9L, 9.000, 9.0000) - ) - ) - } - - test("Test ColumnarProjectExec happen and result is same as native " 
+ - "when cast long to decimal overflow with spark.sql.ansi.enabled=false") { - val res = spark.sql("select c_long_normal, cast(c_long_normal as decimal(3, 1))," + - "cast(c_long_normal as decimal(27,24)) from deci_long") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( - Row(922337203L, null, null), - Row(22723L, null, null), - Row(9L, 9.0, 9.000000000000000000000000) - ) - ) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast long to decimal with null") { - val res = spark.sql("select c_long_null, cast(c_long_null as decimal(17, 6))," + - "cast(c_long_null as decimal(29,7)) from deci_long") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( - Row(null, null, null), - Row(1231313L, 1231313.000000, 1231313.0000000), - Row(-123131L, -123131.000000, -123131.0000000) - ) - ) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast decimal to long") { - val res = spark.sql("select cast(c_deci17_2_null as long)," + - "cast(c_deci38_2_null as long) from deci_long") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( - Row(null, 1922337203685L), - Row(2233720368L, 54775800L), - Row(-2192233720L, null) - ) - ) - } - - // float - test("Test ColumnarProjectExec not happen and result is same as native " + - "when cast float to decimal") { - val res = spark.sql("select c_float_normal, cast(c_float_normal as decimal(14, 3))," + - "cast(c_float_normal as decimal(30,7)) from deci_float") - assertOmniProjectNotHappened(res) - checkAnswer( - res, - Seq( - Row(-3.4E-10F, 0.000, 0.0000000), - Row(3.3999999E10F, 33999998976.00, 33999998976.0000000), - Row(1234.4125F, 1234.412, 1234.4124756), - Row(1234.4128F, 1234.413, 1234.4128418) - ) - ) - } - - test("Test ColumnarProjectExec not happen and result is same as native " + - "when cast float to decimal overflow with spark.sql.ansi.enabled=false") { - val res = spark.sql("select c_float_normal, cast(c_float_normal as decimal(14, 11))," + - "cast(c_float_normal as decimal(30,27)) from deci_float") - assertOmniProjectNotHappened(res) - checkAnswer( - res, - Seq( - Row(-3.4E-10F, -0.00000000034, -0.000000000340000000376150500), - Row(3.3999999E10F, null, null), - Row(1234.4125F, null, null), - Row(1234.4128F, null, null) - ) - ) - } - - test("Test ColumnarProjectExec not happen and result is same as native " + - "when cast float to decimal with null") { - val res = spark.sql("select c_float_null, cast(c_float_null as decimal(17, 6))," + - "cast(c_float_null as decimal(29,7)) from deci_float") - assertOmniProjectNotHappened(res) - checkAnswer( - res, - Seq( - Row(-1123.1113F, -1123.111328, -1123.1113281), - Row(null, null, null), - Row(123.34F, 123.339996, 123.3399963), - Row(123.12F, 123.120003, 123.1200027) - ) - ) - } - - test("Test ColumnarProjectExec not happen and result is same as native " + - "when cast decimal to float") { - val res = spark.sql("select cast(c_deci17_2_null as float)," + - "cast(c_deci38_2_null as float) from deci_float") - assertOmniProjectNotHappened(res) - checkAnswer( - res, - Seq( - Row(-1.0E12F, -1.0E12F), - Row(1.0E12F, 1.0E12F), - Row(1234.11F, 10000.99F), - Row(null, 10000.99F) - ) - ) - } - - // double - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast double to decimal") { - val res = spark.sql("select c_double_normal, cast(c_double_normal as decimal(8, 4))," + - "cast(c_double_normal as decimal(32,4)) from deci_double") - assertOmniProjectNotHappened(res) - checkAnswer( - res, - Seq( - 
Row(1234.4129, 1234.4129, 1234.4129), - Row(1234.4125, 1234.4125, 1234.4125), - Row(1234.4124, 1234.4124, 1234.4124) - ) - ) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast double to decimal overflow with spark.sql.ansi.enabled=false") { - val res = spark.sql("select c_double_normal, cast(c_double_normal as decimal(8, 6))," + - "cast(c_double_normal as decimal(32,30)) from deci_double") - assertOmniProjectNotHappened(res) - checkAnswer( - res, - Seq( - Row(1234.4129, null, null), - Row(1234.4125, null, null), - Row(1234.4124, null, null) - ) - ) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast double to decimal with null") { - val res = spark.sql("select c_double_null, cast(c_double_null as decimal(8, 4))," + - "cast(c_double_null as decimal(34,4)) from deci_double") - assertOmniProjectNotHappened(res) - checkAnswer( - res, - Seq( - Row(1123.0, 1123.0000, 1123.0000), - Row(null, null, null), - Row(1234.0, 1234.0000, 1234.0000) - ) - ) - } - - test("Test ColumnarProjectExecc happen and result is same as native " + - "when cast decimal to double") { - val res = spark.sql("select cast(c_deci17_2_null as double)," + - "cast(c_deci38_2_null as double) from deci_double") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( - Row(10000.99, 1234.14), - Row(10000.99, 1234.14), - Row(10000.99, null) - ) - ) - } - - // decimal - test("Test ColumnarProjectExec happen when cast decimal to decimal") { - val res = spark.sql("select c_deci21_6, cast(c_deci21_6 as decimal(17, 6))," + - "cast(c_deci21_6 as decimal(28,9)) from deci_decimal") - assertOmniProjectHappened(res) - } - - test("Test ColumnarProjectExec happen when cast decimal " + - "to decimal overflow with spark.sql.ansi.enabled=false") { - val res = spark.sql("select c_deci21_6, cast(c_deci21_6 as decimal(17, 14))," + - "cast(c_deci21_6 as decimal(31,29)) from deci_decimal") - assertOmniProjectHappened(res) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast decimal to decimal with null") { - val res = spark.sql("select c_deci17_2_null, cast(c_deci17_2_null as decimal(18, 6))," + - "cast(c_deci17_2_null as decimal(31,10)) from deci_decimal") - assertOmniProjectHappened(res) - } - - // string - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast string to decimal") { - val res = spark.sql("select c_string_normal, cast(c_string_normal as decimal(16, 5))," + - "cast(c_string_normal as decimal(27,5)) from deci_string") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( - Row("abc ", null, null), - Row("-1234.15 ", -1234.15000, -1234.15000), - Row(" 99999 ", 99999.00000, 99999.00000), - Row("999999 ", 999999.00000, 999999.00000) - ) - ) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast string to decimal overflow with spark.sql.ansi.enabled=false") { - val res = spark.sql("select c_string_normal, cast(c_string_normal as decimal(16, 14))," + - "cast(c_string_normal as decimal(27,23)) from deci_string") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( - Row("abc ", null, null), - Row("-1234.15 ", null, -1234.15000000000000000000000), - Row(" 99999 ", null, null), - Row("999999 ", null, null) - ) - ) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast string to decimal with null") { - val res = spark.sql("select c_string_null, cast(c_string_null as decimal(16, 5))," + - 
"cast(c_string_null as decimal(27,5)) from deci_string") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( - Row("-333.33 ", -333.33000, -333.33000), - Row("222.2 ", 222.20000, 222.20000), - Row("111 ", 111.00000, 111.00000), - Row(null, null, null) - ) - ) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast decimal to string") { - val res = spark.sql("select cast(cast(c_deci17_2_null as string) as decimal(38, 2))," + - "cast(cast(c_deci38_2_null as string) as decimal(38, 2)) from deci_string") - val executedPlan = res.queryExecution.executedPlan - println(executedPlan) - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(-11111.22, -99999.19), - Row(99999.11, 99999.99), - Row(null, 128.99), - Row(99999.33, null) - ) - ) - } - - // literal int - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast literal int to decimal") { - val res = spark.sql("select cast(1111 as decimal(7,0)), cast(1111 as decimal(21, 6))") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( - Row(1111, 1111.000000) - ) - ) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast literal int to decimal overflow with spark.sql.ansi.enabled=false") { - val res = spark.sql("select cast(1111 as decimal(7,5)), cast(1111 as decimal(21, 19))") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( - Row(null, null) - ) - ) - } - - // literal long - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast literal long to decimal") { - val res = spark.sql("select cast(111111111111111 as decimal(15,0)), cast(111111111111111 as decimal(21, 6))") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( - Row(111111111111111L, 111111111111111.000000) - ) - ) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast literal long to decimal overflow with spark.sql.ansi.enabled=false") { - val res = spark.sql("select cast(111111111111111 as decimal(15,11)), cast(111111111111111 as decimal(21, 15))") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( - Row(null, null) - ) - ) - } - - // literal double - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast literal double to decimal") { - val res = spark.sql("select cast(666666.666 as decimal(15,3)), cast(666666.666 as decimal(21, 6))") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( - Row(666666.666, 666666.666) - ) - ) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast literal decimal to decimal overflow with spark.sql.ansi.enabled=false") { - val res = spark.sql("select cast(666666.666 as decimal(4,3)), cast(666666.666 as decimal(21, 16))," + - "cast(666666.666 as decimal(21, 16)), cast(666666.666 as decimal(21, 2))") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( - Row(null, null, null, 666666.67) - ) - ) - } - - // literal string - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast literal string to decimal") { - val res = spark.sql("select cast(' 666666.666 ' as decimal(15,3)), cast(' 666666.666 ' as decimal(21, 6))") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( 
- Row(666666.666, 666666.666000) - ) - ) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast literal string to decimal overflow with spark.sql.ansi.enabled=false") { - val res = spark.sql("select cast(' 666666.666 ' as decimal(15,3)), cast(' 666666.666 ' as decimal(21, 6)), " + - "cast(' 666666.666 ' as decimal(21,18)), cast(' 666666.666 ' as decimal(21, 2))") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( - Row(666666.666, 666666.666000, null, 666666.67) - ) - ) - } - - // literal null - test("Test ColumnarProjectExec happen and result is same as native " + - "when cast literal null to decimal") { - val res = spark.sql("select cast(null as decimal(2,0)), cast(null as decimal(21, 0))") - assertOmniProjectHappened(res) - checkAnswer( - res, - Seq( - Row(null, null) - ) - ) - } - - private def assertOmniProjectHappened(res: DataFrame) = { - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - } - - private def assertOmniProjectNotHappened(res: DataFrame) = { - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, s"ColumnarProjectExec happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isDefined, s"ProjectExec not happened, executedPlan as follows: \n$executedPlan") - } -} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarHashAggregateExecSqlSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarHashAggregateExecSqlSuite.scala deleted file mode 100644 index 5b0d14f45cc25c3c90b67e99c73458d3d6c75b13..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarHashAggregateExecSqlSuite.scala +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.forsql - -import org.apache.spark.sql.execution.aggregate.HashAggregateExec -import org.apache.spark.sql.execution.{ColumnarHashAggregateExec, ColumnarSparkPlanTest} -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Row} - -class ColumnarHashAggregateExecSqlSuite extends ColumnarSparkPlanTest { - private var df: DataFrame = _ - - protected override def beforeAll(): Unit = { - super.beforeAll() - df = spark.createDataFrame( - sparkContext.parallelize(Seq( - Row(1, 2.0, 1L, "a"), - Row(1, 2.0, 2L, null), - Row(2, 1.0, 3L, "c"), - Row(null, null, 6L, "e"), - Row(null, 5.0, 7L, "f") - )), new StructType().add("intCol", IntegerType).add("doubleCol", DoubleType) - .add("longCol", LongType).add("stringCol", StringType)) - df.createOrReplaceTempView("test_table") - } - - test("Test ColumnarHashAggregateExec happen and result is correct when execute count(*)") { - val res = spark.sql("select count(*) from test_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq(Row(5)) - ) - } - - test("Test ColumnarHashAggregateExec happen and result is correct when execute count(1)") { - val res = spark.sql("select count(1) from test_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq(Row(5)) - ) - } - - test("Test ColumnarHashAggregateExec happen and result is correct when execute count(-1)") { - val res = spark.sql("select count(-1) from test_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq(Row(5)) - ) - } - - test("Test ColumnarHashAggregateExec happen and result " + - "is correct when execute otherAgg-count(*)") { - val res = spark.sql("select max(intCol), count(*) from test_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq(Row(2, 5)) - ) - } - - test("Test ColumnarHashAggregateExec happen and result " + - "is correct when execute count(*)-otherAgg") { - val res = spark.sql("select count(*), max(intCol) from test_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") - 
assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq(Row(5, 2)) - ) - } - - test("Test ColumnarHashAggregateExec happen and result " + - "is correct when execute count(*)-count(*)") { - val res = spark.sql("select count(*), count(*) from test_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq(Row(5, 5)) - ) - } - - test("Test ColumnarHashAggregateExec happen and result " + - "is correct when execute count(*)-otherAgg-count(*)") { - val res = spark.sql("select count(*), max(intCol), count(*) from test_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq(Row(5, 2, 5)) - ) - } - - test("Test ColumnarHashAggregateExec happen and result " + - "is correct when execute otherAgg-count(*)-otherAgg") { - val res = spark.sql("select max(intCol), count(*), min(intCol) from test_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq(Row(2, 5, 1)) - ) - } - - test("Test ColumnarHashAggregateExec happen and result " + - "is correct when execute otherAgg-count(*)-count(*)") { - val res = spark.sql("select max(intCol), count(*), count(*) from test_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq(Row(2, 5, 5)) - ) - } - - test("Test ColumnarHashAggregateExec happen and result " + - "is correct when execute count(*) with group by") { - val res = spark.sql("select count(*) from test_table group by intCol") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq(Row(2), Row(1), Row(2)) - ) - } - - test("Test ColumnarHashAggregateExec happen and result" + - " is correct when execute count(*) with calculation expr") { - val res = spark.sql("select count(*) / 2 from test_table") - val executedPlan = res.queryExecution.executedPlan - 
assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq(Row(2.5)) - ) - } - - test("Test ColumnarHashAggregateExec happen and result" + - " is correct when execute count(*) with cast expr") { - val res = spark.sql("select cast(count(*) as bigint) from test_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq(Row(5)) - ) - } - - test("Test ColumnarHashAggregateExec happen and result" + - " is correct when execute count(*) with subQuery") { - val res = spark.sql("select count(*) from (select intCol," + - "count(*) from test_table group by intCol)") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq(Row(3)) - ) - } -} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarSupportDataTypeSqlSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarSupportDataTypeSqlSuite.scala deleted file mode 100644 index ca008c377136f8008e9446df9c2a358ec197cf32..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarSupportDataTypeSqlSuite.scala +++ /dev/null @@ -1,463 +0,0 @@ -/* - * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-package org.apache.spark.sql.execution.forsql
-
-import org.apache.spark.sql.execution.aggregate.HashAggregateExec
-import org.apache.spark.sql.execution.joins.{ColumnarBroadcastHashJoinExec, ColumnarShuffledHashJoinExec, ColumnarSortMergeJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
-import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.execution.{ColumnarConditionProjectExec, ColumnarExpandExec, ColumnarFilterExec, ColumnarHashAggregateExec, ColumnarProjectExec, ColumnarShuffleExchangeExec, ColumnarSortExec, ColumnarSparkPlanTest, ColumnarTakeOrderedAndProjectExec, ColumnarUnionExec, ColumnarWindowExec, ExpandExec, FilterExec, ProjectExec, SortExec, TakeOrderedAndProjectExec, UnionExec}
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.{DataFrame, Row}
-
-class ColumnarSupportDataTypeSqlSuite extends ColumnarSparkPlanTest {
-  private var shortDf: DataFrame = _
-  private var joinShort1Df: DataFrame = _
-  private var joinShort2Df: DataFrame = _
-
-  protected override def beforeAll(): Unit = {
-    super.beforeAll()
-    shortDf = spark.createDataFrame(
-      sparkContext.parallelize(Seq(
-        Row(2, 10.toShort, null),
-        Row(4, 15.toShort, null),
-        Row(6, 20.toShort, 3.toShort)
-      )), new StructType().add("id", IntegerType).add("short_normal", ShortType)
-        .add("short_null", ShortType))
-
-    joinShort1Df = spark.createDataFrame(
-      sparkContext.parallelize(Seq(
-        Row(100, 10.toShort),
-        Row(100, null),
-        Row(200, 20.toShort),
-        Row(300, 8.toShort)
-      )), new StructType().add("id", IntegerType).add("short_col", ShortType))
-
-    joinShort2Df = spark.createDataFrame(
-      sparkContext.parallelize(Seq(
-        Row(100, 30.toShort),
-        Row(200, null),
-        Row(300, 80.toShort),
-        Row(400, null)
-      )), new StructType().add("id", IntegerType).add("short_col", ShortType))
-
-    shortDf.createOrReplaceTempView("short_table")
-    joinShort1Df.createOrReplaceTempView("join_short1")
-    joinShort2Df.createOrReplaceTempView("join_short2")
-  }
-
-  test("Test ColumnarProjectExec not happen and result is correct when support short") {
-    val res = spark.sql("select short_normal as col from short_table")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isDefined, s"ProjectExec not happened, executedPlan as follows: \n$executedPlan")
-    assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, s"ColumnarProjectExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(10),
-        Row(15),
-        Row(20))
-    )
-  }
-
-  test("Test ColumnarProjectExec not happen and result is correct when support short with expr") {
-    val res = spark.sql("select short_normal + 2 from short_table")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isDefined, s"ProjectExec not happened, executedPlan as follows: \n$executedPlan")
-    assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, s"ColumnarProjectExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(12),
-        Row(17),
-        Row(22))
-    )
-  }
-
-  test("Test ColumnarProjectExec not happen and result is correct when support short with null") {
-    val res = spark.sql("select short_null + 1 from short_table")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isDefined, s"ProjectExec not happened, executedPlan as follows: \n$executedPlan")
-    assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, s"ColumnarProjectExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(null),
-        Row(null),
-        Row(4))
-    )
-  }
-
-  test("Test ColumnarFilterExec not happen and result is correct when support short") {
-    val res = spark.sql("select short_normal from short_table where short_normal > 3")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[FilterExec]).isDefined, s"FilterExec not happened, executedPlan as follows: \n$executedPlan")
-    assert(executedPlan.find(_.isInstanceOf[ColumnarFilterExec]).isEmpty, s"ColumnarFilterExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(10),
-        Row(15),
-        Row(20))
-    )
-  }
-
-  test("Test ColumnarFilterExec not happen and result is correct when support short with expr") {
-    val res = spark.sql("select short_normal from short_table where short_normal + 2 > 3")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[FilterExec]).isDefined, s"FilterExec not happened, executedPlan as follows: \n$executedPlan")
-    assert(executedPlan.find(_.isInstanceOf[ColumnarFilterExec]).isEmpty, s"ColumnarFilterExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(10),
-        Row(15),
-        Row(20))
-    )
-  }
-
-  test("Test ColumnarFilterExec not happen and result is correct when support short with null") {
-    val res = spark.sql("select short_null from short_table where short_normal > 3")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[FilterExec]).isDefined, s"FilterExec not happened, executedPlan as follows: \n$executedPlan")
-    assert(executedPlan.find(_.isInstanceOf[ColumnarFilterExec]).isEmpty, s"ColumnarFilterExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(null),
-        Row(null),
-        Row(3))
-    )
-  }
-
-  test("Test ColumnarConditionProjectExec not happen and result is correct when support short") {
-    val res = spark.sql("select short_normal + 2 from short_table where short_normal > 3")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarConditionProjectExec]).isEmpty, s"ColumnarConditionProjectExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(12),
-        Row(17),
-        Row(22))
-    )
-  }
-
-  test("Test ColumnarUnionExec happen and result is correct when support short") {
-    val res = spark.sql("select short_null + 2 from short_table where short_normal > 3 " +
-      "union select short_normal from short_table where short_normal > 3")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarUnionExec]).isDefined, s"ColumnarUnionExec not happened, executedPlan as follows: \n$executedPlan")
-    assert(executedPlan.find(_.isInstanceOf[UnionExec]).isEmpty, s"UnionExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(5),
-        Row(10),
-        Row(15),
-        Row(20),
-        Row(null)
-      )
-    )
-  }
-
-  test("Test ColumnarHashAggregateExec happen and result is correct when support short") {
-    val res = spark.sql("select max(short_normal) from short_table")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan")
-    assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"UnionExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(Row(20))
-    )
-  }
-
-  test("Test HashAggregateExec happen and result is correct when support short with expr") {
-    val res = spark.sql("select max(short_normal) + 2 from short_table")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isDefined, s"HashAggregateExec not happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(Row(22))
-    )
-  }
-
-  test("Test ColumnarHashAggregateExec happen and result is correct when support short with null") {
-    val res = spark.sql("select max(short_null) from short_table")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(Row(3))
-    )
-  }
-
-  test("Test ColumnarSortExec happen and result is correct when support short") {
-    val res = spark.sql("select short_normal from short_table order by short_normal")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarSortExec]).isDefined, s"ColumnarSortExec not happened, executedPlan as follows: \n$executedPlan")
-    assert(executedPlan.find(_.isInstanceOf[SortExec]).isEmpty, s"SortExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(10),
-        Row(15),
-        Row(20))
-    )
-  }
-
-  test("Test ColumnarSortExec not happen and result is correct when support short with expr") {
-    val res = spark.sql("select short_normal from short_table order by short_normal + 1")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarSortExec]).isEmpty, s"ColumnarSortExec happened, executedPlan as follows: \n$executedPlan")
-    assert(executedPlan.find(_.isInstanceOf[SortExec]).isDefined, s"SortExec not happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(10),
-        Row(15),
-        Row(20))
-    )
-  }
-
-  test("Test ColumnarSortExec happen and result is correct when support short with null") {
-    val res = spark.sql("select short_null from short_table order by short_null")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarSortExec]).isDefined, s"ColumnarSortExec not happened, executedPlan as follows: \n$executedPlan")
-    assert(executedPlan.find(_.isInstanceOf[SortExec]).isEmpty, s"SortExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(null),
-        Row(null),
-        Row(3))
-    )
-  }
-
-  // window
-  test("Test ColumnarWindowExec happen and result is correct when support short") {
-    val res = spark.sql("select id, short_normal, RANK() OVER (PARTITION BY short_normal ORDER BY id) AS rank from short_table")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarWindowExec]).isDefined, s"ColumnarWindowExec not happened, executedPlan as follows: \n$executedPlan")
-    assert(executedPlan.find(_.isInstanceOf[WindowExec]).isEmpty, s"WindowExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(2, 10, 1),
-        Row(4, 15, 1),
-        Row(6, 20, 1))
-    )
-  }
-
-  test("Test ColumnarWindowExec not happen and result is correct when support short with expr") {
-    val res = spark.sql("select id + 1, short_normal, sum(short_normal) OVER (PARTITION BY short_normal ORDER BY id) AS rank from short_table")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarWindowExec]).isEmpty, s"ColumnarWindowExec happened, executedPlan as follows: \n$executedPlan")
-    assert(executedPlan.find(_.isInstanceOf[WindowExec]).isDefined, s"WindowExec not happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(3, 10, 10),
-        Row(5, 15, 15),
-        Row(7, 20, 20))
-    )
-  }
-
-  test("Test ColumnarWindowExec happen and result is correct when support short with null") {
-    val res = spark.sql("select id, short_null, RANK() OVER (PARTITION BY short_null ORDER BY id) AS rank from short_table")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarWindowExec]).isDefined, s"ColumnarWindowExec not happened, executedPlan as follows: \n$executedPlan")
-    assert(executedPlan.find(_.isInstanceOf[WindowExec]).isEmpty, s"WindowExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(2, null, 1),
-        Row(4, null, 2),
-        Row(6, 3, 1))
-    )
-  }
-
-  test("Test ColumnarTakeOrderedAndProjectExec happen and result is correct when support short") {
-    val res = spark.sql("select short_normal from short_table order by short_normal limit 2")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarTakeOrderedAndProjectExec]).isDefined, s"ColumnarTakeOrderedAndProjectExec not happened, executedPlan as follows: \n$executedPlan")
-    assert(executedPlan.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isEmpty, s"TakeOrderedAndProjectExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(10),
-        Row(15))
-    )
-  }
-
-  test("Test ColumnarTakeOrderedAndProjectExec not happen and result is correct when support short with expr") {
-    val res = spark.sql("select short_normal from short_table order by short_normal + 1 limit 2")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarTakeOrderedAndProjectExec]).isEmpty, s"ColumnarTakeOrderedAndProjectExec happened, executedPlan as follows: \n$executedPlan")
-    assert(executedPlan.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isDefined, s"TakeOrderedAndProjectExec not happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(10),
-        Row(15))
-    )
-  }
-
-  test("Test ColumnarTakeOrderedAndProjectExec happen and result is correct when support short with null") {
-    val res = spark.sql("select short_null from short_table order by short_null limit 2")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarTakeOrderedAndProjectExec]).isDefined, s"ColumnarTakeOrderedAndProjectExec not happened, executedPlan as follows: \n$executedPlan")
-    assert(executedPlan.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isEmpty, s"TakeOrderedAndProjectExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(null),
-        Row(null))
-    )
-  }
-
-  test("Test ColumnarShuffleExchangeExec happen and result is correct when support short with group by no-short") {
-    val res = spark.sql("select id, sum(short_null) from short_table group by id")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarShuffleExchangeExec]).isDefined, s"ColumnarShuffleExchangeExec not happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(2, null),
-        Row(4, null),
-        Row(6, 3))
-    )
-  }
-
-  test("Test ColumnarShuffleExchangeExec not happen and result is correct when support short with group by short") {
-    val res = spark.sql("select short_null, sum(short_null) from short_table group by short_null")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarShuffleExchangeExec]).isEmpty, s"ColumnarShuffleExchangeExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(null, null),
-        Row(3, 3))
-    )
-  }
-
-  test("Test ColumnarExpandExec not happen and result is correct when support short") {
-    val res = spark.sql("select id, short_null, sum(short_normal) as sum from short_table group by " +
-      "grouping sets((id, short_null)) order by id, short_null")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarExpandExec]).isEmpty, s"ColumnarExpandExec happened, executedPlan as follows: \n$executedPlan")
-    assert(executedPlan.find(_.isInstanceOf[ExpandExec]).isDefined, s"ExpandExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(2, null, 10),
-        Row(4, null, 15),
-        Row(6, 3, 20))
-    )
-  }
-
-  test("Test ColumnarSortMergeJoinExec happen and result is correct when support short") {
-    val res = spark.sql("select /*+ MERGEJOIN(t2) */ t1.*, t2.* from join_short1 t1, join_short2 t2 where t1.id = t2.id")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarSortMergeJoinExec]).isDefined, s"ColumnarSortMergeJoinExec not happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(100, null, 100, 30),
-        Row(100, 10, 100, 30),
-        Row(200, 20, 200, null),
-        Row(300, 8, 300, 80))
-    )
-  }
-
-  test("Test ColumnarSortMergeJoinExec not happen and result is correct when support short with expr") {
-    val res = spark.sql("select /*+ MERGEJOIN(t2) */ t1.*, t2.* from join_short1 t1, join_short2 t2 where t1.short_col < t2.short_col")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarSortMergeJoinExec]).isEmpty, s"ColumnarSortMergeJoinExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(100, 10, 100, 30),
-        Row(100, 10, 300, 80),
-        Row(200, 20, 100, 30),
-        Row(200, 20, 300, 80),
-        Row(300, 8, 100, 30),
-        Row(300, 8, 300, 80))
-    )
-  }
-
-  test("Test ColumnarShuffledHashJoinExec happen and result is correct when support short") {
-    val res = spark.sql("select /*+ SHUFFLE_HASH(t2) */ t1.*, t2.* from join_short1 t1, join_short2 t2 where t1.id = t2.id")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarShuffledHashJoinExec]).isDefined, s"ColumnarShuffledHashJoinExec not happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(100, null, 100, 30),
-        Row(100, 10, 100, 30),
-        Row(200, 20, 200, null),
-        Row(300, 8, 300, 80))
-    )
-  }
-
-  test("Test ColumnarShuffledHashJoinExec not happen and result is correct when support short with expr") {
-    val res = spark.sql("select /*+ SHUFFLE_HASH(t2) */ t1.*, t2.* from join_short1 t1, join_short2 t2 where t1.short_col < t2.short_col")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarShuffledHashJoinExec]).isEmpty, s"ColumnarShuffledHashJoinExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(100, 10, 100, 30),
-        Row(100, 10, 300, 80),
-        Row(200, 20, 100, 30),
-        Row(200, 20, 300, 80),
-        Row(300, 8, 100, 30),
-        Row(300, 8, 300, 80))
-    )
-  }
-
-  test("Test ColumnarBroadcastHashJoinExec happen and result is correct when support short") {
-    val res = spark.sql("select /*+ BROADCAST(t2) */ t1.*, t2.* from join_short1 t1, join_short2 t2 where t1.id = t2.id")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarBroadcastHashJoinExec]).isDefined, s"ColumnarBroadcastHashJoinExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(100, null, 100, 30),
-        Row(100, 10, 100, 30),
-        Row(200, 20, 200, null),
-        Row(300, 8, 300, 80))
-    )
-  }
-
-  test("Test ColumnarBroadcastHashJoinExec not happen and result is correct when support short with expr") {
-    val res = spark.sql("select /*+ BROADCAST(t2) */ t1.*, t2.* from join_short1 t1, join_short2 t2 where t1.short_col < t2.short_col")
-    val executedPlan = res.queryExecution.executedPlan
-    assert(executedPlan.find(_.isInstanceOf[ColumnarBroadcastHashJoinExec]).isEmpty, s"ColumnarBroadcastHashJoinExec happened, executedPlan as follows: \n$executedPlan")
-    checkAnswer(
-      res,
-      Seq(
-        Row(100, 10, 100, 30),
-        Row(100, 10, 300, 80),
-        Row(200, 20, 100, 30),
-        Row(200, 20, 300, 80),
-        Row(300, 8, 100, 30),
-        Row(300, 8, 300, 80))
-    )
-  }
-}
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/pom.xml b/omnioperator/omniop-spark-extension/pom.xml
index c0a217a2fbbb9f1ae3e432cb9506979c0e5e8cec..790cebec41bfe6d4398e80e5fa30dfc703fd34cb 100644
--- a/omnioperator/omniop-spark-extension/pom.xml
+++ b/omnioperator/omniop-spark-extension/pom.xml
@@ -8,14 +8,14 @@
 com.huawei.kunpeng
 boostkit-omniop-spark-parent
 pom
- 3.1.1-1.2.0
+ 3.2.1-1.2.0
 BoostKit Spark Native Sql Engine Extension Parent Pom
 2.12.10
 2.12
- 3.1.1
+ 3.2.1
 3.2.2
 UTF-8
 UTF-8
@@ -55,6 +55,18 @@
 org.apache.curator
 curator-recipes
+
+ com.fasterxml.jackson.core
+ jackson-core
+
+ com.fasterxml.jackson.core
+ jackson-annotations
+
+ com.fasterxml.jackson.core
+ jackson-databind
+
@@ -101,6 +113,20 @@
 ${omniruntime.version}
 aarch64
 provided
+
+
+ com.fasterxml.jackson.core
+ jackson-core
+
+
+ com.fasterxml.jackson.core
+ jackson-annotations
+
+
+ com.fasterxml.jackson.core
+ jackson-databind
+
+
 com.google.protobuf
@@ -124,6 +150,18 @@
 org.apache.curator
 curator-recipes
+
+ com.fasterxml.jackson.core
+ jackson-core
+
+
+ com.fasterxml.jackson.core
+ jackson-annotations
+
+
+ com.fasterxml.jackson.core
+ jackson-databind
+