diff --git a/omniadvisor/pom.xml b/omniadvisor/pom.xml index 74ad839147c027537db0dee19ef162f0df6743e7..d72c33d17aa732bcbca24f960e15d14c3c1ad05c 100644 --- a/omniadvisor/pom.xml +++ b/omniadvisor/pom.xml @@ -305,6 +305,10 @@ <groupId>org.apache.commons</groupId> <artifactId>commons-text</artifactId> </dependency> + <dependency> + <groupId>org.apache.commons</groupId> + <artifactId>commons-compress</artifactId> + </dependency> <dependency> <groupId>com.google.guava</groupId> <artifactId>guava</artifactId> diff --git a/omniadvisor/src/main/scala/org/apache/spark/SparkApplicationDataExtractor.scala b/omniadvisor/src/main/scala/org/apache/spark/SparkApplicationDataExtractor.scala index 65368433a2bf733b6c5bdb63784980528765a712..b15800ef498fa27c58c80a1d0133432066e6486a 100644 --- a/omniadvisor/src/main/scala/org/apache/spark/SparkApplicationDataExtractor.scala +++ b/omniadvisor/src/main/scala/org/apache/spark/SparkApplicationDataExtractor.scala @@ -17,11 +17,10 @@ package org.apache.spark import com.huawei.boostkit.omniadvisor.fetcher.FetcherType import com.huawei.boostkit.omniadvisor.models.AppResult -import com.huawei.boostkit.omniadvisor.spark.utils.ScalaUtils.parseMapToJsonString -import com.huawei.boostkit.omniadvisor.spark.utils.ScalaUtils.checkSuccess +import com.huawei.boostkit.omniadvisor.spark.utils.ScalaUtils.{checkSuccess, parseMapToJsonString} import com.huawei.boostkit.omniadvisor.spark.utils.SparkUtils import com.nimbusds.jose.util.StandardCharset -import org.apache.spark.sql.execution.ui.SQLExecutionUIData +import org.apache.spark.sql.execution.ui.{SQLExecutionUIData, SparkPlanGraph} import org.apache.spark.status.api.v1._ import org.slf4j.{Logger, LoggerFactory} @@ -41,7 +40,8 @@ object SparkApplicationDataExtractor { workload: String, environmentInfo: ApplicationEnvironmentInfo, jobsList: Seq[JobData], - sqlExecutionsList: Seq[SQLExecutionUIData]): AppResult = { + sqlExecutionsList: Seq[SQLExecutionUIData], + sqlGraphMap: mutable.Map[Long, SparkPlanGraph]): AppResult = { val appResult = new AppResult appResult.applicationId = appInfo.id appResult.applicationName = appInfo.name @@ -57,7 +57,7 @@ object SparkApplicationDataExtractor { val attempt: ApplicationAttemptInfo = lastAttempt(appInfo) if (attempt.completed && jobsList.nonEmpty && checkSuccess(jobsList)) { - saveSuccessfulStatus(appResult, jobsList, sqlExecutionsList) + saveSuccessfulStatus(appResult, jobsList, sqlExecutionsList, sqlGraphMap) } else { saveFailedStatus(appResult, attempt) } @@ -112,7 +112,7 @@ object SparkApplicationDataExtractor { appResult.query = "" } - private def saveSuccessfulStatus(appResult: AppResult, jobsList: Seq[JobData], sqlExecutionsList: Seq[SQLExecutionUIData]): Unit = { + private def saveSuccessfulStatus(appResult: AppResult, jobsList: Seq[JobData], sqlExecutionsList: Seq[SQLExecutionUIData], sqlGraphMap: mutable.Map[Long, SparkPlanGraph]): Unit = { appResult.executionStatus = AppResult.SUCCEEDED_STATUS val (startTime, finishTime) = extractJobsTime(jobsList) @@ -122,7 +122,7 @@ object SparkApplicationDataExtractor { finishTime - startTime else AppResult.FAILED_JOB_DURATION if (appResult.submit_method.equals(AppResult.SPARK_SQL)) { - appResult.query = extractQuerySQL(sqlExecutionsList) + appResult.query = extractQuerySQL(sqlExecutionsList, sqlGraphMap) } else { appResult.query = "" } @@ -153,16 +153,32 @@ object SparkApplicationDataExtractor { (startTime, finishTime) } - private def extractQuerySQL(sqlExecutionsList: Seq[SQLExecutionUIData]): String = { + private def extractQuerySQL(sqlExecutionsList: Seq[SQLExecutionUIData], sqlGraphMap: mutable.Map[Long, SparkPlanGraph]): String = { require(sqlExecutionsList.nonEmpty) - val nonEmptyDescriptions = sqlExecutionsList.flatMap
{ execution => - Option(execution.description).filter(_.nonEmpty) + val descriptions: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty[String] + var previousExecutionDesc: Option[String] = None + for ((execution, index) <- sqlExecutionsList.zipWithIndex) { + if (index > 0) { + val sqlGraph = sqlGraphMap(execution.executionId) + if (execution.description.nonEmpty && !isDuplicatedQuery(execution, previousExecutionDesc, sqlGraph)) { + descriptions += execution.description.trim + } + } else { + if (execution.description.nonEmpty) { + descriptions += execution.description.trim + } + } + previousExecutionDesc = Some(execution.description) } - nonEmptyDescriptions.mkString(";") + descriptions.mkString(";\n") } private def lastAttempt(applicationInfo: ApplicationInfo): ApplicationAttemptInfo = { require(applicationInfo.attempts.nonEmpty) applicationInfo.attempts.last } + + private def isDuplicatedQuery(execution: SQLExecutionUIData, previousExecutionDesc: Option[String], sqlGraph: SparkPlanGraph): Boolean = { + execution.description.equals(previousExecutionDesc.getOrElse("")) && sqlGraph.allNodes.size == 1 && sqlGraph.allNodes.head.name.equals("LocalTableScan") + } } diff --git a/omniadvisor/src/main/scala/org/apache/spark/SparkDataCollection.scala b/omniadvisor/src/main/scala/org/apache/spark/SparkDataCollection.scala index 236469cce9da0828965d2a694a426e214069e1fb..aa5a4bf7d8b1c68c249575acd90e94518a62376e 100644 --- a/omniadvisor/src/main/scala/org/apache/spark/SparkDataCollection.scala +++ b/omniadvisor/src/main/scala/org/apache/spark/SparkDataCollection.scala @@ -16,16 +16,17 @@ package org.apache.spark import com.huawei.boostkit.omniadvisor.models.AppResult -import org.apache.spark.status.api.v1 -import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} import org.apache.spark.internal.config.Status.ASYNC_TRACKING_ENABLED import org.apache.spark.scheduler.ReplayListenerBus +import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLExecutionUIData, SparkPlanGraph} +import org.apache.spark.status.api.v1 import org.apache.spark.status.{AppStatusListener, AppStatusStore, ElementTrackingStore} -import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLExecutionUIData} import org.apache.spark.util.Utils +import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} import org.slf4j.{Logger, LoggerFactory} import java.io.InputStream +import scala.collection.mutable class SparkDataCollection { val LOG: Logger = LoggerFactory.getLogger(classOf[SparkDataCollection]) @@ -36,6 +37,7 @@ class SparkDataCollection { var jobsList: Seq[v1.JobData] = _ var appInfo: v1.ApplicationInfo = _ var sqlExecutionsList: Seq[SQLExecutionUIData] = _ + var sqlGraphMap: mutable.Map[Long, SparkPlanGraph] = _ def replayEventLogs(in: InputStream, sourceName: String): Unit = { @@ -66,12 +68,22 @@ class SparkDataCollection { val sqlAppStatusStore: SQLAppStatusStore = new SQLAppStatusStore(store) sqlExecutionsList = sqlAppStatusStore.executionsList() + sqlGraphMap = mutable.HashMap.empty[Long, SparkPlanGraph] + sqlExecutionsList.foreach { sqlExecution => + try { + val planGraph = sqlAppStatusStore.planGraph(sqlExecution.executionId) + sqlGraphMap.put(sqlExecution.executionId, planGraph) + } catch { + case e: Exception => + LOG.warn(s"Get PlanGraph for SQLExecution [${sqlExecution.executionId}] in ${appInfo.id} failed") + } + } appStatusStore.close() } def getAppResult(workload: String): AppResult = { - 
SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(appInfo, workload, environmentInfo, jobsList, sqlExecutionsList) + SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(appInfo, workload, environmentInfo, jobsList, sqlExecutionsList, sqlGraphMap) } private def createInMemoryStore(): KVStore = { diff --git a/omniadvisor/src/test/java/org/apache/spark/TestSparkApplicationDataExtractor.java b/omniadvisor/src/test/java/org/apache/spark/TestSparkApplicationDataExtractor.java index 7eed130902a290b108988aadbedd3db301053c24..02953cb7cc46400f5de9343995b8e07dfbf47d28 100644 --- a/omniadvisor/src/test/java/org/apache/spark/TestSparkApplicationDataExtractor.java +++ b/omniadvisor/src/test/java/org/apache/spark/TestSparkApplicationDataExtractor.java @@ -113,7 +113,7 @@ public class TestSparkApplicationDataExtractor { sqlExecutionList.add(new SQLExecutionUIData(1, sqlQ10, "", "", asScalaBuffer(ImmutableList.of()), 0, Option.apply(new Date()), new HashMap<>(), new HashSet<>(), new HashMap<>())); sqlExecutionList.add(new SQLExecutionUIData(2, sqlQ11, "", "", asScalaBuffer(ImmutableList.of()), 0, Option.apply(new Date()), new HashMap<>(), new HashSet<>(), new HashMap<>())); - AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(jobsList), asScalaBuffer(sqlExecutionList)); + AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(jobsList), asScalaBuffer(sqlExecutionList), null); assertEquals(result.applicationId, "id"); assertEquals(result.durationTime, 15 * 60 * 1000L); assertEquals(result.submit_method, "spark-sql"); @@ -134,7 +134,7 @@ public class TestSparkApplicationDataExtractor { when(environmentInfo.sparkProperties()).thenReturn(asScalaBuffer(clientSparkProperties)); when(environmentInfo.systemProperties()).thenReturn(asScalaBuffer(systemProperties)); - AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(jobsList), asScalaBuffer(ImmutableList.of())); + AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(jobsList), asScalaBuffer(ImmutableList.of()), null); assertEquals(result.applicationId, "id"); assertEquals(result.durationTime, 15 * 60 * 1000L); assertEquals(result.submit_method, "spark-submit"); @@ -154,7 +154,7 @@ public class TestSparkApplicationDataExtractor { when(environmentInfo.sparkProperties()).thenReturn(asScalaBuffer(clusterSparkProperties)); when(environmentInfo.systemProperties()).thenReturn(asScalaBuffer(systemProperties)); - AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(jobsList), asScalaBuffer(ImmutableList.of())); + AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(jobsList), asScalaBuffer(ImmutableList.of()), null); assertEquals(result.applicationId, "id"); assertEquals(result.durationTime, 15 * 60 * 1000L); assertEquals(result.submit_method, "spark-submit"); @@ -169,7 +169,7 @@ public class TestSparkApplicationDataExtractor { ApplicationEnvironmentInfo environmentInfo = Mockito.mock(ApplicationEnvironmentInfo.class); 
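
For context on the extractor behaviour these tests exercise: the new `extractQuerySQL` keeps the first of any run of consecutive SQL executions whose description repeats and whose plan graph collapses to a single `LocalTableScan` node, then joins the surviving descriptions with `";\n"`. A minimal Java sketch of that filtering rule, assuming a hypothetical `ExecutionSummary` holder in place of Spark's `SQLExecutionUIData` plus `SparkPlanGraph` (the real implementation is the Scala code above):

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in for a SQLExecutionUIData plus a summary of its SparkPlanGraph.
final class ExecutionSummary {
    final String description;   // the SQL text recorded for the execution
    final int planNodeCount;    // SparkPlanGraph.allNodes.size
    final String firstNodeName; // SparkPlanGraph.allNodes.head.name

    ExecutionSummary(String description, int planNodeCount, String firstNodeName) {
        this.description = description;
        this.planNodeCount = planNodeCount;
        this.firstNodeName = firstNodeName;
    }

    boolean isLocalTableScanOnly() {
        return planNodeCount == 1 && "LocalTableScan".equals(firstNodeName);
    }
}

final class QueryDeduplication {
    /** Mirrors extractQuerySQL/isDuplicatedQuery: drop an execution only when its
     *  description equals the previous one and its plan is a lone LocalTableScan. */
    static String joinQueries(List<ExecutionSummary> executions) {
        List<String> kept = new ArrayList<>();
        String previousDescription = null;
        for (ExecutionSummary execution : executions) {
            boolean duplicate = execution.description.equals(previousDescription)
                    && execution.isLocalTableScanOnly();
            if (!execution.description.isEmpty() && !duplicate) {
                kept.add(execution.description.trim());
            }
            previousDescription = execution.description;
        }
        return String.join(";\n", kept);
    }

    public static void main(String[] args) {
        List<ExecutionSummary> executions = Arrays.asList(
                new ExecutionSummary("select 1", 1, "LocalTableScan"),
                new ExecutionSummary("select 1", 1, "LocalTableScan"),  // dropped as duplicate
                new ExecutionSummary("select * from t", 5, "FileScan"));
        System.out.println(joinQueries(executions));
    }
}
```

One difference worth noting: the Scala code looks the graph up with `sqlGraphMap(execution.executionId)`, and `SparkDataCollection` skips entries whose `planGraph` lookup fails, so a missing entry there surfaces as a lookup error rather than being treated as a non-duplicate.
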
when(environmentInfo.sparkProperties()).thenReturn(asScalaBuffer(ImmutableList.of())); when(environmentInfo.systemProperties()).thenReturn(asScalaBuffer(ImmutableList.of())); - AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of(runningData)), asScalaBuffer(ImmutableList.of())); + AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of(runningData)), asScalaBuffer(ImmutableList.of()), null); assertEquals(result.applicationId, "id"); assertEquals(result.executionStatus, AppResult.FAILED_STATUS); assertEquals(result.durationTime, AppResult.FAILED_JOB_DURATION); @@ -181,7 +181,7 @@ public class TestSparkApplicationDataExtractor { ApplicationEnvironmentInfo environmentInfo = Mockito.mock(ApplicationEnvironmentInfo.class); when(environmentInfo.sparkProperties()).thenReturn(asScalaBuffer(ImmutableList.of())); when(environmentInfo.systemProperties()).thenReturn(asScalaBuffer(ImmutableList.of())); - AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of(failedData)), asScalaBuffer(ImmutableList.of())); + AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of(failedData)), asScalaBuffer(ImmutableList.of()), null); assertEquals(result.applicationId, "id"); assertEquals(result.executionStatus, AppResult.FAILED_STATUS); assertEquals(result.durationTime, AppResult.FAILED_JOB_DURATION); @@ -193,7 +193,7 @@ public class TestSparkApplicationDataExtractor { ApplicationEnvironmentInfo environmentInfo = Mockito.mock(ApplicationEnvironmentInfo.class); when(environmentInfo.sparkProperties()).thenReturn(asScalaBuffer(ImmutableList.of())); when(environmentInfo.systemProperties()).thenReturn(asScalaBuffer(ImmutableList.of())); - AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of()), asScalaBuffer(ImmutableList.of())); + AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of()), asScalaBuffer(ImmutableList.of()), null); assertEquals(result.applicationId, "id"); assertEquals(result.executionStatus, AppResult.FAILED_STATUS); assertEquals(result.durationTime, AppResult.FAILED_JOB_DURATION); @@ -205,6 +205,6 @@ public class TestSparkApplicationDataExtractor { ApplicationEnvironmentInfo environmentInfo = Mockito.mock(ApplicationEnvironmentInfo.class); when(environmentInfo.sparkProperties()).thenReturn(asScalaBuffer(ImmutableList.of())); when(environmentInfo.systemProperties()).thenReturn(asScalaBuffer(ImmutableList.of())); - SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of()), asScalaBuffer(ImmutableList.of())); + SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of()), asScalaBuffer(ImmutableList.of()), null); } } diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniExecuteWithHookContext.java 
b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniExecuteWithHookContext.java index 21e717a0d58d54bc52967f3416095426a0719c2b..6d83c383bced9885552cabd804a8afd34ee20d42 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniExecuteWithHookContext.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniExecuteWithHookContext.java @@ -18,10 +18,12 @@ package com.huawei.boostkit.hive; +import static com.huawei.boostkit.hive.expression.TypeUtils.checkUnsupportedTimestamp; import static com.huawei.boostkit.hive.expression.TypeUtils.checkOmniJsonWhiteList; import static com.huawei.boostkit.hive.expression.TypeUtils.checkUnsupportedArithmetic; import static com.huawei.boostkit.hive.expression.TypeUtils.checkUnsupportedCast; import static com.huawei.boostkit.hive.expression.TypeUtils.convertHiveTypeToOmniType; +import static com.huawei.boostkit.hive.expression.TypeUtils.isValidConversion; import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_AVG; import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_SUM; import static org.apache.hadoop.hive.serde.serdeConstants.SERIALIZATION_LIB; @@ -36,6 +38,7 @@ import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspe import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.STRING; import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.TIMESTAMP; import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.VARCHAR; +import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.VOID; import com.huawei.boostkit.hive.expression.BaseExpression; import com.huawei.boostkit.hive.expression.CastFunctionExpression; @@ -50,6 +53,7 @@ import nova.hetu.omniruntime.constants.FunctionType; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.api.FieldSchema; import org.apache.hadoop.hive.ql.QueryPlan; import org.apache.hadoop.hive.ql.exec.ExplainTask; import org.apache.hadoop.hive.ql.exec.MapJoinOperator; @@ -93,6 +97,7 @@ import org.apache.hadoop.hive.ql.plan.ReduceWork; import org.apache.hadoop.hive.ql.plan.SelectDesc; import org.apache.hadoop.hive.ql.plan.TableDesc; import org.apache.hadoop.hive.ql.plan.TableScanDesc; +import org.apache.hadoop.hive.ql.plan.VectorTableScanDesc; import org.apache.hadoop.hive.ql.plan.TezEdgeProperty; import org.apache.hadoop.hive.ql.plan.TezWork; import org.apache.hadoop.hive.ql.plan.UnionWork; @@ -131,7 +136,7 @@ public class OmniExecuteWithHookContext implements ExecuteWithHookContext { public static final Set SUPPORTED_JOIN = new HashSet<>(Arrays.asList(JoinDesc.INNER_JOIN, JoinDesc.LEFT_OUTER_JOIN, JoinDesc.FULL_OUTER_JOIN, JoinDesc.LEFT_SEMI_JOIN)); private static final Set SUPPORTED_TYPE = new HashSet<>(Arrays.asList(BOOLEAN, - SHORT, INT, LONG, DOUBLE, STRING, DATE, DECIMAL, VARCHAR, CHAR)); + SHORT, INT, LONG, DOUBLE, STRING, DATE, DECIMAL, VARCHAR, CHAR, VOID)); private static final int DECIMAL64_MAX_PRECISION = 19; @@ -530,6 +535,21 @@ public class OmniExecuteWithHookContext implements ExecuteWithHookContext { if (tableMetadata != null && (!tableMetadata.getInputFormatClass().equals(OrcInputFormat.class) || tableMetadata.getParameters().getOrDefault("transactional", "").equals("true"))) { return 
false; } + if (tableScanDesc.isVectorized()) { + TypeInfo[] columnTypeInfos = ((VectorTableScanDesc) tableScanDesc.getVectorDesc()).getProjectedColumnTypeInfos(); + for (int id : tableScanDesc.getNeededColumnIDs()) { + if ("timestamp".equals(columnTypeInfos[id].getTypeName())) { + return false; + } + } + } else if (tableMetadata != null && tableMetadata.getCols() != null) { + List colList = tableMetadata.getCols(); + for (int id : tableScanDesc.getNeededColumnIDs()) { + if ("timestamp".equals(colList.get(id).getType())) { + return false; + } + } + } List> childOperators = op.getChildOperators(); for (Operator childOperator : childOperators) { if (childOperator.getType().equals(OperatorType.REDUCESINK) && reduceSinkDescUnReplaceable((ReduceSinkDesc) childOperator.getConf())) { @@ -619,7 +639,7 @@ public class OmniExecuteWithHookContext implements ExecuteWithHookContext { return replaceable; case FILTER: List colList = Collections.singletonList(((FilterDesc) operator.getConf()).getPredicate()); - if (!isUDFSupport(colList) || !isLegalDeciConstant(colList)) { + if ((!isUDFSupport(colList) && !isLegalTimestamp(colList)) || !isLegalDeci(colList)) { return false; } boolean result = true; @@ -629,7 +649,7 @@ public class OmniExecuteWithHookContext implements ExecuteWithHookContext { for (Operator child : operator.getChildOperators()) { if (child.getType() != null && child.getType().equals(OperatorType.SELECT)) { SelectDesc conf = (SelectDesc) child.getConf(); - result = result && isUDFSupport(conf.getColList()) && isLegalDeciConstant(conf.getColList()); + result = result && isUDFSupport(conf.getColList()) && isLegalDeci(conf.getColList()); } } return result; @@ -698,6 +718,16 @@ public class OmniExecuteWithHookContext implements ExecuteWithHookContext { return false; } List windowFunctionDefs = ((WindowTableFunctionDef) conf.getFuncDef()).getWindowFunctions(); + for (WindowFunctionDef functionDef : windowFunctionDefs) { + if (functionDef.getArgs() == null) { + continue; + } + for (PTFExpressionDef expressionDef : functionDef.getArgs()) { + if (expressionDef.getExprNode() != null && "timestamp".equals(expressionDef.getExprNode().getTypeInfo().getTypeName())) { + return false; + } + } + } if (!PTFSupportedAgg(windowFunctionDefs)) { return false; } @@ -724,6 +754,12 @@ public class OmniExecuteWithHookContext implements ExecuteWithHookContext { if (!isUDFSupport(exprNodeDescList)) { return false; } + JoinCondDesc[] joinCondDescs = mapJoinDesc.getConds(); + if (joinCondDescs.length >= 2) { + if (joinCondDescs[0].getType() == JoinDesc.LEFT_OUTER_JOIN && joinCondDescs[1].getType() == JoinDesc.LEFT_SEMI_JOIN) { + return false; + } + } return joinReplaceable(operator); } @@ -783,12 +819,14 @@ public class OmniExecuteWithHookContext implements ExecuteWithHookContext { private boolean joinReplaceable(Operator operator) { JoinDesc joinDesc = (JoinDesc) operator.getConf(); JoinCondDesc[] joinCondDescs = joinDesc.getConds(); - if (joinDesc.getConds()[0].getType() == JoinDesc.FULL_OUTER_JOIN && joinDesc.getKeysString().get(0) == null || !SUPPORTED_JOIN.contains(joinCondDescs[0].getType())) { - return false; + for (JoinCondDesc joinCondDesc: joinCondDescs) { + if (joinCondDesc.getType() == JoinDesc.FULL_OUTER_JOIN && joinDesc.getKeysString().get("0") == null || !SUPPORTED_JOIN.contains(joinCondDesc.getType())) { + return false; + } } if (joinCondDescs.length >= 2) { - for (int i = 0; i < joinCondDescs.length; i++) { - if (joinCondDescs[i].getType() != JoinDesc.INNER_JOIN) { + for (List filters: joinDesc.getFilters().values())
{ + if (!filters.isEmpty()) { return false; } } @@ -859,16 +897,16 @@ public class OmniExecuteWithHookContext implements ExecuteWithHookContext { return checkOmniJsonWhiteList("", expressions.toArray(new String[0])); } - private boolean isLegalDeciConstant(List colList) { + private boolean isLegalDeci(List colList) { for (ExprNodeDesc desc : colList) { - if (!checkDecimalConstant(desc)) { + if (!isValidConversion(desc)) { return false; } } if (colList.size() > 0 && colList.get(0).getChildren() != null) { List childList = colList.get(0).getChildren(); for (ExprNodeDesc desc : childList) { - if (!checkDecimalConstant(desc)) { + if (!isValidConversion(desc)) { return false; } } @@ -876,23 +914,16 @@ public class OmniExecuteWithHookContext implements ExecuteWithHookContext { return true; } - private boolean checkDecimalConstant(ExprNodeDesc desc) { - if (desc instanceof ExprNodeGenericFuncDesc && desc.getChildren() != null && desc.getChildren().size() == 2) { - List child = desc.getChildren(); - if (child.get(0) instanceof ExprNodeConstantDesc && child.get(1) instanceof ExprNodeColumnDesc) { - Collections.swap(child, 0, 1); + private boolean isLegalTimestamp(List colList) { + for (ExprNodeDesc desc : colList) { + if (checkUnsupportedTimestamp(desc)) { + return false; } - if (child.get(0) instanceof ExprNodeColumnDesc && child.get(1) instanceof ExprNodeConstantDesc) { - TypeInfo deciInfo = child.get(0).getTypeInfo(); - TypeInfo constInfo = child.get(1).getTypeInfo(); - if (!(deciInfo instanceof DecimalTypeInfo && constInfo instanceof DecimalTypeInfo)) { - return true; - } - int deciPrecision = ((DecimalTypeInfo) deciInfo).getPrecision(); - int deciScale = ((DecimalTypeInfo) deciInfo).getScale(); - int constPrecision = ((DecimalTypeInfo) constInfo).getPrecision(); - int constScale = ((DecimalTypeInfo) constInfo).getScale(); - if (constPrecision - constScale > deciPrecision - deciScale || constScale > deciScale) { + } + if (colList.size() > 0 && colList.get(0).getChildren() != null) { + List childList = colList.get(0).getChildren(); + for (ExprNodeDesc desc : childList) { + if (checkUnsupportedTimestamp(desc)) { return false; } } diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniFilterOperator.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniFilterOperator.java index 8a8b2f550a196cb1e381fc9d3f5ab0f9f936018b..134b02c72b4c8e66d115d2c7e47b4b6533b045af 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniFilterOperator.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniFilterOperator.java @@ -23,6 +23,9 @@ import com.huawei.boostkit.hive.expression.ExpressionUtils; import com.huawei.boostkit.hive.expression.ReferenceFactor; import com.huawei.boostkit.hive.expression.TypeUtils; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; + import nova.hetu.omniruntime.operator.OmniOperator; import nova.hetu.omniruntime.operator.config.OperatorConfig; import nova.hetu.omniruntime.operator.config.OverflowConfig; @@ -32,7 +35,9 @@ import nova.hetu.omniruntime.type.DataType; import nova.hetu.omniruntime.vector.VecBatch; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.ql.CompilationOpContext; +import org.apache.hadoop.hive.ql.metadata.Hive; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc; import 
org.apache.hadoop.hive.ql.plan.ExprNodeDesc; @@ -57,6 +62,13 @@ public class OmniFilterOperator extends OmniHiveOperator impleme private transient OmniOperator omniOperator; private transient Iterator output; + private static Cache cache= CacheBuilder.newBuilder().concurrencyLevel(8).initialCapacity(10) + .maximumSize(100).recordStats().removalListener(notification ->{ + ((OmniFilterAndProjectOperatorFactory) notification.getValue()).close(); + }).build(); + + private static boolean addCloseThread = false; + public OmniFilterOperator() { super(); } @@ -74,6 +86,16 @@ public class OmniFilterOperator extends OmniHiveOperator impleme protected void initializeOp(Configuration hconf) throws HiveException { super.initializeOp(hconf); ExprNodeDesc predicate = conf.getPredicate(); + String queryId = HiveConf.getVar(hconf, HiveConf.ConfVars.HIVEQUERYID); + String cacheKey = queryId + "FILTER_" + this.getOperatorId() + "_container"; + + OmniFilterAndProjectOperatorFactory omniFilterAndProjectOperatorFactory = (OmniFilterAndProjectOperatorFactory) cache + .getIfPresent(cacheKey); + if (omniFilterAndProjectOperatorFactory != null) { + this.filterAndProjectOperatorFactory = omniFilterAndProjectOperatorFactory; + this.omniOperator = this.filterAndProjectOperatorFactory.createOperator(); + return; + } BaseExpression root; if (predicate instanceof ExprNodeGenericFuncDesc) { root = ExpressionUtils.build((ExprNodeGenericFuncDesc) predicate, inputObjInspectors[0]); @@ -97,6 +119,12 @@ public class OmniFilterOperator extends OmniHiveOperator impleme Arrays.asList(projections), 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OVERFLOW_CONFIG_NULL), true)); this.omniOperator = this.filterAndProjectOperatorFactory.createOperator(); + + cache.put(cacheKey, this.filterAndProjectOperatorFactory); + if (!addCloseThread) { + Runtime.getRuntime().addShutdownHook(new Thread(() -> cache.invalidateAll())); + addCloseThread = true; + } } @Override @@ -121,8 +149,12 @@ public class OmniFilterOperator extends OmniHiveOperator impleme @Override protected void closeOp(boolean abort) throws HiveException { - filterAndProjectOperatorFactory.close(); - omniOperator.close(); + if (filterAndProjectOperatorFactory != null) { + filterAndProjectOperatorFactory.close(); + } + if (omniOperator != null) { + omniOperator.close(); + } output = null; super.closeOp(abort); } diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniGroupByOperator.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniGroupByOperator.java index 8105b4c1a38579bc59def1225c6e919c01bc46af..9a0e7c7912930feaed35a93586007a6d2aefe992 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniGroupByOperator.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniGroupByOperator.java @@ -46,6 +46,7 @@ import nova.hetu.omniruntime.vector.Vec; import nova.hetu.omniruntime.vector.VecBatch; import nova.hetu.omniruntime.vector.VecFactory; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.common.type.Date; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.ql.CompilationOpContext; import org.apache.hadoop.hive.ql.exec.ExprNodeColumnEvaluator; @@ -570,38 +571,48 @@ public class OmniGroupByOperator extends OmniHiveOperator imple Vec newVec = VecFactory.createFlatVec(rowCount, dataType); DataType.DataTypeId dataTypeId = dataType.getId(); for (int i = 0; i < rowCount; i++) { + 
Object exprValue = exprNodeConstantEvaluator.getExpr().getValue(); + if (exprValue == null) { + newVec.setNull(i); + continue; + } switch (dataTypeId) { case OMNI_INT: case OMNI_DATE32: - ((IntVec) newVec).set(i, (int) exprNodeConstantEvaluator.getExpr().getValue()); + if (exprValue instanceof Date) { + ((IntVec) newVec).set(i, ((Date) exprValue).toEpochDay()); + } else { + ((IntVec) newVec).set(i, (int) exprValue); + } break; case OMNI_LONG: case OMNI_DATE64: case OMNI_DECIMAL64: - Object exprValue = exprNodeConstantEvaluator.getExpr().getValue(); if (exprValue instanceof Timestamp) { ((LongVec) newVec).set(i, ((Timestamp) exprValue).toEpochMilli()); + } else if (exprValue instanceof Date) { + ((LongVec) newVec).set(i, ((Date) exprValue).toEpochDay()); } else { - ((LongVec) newVec).set(i, (long) exprNodeConstantEvaluator.getExpr().getValue()); + ((LongVec) newVec).set(i, (long) exprValue); } break; case OMNI_DOUBLE: - ((DoubleVec) newVec).set(i, (double) exprNodeConstantEvaluator.getExpr().getValue()); + ((DoubleVec) newVec).set(i, (double) exprValue); break; case OMNI_BOOLEAN: - ((BooleanVec) newVec).set(i, (boolean) exprNodeConstantEvaluator.getExpr().getValue()); + ((BooleanVec) newVec).set(i, (boolean) exprValue); break; case OMNI_SHORT: - ((ShortVec) newVec).set(i, (short) exprNodeConstantEvaluator.getExpr().getValue()); + ((ShortVec) newVec).set(i, (short) exprValue); break; case OMNI_DECIMAL128: - HiveDecimal hiveDecimal = (HiveDecimal) exprNodeConstantEvaluator.getExpr().getValue(); + HiveDecimal hiveDecimal = (HiveDecimal) exprValue; DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) exprNodeConstantEvaluator.getExpr().getTypeInfo(); ((Decimal128Vec) newVec).setBigInteger(i, hiveDecimal.bigIntegerBytesScaled(decimalTypeInfo.getScale()), hiveDecimal.signum() == -1); break; case OMNI_VARCHAR: case OMNI_CHAR: - ((VarcharVec) newVec).set(i, exprNodeConstantEvaluator.getExpr().getValue().toString().getBytes()); + ((VarcharVec) newVec).set(i, exprValue.toString().getBytes()); break; default: throw new RuntimeException("Not support dataType, dataTypeId: " + dataTypeId); diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniJoinOperator.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniJoinOperator.java deleted file mode 100644 index c3cb6e59831b0c9c56b6a8f8e7e87657e5822fa2..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniJoinOperator.java +++ /dev/null @@ -1,490 +0,0 @@ -/* - * Copyright (C) 2023-2024. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.huawei.boostkit.hive; - -import static com.huawei.boostkit.hive.JoinUtils.getExprNodeColumnEvaluator; -import static com.huawei.boostkit.hive.OmniMapJoinOperator.JOIN_TYPE_MAP; - -import com.huawei.boostkit.hive.expression.BaseExpression; -import com.huawei.boostkit.hive.expression.ExpressionUtils; -import com.huawei.boostkit.hive.expression.TypeUtils; - -import nova.hetu.omniruntime.constants.JoinType; -import nova.hetu.omniruntime.operator.OmniOperator; -import nova.hetu.omniruntime.operator.OmniOperatorFactory; -import nova.hetu.omniruntime.operator.join.OmniSmjBufferedTableWithExprOperatorFactory; -import nova.hetu.omniruntime.operator.join.OmniSmjStreamedTableWithExprOperatorFactory; -import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.vector.BooleanVec; -import nova.hetu.omniruntime.vector.Decimal128Vec; -import nova.hetu.omniruntime.vector.DoubleVec; -import nova.hetu.omniruntime.vector.IntVec; -import nova.hetu.omniruntime.vector.LongVec; -import nova.hetu.omniruntime.vector.ShortVec; -import nova.hetu.omniruntime.vector.VarcharVec; -import nova.hetu.omniruntime.vector.Vec; -import nova.hetu.omniruntime.vector.VecBatch; - -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.hive.ql.CompilationOpContext; -import org.apache.hadoop.hive.ql.exec.CommonJoinOperator; -import org.apache.hadoop.hive.ql.exec.ExprNodeColumnEvaluator; -import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator; -import org.apache.hadoop.hive.ql.exec.Operator; -import org.apache.hadoop.hive.ql.exec.UDFArgumentException; -import org.apache.hadoop.hive.ql.exec.tez.RecordSource; -import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc; -import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; -import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc; -import org.apache.hadoop.hive.ql.plan.JoinDesc; -import org.apache.hadoop.hive.ql.plan.OperatorDesc; -import org.apache.hadoop.hive.ql.plan.api.OperatorType; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPAnd; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; -import org.apache.hadoop.hive.serde2.objectinspector.StructField; -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.AbstractPrimitiveObjectInspector; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Comparator; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Iterator; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Queue; -import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -public class OmniJoinOperator extends CommonJoinOperator implements Serializable { - private static final long serialVersionUID = 1L; - - protected static final int SMJ_NEED_ADD_STREAM_TBL_DATA = 2; - protected static final int SMJ_NEED_ADD_BUFFERED_TBL_DATA = 3; - protected static final int SCAN_FINISH = 4; - protected static final int RES_INIT = 0; - protected static final int SMJ_FETCH_JOIN_DATA = 5; - - protected transient RecordSource[] sources; - protected transient boolean[] fetchDone; - - protected transient OmniSmjBufferedTableWithExprOperatorFactory[] bufferFactories; - protected transient 
OmniSmjStreamedTableWithExprOperatorFactory[] streamFactories; - protected transient OmniOperator[] bufferOperators; - protected transient OmniOperator[] streamOperators; - - protected transient int[] resCode; - protected transient int[] flowControlCode; - protected transient Queue[] streamData; - protected transient Queue[] bufferData; - protected transient DataType[][] streamTypes; - protected transient DataType[][] bufferTypes; - private transient Iterator output; - - protected OmniJoinOperator() { - super(); - } - - public OmniJoinOperator(CompilationOpContext ctx) { - super((ctx)); - } - - public OmniJoinOperator(CompilationOpContext ctx, JoinDesc joinDesc) { - super(ctx); - this.conf = new OmniMergeJoinDesc(joinDesc); - } - - // If mergeJoinOperator has n (n>=2) tables, first join tables0 and table1, and output all columns of tables0 and - // tables1, get result table_0_1. Then use table_0_1 to join tables2, and outout all columns, get result - // tables_0_1_2. Then use the result table_0_..._n-1 join table_n and output the required columns. - @Override - protected void initializeOp(Configuration hconf) throws HiveException { - int sourceNum = parentOperators.get(0).getInputObjInspectors().length; - ObjectInspector[] newInputObjInspectors = new ObjectInspector[sourceNum]; - for (int i = 0; i < sourceNum; i++) { - newInputObjInspectors[i] = ((StructObjectInspector) inputObjInspectors[0]).getAllStructFieldRefs().get(i) - .getFieldObjectInspector(); - } - inputObjInspectors = newInputObjInspectors; - super.initializeOp(hconf); - fetchDone = new boolean[sourceNum]; - streamFactories = new OmniSmjStreamedTableWithExprOperatorFactory[sourceNum - 1]; - streamOperators = new OmniOperator[sourceNum - 1]; - bufferFactories = new OmniSmjBufferedTableWithExprOperatorFactory[sourceNum - 1]; - bufferOperators = new OmniOperator[sourceNum - 1]; - resCode = new int[sourceNum - 1]; - flowControlCode = new int[sourceNum - 1]; - Arrays.fill(flowControlCode, SMJ_NEED_ADD_STREAM_TBL_DATA); - streamData = new Queue[sourceNum - 1]; - bufferData = new Queue[sourceNum - 1]; - for (int i = 0; i < streamData.length; i++) { - streamData[i] = new LinkedList<>(); - bufferData[i] = new LinkedList<>(); - } - streamTypes = new DataType[sourceNum - 1][]; - bufferTypes = new DataType[sourceNum - 1][]; - for (int i = 1; i < sourceNum; i++) { - generateOmniOperator(i, true); - } - generateOmniOperator(streamFactories.length, false); - } - - private void generateOmniOperator(int bufferIndex, boolean getAll) { - int opIndex = bufferIndex - 1; - List streamAliasList = new ArrayList<>(); - for (int i = 0; i < bufferIndex; i++) { - streamAliasList.add(i); - } - streamFactories[opIndex] = (OmniSmjStreamedTableWithExprOperatorFactory) getFactory(streamAliasList, null, - getAll, opIndex); - streamOperators[opIndex] = streamFactories[opIndex].createOperator(); - bufferFactories[opIndex] = (OmniSmjBufferedTableWithExprOperatorFactory) getFactory(Arrays.asList(bufferIndex), - streamFactories[opIndex], getAll, opIndex); - bufferOperators[opIndex] = bufferFactories[opIndex].createOperator(); - } - - private OmniOperatorFactory getFactory(List aliasList, - OmniSmjStreamedTableWithExprOperatorFactory streamFactory, - boolean getAll, int opIndex) { - List inputFields = aliasList.stream() - .flatMap(alias -> ((StructObjectInspector) inputObjInspectors[alias]).getAllStructFieldRefs().stream() - .flatMap(keyValue -> ((StructObjectInspector) keyValue.getFieldObjectInspector()) - 
.getAllStructFieldRefs().stream())).collect(Collectors.toList()); - DataType[] inputTypes = new DataType[inputFields.size()]; - List> colNameToId = new ArrayList<>(); - aliasList.forEach(a -> colNameToId.add(new HashMap<>())); - int[] fieldNum = new int[aliasList.size()]; - fieldNum[0] = ((StructObjectInspector) inputObjInspectors[aliasList.get(0)]).getAllStructFieldRefs().stream() - .mapToInt(keyValue -> ((StructObjectInspector) keyValue.getFieldObjectInspector()) - .getAllStructFieldRefs().size()).sum(); - for (int i = 1; i < fieldNum.length; i++) { - fieldNum[i] = fieldNum[i - 1] - + ((StructObjectInspector) inputObjInspectors[aliasList.get(i)]).getAllStructFieldRefs().stream() - .mapToInt(keyValue -> ((StructObjectInspector) keyValue.getFieldObjectInspector()) - .getAllStructFieldRefs().size()).sum(); - } - int tagIndex = 0; - for (int i = 0; i < inputFields.size(); i++) { - if (i >= fieldNum[tagIndex]) { - ++tagIndex; - } - inputTypes[i] = TypeUtils.buildInputDataType(((AbstractPrimitiveObjectInspector) inputFields.get(i).getFieldObjectInspector()).getTypeInfo()); - colNameToId.get(tagIndex).put(inputFields.get(i).getFieldName(), i); - } - int[] outputCols; - if (getAll) { - outputCols = new int[inputTypes.length]; - for (int i = 0; i < inputTypes.length; i++) { - outputCols[i] = i; - } - } else { - int start = 0; - outputCols = new int[aliasList.stream().mapToInt(a -> joinValuesObjectInspectors[a].size()).sum()]; - for (int i = 0; i < aliasList.size(); i++) { - List outputFieldsName = getExprNodeColumnEvaluator(joinValues[aliasList.get(i)]).stream() - .map(evaluator -> ((ExprNodeColumnEvaluator) evaluator).getExpr().getColumn() - .split("\\.")[1]).collect(Collectors.toList()); - for (int j = start; j < start + outputFieldsName.size(); j++) { - outputCols[j] = colNameToId.get(i).get(outputFieldsName.get(j - start)); - } - start += outputFieldsName.size(); - } - } - String[] hashKey = getHashKey(aliasList, streamFactory, opIndex, colNameToId); - JoinType joinType = JOIN_TYPE_MAP.get(condn[opIndex].getType()); - if (streamFactory == null) { - Optional filter = generateFilter(opIndex); - streamTypes[opIndex] = inputTypes; - return new OmniSmjStreamedTableWithExprOperatorFactory(inputTypes, hashKey, outputCols, joinType, filter); - } else { - bufferTypes[opIndex] = inputTypes; - return new OmniSmjBufferedTableWithExprOperatorFactory(inputTypes, hashKey, outputCols, streamFactory); - } - } - - // sql like cs1.cs_warehouse_sk <> cs2.cs_warehouse_sk will have - // residualJoinFilters - private Optional generateFilter(int opIndex) { - if (residualJoinFilters == null || residualJoinFilters.get(opIndex) == null) { - return Optional.empty(); - } - int bufferIndex = opIndex + 1; - List inspectors = IntStream.range(0, bufferIndex + 1).boxed() - .flatMap(tableIndex -> ((StructObjectInspector) inputObjInspectors[tableIndex]).getAllStructFieldRefs() - .stream().flatMap(keyValue -> ((StructObjectInspector) keyValue.getFieldObjectInspector()) - .getAllStructFieldRefs().stream())).sorted(Comparator.comparing(StructField::getFieldName)) - .map(field -> field.getFieldObjectInspector()).collect(Collectors.toList()); - Map inputColNameToExprName = new HashMap<>(); - for (Map.Entry entry : conf.getColumnExprMap().entrySet()) { - ExprNodeColumnDesc exprNodeColumnDesc = (ExprNodeColumnDesc) entry.getValue(); - inputColNameToExprName.put(exprNodeColumnDesc.getColumn().replace("VALUE.", "").replace("KEY.", ""), entry.getKey()); - } - List fieldNames = 
conf.getColumnExprMap().keySet().stream().sorted().collect(Collectors.toList()); - StructObjectInspector exprObjInspector = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, - inspectors); - BaseExpression root = ExpressionUtils.build(getResidualFilter(), exprObjInspector); - return Optional.of(root.toString()); - } - - private ExprNodeGenericFuncDesc getResidualFilter() { - List filters = residualJoinFilters.stream().map(ExprNodeEvaluator::getExpr).collect(Collectors.toList()); - if (filters.size() ==1) { - return (ExprNodeGenericFuncDesc) filters.get(0); - } - try { - return ExprNodeGenericFuncDesc.newInstance(new GenericUDFOPAnd(), filters); - } catch (UDFArgumentException e) { - throw new RuntimeException("wrong udf", e); - } - } - - private String[] getHashKey(List aliasList, OmniSmjStreamedTableWithExprOperatorFactory streamFactory, - int index, List> colNameToId) { - List expressions = new ArrayList<>(); - int keyIndex = streamFactory == null ? condn[index].getLeft() : condn[index].getRight(); - for (int i = 0; i < aliasList.size(); i++) { - if (aliasList.get(i) != keyIndex) { - continue; - } - int finalI = i; - expressions = ((StructObjectInspector) ((StructObjectInspector) inputObjInspectors[aliasList.get(i)]) - .getAllStructFieldRefs().get(0).getFieldObjectInspector()).getAllStructFieldRefs().stream() - .map(field -> TypeUtils.buildExpression( - ((AbstractPrimitiveObjectInspector) field.getFieldObjectInspector()).getTypeInfo(), - colNameToId.get(finalI).get(field.getFieldName()))).collect(Collectors.toList()); - } - return expressions.toArray(new String[0]); - } - - @Override - public void endGroup() throws HiveException { - // we do not want the end group to cause a checkAndGenObject - defaultEndGroup(); - } - - @Override - public void startGroup() throws HiveException { - // we do not want the start group to cause a checkAndGenObject - defaultStartGroup(); - } - - @Override - public void process(Object row, int tag) throws HiveException { - VecBatch input = (VecBatch) row; - if (tag == 0) { - streamData[0].offer(input); - } else if (tag >= 1) { - bufferData[tag - 1].offer(input); - } - processOmni(0, 1); - for (int opIndex = 1; opIndex < streamFactories.length; opIndex++) { - if (!streamData[opIndex].isEmpty()) { - processOmni(opIndex, opIndex + 1); - } - } - } - - protected void processOmni(int opIndex, int bufferIndex) throws HiveException { - if (flowControlCode[opIndex] != SCAN_FINISH && resCode[opIndex] == RES_INIT) { - if (flowControlCode[opIndex] == SMJ_NEED_ADD_STREAM_TBL_DATA) { - processOmniSmj(opIndex, opIndex, streamData, streamOperators, SMJ_NEED_ADD_STREAM_TBL_DATA, streamTypes); - } else { - processOmniSmj(opIndex, bufferIndex, bufferData, bufferOperators, SMJ_NEED_ADD_BUFFERED_TBL_DATA, bufferTypes); - } - } - if (resCode[opIndex] == SMJ_FETCH_JOIN_DATA) { - output = bufferOperators[opIndex].getOutput(); - while (!getDone() && output.hasNext()) { - VecBatch vecBatch = output.next(); - if (streamFactories.length > opIndex + 1) { - if (flowControlCode[opIndex + 1] != SCAN_FINISH) { - streamData[opIndex + 1].offer(vecBatch); - processOmni(opIndex + 1, opIndex + 2); - } else { - vecBatch.releaseAllVectors(); - vecBatch.close(); - } - } else { - forward(vecBatch, outputObjInspector); - } - } - resCode[opIndex] = RES_INIT; - } - } - - /** - * processOmniSmj - * - * @param opIndex 0 is the first join, 1is the second join - * @param dataIndex data source index, indicate table0, table1, table2 - * @param data data queue - * @param operators 
streamOperators or bufferOperators - * @param controlCode flowControlCode - * @param types bufferTypes or streamTypes - * @throws HiveException HiveException - */ - protected void processOmniSmj(int opIndex, int dataIndex, Queue[] data, OmniOperator[] operators, - int controlCode, DataType[][] types) throws HiveException { - if (data[opIndex].isEmpty() && fetchDone[dataIndex]) { - setStatus(operators[opIndex].addInput(createEofVecBatch(types[opIndex])), opIndex); - } else { - while (flowControlCode[opIndex] == controlCode && resCode[opIndex] == RES_INIT && !data[opIndex].isEmpty()) { - setStatus(operators[opIndex].addInput(data[opIndex].poll()), opIndex); - } - } - } - - protected void setStatus(int code, int tag) { - flowControlCode[tag] = code >> 16; - resCode[tag] = code & 0xFFFF; - } - - @Override - public String getName() { - return getOperatorName(); - } - - public static String getOperatorName() { - return "MERGEJOIN_OMNI"; - } - - @Override - public OperatorType getType() { - return OperatorType.MERGEJOIN; - } - - @Override - public void close(boolean abort) throws HiveException { - if (!allInitializedParentsAreClosed()) { - return; - } - if (sources == null) { - fetchDone = new boolean[]{true, true, true}; - } - Set needDeal = new HashSet<>(); - for (int opIndex = streamFactories.length - 1; opIndex >= 0; opIndex--) { - if (flowControlCode[opIndex] == SCAN_FINISH) { - break; - } - needDeal.add(opIndex); - } - for (int opIndex = 0; opIndex < streamFactories.length; opIndex++) { - if (!needDeal.contains(opIndex)) { - continue; - } - while (!getDone() && flowControlCode[opIndex] != SCAN_FINISH && flowControlCode[opIndex] != 0) { - processOmni(opIndex, opIndex + 1); - } - } - super.close(abort); - } - - protected VecBatch createEofVecBatch(DataType[] dataTypes) { - Vec[] vecs = new Vec[dataTypes.length]; - for (int i = 0; i < dataTypes.length; i++) { - switch (dataTypes[i].getId()) { - case OMNI_INT: - case OMNI_DATE32: - vecs[i] = new IntVec(0); - break; - case OMNI_LONG: - case OMNI_DECIMAL64: - vecs[i] = new LongVec(0); - break; - case OMNI_DOUBLE: - vecs[i] = new DoubleVec(0); - break; - case OMNI_BOOLEAN: - vecs[i] = new BooleanVec(0); - break; - case OMNI_CHAR: - case OMNI_VARCHAR: - vecs[i] = new VarcharVec(0); - break; - case OMNI_DECIMAL128: - vecs[i] = new Decimal128Vec(0); - break; - case OMNI_SHORT: - vecs[i] = new ShortVec(0); - break; - default: - throw new IllegalArgumentException(String.format("VecType %s is not supported in %s yet", - dataTypes[i].getClass().getSimpleName(), this.getClass().getSimpleName())); - } - } - return new VecBatch(vecs, 0); - } - - public boolean[] getFetchDone() { - return fetchDone; - } - - @Override - protected void forward(Object row, ObjectInspector rowInspector) throws HiveException { - VecBatch vecBatch = (VecBatch) row; - this.runTimeNumRows += vecBatch.getRowCount(); - if (getDone()) { - vecBatch.releaseAllVectors(); - vecBatch.close(); - return; - } - int childrenDone = 0; - for (int i = 0; i < childOperatorsArray.length; i++) { - Operator o = childOperatorsArray[i]; - if (o.getDone()) { - childrenDone++; - } else { - o.process(row, childOperatorsTag[i]); - } - } - - if (childrenDone != 0 && childrenDone == childOperatorsArray.length) { - setDone(true); - vecBatch.releaseAllVectors(); - vecBatch.close(); - } - } - - @Override - public void closeOp(boolean abort) throws HiveException { - for (int i = 0; i < streamOperators.length; i++) { - streamOperators[i].close(); - bufferOperators[i].close(); - streamFactories[i].close(); - 
bufferFactories[i].close(); - for (VecBatch vecBatch : streamData[i]) { - vecBatch.releaseAllVectors(); - vecBatch.close(); - } - for (VecBatch vecBatch : bufferData[i]) { - vecBatch.releaseAllVectors(); - vecBatch.close(); - } - } - output = null; - super.closeOp(abort); - } -} diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniMapJoinOperator.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniMapJoinOperator.java index 75ab7c35dda65b9aa8b25bdfd4acacb4d4fe74a6..31f9feea6964a6c8746b61ac3a473a86a8ffc796 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniMapJoinOperator.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniMapJoinOperator.java @@ -198,7 +198,7 @@ public class OmniMapJoinOperator extends AbstractMapJoinOperator buildInspectors[buildIndex] = new ArrayList<>(joinKeysObjectInspectors[buildIndex]); buildInspectors[buildIndex].addAll(joinValuesObjectInspectors[buildIndex]); } - JoinType joinType = JOIN_TYPE_MAP.get(condn[buildIndexes.size() - 1].getType()); + JoinType joinType = JOIN_TYPE_MAP.get(condn[Math.min(posBigTable, condn.length -1)].getType()); DataType[] buildTypes = getTypeFromInspectors(Arrays.stream(buildInspectors).filter(Objects::nonNull) .flatMap(List::stream).collect(Collectors.toList())); omniHashBuilderWithExprOperatorFactory = getOmniHashBuilderWithExprOperatorFactory(joinType, buildTypes, @@ -729,17 +729,38 @@ public class OmniMapJoinOperator extends AbstractMapJoinOperator if ((conf.getResidualFilterExprs() == null || conf.getResidualFilterExprs().isEmpty()) && joinFilter == null) { return Optional.empty(); } - Map inputColNameToExprName = new HashMap<>(); + Map> inputColNameToExprName = new HashMap<>(); for (Map.Entry entry : conf.getColumnExprMap().entrySet()) { ExprNodeColumnDesc exprNodeColumnDesc = (ExprNodeColumnDesc) entry.getValue(); - inputColNameToExprName.put(exprNodeColumnDesc.getColumn(), entry.getKey()); + if (!inputColNameToExprName.containsKey(exprNodeColumnDesc.getColumn())) { + inputColNameToExprName.put(exprNodeColumnDesc.getColumn(), new ArrayList<>()); + } + inputColNameToExprName.get(exprNodeColumnDesc.getColumn()).add(entry.getKey()); } List fields = ((StructObjectInspector) inputObjInspectors[posBigTable]).getAllStructFieldRefs(); - List fieldNames = fields.stream().map(StructField::getFieldName).collect(Collectors.toList()); + List fieldNames = fields.stream().map(field -> { + String key = field.getFieldName().replace("value.", "VALUE.").replace("key.", "KEY."); + if (inputColNameToExprName.containsKey(key)) { + List exprNames = inputColNameToExprName.get(key); + return exprNames.get(Math.min(exprNames.size() - 1, posBigTable)).replace("value.", "").replace("key.", ""); + } else { + return field.getFieldName().replace("value.", "").replace("key.", ""); + } + } + ).collect(Collectors.toList()); List inspectors = fields.stream().map(StructField::getFieldObjectInspector).collect(Collectors.toList()); for (int buildIndex : buildIndexes) { fields = ((StructObjectInspector) inputObjInspectors[buildIndex]).getAllStructFieldRefs(); - fieldNames.addAll(fields.stream().map(field -> inputColNameToExprName.getOrDefault(field.getFieldName(), field.getFieldName())).collect(Collectors.toList())); + fieldNames.addAll(fields.stream().map(field -> { + String key = field.getFieldName().replace("value.", "VALUE.").replace("key.", "KEY."); + if (inputColNameToExprName.containsKey(key)) { + List exprNames = 
inputColNameToExprName.get(key); + return exprNames.get(Math.min(exprNames.size() - 1, buildIndex)).replace("value.", "").replace("key.", ""); + } else { + return field.getFieldName().replace("value.", "").replace("key.", ""); + } + } + ).collect(Collectors.toList())); inspectors.addAll(fields.stream().map(StructField::getFieldObjectInspector).collect(Collectors.toList())); } StructObjectInspector exprObjInspector = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, inspectors); diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniMergeJoinOperator.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniMergeJoinOperator.java index e5364729b4b4306934a9839fbb6cea4fe598ddbf..bc2f2bedb628f33ed75072fd9aadd4760cb0bd8a 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniMergeJoinOperator.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniMergeJoinOperator.java @@ -18,20 +18,94 @@ package com.huawei.boostkit.hive; +import com.huawei.boostkit.hive.expression.BaseExpression; +import com.huawei.boostkit.hive.expression.ExpressionUtils; +import com.huawei.boostkit.hive.expression.TypeUtils; +import nova.hetu.omniruntime.constants.JoinType; import nova.hetu.omniruntime.operator.OmniOperator; +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.join.OmniSmjBufferedTableWithExprOperatorFactory; +import nova.hetu.omniruntime.operator.join.OmniSmjStreamedTableWithExprOperatorFactory; import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.vector.BooleanVec; +import nova.hetu.omniruntime.vector.Decimal128Vec; +import nova.hetu.omniruntime.vector.DoubleVec; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.ShortVec; +import nova.hetu.omniruntime.vector.VarcharVec; +import nova.hetu.omniruntime.vector.Vec; import nova.hetu.omniruntime.vector.VecBatch; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hive.ql.CompilationOpContext; +import org.apache.hadoop.hive.ql.exec.CommonJoinOperator; +import org.apache.hadoop.hive.ql.exec.ExprNodeColumnEvaluator; +import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator; import org.apache.hadoop.hive.ql.exec.MapredContext; +import org.apache.hadoop.hive.ql.exec.Operator; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.Utilities; +import org.apache.hadoop.hive.ql.exec.tez.RecordSource; import org.apache.hadoop.hive.ql.exec.tez.TezContext; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.plan.CommonMergeJoinDesc; +import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc; +import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; +import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc; +import org.apache.hadoop.hive.ql.plan.OperatorDesc; +import org.apache.hadoop.hive.ql.plan.api.OperatorType; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPAnd; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.AbstractPrimitiveObjectInspector; +import 
java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.Queue; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; -public class OmniMergeJoinOperator extends OmniJoinOperator { +import static com.huawei.boostkit.hive.JoinUtils.getExprNodeColumnEvaluator; +import static com.huawei.boostkit.hive.OmniMapJoinOperator.JOIN_TYPE_MAP; + +public class OmniMergeJoinOperator extends CommonJoinOperator implements Serializable { + private static final long serialVersionUID = 1L; + + protected static final int SMJ_NEED_ADD_STREAM_TBL_DATA = 2; + protected static final int SMJ_NEED_ADD_BUFFERED_TBL_DATA = 3; + protected static final int SCAN_FINISH = 4; + protected static final int RES_INIT = 0; + protected static final int SMJ_FETCH_JOIN_DATA = 5; + + protected transient RecordSource[] sources; + protected transient boolean[] fetchDone; + + protected transient OmniSmjBufferedTableWithExprOperatorFactory[] bufferFactories; + protected transient OmniSmjStreamedTableWithExprOperatorFactory[] streamFactories; + protected transient OmniOperator[] bufferOperators; + protected transient OmniOperator[] streamOperators; + + protected transient int[] resCode; + protected transient int[] flowControlCode; + protected transient Queue[] streamData; + protected transient Queue[] bufferData; + protected transient DataType[][] streamTypes; + protected transient DataType[][] bufferTypes; + private transient Iterator output; protected int posBigTable; private OmniVectorOperator omniVectorOperator; @@ -57,13 +131,194 @@ public class OmniMergeJoinOperator extends OmniJoinOperator { // all columns of table0 and table1. // Then use the output to join table2, and output required columns. 
protected void initializeOp(Configuration hconf) throws HiveException { + int sourceNum = parentOperators.get(0).getInputObjInspectors().length; + ObjectInspector[] newInputObjInspectors = new ObjectInspector[sourceNum]; + for (int i = 0; i < sourceNum; i++) { + newInputObjInspectors[i] = ((StructObjectInspector) inputObjInspectors[0]).getAllStructFieldRefs().get(i) + .getFieldObjectInspector(); + } + inputObjInspectors = newInputObjInspectors; super.initializeOp(hconf); + fetchDone = new boolean[sourceNum]; + streamFactories = new OmniSmjStreamedTableWithExprOperatorFactory[sourceNum - 1]; + streamOperators = new OmniOperator[sourceNum - 1]; + bufferFactories = new OmniSmjBufferedTableWithExprOperatorFactory[sourceNum - 1]; + bufferOperators = new OmniOperator[sourceNum - 1]; + resCode = new int[sourceNum - 1]; + flowControlCode = new int[sourceNum - 1]; + Arrays.fill(flowControlCode, SMJ_NEED_ADD_STREAM_TBL_DATA); + streamData = new Queue[sourceNum - 1]; + bufferData = new Queue[sourceNum - 1]; + for (int i = 0; i < streamData.length; i++) { + streamData[i] = new LinkedList<>(); + bufferData[i] = new LinkedList<>(); + } + streamTypes = new DataType[sourceNum - 1][]; + bufferTypes = new DataType[sourceNum - 1][]; + for (int i = 1; i < sourceNum; i++) { + generateOmniOperator(i, true); + } + generateOmniOperator(streamFactories.length, false); sources = ((TezContext) MapredContext.get()).getRecordSources(); if (parentOperators.get(0) instanceof OmniVectorOperator) { omniVectorOperator = (OmniVectorOperator) parentOperators.get(0); } } + private void generateOmniOperator(int bufferIndex, boolean getAll) throws HiveException{ + int opIndex = bufferIndex - 1; + List streamAliasList = new ArrayList<>(); + for (int i = 0; i < bufferIndex; i++) { + streamAliasList.add(i); + } + streamFactories[opIndex] = (OmniSmjStreamedTableWithExprOperatorFactory) getFactory(streamAliasList, null, + getAll, opIndex); + streamOperators[opIndex] = streamFactories[opIndex].createOperator(); + bufferFactories[opIndex] = (OmniSmjBufferedTableWithExprOperatorFactory) getFactory(Arrays.asList(bufferIndex), + streamFactories[opIndex], getAll, opIndex); + bufferOperators[opIndex] = bufferFactories[opIndex].createOperator(); + } + + private OmniOperatorFactory getFactory(List aliasList, + OmniSmjStreamedTableWithExprOperatorFactory streamFactory, + boolean getAll, int opIndex) throws HiveException{ + List inputFields = aliasList.stream() + .flatMap(alias -> ((StructObjectInspector) inputObjInspectors[alias]).getAllStructFieldRefs().stream() + .flatMap(keyValue -> ((StructObjectInspector) keyValue.getFieldObjectInspector()) + .getAllStructFieldRefs().stream())).collect(Collectors.toList()); + DataType[] inputTypes = new DataType[inputFields.size()]; + List> colNameToId = new ArrayList<>(); + aliasList.forEach(a -> colNameToId.add(new HashMap<>())); + int[] fieldNum = new int[aliasList.size()]; + fieldNum[0] = ((StructObjectInspector) inputObjInspectors[aliasList.get(0)]).getAllStructFieldRefs().stream() + .mapToInt(keyValue -> ((StructObjectInspector) keyValue.getFieldObjectInspector()) + .getAllStructFieldRefs().size()).sum(); + for (int i = 1; i < fieldNum.length; i++) { + fieldNum[i] = fieldNum[i - 1] + + ((StructObjectInspector) inputObjInspectors[aliasList.get(i)]).getAllStructFieldRefs().stream() + .mapToInt(keyValue -> ((StructObjectInspector) keyValue.getFieldObjectInspector()) + .getAllStructFieldRefs().size()).sum(); + } + int tagIndex = 0; + for (int i = 0; i < inputFields.size(); i++) { + if (i >= 
fieldNum[tagIndex]) { + ++tagIndex; + } + inputTypes[i] = TypeUtils.buildInputDataType(((AbstractPrimitiveObjectInspector) inputFields.get(i).getFieldObjectInspector()).getTypeInfo()); + colNameToId.get(tagIndex).put(inputFields.get(i).getFieldName(), i); + } + int[] outputCols; + if (getAll) { + outputCols = new int[inputTypes.length]; + for (int i = 0; i < inputTypes.length; i++) { + outputCols[i] = i; + } + } else { + int start = 0; + outputCols = new int[aliasList.stream().mapToInt(a -> joinValuesObjectInspectors[a].size()).sum()]; + for (int i = 0; i < aliasList.size(); i++) { + List outputFieldsName = getExprNodeColumnEvaluator(joinValues[aliasList.get(i)]).stream() + .map(evaluator -> ((ExprNodeColumnEvaluator) evaluator).getExpr().getColumn() + .split("\\.")[1]).collect(Collectors.toList()); + for (int j = start; j < start + outputFieldsName.size(); j++) { + outputCols[j] = colNameToId.get(i).get(outputFieldsName.get(j - start)); + } + start += outputFieldsName.size(); + } + } + String[] hashKey = getHashKey(aliasList, streamFactory, opIndex, colNameToId); + JoinType joinType = JOIN_TYPE_MAP.get(condn[opIndex].getType()); + if (streamFactory == null) { + Optional filter = generateFilter(opIndex); + streamTypes[opIndex] = inputTypes; + return new OmniSmjStreamedTableWithExprOperatorFactory(inputTypes, hashKey, outputCols, joinType, filter); + } else { + bufferTypes[opIndex] = inputTypes; + return new OmniSmjBufferedTableWithExprOperatorFactory(inputTypes, hashKey, outputCols, streamFactory); + } + } + + // sql like cs1.cs_warehouse_sk <> cs2.cs_warehouse_sk will have + // residualJoinFilters + private Optional generateFilter(int opIndex) throws HiveException{ + ExprNodeGenericFuncDesc joinFilter = getJoinFilter(opIndex); + if ((residualJoinFilters == null || residualJoinFilters.get(opIndex) == null) && joinFilter == null) { + return Optional.empty(); + } + BaseExpression root; + if (joinFilter == null) { + int bufferIndex = opIndex + 1; + List inspectors = IntStream.range(0, bufferIndex + 1).boxed() + .flatMap(tableIndex -> ((StructObjectInspector) inputObjInspectors[tableIndex]).getAllStructFieldRefs() + .stream().flatMap(keyValue -> ((StructObjectInspector) keyValue.getFieldObjectInspector()) + .getAllStructFieldRefs().stream())).sorted(Comparator.comparing(StructField::getFieldName)) + .map(field -> field.getFieldObjectInspector()).collect(Collectors.toList()); + Map inputColNameToExprName = new HashMap<>(); + for (Map.Entry entry : conf.getColumnExprMap().entrySet()) { + ExprNodeColumnDesc exprNodeColumnDesc = (ExprNodeColumnDesc) entry.getValue(); + inputColNameToExprName.put(exprNodeColumnDesc.getColumn().replace("VALUE.", "").replace("KEY.", ""), entry.getKey()); + } + List fieldNames = conf.getColumnExprMap().keySet().stream().sorted().collect(Collectors.toList()); + StructObjectInspector exprObjInspector = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, + inspectors); + root = ExpressionUtils.build(getResidualFilter(), exprObjInspector); + } else { + StructObjectInspector flattenInspector = Utilities.constructVectorizedReduceRowOI((StructObjectInspector) ((StructObjectInspector) inputObjInspectors[opIndex]).getAllStructFieldRefs().get(0).getFieldObjectInspector(), + (StructObjectInspector) ((StructObjectInspector) inputObjInspectors[opIndex]).getAllStructFieldRefs().get(1).getFieldObjectInspector()); + List inspectors = flattenInspector.getAllStructFieldRefs().stream().map(field -> field.getFieldObjectInspector()).collect(Collectors.toList()); + List 
fieldNames = flattenInspector.getAllStructFieldRefs().stream().map(field -> field.getFieldName()).collect(Collectors.toList()); + StructObjectInspector exprObjInspector = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, + inspectors); + root = ExpressionUtils.build(joinFilter, exprObjInspector); + } + return Optional.of(root.toString()); + } + + private ExprNodeGenericFuncDesc getJoinFilter(int opIndex) { + List filters = joinFilters[opIndex].stream().map(ExprNodeEvaluator::getExpr).collect(Collectors.toList()); + if (filters.isEmpty()) { + return null; + } + if (filters.size() == 1) { + return (ExprNodeGenericFuncDesc) filters.get(0); + } + try { + return ExprNodeGenericFuncDesc.newInstance(new GenericUDFOPAnd(), filters); + } catch (UDFArgumentException e) { + throw new RuntimeException("wrong UDF", e); + } + } + + private ExprNodeGenericFuncDesc getResidualFilter() { + List filters = residualJoinFilters.stream().map(ExprNodeEvaluator::getExpr).collect(Collectors.toList()); + if (filters.size() ==1) { + return (ExprNodeGenericFuncDesc) filters.get(0); + } + try { + return ExprNodeGenericFuncDesc.newInstance(new GenericUDFOPAnd(), filters); + } catch (UDFArgumentException e) { + throw new RuntimeException("wrong udf", e); + } + } + + private String[] getHashKey(List aliasList, OmniSmjStreamedTableWithExprOperatorFactory streamFactory, + int index, List> colNameToId) { + List expressions = new ArrayList<>(); + int keyIndex = streamFactory == null ? condn[index].getLeft() : condn[index].getRight(); + for (int i = 0; i < aliasList.size(); i++) { + if (aliasList.get(i) != keyIndex) { + continue; + } + int finalI = i; + expressions = ((StructObjectInspector) ((StructObjectInspector) inputObjInspectors[aliasList.get(i)]) + .getAllStructFieldRefs().get(0).getFieldObjectInspector()).getAllStructFieldRefs().stream() + .map(field -> TypeUtils.buildExpression( + ((AbstractPrimitiveObjectInspector) field.getFieldObjectInspector()).getTypeInfo(), + colNameToId.get(finalI).get(field.getFieldName()))).collect(Collectors.toList()); + } + return expressions.toArray(new String[0]); + } + @Override public void process(Object row, int tag) throws HiveException { VecBatch input = (VecBatch) row; @@ -92,7 +347,36 @@ public class OmniMergeJoinOperator extends OmniJoinOperator { } } - @Override + protected void processOmni(int opIndex, int bufferIndex) throws HiveException { + if (flowControlCode[opIndex] != SCAN_FINISH && resCode[opIndex] == RES_INIT) { + if (flowControlCode[opIndex] == SMJ_NEED_ADD_STREAM_TBL_DATA) { + processOmniSmj(opIndex, opIndex, streamData, streamOperators, SMJ_NEED_ADD_STREAM_TBL_DATA, streamTypes); + } else { + processOmniSmj(opIndex, bufferIndex, bufferData, bufferOperators, SMJ_NEED_ADD_BUFFERED_TBL_DATA, bufferTypes); + } + } + if (resCode[opIndex] == SMJ_FETCH_JOIN_DATA) { + output = bufferOperators[opIndex].getOutput(); + while (!getDone() && output.hasNext()) { + VecBatch vecBatch = output.next(); + if (streamFactories.length <= opIndex + 1) { + forward(vecBatch, outputObjInspector); + continue; + } + if (flowControlCode[opIndex + 1] == SCAN_FINISH) { + vecBatch.releaseAllVectors(); + vecBatch.close(); + continue; + } + streamData[opIndex + 1].offer(vecBatch); + if (flowControlCode[opIndex + 1] == SMJ_NEED_ADD_STREAM_TBL_DATA) { + processOmni(opIndex + 1, opIndex + 2); + } + } + resCode[opIndex] = RES_INIT; + } + } + protected void processOmniSmj(int opIndex, int dataIndex, Queue[] data, OmniOperator[] operators, int controlCode, DataType[][] types) throws 
HiveException { if (!data[opIndex].isEmpty()) { @@ -124,6 +408,11 @@ public class OmniMergeJoinOperator extends OmniJoinOperator { } } + protected void setStatus(int code, int tag) { + flowControlCode[tag] = code >> 16; + resCode[tag] = code & 0xFFFF; + } + public int getPosBigTable() { return posBigTable; } @@ -139,4 +428,130 @@ public class OmniMergeJoinOperator extends OmniJoinOperator { public void publicSetDone(boolean done) { this.done = done; } + + @Override + public String getName() { + return getOperatorName(); + } + + public static String getOperatorName() { + return "MERGEJOIN_OMNI"; + } + + @Override + public OperatorType getType() { + return OperatorType.MERGEJOIN; + } + + @Override + public void close(boolean abort) throws HiveException { + if (!allInitializedParentsAreClosed()) { + return; + } + if (sources == null) { + fetchDone = new boolean[]{true, true, true}; + } + Set needDeal = new HashSet<>(); + for (int opIndex = streamFactories.length - 1; opIndex >= 0; opIndex--) { + if (flowControlCode[opIndex] == SCAN_FINISH) { + break; + } + needDeal.add(opIndex); + } + for (int opIndex = 0; opIndex < streamFactories.length; opIndex++) { + if (!needDeal.contains(opIndex)) { + continue; + } + while (!getDone() && flowControlCode[opIndex] != SCAN_FINISH && flowControlCode[opIndex] != 0) { + processOmni(opIndex, opIndex + 1); + } + } + super.close(abort); + } + + protected VecBatch createEofVecBatch(DataType[] dataTypes) { + Vec[] vecs = new Vec[dataTypes.length]; + for (int i = 0; i < dataTypes.length; i++) { + switch (dataTypes[i].getId()) { + case OMNI_INT: + case OMNI_DATE32: + vecs[i] = new IntVec(0); + break; + case OMNI_LONG: + case OMNI_DECIMAL64: + vecs[i] = new LongVec(0); + break; + case OMNI_DOUBLE: + vecs[i] = new DoubleVec(0); + break; + case OMNI_BOOLEAN: + vecs[i] = new BooleanVec(0); + break; + case OMNI_CHAR: + case OMNI_VARCHAR: + vecs[i] = new VarcharVec(0); + break; + case OMNI_DECIMAL128: + vecs[i] = new Decimal128Vec(0); + break; + case OMNI_SHORT: + vecs[i] = new ShortVec(0); + break; + default: + throw new IllegalArgumentException(String.format("VecType %s is not supported in %s yet", + dataTypes[i].getClass().getSimpleName(), this.getClass().getSimpleName())); + } + } + return new VecBatch(vecs, 0); + } + + public boolean[] getFetchDone() { + return fetchDone; + } + + @Override + protected void forward(Object row, ObjectInspector rowInspector) throws HiveException { + VecBatch vecBatch = (VecBatch) row; + this.runTimeNumRows += vecBatch.getRowCount(); + if (getDone()) { + vecBatch.releaseAllVectors(); + vecBatch.close(); + return; + } + int childrenDone = 0; + for (int i = 0; i < childOperatorsArray.length; i++) { + Operator o = childOperatorsArray[i]; + if (o.getDone()) { + childrenDone++; + } else { + o.process(row, childOperatorsTag[i]); + } + } + + if (childrenDone != 0 && childrenDone == childOperatorsArray.length) { + setDone(true); + vecBatch.releaseAllVectors(); + vecBatch.close(); + } + } + + @Override + public void closeOp(boolean abort) throws HiveException { + for (int i = 0; i < streamOperators.length; i++) { + streamOperators[i].close(); + bufferOperators[i].close(); + streamFactories[i].close(); + bufferFactories[i].close(); + for (VecBatch vecBatch : streamData[i]) { + vecBatch.releaseAllVectors(); + vecBatch.close(); + } + for (VecBatch vecBatch : bufferData[i]) { + vecBatch.releaseAllVectors(); + vecBatch.close(); + } + } + output = null; + super.closeOp(abort); + } } \ No newline at end of file diff --git 
a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniMergeJoinWithSortOperator.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniMergeJoinWithSortOperator.java index c6c6c3df208c50049a4b9c8c73a983b1638419ac..a3b41444b3b76a193575f78dd860c5e914337d61 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniMergeJoinWithSortOperator.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniMergeJoinWithSortOperator.java @@ -80,7 +80,7 @@ public class OmniMergeJoinWithSortOperator extends OmniMergeJoinOperator { setStatus(operators[opIndex].addInput(data[opIndex].poll()), opIndex); } } - } else if (!omniVectorWithSortOperator.outputs[dataIndex].hasNext()) { + } else { setStatus(operators[opIndex].addInput(createEofVecBatch(types[opIndex])), opIndex); } } diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/cache/VecBufferCache.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/cache/VecBufferCache.java index 5aa10fb19e906fe2a5c0e1300da5ac14bebe2bcc..43f39df5af2b9f44c47da1d13d03d568e1bccc50 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/cache/VecBufferCache.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/cache/VecBufferCache.java @@ -107,6 +107,7 @@ public class VecBufferCache { vec = new ShortVec(rowCount); break; case BOOLEAN: + case VOID: vec = new BooleanVec(rowCount); break; case DOUBLE: diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/converter/Decimal64VecConverter.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/converter/Decimal64VecConverter.java index e3ae983ec84d96f401b71437f91139c5ef9f3672..398b55d24f8634c3e14bc84338b15a855541986d 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/converter/Decimal64VecConverter.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/converter/Decimal64VecConverter.java @@ -117,10 +117,10 @@ public class Decimal64VecConverter extends LongVecConverter { try { if (vectorizationContext != null) { convertedDecimal64 = vectorizationContext.getDataTypePhysicalVariation( - vectorizationContext.getInputColumnIndex(fieldName)) == DataTypePhysicalVariation.DECIMAL_64; + vectorizationContext.getInputColumnIndex(fieldName.replace("key.","KEY.").replace("value.","VALUE."))) == DataTypePhysicalVariation.DECIMAL_64; } } catch (HiveException e) { - LOG.error("error occurs when finding field from vectorizationContext"); + throw new RuntimeException("error occurs when finding field from vectorizationContext", e); } return convertedDecimal64; } diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/converter/VecConverter.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/converter/VecConverter.java index f28517f37a75bdaeb2c3c5c1b42c23adf75fcc50..e9892fe9f022f21d13628b70958803e8136b0848 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/converter/VecConverter.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/converter/VecConverter.java @@ -48,6 +48,7 @@ public interface VecConverter { put(PrimitiveObjectInspector.PrimitiveCategory.TIMESTAMP, new TimestampVecConverter()); put(PrimitiveObjectInspector.PrimitiveCategory.DATE, new 
DateVecConverter()); put(PrimitiveObjectInspector.PrimitiveCategory.DECIMAL, new DecimalVecConverter()); + put(PrimitiveObjectInspector.PrimitiveCategory.VOID, new BooleanVecConverter()); } }; diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/expression/TypeUtils.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/expression/TypeUtils.java index 0ecfe38bf7dbb89650b337b0bdf005d0b549cc61..1e061f0ff2d46df2c9713ea61028432d9e08a88c 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/expression/TypeUtils.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/expression/TypeUtils.java @@ -87,6 +87,7 @@ import org.apache.parquet.format.DecimalType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -108,6 +109,7 @@ public class TypeUtils { put(PrimitiveObjectInspector.PrimitiveCategory.INTERVAL_DAY_TIME, LongDataType.LONG); put(PrimitiveObjectInspector.PrimitiveCategory.BYTE, ShortDataType.SHORT); put(PrimitiveObjectInspector.PrimitiveCategory.FLOAT, DoubleDataType.DOUBLE); + put(PrimitiveObjectInspector.PrimitiveCategory.VOID, BooleanDataType.BOOLEAN); } }; @@ -371,15 +373,88 @@ public class TypeUtils { return true; } } - boolean anyDecimal128 = children.stream().anyMatch(child -> child.getTypeInfo() instanceof DecimalTypeInfo - && ((DecimalTypeInfo) child.getTypeInfo()).getPrecision() > 18); - if ((functionName.equals("GenericUDFOPMultiply") || functionName.equals("GenericUDFOPDivide") - || functionName.equals("GenericUDFOPMod")) && anyDecimal128) { - return true; + + if (functionName.equals("GenericUDFOPMultiply") || functionName.equals("GenericUDFOPDivide") + || functionName.equals("GenericUDFOPMod")) { + return !isValidConversion(node); + } + return false; + } + + public static boolean checkUnsupportedTimestamp(ExprNodeDesc desc) { + TypeInfo typeInfo = desc.getTypeInfo(); + if (typeInfo instanceof PrimitiveTypeInfo) { + if (!"timestamp".equals(typeInfo.getTypeName())) { + return true; + } + if (desc instanceof ExprNodeConstantDesc) { + Timestamp timeValue = (Timestamp) ((ExprNodeConstantDesc) desc).getValue(); + if (timeValue.getNanos() % 1000000 != 0) { + return true; + } + } else { + return true; + } } return false; } + public static boolean isValidConversion(ExprNodeDesc node) { + if (node instanceof ExprNodeGenericFuncDesc && node.getChildren() != null && node.getChildren().size() == 2) { + List<ExprNodeDesc> children = node.getChildren(); + int precision = 0; + int scale = 0; + int maxScale = 0; + if (node.getTypeInfo() instanceof DecimalTypeInfo) { + precision = ((DecimalTypeInfo) node.getTypeInfo()).getPrecision(); + scale = ((DecimalTypeInfo) node.getTypeInfo()).getScale(); + } + if (children.get(0) instanceof ExprNodeConstantDesc && children.get(1) instanceof ExprNodeColumnDesc) { + Collections.swap(children, 0, 1); + } + if (children.get(0) instanceof ExprNodeColumnDesc && children.get(1) instanceof ExprNodeConstantDesc) { + ExprNodeDesc exprNodeDesc = children.get(0); + if (exprNodeDesc.getTypeInfo() instanceof DecimalTypeInfo) { + maxScale = ((DecimalTypeInfo) exprNodeDesc.getTypeInfo()).getScale(); + } + } else { + maxScale = getMaxScale(children, scale); + } + + int targetChildPrecision = 0; + int targetChildScale = 0; + for (ExprNodeDesc child : children) { + if (child.getTypeInfo() instanceof DecimalTypeInfo) { + int childScale = ((DecimalTypeInfo)
child.getTypeInfo()).getScale(); + int childPrecision = ((DecimalTypeInfo) child.getTypeInfo()).getPrecision(); + if (maxScale != childScale) { + targetChildPrecision = Math.min(Math.max(childPrecision + maxScale - childScale, precision), 38); + targetChildScale = maxScale; + if (childPrecision - childScale > targetChildPrecision - targetChildScale || childScale > targetChildScale) { + return false; + } + } + } + } + return true; + } + return true; + } + + public static int getMaxScale(List children, int maxScale) { + for (ExprNodeDesc child : children) { + if (!(child.getTypeInfo() instanceof DecimalTypeInfo)) { + continue; + } + DecimalTypeInfo childTypeInfo = (DecimalTypeInfo) child.getTypeInfo(); + int childScale = childTypeInfo.getScale(); + if (childScale >= maxScale) { + maxScale = childScale; + } + } + return maxScale; + } + public static boolean checkOmniJsonWhiteList(String filterExpr, String[] projections) { // inputTypes will not be checked if parseFormat is json( == 1), // only if its parseFormat is String(==0) diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/ArithmeticExpressionProcessor.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/ArithmeticExpressionProcessor.java index 63cd5d7a4cf763a00819da3d2cd9c5c351b269a5..097d17cd63788d44e5d6dfb0ce35174f874147df 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/ArithmeticExpressionProcessor.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/ArithmeticExpressionProcessor.java @@ -18,6 +18,8 @@ package com.huawei.boostkit.hive.processor; +import static com.huawei.boostkit.hive.expression.TypeUtils.getMaxScale; + import com.huawei.boostkit.hive.expression.BaseExpression; import com.huawei.boostkit.hive.expression.CastFunctionExpression; import com.huawei.boostkit.hive.expression.DivideExpression; @@ -104,18 +106,4 @@ public class ArithmeticExpressionProcessor implements ExpressionProcessor { TypeUtils.getCharWidth(node), childPrecision, childScale); compareExpression.add(ExpressionUtils.optimizeCast(childNode, functionExpression)); } - - private int getMaxScale(List children, int maxScale) { - for (ExprNodeDesc child : children) { - if (!(child.getTypeInfo() instanceof DecimalTypeInfo)) { - continue; - } - DecimalTypeInfo childTypeInfo = (DecimalTypeInfo) child.getTypeInfo(); - int childScale = childTypeInfo.getScale(); - if (childScale >= maxScale) { - maxScale = childScale; - } - } - return maxScale; - } } diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/TimestampExpressionProcessor.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/TimestampExpressionProcessor.java index 4d0c9b5166dcd64faf91105c8e6f28b4be19dbb0..fb12b7b0e6aa6804c9d0bb8fbab76a0e2b682092 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/TimestampExpressionProcessor.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/TimestampExpressionProcessor.java @@ -20,10 +20,16 @@ package com.huawei.boostkit.hive.processor; import com.huawei.boostkit.hive.expression.BaseExpression; import com.huawei.boostkit.hive.expression.CastFunctionExpression; +import com.huawei.boostkit.hive.expression.DivideExpression; import com.huawei.boostkit.hive.expression.ExpressionUtils; +import com.huawei.boostkit.hive.expression.LiteralFactor; 
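Two timestamp rules apply around here: checkUnsupportedTimestamp in TypeUtils above only accepts timestamp constants whose fractional seconds are a whole number of milliseconds, and the processor below builds its result around an 86,400,000 (milliseconds-per-day) literal. A small sketch of the arithmetic, under the assumptions that timestamps are carried as epoch milliseconds and that the literal reduces them to whole days (the TimestampRules class name is illustrative only):

    import java.sql.Timestamp;

    // Sketch of the two timestamp rules referenced above.
    public final class TimestampRules {
        static final long MILLIS_PER_DAY = 86_400_000L;

        // getNanos() holds the full fractional second in nanoseconds; anything
        // that is not a whole millisecond cannot be represented downstream.
        static boolean hasSubMillisecondPrecision(Timestamp ts) {
            return ts.getNanos() % 1_000_000 != 0;
        }

        // Assumption: epoch milliseconds divided by ms-per-day gives whole days.
        static long epochMillisToDays(long epochMillis) {
            return epochMillis / MILLIS_PER_DAY;
        }

        public static void main(String[] args) {
            Timestamp ts = Timestamp.valueOf("2024-01-02 00:00:00.123");
            System.out.println(hasSubMillisecondPrecision(ts)); // false: .123 is exactly 123 ms
            System.out.println(epochMillisToDays(ts.getTime())); // whole days since the epoch
        }
    }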
import com.huawei.boostkit.hive.expression.TypeUtils; import com.sun.jdi.LongType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.Decimal128DataType; +import nova.hetu.omniruntime.type.Decimal64DataType; import nova.hetu.omniruntime.type.LongDataType; +import org.apache.hadoop.hive.common.type.Decimal128; import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @@ -32,13 +38,41 @@ public class TimestampExpressionProcessor implements ExpressionProcessor { @Override public BaseExpression process(ExprNodeGenericFuncDesc node, String operator, ObjectInspector inspector) { ExprNodeDesc exprNodeDesc = node.getChildren().get(0); - CastFunctionExpression cast = new CastFunctionExpression(LongDataType.LONG.getId().toValue(), - TypeUtils.getCharWidth(node), null, null); + BaseExpression baseExpression; + int dataType = TypeUtils.convertHiveTypeToOmniType(exprNodeDesc.getTypeInfo()); if (exprNodeDesc instanceof ExprNodeGenericFuncDesc) { - cast.add(ExpressionUtils.build((ExprNodeGenericFuncDesc) exprNodeDesc, inspector)); + baseExpression = ExpressionUtils.build((ExprNodeGenericFuncDesc) exprNodeDesc, inspector); } else { - cast.add(ExpressionUtils.createNode(exprNodeDesc, inspector)); + baseExpression = ExpressionUtils.createNode(exprNodeDesc, inspector); } - return cast; + LiteralFactor longLiteralFactorDay = new LiteralFactor<>("LITERAL", null, null, + 86400000L, null, LongDataType.LONG.getId().toValue()); + DivideExpression divideExpression = new DivideExpression(LongDataType.LONG.getId().toValue(), "MULTIPLY", null, null); + if (dataType == Decimal128DataType.DECIMAL128.getId().toValue() || dataType == DoubleDataType.DOUBLE.getId().toValue() + || dataType == Decimal64DataType.DECIMAL64.getId().toValue()) { + DivideExpression secondExpression = new DivideExpression(dataType, "MULTIPLY", null, null); + secondExpression.add(baseExpression); + LiteralFactor longLiteralFactorSecond = new LiteralFactor<>("LITERAL", null, null, + 1L, null, LongDataType.LONG.getId().toValue()); + LiteralFactor doubleLiteralFactorSecond = new LiteralFactor<>("LITERAL", null, null, + 1000.0, null, DoubleDataType.DOUBLE.getId().toValue()); + secondExpression.add(doubleLiteralFactorSecond); + CastFunctionExpression cast = new CastFunctionExpression(LongDataType.LONG.getId().toValue(), + null, null, null); + cast.add(secondExpression); + divideExpression.add(cast); + divideExpression.add(longLiteralFactorSecond); + return divideExpression; + } else if (dataType != LongDataType.LONG.getId().toValue()) { + CastFunctionExpression cast = new CastFunctionExpression(LongDataType.LONG.getId().toValue(), + null, null, null); + cast.add(baseExpression); + divideExpression.add(cast); + divideExpression.add(longLiteralFactorDay); + return divideExpression; + } + divideExpression.add(baseExpression); + divideExpression.add(longLiteralFactorDay); + return divideExpression; } } diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/reader/OmniOrcRecordReader.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/reader/OmniOrcRecordReader.java index 1456cd56564aa59161a2941a6d241b6a5195fd7d..e61c142d046c4bee25bf82bafc04b0a86725a2fb 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/reader/OmniOrcRecordReader.java +++ 
b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/reader/OmniOrcRecordReader.java @@ -19,9 +19,19 @@ package com.huawei.boostkit.hive.reader; import static com.huawei.boostkit.hive.cache.VectorCache.BATCH; +import static com.huawei.boostkit.hive.expression.TypeUtils.DEFAULT_VARCHAR_LENGTH; import static org.apache.hadoop.hive.ql.io.orc.OrcInputFormat.getDesiredRowTypeDescr; import static org.apache.hadoop.hive.serde2.ColumnProjectionUtils.READ_COLUMN_IDS_CONF_STR; +import nova.hetu.omniruntime.type.BooleanDataType; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.Decimal128DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.ShortDataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.vector.Decimal128Vec; import nova.hetu.omniruntime.vector.Vec; import nova.hetu.omniruntime.vector.VecBatch; @@ -35,6 +45,7 @@ import org.apache.hadoop.hive.ql.io.sarg.SearchArgument; import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentImpl; import org.apache.hadoop.hive.ql.plan.api.OperatorType; import org.apache.hadoop.hive.serde2.SerDeStats; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.io.NullWritable; import org.apache.hadoop.mapred.FileSplit; import org.apache.hadoop.mapred.RecordReader; @@ -45,12 +56,31 @@ import org.apache.orc.TypeDescription; import java.io.IOException; import java.util.Arrays; +import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; public class OmniOrcRecordReader implements RecordReader, StatsProvidingRecordReader { + private static final Map CATEGORY_TO_OMNI_TYPE = new HashMap() { + { + put(TypeDescription.Category.SHORT, ShortDataType.SHORT); + put(TypeDescription.Category.INT, IntDataType.INTEGER); + put(TypeDescription.Category.LONG, LongDataType.LONG); + put(TypeDescription.Category.BOOLEAN, BooleanDataType.BOOLEAN); + put(TypeDescription.Category.DOUBLE, DoubleDataType.DOUBLE); + put(TypeDescription.Category.STRING, new VarcharDataType(DEFAULT_VARCHAR_LENGTH)); + put(TypeDescription.Category.TIMESTAMP, LongDataType.LONG); + put(TypeDescription.Category.DATE, IntDataType.INTEGER); + put(TypeDescription.Category.BYTE, ShortDataType.SHORT); + put(TypeDescription.Category.FLOAT, DoubleDataType.DOUBLE); + put(TypeDescription.Category.DECIMAL, Decimal128DataType.DECIMAL128); + put(TypeDescription.Category.CHAR, VarcharDataType.VARCHAR); + put(TypeDescription.Category.VARCHAR, VarcharDataType.VARCHAR); + } + }; protected OrcColumnarBatchScanReader recordReader; protected Vec[] vecs; protected final long offset; @@ -59,6 +89,7 @@ public class OmniOrcRecordReader implements RecordReader included; protected Operator tableScanOp; + protected int[] typeIds; OmniOrcRecordReader(Configuration conf, FileSplit split) throws IOException { TypeDescription schema = getDesiredRowTypeDescr(conf, false, Integer.MAX_VALUE); @@ -70,6 +101,10 @@ public class OmniOrcRecordReader implements RecordReader +#include #include "jni_common.h" #include "common/UriInfo.h" @@ -28,7 +29,6 @@ using namespace std; using namespace orc; static constexpr int32_t MAX_DECIMAL64_DIGITS = 18; -bool isDecimal64Transfor128 = false; // vecFildsNames存储文件每列的列名,从orc reader c++侧获取,回传到java侧使用 JNIEXPORT jlong JNICALL 
Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_initializeReader(JNIEnv *env, @@ -80,7 +80,7 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniRea env->NewStringUTF("notNeedFSCache")); reader = createReader(orc::readFileOverride(uri, notNeedFSCache), readerOptions); std::vector orcColumnNames = reader->getAllFiedsName(); - for (int i = 0; i < orcColumnNames.size(); i++) { + for (uint32_t i = 0; i < orcColumnNames.size(); i++) { jstring fildname = env->NewStringUTF(orcColumnNames[i].c_str()); // use ArrayList and function env->CallBooleanMethod(vecFildsNames, arrayListAdd, fildname); @@ -274,12 +274,7 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniRea { JNI_FUNC_START orc::Reader *readerPtr = (orc::Reader *)reader; - // Get if the decimal for spark or hive - jboolean jni_isDecimal64Transfor128 = env->CallBooleanMethod(jsonObj, jsonMethodHas, - env->NewStringUTF("isDecimal64Transfor128")); - if (jni_isDecimal64Transfor128) { - isDecimal64Transfor128 = true; - } + // get offset from json obj jlong offset = env->CallLongMethod(jsonObj, jsonMethodLong, env->NewStringUTF("offset")); jlong length = env->CallLongMethod(jsonObj, jsonMethodLong, env->NewStringUTF("length")); @@ -334,75 +329,80 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniRea JNI_FUNC_END(runtimeExceptionClass) } -template uint64_t CopyFixedWidth(orc::ColumnVectorBatch *field) +template +std::unique_ptr CopyFixedWidth(orc::ColumnVectorBatch *field) { using T = typename NativeType::type; ORC_TYPE *lvb = dynamic_cast(field); auto numElements = lvb->numElements; auto values = lvb->data.data(); auto notNulls = lvb->notNull.data(); - auto newVector = new Vector(numElements); + auto newVector = std::make_unique>(numElements); + auto newVectorPtr = newVector.get(); // Check ColumnVectorBatch has null or not firstly if (lvb->hasNulls) { for (uint i = 0; i < numElements; i++) { if (notNulls[i]) { - newVector->SetValue(i, (T)(values[i])); + newVectorPtr->SetValue(i, (T)(values[i])); } else { - newVector->SetNull(i); + newVectorPtr->SetNull(i); } } } else { for (uint i = 0; i < numElements; i++) { - newVector->SetValue(i, (T)(values[i])); + newVectorPtr->SetValue(i, (T)(values[i])); } } - return (uint64_t)newVector; + return newVector; } -template uint64_t CopyOptimizedForInt64(orc::ColumnVectorBatch *field) +template +std::unique_ptr CopyOptimizedForInt64(orc::ColumnVectorBatch *field) { using T = typename NativeType::type; ORC_TYPE *lvb = dynamic_cast(field); auto numElements = lvb->numElements; auto values = lvb->data.data(); auto notNulls = lvb->notNull.data(); - auto newVector = new Vector(numElements); + auto newVector = std::make_unique>(numElements); + auto newVectorPtr = newVector.get(); // Check ColumnVectorBatch has null or not firstly if (lvb->hasNulls) { for (uint i = 0; i < numElements; i++) { if (!notNulls[i]) { - newVector->SetNull(i); + newVectorPtr->SetNull(i); } } } - newVector->SetValues(0, values, numElements); - return (uint64_t)newVector; + newVectorPtr->SetValues(0, values, numElements); + return newVector; } -uint64_t CopyVarWidth(orc::ColumnVectorBatch *field) +std::unique_ptr CopyVarWidth(orc::ColumnVectorBatch *field) { orc::StringVectorBatch *lvb = dynamic_cast(field); auto numElements = lvb->numElements; auto values = lvb->data.data(); auto notNulls = lvb->notNull.data(); auto lens = lvb->length.data(); - auto newVector = new Vector>(numElements); + auto newVector = std::make_unique>>(numElements); + auto 
newVectorPtr = newVector.get(); if (lvb->hasNulls) { for (uint i = 0; i < numElements; i++) { if (notNulls[i]) { auto data = std::string_view(reinterpret_cast(values[i]), lens[i]); - newVector->SetValue(i, data); + newVectorPtr->SetValue(i, data); } else { - newVector->SetNull(i); + newVectorPtr->SetNull(i); } } } else { for (uint i = 0; i < numElements; i++) { auto data = std::string_view(reinterpret_cast(values[i]), lens[i]); - newVector->SetValue(i, data); + newVectorPtr->SetValue(i, data); } } - return (uint64_t)newVector; + return newVector; } inline void FindLastNotEmpty(const char *chars, long &len) @@ -412,14 +412,15 @@ inline void FindLastNotEmpty(const char *chars, long &len) } } -uint64_t CopyCharType(orc::ColumnVectorBatch *field) +std::unique_ptr CopyCharType(orc::ColumnVectorBatch *field) { orc::StringVectorBatch *lvb = dynamic_cast(field); auto numElements = lvb->numElements; auto values = lvb->data.data(); auto notNulls = lvb->notNull.data(); auto lens = lvb->length.data(); - auto newVector = new Vector>(numElements); + auto newVector = std::make_unique>>(numElements); + auto newVectorPtr = newVector.get(); if (lvb->hasNulls) { for (uint i = 0; i < numElements; i++) { if (notNulls[i]) { @@ -427,9 +428,9 @@ uint64_t CopyCharType(orc::ColumnVectorBatch *field) auto len = lens[i]; FindLastNotEmpty(chars, len); auto data = std::string_view(chars, len); - newVector->SetValue(i, data); + newVectorPtr->SetValue(i, data); } else { - newVector->SetNull(i); + newVectorPtr->SetNull(i); } } } else { @@ -438,133 +439,170 @@ uint64_t CopyCharType(orc::ColumnVectorBatch *field) auto len = lens[i]; FindLastNotEmpty(chars, len); auto data = std::string_view(chars, len); - newVector->SetValue(i, data); + newVectorPtr->SetValue(i, data); } } - return (uint64_t)newVector; + return newVector; } -uint64_t CopyToOmniDecimal128Vec(orc::ColumnVectorBatch *field) +std::unique_ptr CopyToOmniDecimal128Vec(orc::ColumnVectorBatch *field) { orc::Decimal128VectorBatch *lvb = dynamic_cast(field); auto numElements = lvb->numElements; auto values = lvb->values.data(); auto notNulls = lvb->notNull.data(); - auto newVector = new Vector(numElements); + auto newVector = std::make_unique>(numElements); + auto newVectorPtr = newVector.get(); if (lvb->hasNulls) { for (uint i = 0; i < numElements; i++) { if (notNulls[i]) { - newVector->SetValue(i, Decimal128(values[i].getHighBits(), values[i].getLowBits())); + __int128_t dst = values[i].getHighBits(); + dst <<= 64; + dst |= values[i].getLowBits(); + newVectorPtr->SetValue(i, Decimal128(dst)); } else { - newVector->SetNull(i); + newVectorPtr->SetNull(i); } } } else { for (uint i = 0; i < numElements; i++) { - newVector->SetValue(i, Decimal128(values[i].getHighBits(), values[i].getLowBits())); + newVectorPtr->SetValue(i, Decimal128(values[i].getHighBits(), values[i].getLowBits())); } } - return (uint64_t)newVector; + return newVector; } -uint64_t CopyToOmniDecimal64Vec(orc::ColumnVectorBatch *field) +std::unique_ptr CopyToOmniDecimal64Vec(orc::ColumnVectorBatch *field) { orc::Decimal64VectorBatch *lvb = dynamic_cast(field); auto numElements = lvb->numElements; auto values = lvb->values.data(); auto notNulls = lvb->notNull.data(); - auto newVector = new Vector(numElements); + auto newVector = std::make_unique>(numElements); + auto newVectorPtr = newVector.get(); if (lvb->hasNulls) { for (uint i = 0; i < numElements; i++) { if (!notNulls[i]) { - newVector->SetNull(i); + newVectorPtr->SetNull(i); } } } - newVector->SetValues(0, values, numElements); - return 
(uint64_t)newVector; + newVectorPtr->SetValues(0, values, numElements); + return newVector; } -uint64_t CopyToOmniDecimal128VecFrom64(orc::ColumnVectorBatch *field) +std::unique_ptr CopyToOmniDecimal128VecFrom64(orc::ColumnVectorBatch *field) { orc::Decimal64VectorBatch *lvb = dynamic_cast(field); auto numElements = lvb->numElements; auto values = lvb->values.data(); auto notNulls = lvb->notNull.data(); - auto newVector = new Vector(numElements); + auto newVector = std::make_unique>(numElements); + auto newVectorPtr = newVector.get(); if (lvb->hasNulls) { for (uint i = 0; i < numElements; i++) { if (!notNulls[i]) { - newVector->SetNull(i); + newVectorPtr->SetNull(i); } else { Decimal128 d128(values[i]); - newVector->SetValue(i, d128); + newVectorPtr->SetValue(i, d128); } } } else { for (uint i = 0; i < numElements; i++) { Decimal128 d128(values[i]); - newVector->SetValue(i, d128); + newVectorPtr->SetValue(i, d128); + } + } + + return newVector; +} + +std::unique_ptr DealLongVectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) { + switch (id) { + case omniruntime::type::OMNI_BOOLEAN: + return CopyFixedWidth(field); + case omniruntime::type::OMNI_SHORT: + return CopyFixedWidth(field); + case omniruntime::type::OMNI_INT: + return CopyFixedWidth(field); + case omniruntime::type::OMNI_LONG: + return CopyOptimizedForInt64(field); + case omniruntime::type::OMNI_DATE32: + return CopyFixedWidth(field); + case omniruntime::type::OMNI_DATE64: + return CopyOptimizedForInt64(field); + default: { + throw std::runtime_error("DealLongVectorBatch not support for type: " + id); } } +} + +std::unique_ptr DealDoubleVectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) { + switch (id) { + case omniruntime::type::OMNI_DOUBLE: + return CopyOptimizedForInt64(field); + default: { + throw std::runtime_error("DealDoubleVectorBatch not support for type: " + id); + } + } +} + +std::unique_ptr DealDecimal64VectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) { + switch (id) { + case omniruntime::type::OMNI_DECIMAL64: + return CopyToOmniDecimal64Vec(field); + case omniruntime::type::OMNI_DECIMAL128: + return CopyToOmniDecimal128VecFrom64(field); + default: { + throw std::runtime_error("DealDecimal64VectorBatch not support for type: " + id); + } + } +} - return (uint64_t)newVector; +std::unique_ptr DealDecimal128VectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) { + switch (id) { + case omniruntime::type::OMNI_DECIMAL128: + return CopyToOmniDecimal128Vec(field); + default: { + throw std::runtime_error("DealDecimal128VectorBatch not support for type: " + id); + } + } } -int CopyToOmniVec(const orc::Type *type, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field, - bool isDecimal64Transfor128) +std::unique_ptr CopyToOmniVec(const orc::Type *type, int omniTypeId, orc::ColumnVectorBatch *field) { + DataTypeId dataTypeId = static_cast(omniTypeId); switch (type->getKind()) { case orc::TypeKind::BOOLEAN: - omniTypeId = static_cast(OMNI_BOOLEAN); - omniVecId = CopyFixedWidth(field); - break; case orc::TypeKind::SHORT: - omniTypeId = static_cast(OMNI_SHORT); - omniVecId = CopyFixedWidth(field); - break; case orc::TypeKind::DATE: - omniTypeId = static_cast(OMNI_DATE32); - omniVecId = CopyFixedWidth(field); - break; case orc::TypeKind::INT: - omniTypeId = static_cast(OMNI_INT); - omniVecId = CopyFixedWidth(field); - break; case orc::TypeKind::LONG: - omniTypeId = static_cast(OMNI_LONG); - omniVecId = CopyOptimizedForInt64(field); - break; + return DealLongVectorBatch(dataTypeId, field); case 
orc::TypeKind::DOUBLE: - omniTypeId = static_cast(OMNI_DOUBLE); - omniVecId = CopyOptimizedForInt64(field); - break; + return DealDoubleVectorBatch(dataTypeId, field); case orc::TypeKind::CHAR: - omniTypeId = static_cast(OMNI_VARCHAR); - omniVecId = CopyCharType(field); - break; + if (dataTypeId != OMNI_VARCHAR) { + throw std::runtime_error("Cannot transfer to other OMNI_TYPE but VARCHAR for orc char"); + } + return CopyCharType(field); case orc::TypeKind::STRING: case orc::TypeKind::VARCHAR: - omniTypeId = static_cast(OMNI_VARCHAR); - omniVecId = CopyVarWidth(field); - break; + if (dataTypeId != OMNI_VARCHAR) { + throw std::runtime_error("Cannot transfer to other OMNI_TYPE but VARCHAR for orc string/varchar"); + } + return CopyVarWidth(field); case orc::TypeKind::DECIMAL: if (type->getPrecision() > MAX_DECIMAL64_DIGITS) { - omniTypeId = static_cast(OMNI_DECIMAL128); - omniVecId = CopyToOmniDecimal128Vec(field); - } else if (isDecimal64Transfor128) { - omniTypeId = static_cast(OMNI_DECIMAL128); - omniVecId = CopyToOmniDecimal128VecFrom64(field); + return DealDecimal128VectorBatch(dataTypeId, field); } else { - omniTypeId = static_cast(OMNI_DECIMAL64); - omniVecId = CopyToOmniDecimal64Vec(field); + return DealDecimal64VectorBatch(dataTypeId, field); } - break; default: { throw std::runtime_error("Native ColumnarFileScan Not support For This Type: " + type->getKind()); } } - return 1; } JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_recordReaderNext(JNIEnv *env, @@ -573,24 +611,33 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniRea JNI_FUNC_START orc::RowReader *rowReaderPtr = (orc::RowReader *)rowReader; orc::ColumnVectorBatch *columnVectorBatch = (orc::ColumnVectorBatch *)batch; + std::vector> omniVecs; + const orc::Type &baseTp = rowReaderPtr->getSelectedType(); - int vecCnt = 0; - long batchRowSize = 0; + uint64_t batchRowSize = 0; + auto ptr = env->GetIntArrayElements(typeId, JNI_FALSE); + if (ptr == NULL) { + throw std::runtime_error("Types should not be null"); + } + int32_t arrLen = (int32_t) env->GetArrayLength(typeId); if (rowReaderPtr->next(*columnVectorBatch)) { orc::StructVectorBatch *root = dynamic_cast(columnVectorBatch); - vecCnt = root->fields.size(); batchRowSize = root->fields[0]->numElements; - for (int id = 0; id < vecCnt; id++) { + int32_t vecCnt = root->fields.size(); + if (vecCnt != arrLen) { + throw std::runtime_error("Types should align to root fields"); + } + for (int32_t id = 0; id < vecCnt; id++) { auto type = baseTp.getSubtype(id); - int omniTypeId = 0; - uint64_t omniVecId = 0; - CopyToOmniVec(type, omniTypeId, omniVecId, root->fields[id], isDecimal64Transfor128); - env->SetIntArrayRegion(typeId, id, 1, &omniTypeId); - jlong omniVec = static_cast(omniVecId); + int omniTypeId = ptr[id]; + omniVecs.emplace_back(CopyToOmniVec(type, omniTypeId, root->fields[id])); + } + for (int32_t id = 0; id < vecCnt; id++) { + jlong omniVec = reinterpret_cast(omniVecs[id].release()); env->SetLongArrayRegion(vecNativeId, id, 1, &omniVec); } } - return (jlong)batchRowSize; + return (jlong) batchRowSize; JNI_FUNC_END(runtimeExceptionClass) } diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h index 829f5c0744d3d563601ec6506ebfc82b5a020e93..8b942fe8b3e0975cccfece4a68e9112b6c550802 100644 --- a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h +++ 
b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h @@ -141,8 +141,7 @@ int BuildLeaves(PredicateOperatorType leafOp, std::vector &litList bool StringToBool(const std::string &boolStr); -int CopyToOmniVec(const orc::Type *type, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field, - bool isDecimal64Transfor128); +std::unique_ptr CopyToOmniVec(const orc::Type *type, int omniTypeId, orc::ColumnVectorBatch *field); #ifdef __cplusplus } diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp index 21c0b81c96b2c83fcade1207ae7a2816da7af7fb..991699a7be573db1f191fe7b20de0a45baf4964d 100644 --- a/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp +++ b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp @@ -93,10 +93,14 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJn { JNI_FUNC_START ParquetReader *pReader = (ParquetReader *)reader; - std::vector recordBatch(pReader->columnReaders.size()); + std::vector recordBatch(pReader->columnReaders.size(), 0); long batchRowSize = 0; auto state = pReader->ReadNextBatch(recordBatch, &batchRowSize); if (state != Status::OK()) { + for (auto vec : recordBatch) { + delete vec; + } + recordBatch.clear(); env->ThrowNew(runtimeExceptionClass, state.ToString().c_str()); return 0; } diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp index 5f4f379e413c854161dc5e04cce146e98e5a148f..672f4ceedffa043b5577f7a954154ae8e8e599e2 100644 --- a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp @@ -143,8 +143,12 @@ Status ParquetReader::GetRecordBatchReader(const std::vector &row_group_ind return Status::OK(); } - for (uint64_t i = 0; i < columnReaders.size(); i++) { - RETURN_NOT_OK(columnReaders[i]->NextBatch(read_size, &batch[i])); + try { + for (uint64_t i = 0; i < columnReaders.size(); i++) { + RETURN_NOT_OK(columnReaders[i]->NextBatch(read_size, &batch[i])); + } + } catch (const std::exception &e) { + return Status::Invalid(e.what()); } // Check BaseVector diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.h b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.h index 76108fab68de5c3583b6708cfba60f899a7aa2d9..3f602c979d71e76d2995e99c15f73355c956a98c 100644 --- a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.h +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.h @@ -486,11 +486,10 @@ namespace omniruntime::reader { virtual void InitVec(int64_t capacity) { vec_ = new Vector(capacity); + auto capacity_bytes = capacity * byte_width_; if (parquet_vec_ != nullptr) { - auto capacity_bytes = capacity * byte_width_; memset(parquet_vec_, 0, capacity_bytes); } else { - auto capacity_bytes = capacity * byte_width_; parquet_vec_ = new uint8_t[capacity_bytes]; } // Init nulls diff --git a/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt index 491cfb7086037229608f2963cf6c278ca132b198..10f630ad13925922872540fb13b379a0b52e15b3 100644 --- a/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt @@ -1,9 +1,9 @@ -# project 
name -project(spark-thestral-plugin) - # required cmake version cmake_minimum_required(VERSION 3.10) +# project name +project(spark-thestral-plugin) + # configure cmake set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_COMPILER "g++") diff --git a/omnioperator/omniop-spark-extension/cpp/src/common/Buffer.h b/omnioperator/omniop-spark-extension/cpp/src/common/Buffer.h index 73fe13732d27dca87e12ac72900635f8f26cd5f4..ab8a52c229017b4277c7c3b5552477133aa27e4b 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/common/Buffer.h +++ b/omnioperator/omniop-spark-extension/cpp/src/common/Buffer.h @@ -16,29 +16,42 @@ * limitations under the License. */ - #ifndef CPP_BUFFER_H - #define CPP_BUFFER_H +#ifndef CPP_BUFFER_H +#define CPP_BUFFER_H - #include - #include - #include - #include - #include - - class Buffer { - public: - Buffer(uint8_t* data, int64_t size, int64_t capacity) - : data_(data), - size_(size), - capacity_(capacity) { +#include +#include +#include +#include +#include +#include + +class Buffer { +public: + Buffer(uint8_t* data, int64_t size, int64_t capacity, bool isOmniAllocated = true) + : data_(data), + size_(size), + capacity_(capacity), + allocatedByOmni(isOmniAllocated) { + } + + ~Buffer() { + if (allocatedByOmni && not releaseFlag) { + auto *allocator = omniruntime::mem::Allocator::GetAllocator(); + allocator->Free(data_, capacity_); } + } - ~Buffer() {} + void SetReleaseFlag() { + releaseFlag = true; + } - public: - uint8_t * data_; - int64_t size_; - int64_t capacity_; - }; +public: + uint8_t * data_; + int64_t size_; + int64_t capacity_; + bool allocatedByOmni = true; + bool releaseFlag = false; +}; - #endif //CPP_BUFFER_H \ No newline at end of file +#endif //CPP_BUFFER_H \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp index ca982c0a4ca56100cb6c11599d6d0c334009da92..14785a9cf453f5925974f8b85dc80538a5b85a17 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp @@ -131,7 +131,6 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_nativ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_split( JNIEnv *env, jobject jObj, jlong splitter_id, jlong jVecBatchAddress) { - JNI_FUNC_START auto splitter = g_shuffleSplitterHolder.Lookup(splitter_id); if (!splitter) { std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); @@ -140,10 +139,11 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_split } auto vecBatch = (VectorBatch *) jVecBatchAddress; - + splitter->SetInputVecBatch(vecBatch); + JNI_FUNC_START splitter->Split(*vecBatch); return 0L; - JNI_FUNC_END(runtimeExceptionClass) + JNI_FUNC_END_WITH_VECBATCH(runtimeExceptionClass, splitter->GetInputVecBatch()) } JNIEXPORT jobject JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_stop( diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h b/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h index 4b59296e152876062a06db3d69c81a7ed22b670b..964fab6dfc06ac692294fa212f40afc21a4d1041 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h @@ -48,6 +48,15 @@ jmethodID GetMethodID(JNIEnv* env, jclass this_class, const char* name, const ch return; \ } \ +#define 
JNI_FUNC_END_WITH_VECBATCH(exceptionClass, toDeleteVecBatch) \ + } \ + catch (const std::exception &e) \ + { \ + VectorHelper::FreeVecBatch(toDeleteVecBatch); \ + env->ThrowNew(exceptionClass, e.what()); \ + return 0; \ + } + extern jclass runtimeExceptionClass; extern jclass splitResultClass; extern jclass jsonClass; diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp index f0a9b3e76ab9956581a13744b16851cd4f34d040..462aee16731ffe932d4881ff28fb22d9593dabac 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp @@ -77,7 +77,7 @@ int Splitter::AllocatePartitionBuffers(int32_t partition_id, int32_t new_size) { case SHUFFLE_DECIMAL128: default: { void *ptr_tmp = static_cast(options_.allocator->Alloc(new_size * (1 << column_type_id_[i]))); - fixed_valueBuffer_size_[partition_id] = new_size * (1 << column_type_id_[i]); + fixed_valueBuffer_size_[partition_id] += new_size * (1 << column_type_id_[i]); if (nullptr == ptr_tmp) { throw std::runtime_error("Allocator for AllocatePartitionBuffers Failed! "); } @@ -346,7 +346,7 @@ int Splitter::SplitFixedWidthValidityBuffer(VectorBatch& vb){ dst_addrs[pid] = const_cast(validity_buffer->data_); std::memset(validity_buffer->data_, 0, new_size); partition_fixed_width_buffers_[col][pid][0] = std::move(validity_buffer); - fixed_nullBuffer_size_[pid] = new_size; + fixed_nullBuffer_size_[pid] += new_size; } } @@ -449,6 +449,7 @@ int Splitter::DoSplit(VectorBatch& vb) { num_row_splited_ += vb.GetRowCount(); // release the fixed width vector and release vectorBatch at the same time ReleaseVectorBatch(&vb); + this->ResetInputVecBatch(); // 阈值检查,是否溢写 if (num_row_splited_ >= SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD) { @@ -675,7 +676,7 @@ void Splitter::SerializingFixedColumns(int32_t partitionId, valueStr.resize(onceCopyLen); std::string nullStr; - std::shared_ptr ptr_value (new Buffer((uint8_t*)valueStr.data(), 0, onceCopyLen)); + std::shared_ptr ptr_value (new Buffer((uint8_t*)valueStr.data(), 0, onceCopyLen, false)); std::shared_ptr ptr_validity; // options_.spill_batch_row_num长度切割与拼接 @@ -698,7 +699,7 @@ void Splitter::SerializingFixedColumns(int32_t partitionId, splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp]); if (not nullAllocated && partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0] != nullptr) { nullStr.resize(splitRowInfoTmp->onceCopyRow); - ptr_validity.reset(new Buffer((uint8_t*)nullStr.data(), 0, splitRowInfoTmp->onceCopyRow)); + ptr_validity.reset(new Buffer((uint8_t*)nullStr.data(), 0, splitRowInfoTmp->onceCopyRow, false)); nullAllocated = true; } if ((onceCopyLen - destCopyedLength) >= (cacheBatchSize - splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp])) { @@ -714,9 +715,11 @@ void Splitter::SerializingFixedColumns(int32_t partitionId, // 释放内存 options_.allocator->Free(partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->data_, partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->capacity_); + partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->SetReleaseFlag(); } options_.allocator->Free(partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->data_, 
partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->capacity_); + partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->SetReleaseFlag(); destCopyedLength += memCopyLen; splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp] += 1; // cacheBatchIndex下标后移 splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp] = 0; // 初始化下一个cacheBatch的起始偏移 @@ -1058,6 +1061,7 @@ int Splitter::DeleteSpilledTmpFile() { auto tmpDataFilePath = pair.first + ".data"; // 释放存储有各个临时文件的偏移数据内存 options_.allocator->Free(pair.second->data_, pair.second->capacity_); + pair.second->SetReleaseFlag(); if (IsFileExist(tmpDataFilePath)) { remove(tmpDataFilePath.c_str()); } diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h index 4e9684de963a3bb4eedf0720ec8a01c1d2160f2b..d54b97849dadb06c72c6a456aedd9eefa1f21092 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h @@ -165,6 +165,7 @@ private: } } vectorAddress.clear(); + vb->ClearVectors(); delete vb; } @@ -173,7 +174,7 @@ private: std::vector vector_batch_col_types_; InputDataTypes input_col_types; std::vector binary_array_empirical_size_; - + omniruntime::vec::VectorBatch *inputVecBatch = nullptr; public: bool singlePartitionFlag = false; int32_t num_partitions_; @@ -221,6 +222,22 @@ public: int64_t TotalComputePidTime() const { return total_compute_pid_time_; } const std::vector& PartitionLengths() const { return partition_lengths_; } + + omniruntime::vec::VectorBatch *GetInputVecBatch() + { + return inputVecBatch; + } + + void SetInputVecBatch(omniruntime::vec::VectorBatch *inVecBatch) + { + inputVecBatch = inVecBatch; + } + + // no need to clear memory when exception, so we have to reset + void ResetInputVecBatch() + { + inputVecBatch = nullptr; + } }; diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java index 8edbdf4622f0b5888cd4fd680fee6fc1eb3c4880..611a10826042a7916ff7d98f804c7ee5eba319b1 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java @@ -34,7 +34,6 @@ import org.slf4j.LoggerFactory; import java.net.URI; import java.sql.Date; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; public class OrcColumnarBatchScanReader { @@ -253,8 +252,7 @@ public class OrcColumnarBatchScanReader { } } - public int next(Vec[] vecList) { - int[] typeIds = new int[realColsCnt]; + public int next(Vec[] vecList, int[] typeIds) { long[] vecNativeIds = new long[realColsCnt]; long rtn = jniReader.recordReaderNext(recordReader, batchReader, typeIds, vecNativeIds); if (rtn == 0) { diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java index 9f6cadf70b7348a1696e140a13887c87018af4b2..6a0c1b27c4282016aecada2ba4ef0c48c320f20f 100644 --- 
a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java @@ -20,6 +20,7 @@ package com.huawei.boostkit.spark.serialize; import com.google.protobuf.InvalidProtocolBufferException; +import nova.hetu.omniruntime.utils.OmniRuntimeException; import nova.hetu.omniruntime.vector.BooleanVec; import nova.hetu.omniruntime.vector.Decimal128Vec; import nova.hetu.omniruntime.vector.DoubleVec; @@ -35,21 +36,31 @@ import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarBatch; - public class ShuffleDataSerializer { public static ColumnarBatch deserialize(byte[] bytes) { + ColumnVector[] vecs = null; try { VecData.VecBatch vecBatch = VecData.VecBatch.parseFrom(bytes); int vecCount = vecBatch.getVecCnt(); int rowCount = vecBatch.getRowCnt(); - ColumnVector[] vecs = new ColumnVector[vecCount]; + vecs = new ColumnVector[vecCount]; for (int i = 0; i < vecCount; i++) { vecs[i] = buildVec(vecBatch.getVecs(i), rowCount); } return new ColumnarBatch(vecs, rowCount); } catch (InvalidProtocolBufferException e) { throw new RuntimeException("deserialize failed. errmsg:" + e.getMessage()); + } catch (OmniRuntimeException e) { + if (vecs != null) { + for (int i = 0; i < vecs.length; i++) { + ColumnVector vec = vecs[i]; + if (vec != null) { + vec.close(); + } + } + } + throw new RuntimeException("deserialize failed. errmsg:" + e.getMessage()); } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/parquet/OmniParquetColumnarBatchReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/parquet/OmniParquetColumnarBatchReader.java index 0bf7527a7b2744e50a4b55ba7ba73d6c5f612b5e..12f297f52f48b98e3c68c6fca631f242d2a68cb4 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/parquet/OmniParquetColumnarBatchReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/parquet/OmniParquetColumnarBatchReader.java @@ -108,8 +108,10 @@ public class OmniParquetColumnarBatchReader extends RecordReader - if (!isSlice) { - vector.getVec - } else { - vector.getVec.slice(0, cb.numRows()) + try { + for (i <- 0 until cb.numCols()) { + val omniVec: Vec = cb.column(i) match { + case vector: OmniColumnVector => + if (!isSlice) { + vector.getVec + } else { + vector.getVec.slice(0, cb.numRows()) + } + case vector: ColumnVector => + transColumnVector(vector, cb.numRows()) + case _ => + throw new UnsupportedOperationException("unsupport column vector!") + } + input(i) = omniVec + } + } catch { + case e: Exception => { + for (j <- 0 until cb.numCols()) { + val vec = input(j) + if (vec != null) vec.close + cb.column(j) match { + case vector: OmniColumnVector => + vector.close() } - case vector: ColumnVector => - transColumnVector(vector, cb.numRows()) - case _ => - throw new UnsupportedOperationException("unsupport column vector!") + } + throw new RuntimeException("allocate memory failed!") } - input(i) = omniVec } input } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala index 
77707365e2e92680b027c4649144a0d599768720..33e8037a12aa253da8b7a40f0895c511d31cfec8 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala @@ -312,7 +312,7 @@ object ColumnarBatchToInternalRow { val batchIter = batches.flatMap { batch => - // toClosedVecs closed case: + // toClosedVecs closed case: [Deprecated] // 1) all rows of batch fetched and closed // 2) only fetch partial rows (e.g. top-n, limit-n), closed at task CompletionListener callback val toClosedVecs = new ListBuffer[Vec] @@ -330,27 +330,22 @@ object ColumnarBatchToInternalRow { new Iterator[InternalRow] { val numOutputRowsMetric: SQLMetric = numOutputRows - var closed = false - - // only invoke if fetch partial rows of batch - if (mayPartialFetch) { - SparkMemoryUtils.addLeakSafeTaskCompletionListener { _ => - if (!closed) { - toClosedVecs.foreach {vec => - vec.close() - } + + + SparkMemoryUtils.addLeakSafeTaskCompletionListener { _ => + toClosedVecs.foreach {vec => + vec.close() } - } } override def hasNext: Boolean = { val has = iter.hasNext - // fetch all rows and closed - if (!has && !closed) { + // fetch all rows + if (!has) { toClosedVecs.foreach {vec => vec.close() } + // all vecs closed; clear the list so the completion listener does not close them again + toClosedVecs.clear() - closed = true } has } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala index 11541ccee1c0e9ad775eb71360bf0d081d72bcd2..12fbdb6b4a2c5b8b4eec6ea51fbe7f6d919dbb76 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala @@ -330,24 +330,27 @@ case class ColumnarOptRollupExec( omniOutputPartials) omniCodegenTimeMetric += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + + val results = new ListBuffer[VecBatch]() + var hashaggResults: java.util.Iterator[VecBatch] = null + // close operator addLeakSafeTaskCompletionListener[Unit](_ => { projectOperators.foreach(operator => operator.close()) hashaggOperator.close() + results.foreach(vecBatch => { + vecBatch.releaseAllVectors() + vecBatch.close() + }) }) - val results = new ListBuffer[VecBatch]() - var hashaggResults: java.util.Iterator[VecBatch] = null - while (iter.hasNext) { val batch = iter.next() val input = transColBatchToOmniVecs(batch) val vecBatch = new VecBatch(input, batch.numRows()) results.append(vecBatch) projectOperators.foreach(projectOperator => { - val vecs = vecBatch.getVectors.map(vec => { - vec.slice(0, vecBatch.getRowCount) - }) + val vecs = transColBatchToOmniVecs(batch, true) val projectInput = new VecBatch(vecs, vecBatch.getRowCount) var startInput = System.nanoTime() diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala index f928ad11c258bb07c64f11d065adadb5f439dcc0..233503141dc11a6d74ea8e000a02594cb4607116 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala +++
b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.createShuffleWriteProcessor import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.metric._ -import org.apache.spark.sql.execution.util.MergeIterator +import org.apache.spark.sql.execution.util.{MergeIterator, SparkMemoryUtils} import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener import org.apache.spark.sql.execution.vectorized.OmniColumnVector import org.apache.spark.sql.internal.SQLConf @@ -167,10 +167,14 @@ class ColumnarShuffleExchangeExec( if (enableShuffleBatchMerge) { cachedShuffleRDD.mapPartitionsWithIndexInternal { (index, iter) => - new MergeIterator(iter, + val mergeIterator = new MergeIterator(iter, StructType.fromAttributes(child.output), longMetric("numMergedVecBatches"), longMetric("bypassVecBatches")) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + mergeIterator.close() + }) + mergeIterator } } else { cachedShuffleRDD diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala index d7365c8c152b507c366d917cce8fd127acf55a2b..098b606becd652009b272f2bbbc4e130f81e8f80 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartit import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.util.MergeIterator +import org.apache.spark.sql.execution.util.{MergeIterator, SparkMemoryUtils} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch @@ -207,10 +207,14 @@ case class ColumnarCustomShuffleReaderExec( if (enableShuffleBatchMerge) { cachedShuffleRDD.mapPartitionsWithIndexInternal { (index, iter) => - new MergeIterator(iter, + val mergeIterator = new MergeIterator(iter, StructType.fromAttributes(child.output), longMetric("numMergedVecBatches"), longMetric("bypassVecBatches")) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + mergeIterator.close() + }) + mergeIterator } } else { cachedShuffleRDD diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java similarity index 86% rename from omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java rename to omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java index 
24a93ede4cf4e17ccd0050157e04a848ab38ba0d..3fdb0d28683c90b2c8904ae094a461cb90a68e89 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.orc; import com.google.common.annotations.VisibleForTesting; +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor; import com.huawei.boostkit.spark.jni.OrcColumnarBatchScanReader; import nova.hetu.omniruntime.vector.Vec; import org.apache.hadoop.conf.Configuration; @@ -79,6 +80,8 @@ public class OmniOrcColumnarBatchReader extends RecordReader(); + // collect read cols types + ArrayList typeBuilder = new ArrayList<>(); for (int i = 0; i < requiredfieldNames.length; i++) { String target = requiredfieldNames[i]; boolean is_find = false; @@ -161,6 +168,7 @@ public class OmniOrcColumnarBatchReader extends RecordReader val startBuildInput = System.nanoTime() @@ -356,7 +360,19 @@ case class ColumnarBroadcastHashJoinExec( buildAddInputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildInput) } val startBuildGetOp = System.nanoTime() - op.getOutput + try { + op.getOutput + } catch { + case e: Exception => { + if (isShared) { + OmniHashBuilderWithExprOperatorFactory.dereferenceHashBuilderOperatorAndFactory(buildPlan.id) + } else { + op.close() + opFactory.close() + } + throw new RuntimeException("HashBuilder getOutput failed") + } + } buildGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildGetOp) (opFactory, op) } @@ -368,11 +384,9 @@ case class ColumnarBroadcastHashJoinExec( try { buildOpFactory = OmniHashBuilderWithExprOperatorFactory.getHashBuilderOperatorFactory(buildPlan.id) if (buildOpFactory == null) { - val (opFactory, op) = createBuildOpFactoryAndOp() + val (opFactory, op) = createBuildOpFactoryAndOp(true) buildOpFactory = opFactory buildOp = op - OmniHashBuilderWithExprOperatorFactory.saveHashBuilderOperatorAndFactory(buildPlan.id, - buildOpFactory, buildOp) } } catch { case e: Exception => { @@ -382,7 +396,7 @@ case class ColumnarBroadcastHashJoinExec( OmniHashBuilderWithExprOperatorFactory.gLock.unlock() } } else { - val (opFactory, op) = createBuildOpFactoryAndOp() + val (opFactory, op) = createBuildOpFactoryAndOp(false) buildOpFactory = opFactory buildOp = op } @@ -492,7 +506,11 @@ case class ColumnarBroadcastHashJoinExec( } if (enableJoinBatchMerge) { - new MergeIterator(iterBatch, resultSchema, numMergedVecBatches, bypassVecBatches) + val mergeIterator = new MergeIterator(iterBatch, resultSchema, numMergedVecBatches, bypassVecBatches) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + mergeIterator.close() + }) + mergeIterator } else { iterBatch } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala index f3fb45a1cab4a0b1f9fe23c45c9f298b728a3761..8a61c3fdbe61f4a568d10c796dd32ba5e64bbab0 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala @@ -285,7 +285,11 @@ case class 
ColumnarSortMergeJoinExec( } if (enableSortMergeJoinBatchMerge) { - new MergeIterator(iterBatch, resultSchema, numMergedVecBatches, bypassVecBatches) + val mergeIterator = new MergeIterator(iterBatch, resultSchema, numMergedVecBatches, bypassVecBatches) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + mergeIterator.close() + }) + mergeIterator } else { iterBatch } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala index 105d3435367c2818c7f2b1da267ca19f43ce27ef..cea35f2a37d159fee797632c49101939b09ae67c 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala @@ -44,29 +44,40 @@ class MergeIterator(iter: Iterator[ColumnarBatch], localSchema: StructType, private def createOmniVectors(schema: StructType, columnSize: Int): Array[Vec] = { val vecs = new Array[Vec](schema.fields.length) - schema.fields.zipWithIndex.foreach { case (field, index) => - field.dataType match { - case LongType => - vecs(index) = new LongVec(columnSize) - case DateType | IntegerType => - vecs(index) = new IntVec(columnSize) - case ShortType => - vecs(index) = new ShortVec(columnSize) - case DoubleType => - vecs(index) = new DoubleVec(columnSize) - case BooleanType => - vecs(index) = new BooleanVec(columnSize) - case StringType => - val vecType: DataType = sparkTypeToOmniType(field.dataType, field.metadata) - vecs(index) = new VarcharVec(columnSize) - case dt: DecimalType => - if (DecimalType.is64BitDecimalType(dt)) { + try { + schema.fields.zipWithIndex.foreach { case (field, index) => + field.dataType match { + case LongType => vecs(index) = new LongVec(columnSize) - } else { - vecs(index) = new Decimal128Vec(columnSize) + case DateType | IntegerType => + vecs(index) = new IntVec(columnSize) + case ShortType => + vecs(index) = new ShortVec(columnSize) + case DoubleType => + vecs(index) = new DoubleVec(columnSize) + case BooleanType => + vecs(index) = new BooleanVec(columnSize) + case StringType => + val vecType: DataType = sparkTypeToOmniType(field.dataType, field.metadata) + vecs(index) = new VarcharVec(columnSize) + case dt: DecimalType => + if (DecimalType.is64BitDecimalType(dt)) { + vecs(index) = new LongVec(columnSize) + } else { + vecs(index) = new Decimal128Vec(columnSize) + } + case _ => + throw new UnsupportedOperationException("Fail to create omni vector, unsupported fields") + } + } + } catch { + case e: Exception => { + for (vec <- vecs) { + if (vec != null) { + vec.close() } - case _ => - throw new UnsupportedOperationException("Fail to create omni vector, unsupported fields") + } + throw new RuntimeException("allocate memory failed!") } } vecs @@ -165,4 +176,15 @@ class MergeIterator(iter: Iterator[ColumnarBatch], localSchema: StructType, def isFull(): Boolean = { totalRows > maxRowCount || currentBatchSizeInBytes >= maxBatchSizeInBytes } + + def close(): Unit = { + for (elem <- bufferedVecBatch) { + elem.releaseAllVectors() + elem.close() + } + for (elem <- outputQueue) { + elem.releaseAllVectors() + elem.close() + } + } } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java 
b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java index 0c65d0bcc16354910ba8808995caf9fff4779aba..5490f10eed8ac897a76a807117f628e3b9f6a16a 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java @@ -40,6 +40,10 @@ import java.util.ArrayList; import java.net.URI; import java.net.URISyntaxException; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT; + @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderDataTypeTest extends TestCase { public OrcColumnarBatchScanReader orcColumnarBatchScanReader; @@ -101,7 +105,7 @@ public class OrcColumnarBatchJniReaderDataTypeTest extends TestCase { @Test public void testNext() { - int[] typeId = new int[4]; + int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_INT.ordinal()}; long[] vecNativeId = new long[4]; long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId); assertTrue(rtn == 4096); @@ -120,4 +124,27 @@ public class OrcColumnarBatchJniReaderDataTypeTest extends TestCase { vec3.close(); vec4.close(); } + + // Here we test OMNI_LONG type instead of OMNI_INT in 4th field. + @Test + public void testNextIfSchemaChange() { + int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_LONG.ordinal()}; + long[] vecNativeId = new long[4]; + long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId); + assertTrue(rtn == 4096); + LongVec vec1 = new LongVec(vecNativeId[0]); + VarcharVec vec2 = new VarcharVec(vecNativeId[1]); + VarcharVec vec3 = new VarcharVec(vecNativeId[2]); + LongVec vec4 = new LongVec(vecNativeId[3]); + assertTrue(vec1.get(10) == 11); + String tmp1 = new String(vec2.get(4080)); + assertTrue(tmp1.equals("AAAAAAAABPPAAAAA")); + String tmp2 = new String(vec3.get(4070)); + assertTrue(tmp2.equals("Particular, arab cases shall like less current, different names. Computers start for the changes. 
Scottish, trying exercises operate marks; long, supreme miners may ro")); + assertTrue(0 == vec4.get(1000)); + vec1.close(); + vec2.close(); + vec3.close(); + vec4.close(); + } } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java index 0a06c7016a1f945e433f140e4e1759569b5824f2..358eb6c62b627d5ce9927e256bad2294f4448a47 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java @@ -37,6 +37,9 @@ import java.util.ArrayList; import java.net.URI; import java.net.URISyntaxException; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; + @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderNotPushDownTest extends TestCase { public OrcColumnarBatchScanReader orcColumnarBatchScanReader; @@ -92,7 +95,7 @@ public class OrcColumnarBatchJniReaderNotPushDownTest extends TestCase { @Test public void testNext() { - int[] typeId = new int[2]; + int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal()}; long[] vecNativeId = new long[2]; long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId); assertTrue(rtn == 4096); diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java index 28598ebc9a2322c86892e779c73014a2dbbf6aaf..029516c1c33dc66f7a3caab4b39a25af52d754f7 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java @@ -45,6 +45,9 @@ import java.net.URISyntaxException; import java.lang.reflect.Array; import java.util.ArrayList; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; + @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderPushDownTest extends TestCase { public OrcColumnarBatchScanReader orcColumnarBatchScanReader; @@ -146,7 +149,7 @@ public class OrcColumnarBatchJniReaderPushDownTest extends TestCase { @Test public void testNext() { - int[] typeId = new int[2]; + int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal()}; long[] vecNativeId = new long[2]; long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId); assertTrue(rtn == 4096); diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java 
b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java index 4b9625930a13a7ed570bf1d741a3b919780b50ce..0c83e19102b1a03d4133cacdc246ba1a9f639489 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java @@ -40,6 +40,10 @@ import java.util.ArrayList; import java.net.URI; import java.net.URISyntaxException; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT; + @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderSparkORCNotPushDownTest extends TestCase { public OrcColumnarBatchScanReader orcColumnarBatchScanReader; @@ -101,7 +105,7 @@ public class OrcColumnarBatchJniReaderSparkORCNotPushDownTest extends TestCase { @Test public void testNext() { - int[] typeId = new int[4]; + int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_INT.ordinal()}; long[] vecNativeId = new long[4]; long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId); assertTrue(rtn == 4096); diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java index 196d1ff350efa81be59282b090a9375edfd01a10..8246e5606e626fee237ea246d1d5454ee37232c0 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java @@ -40,6 +40,10 @@ import java.util.ArrayList; import java.net.URI; import java.net.URISyntaxException; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT; + @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderSparkORCPushDownTest extends TestCase { public OrcColumnarBatchScanReader orcColumnarBatchScanReader; @@ -147,7 +151,7 @@ public class OrcColumnarBatchJniReaderSparkORCPushDownTest extends TestCase { @Test public void testNext() { - int[] typeId = new int[4]; + int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_INT.ordinal()}; long[] vecNativeId = new long[4]; long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId); assertTrue(rtn == 4096); diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java index 
c8581f35ebc605845896b5f35731b0908779f326..eab15fef660250780e0beb311b489e3ceeb8ff5b 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java @@ -21,20 +21,15 @@ package com.huawei.boostkit.spark.jni; import com.esotericsoftware.kryo.Kryo; import com.esotericsoftware.kryo.io.Input; import junit.framework.TestCase; -import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.vector.IntVec; import nova.hetu.omniruntime.vector.LongVec; import nova.hetu.omniruntime.vector.VarcharVec; import nova.hetu.omniruntime.vector.Vec; import org.apache.commons.codec.binary.Base64; import org.apache.hadoop.hive.ql.io.sarg.SearchArgument; import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentImpl; -import org.apache.orc.OrcConf; import org.apache.orc.OrcFile; -import org.apache.orc.Reader; import org.apache.orc.TypeDescription; import org.apache.orc.mapred.OrcInputFormat; -import org.json.JSONObject; import org.junit.After; import org.junit.Before; import org.junit.FixMethodOrder; @@ -51,6 +46,9 @@ import org.apache.orc.Reader.Options; import static org.junit.Assert.*; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; + @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderTest extends TestCase { public Configuration conf = new Configuration(); @@ -152,7 +150,8 @@ public class OrcColumnarBatchJniReaderTest extends TestCase { @Test public void testNext() { Vec[] vecs = new Vec[2]; - long rtn = orcColumnarBatchScanReader.next(vecs); + int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal()}; + long rtn = orcColumnarBatchScanReader.next(vecs, typeId); assertTrue(rtn == 4096); assertTrue(((LongVec) vecs[0]).get(0) == 1); String str = new String(((VarcharVec) vecs[1]).get(0));
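
A note on the recurring pattern in this patch: several call sites (ShuffleDataSerializer.deserialize, transColBatchToOmniVecs, MergeIterator.createOmniVectors) now close every vector that was already built when a later allocation throws, since the vectors own off-heap memory the JVM garbage collector will not reclaim. Below is a minimal, self-contained Java sketch of that pattern; NativeVector and allocate are illustrative stand-ins, not the project's API.

import java.util.List;

public class PartialBuildCleanup {
    // Stand-in for a native-backed vector; close() releases its off-heap memory.
    interface NativeVector extends AutoCloseable {
        @Override
        void close(); // mirrors Vec.close(): no checked exception
    }

    // Build all vectors, or close the ones already built and rethrow if any step fails.
    static NativeVector[] buildAll(List<Integer> sizes) {
        NativeVector[] built = new NativeVector[sizes.size()];
        try {
            for (int i = 0; i < sizes.size(); i++) {
                built[i] = allocate(sizes.get(i)); // may throw partway through
            }
            return built;
        } catch (RuntimeException e) {
            for (NativeVector v : built) {    // close everything created so far
                if (v != null) {
                    v.close();
                }
            }
            throw new RuntimeException("deserialize failed. errmsg:" + e.getMessage(), e);
        }
    }

    // Stand-in allocator that fails for oversized requests to exercise the cleanup path.
    static NativeVector allocate(int size) {
        if (size > 1_000_000) {
            throw new RuntimeException("allocate memory failed");
        }
        return () -> System.out.println("closed vector of size " + size);
    }

    public static void main(String[] args) {
        try {
            buildAll(List.of(10, 20, 2_000_000));
        } catch (RuntimeException e) {
            System.out.println(e.getMessage()); // the two small vectors were already closed
        }
    }
}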
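
The patch also moves cleanup of MergeIterator's buffered batches (and of toClosedVecs in ColumnarBatchToInternalRow) into a task completion listener, so resources are released even when the consumer never drains the iterator (e.g. limit-n plans). The sketch below models that idea without Spark; TASK_CLEANUP stands in for SparkMemoryUtils.addLeakSafeTaskCompletionListener and the long[] batches stand in for VecBatch instances.

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

public class CompletionListenerSketch {
    // Stand-in for the task completion listeners Spark runs when a task finishes.
    static final List<Runnable> TASK_CLEANUP = new ArrayList<>();

    // Stand-in for MergeIterator: it buffers batches that must be released eventually.
    static class BufferingIterator implements AutoCloseable {
        private final Deque<long[]> buffered = new ArrayDeque<>();

        void buffer(long[] batch) {
            buffered.add(batch);
        }

        @Override
        public void close() {
            int released = 0;
            while (!buffered.isEmpty()) {      // real code would releaseAllVectors() + close()
                buffered.poll();
                released++;
            }
            System.out.println("closed iterator, released " + released + " buffered batches");
        }
    }

    public static void main(String[] args) {
        BufferingIterator it = new BufferingIterator();
        TASK_CLEANUP.add(it::close);        // register cleanup before any batch is buffered
        it.buffer(new long[]{1, 2, 3});     // the consumer abandons the iterator mid-stream

        // Simulated end of task: the listener runs even though nobody drained the iterator.
        TASK_CLEANUP.forEach(Runnable::run);
    }
}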
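
Finally, OrcColumnarBatchScanReader.next(Vec[], int[]) now takes the expected Omni type ids from the caller, which is why the JNI tests build arrays such as {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal(), ...} instead of passing an empty int[4]. The sketch below shows the general shape of building such an array once from a read schema and reusing it for every batch; the enum, column names, and mapping are illustrative assumptions, not the adaptor the reader actually uses.

import java.util.LinkedHashMap;
import java.util.Map;

public class TypeIdSketch {
    // Illustrative subset of Omni type ids; the real ids come from nova.hetu.omniruntime.
    enum OmniTypeId { OMNI_INT, OMNI_LONG, OMNI_VARCHAR }

    // Build the type-id array once from the read schema, in column order.
    static int[] buildTypeIds(Map<String, OmniTypeId> readSchema) {
        int[] typeIds = new int[readSchema.size()];
        int i = 0;
        for (OmniTypeId t : readSchema.values()) {
            typeIds[i++] = t.ordinal(); // same ordinal convention the tests use
        }
        return typeIds;
    }

    public static void main(String[] args) {
        Map<String, OmniTypeId> schema = new LinkedHashMap<>();    // preserves column order
        schema.put("ss_sold_date_sk", OmniTypeId.OMNI_LONG);       // column names are examples
        schema.put("ss_item_id", OmniTypeId.OMNI_VARCHAR);
        int[] typeIds = buildTypeIds(schema);
        System.out.println(typeIds.length + " type ids: " + typeIds[0] + ", " + typeIds[1]);
        // typeIds would then be passed to reader.next(vecs, typeIds) for every batch
    }
}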