From 98ea09f502889d726fb3cac1318f3bea01c8839f Mon Sep 17 00:00:00 2001 From: zhousipei Date: Thu, 29 Aug 2024 19:42:47 +0800 Subject: [PATCH 01/12] refactor orc push down filter --- .../spark/jni/OrcColumnarBatchScanReader.java | 442 +++++++++++------- .../orc/OmniOrcColumnarBatchReader.java | 160 +++---- .../datasources/orc/OmniOrcFileFormat.scala | 94 +--- ...OrcColumnarBatchJniReaderDataTypeTest.java | 2 +- ...ColumnarBatchJniReaderNotPushDownTest.java | 2 +- ...OrcColumnarBatchJniReaderPushDownTest.java | 2 +- ...BatchJniReaderSparkORCNotPushDownTest.java | 2 +- ...narBatchJniReaderSparkORCPushDownTest.java | 2 +- .../jni/OrcColumnarBatchJniReaderTest.java | 134 +++--- 9 files changed, 418 insertions(+), 422 deletions(-) rename omnioperator/omniop-spark-extension/java/src/main/{scala => java}/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java (57%) diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java index df8c564b5..7147fb45d 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java @@ -25,65 +25,78 @@ import nova.hetu.omniruntime.type.DataType; import nova.hetu.omniruntime.vector.*; import org.apache.orc.impl.writer.TimestampTreeWriter; +import org.apache.spark.sql.catalyst.util.CharVarcharUtils; import org.apache.spark.sql.catalyst.util.RebaseDateTime; -import org.apache.hadoop.hive.ql.io.sarg.ExpressionTree; -import org.apache.hadoop.hive.ql.io.sarg.PredicateLeaf; -import org.apache.orc.OrcFile.ReaderOptions; -import org.apache.orc.Reader.Options; +import org.apache.spark.sql.sources.And; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.GreaterThanOrEqual; +import org.apache.spark.sql.sources.In; +import org.apache.spark.sql.sources.IsNotNull; +import org.apache.spark.sql.sources.IsNull; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.LessThanOrEqual; +import org.apache.spark.sql.sources.Not; +import org.apache.spark.sql.sources.Or; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; import org.json.JSONObject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.math.BigDecimal; import java.net.URI; -import java.sql.Date; -import java.sql.Timestamp; +import java.time.LocalDate; import java.text.DateFormat; import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.ArrayList; -import java.util.List; +import java.util.Arrays; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import java.util.TimeZone; public class OrcColumnarBatchScanReader { private static final Logger LOGGER = 
LoggerFactory.getLogger(OrcColumnarBatchScanReader.class); private boolean nativeSupportTimestampRebase; + private static final Pattern CHAR_TYPE = Pattern.compile("char\\(\\s*(\\d+)\\s*\\)"); + + private static final int MAX_LEAF_THRESHOLD = 256; + public long reader; public long recordReader; public long batchReader; - public int[] colsToGet; - public int realColsCnt; - public ArrayList fildsNames; + // All ORC fieldNames + public ArrayList allFieldsNames; - public ArrayList colToInclu; + // Indicate columns to read + public int[] colsToGet; - public String[] requiredfieldNames; + // Actual columns to read + public ArrayList includedColumns; - public int[] precisionArray; + // max threshold for leaf node + private int leafIndex; - public int[] scaleArray; + // spark required schema + private StructType requiredSchema; public OrcColumnarBatchJniReader jniReader; public OrcColumnarBatchScanReader() { jniReader = new OrcColumnarBatchJniReader(); - fildsNames = new ArrayList(); - } - - public JSONObject getSubJson(ExpressionTree node) { - JSONObject jsonObject = new JSONObject(); - jsonObject.put("op", node.getOperator().ordinal()); - if (node.getOperator().toString().equals("LEAF")) { - jsonObject.put("leaf", node.toString()); - return jsonObject; - } - ArrayList child = new ArrayList(); - for (ExpressionTree childNode : node.getChildren()) { - JSONObject rtnJson = getSubJson(childNode); - child.add(rtnJson); - } - jsonObject.put("child", child); - return jsonObject; + allFieldsNames = new ArrayList(); } public String padZeroForDecimals(String [] decimalStrArray, int decimalScale) { @@ -95,91 +108,6 @@ public class OrcColumnarBatchScanReader { return String.format("%1$-" + decimalScale + "s", decimalVal).replace(' ', '0'); } - public int getPrecision(String colname) { - for (int i = 0; i < requiredfieldNames.length; i++) { - if (colname.equals(requiredfieldNames[i])) { - return precisionArray[i]; - } - } - - return -1; - } - - public int getScale(String colname) { - for (int i = 0; i < requiredfieldNames.length; i++) { - if (colname.equals(requiredfieldNames[i])) { - return scaleArray[i]; - } - } - - return -1; - } - - public JSONObject getLeavesJson(List leaves) { - JSONObject jsonObjectList = new JSONObject(); - for (int i = 0; i < leaves.size(); i++) { - PredicateLeaf pl = leaves.get(i); - JSONObject jsonObject = new JSONObject(); - jsonObject.put("op", pl.getOperator().ordinal()); - jsonObject.put("name", pl.getColumnName()); - jsonObject.put("type", pl.getType().ordinal()); - if (pl.getLiteral() != null) { - if (pl.getType() == PredicateLeaf.Type.DATE) { - jsonObject.put("literal", ((int)Math.ceil(((Date)pl.getLiteral()).getTime()* 1.0/3600/24/1000)) + ""); - } else if (pl.getType() == PredicateLeaf.Type.DECIMAL) { - int decimalP = getPrecision(pl.getColumnName()); - int decimalS = getScale(pl.getColumnName()); - String[] spiltValues = pl.getLiteral().toString().split("\\."); - if (decimalS == 0) { - jsonObject.put("literal", spiltValues[0] + " " + decimalP + " " + decimalS); - } else { - String scalePadZeroStr = padZeroForDecimals(spiltValues, decimalS); - jsonObject.put("literal", spiltValues[0] + "." 
+ scalePadZeroStr + " " + decimalP + " " + decimalS); - } - } else if (pl.getType() == PredicateLeaf.Type.TIMESTAMP) { - Timestamp t = (Timestamp)pl.getLiteral(); - jsonObject.put("literal", formatSecs(t.getTime() / TimestampTreeWriter.MILLIS_PER_SECOND) + " " + formatNanos(t.getNanos())); - } else { - jsonObject.put("literal", pl.getLiteral().toString()); - } - } else { - jsonObject.put("literal", ""); - } - if ((pl.getLiteralList() != null) && (pl.getLiteralList().size() != 0)){ - List lst = new ArrayList<>(); - for (Object ob : pl.getLiteralList()) { - if (ob == null) { - lst.add(null); - continue; - } - if (pl.getType() == PredicateLeaf.Type.DECIMAL) { - int decimalP = getPrecision(pl.getColumnName()); - int decimalS = getScale(pl.getColumnName()); - String[] spiltValues = ob.toString().split("\\."); - if (decimalS == 0) { - lst.add(spiltValues[0] + " " + decimalP + " " + decimalS); - } else { - String scalePadZeroStr = padZeroForDecimals(spiltValues, decimalS); - lst.add(spiltValues[0] + "." + scalePadZeroStr + " " + decimalP + " " + decimalS); - } - } else if (pl.getType() == PredicateLeaf.Type.DATE) { - lst.add(((int)Math.ceil(((Date)ob).getTime()* 1.0/3600/24/1000)) + ""); - } else if (pl.getType() == PredicateLeaf.Type.TIMESTAMP) { - Timestamp t = (Timestamp)pl.getLiteral(); - lst.add(formatSecs(t.getTime() / TimestampTreeWriter.MILLIS_PER_SECOND) + " " + formatNanos(t.getNanos())); - } else { - lst.add(ob.toString()); - } - } - jsonObject.put("literalList", lst); - } else { - jsonObject.put("literalList", new ArrayList()); - } - jsonObjectList.put("leaf-" + i, jsonObject); - } - return jsonObjectList; - } - private long formatSecs(long secs) { DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); long epoch; @@ -224,15 +152,11 @@ public class OrcColumnarBatchScanReader { * Init Orc reader. * * @param uri split file path - * @param options split file options */ - public long initializeReaderJava(URI uri, ReaderOptions options) { + public long initializeReaderJava(URI uri) { JSONObject job = new JSONObject(); - if (options.getOrcTail() == null) { - job.put("serializedTail", ""); - } else { - job.put("serializedTail", options.getOrcTail().getSerializedTail().toString()); - } + + job.put("serializedTail", ""); job.put("tailLocation", 9223372036854775807L); job.put("scheme", uri.getScheme() == null ? "" : uri.getScheme()); @@ -240,38 +164,37 @@ public class OrcColumnarBatchScanReader { job.put("port", uri.getPort()); job.put("path", uri.getPath() == null ? "" : uri.getPath()); - reader = jniReader.initializeReader(job, fildsNames); + reader = jniReader.initializeReader(job, allFieldsNames); return reader; } /** * Init Orc RecordReader. * - * @param options split file options + * @param offset split file offset + * @param length split file read length + * @param pushedFilter the filter push down to native + * @param requiredSchema the columns read from native */ - public long initializeRecordReaderJava(Options options) { + public long initializeRecordReaderJava(long offset, long length, Filter pushedFilter, StructType requiredSchema) { + this.requiredSchema = requiredSchema; JSONObject job = new JSONObject(); - if (options.getInclude() == null) { - job.put("include", ""); - } else { - job.put("include", options.getInclude().toString()); - } - job.put("offset", options.getOffset()); - job.put("length", options.getLength()); - // When the number of pushedFilters > hive.CNF_COMBINATIONS_THRESHOLD, the expression is rewritten to - // 'YES_NO_NULL'. 
Under the circumstances, filter push down will be skipped. - if (options.getSearchArgument() != null - && !options.getSearchArgument().toString().contains("YES_NO_NULL")) { - LOGGER.debug("SearchArgument: {}", options.getSearchArgument().toString()); - JSONObject jsonexpressionTree = getSubJson(options.getSearchArgument().getExpression()); - job.put("expressionTree", jsonexpressionTree); - JSONObject jsonleaves = getLeavesJson(options.getSearchArgument().getLeaves()); - job.put("leaves", jsonleaves); + + job.put("offset", offset); + job.put("length", length); + + if (pushedFilter != null) { + JSONObject jsonExpressionTree = new JSONObject(); + JSONObject jsonLeaves = new JSONObject(); + boolean flag = canPushDown(pushedFilter, jsonExpressionTree, jsonLeaves); + if (flag) { + job.put("expressionTree", jsonExpressionTree); + job.put("leaves", jsonLeaves); + } } - job.put("includedColumns", colToInclu.toArray()); + job.put("includedColumns", includedColumns.toArray()); addJulianGregorianInfo(job); - recordReader = jniReader.initializeRecordReader(reader, job); return recordReader; } @@ -318,13 +241,13 @@ public class OrcColumnarBatchScanReader { } public int next(Vec[] vecList, int[] typeIds) { - long[] vecNativeIds = new long[realColsCnt]; + long[] vecNativeIds = new long[typeIds.length]; long rtn = jniReader.recordReaderNext(recordReader, batchReader, typeIds, vecNativeIds); if (rtn == 0) { return 0; } int nativeGetId = 0; - for (int i = 0; i < realColsCnt; i++) { + for (int i = 0; i < colsToGet.length; i++) { if (colsToGet[i] != 0) { continue; } @@ -372,7 +295,7 @@ public class OrcColumnarBatchScanReader { } default: { throw new RuntimeException("UnSupport type for ColumnarFileScan:" + - DataType.DataTypeId.values()[typeIds[i]]); + DataType.DataTypeId.values()[typeIds[i]]); } } nativeGetId++; @@ -380,18 +303,225 @@ public class OrcColumnarBatchScanReader { return (int)rtn; } - private static String bytesToHexString(byte[] bytes) { - if (bytes == null || bytes.length < 1) { - throw new IllegalArgumentException("this bytes must not be null or empty"); + enum OrcOperator { + OR, + AND, + NOT, + LEAF, + CONSTANT + } + + enum OrcLeafOperator { + EQUALS, + NULL_SAFE_EQUALS, + LESS_THAN, + LESS_THAN_EQUALS, + IN, + BETWEEN, // not use, spark transfers it to gt and lt + IS_NULL + } + + enum OrcPredicateDataType { + LONG, // all of integer types + FLOAT, // float and double + STRING, // string, char, varchar + DATE, + DECIMAL, + TIMESTAMP, + BOOLEAN + } + + private OrcPredicateDataType getOrcPredicateDataType(String attribute) { + StructField field = requiredSchema.apply(attribute); + org.apache.spark.sql.types.DataType dataType = field.dataType(); + if (dataType instanceof ShortType || dataType instanceof IntegerType || + dataType instanceof LongType) { + return OrcPredicateDataType.LONG; + } else if (dataType instanceof DoubleType) { + return OrcPredicateDataType.FLOAT; + } else if (dataType instanceof StringType) { + if (isCharType(field.metadata())) { + throw new UnsupportedOperationException("Unsupported orc push down filter data type: char"); + } + return OrcPredicateDataType.STRING; + } else if (dataType instanceof DateType) { + return OrcPredicateDataType.DATE; + } else if (dataType instanceof DecimalType) { + return OrcPredicateDataType.DECIMAL; + } else if (dataType instanceof BooleanType) { + return OrcPredicateDataType.BOOLEAN; + } else { + throw new UnsupportedOperationException("Unsupported orc push down filter data type: " + + dataType.getClass().getSimpleName()); + } + } + + // 
Check the type whether is char type, which orc native does not support push down + private boolean isCharType(Metadata metadata) { + if (metadata != null) { + String rawTypeString = CharVarcharUtils.getRawTypeString(metadata).getOrElse(null); + if (rawTypeString != null) { + Matcher matcher = CHAR_TYPE.matcher(rawTypeString); + return matcher.matches(); + } + } + return false; + } + + private boolean canPushDown(Filter pushedFilter, JSONObject jsonExpressionTree, + JSONObject jsonLeaves) { + try { + getExprJson(pushedFilter, jsonExpressionTree, jsonLeaves); + if (leafIndex > MAX_LEAF_THRESHOLD) { + throw new UnsupportedOperationException("leaf node nums is " + leafIndex + + ", which is bigger than max threshold " + MAX_LEAF_THRESHOLD + "."); + } + return true; + } catch (Exception e) { + LOGGER.info("Unable to push down orc filter because " + e.getMessage()); + return false; + } + } + + private void getExprJson(Filter filterPredicate, JSONObject jsonExpressionTree, + JSONObject jsonLeaves) { + if (filterPredicate instanceof And) { + addChildJson(jsonExpressionTree, jsonLeaves, OrcOperator.AND, + ((And) filterPredicate).left(), ((And) filterPredicate).right()); + } else if (filterPredicate instanceof Or) { + addChildJson(jsonExpressionTree, jsonLeaves, OrcOperator.OR, + ((Or) filterPredicate).left(), ((Or) filterPredicate).right()); + } else if (filterPredicate instanceof Not) { + addChildJson(jsonExpressionTree, jsonLeaves, OrcOperator.NOT, + ((Not) filterPredicate).child()); + } else if (filterPredicate instanceof EqualTo) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.EQUALS, jsonLeaves, + ((EqualTo) filterPredicate).attribute(), ((EqualTo) filterPredicate).value(), null); + leafIndex++; + } else if (filterPredicate instanceof GreaterThan) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, true); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.LESS_THAN_EQUALS, jsonLeaves, + ((GreaterThan) filterPredicate).attribute(), ((GreaterThan) filterPredicate).value(), null); + leafIndex++; + } else if (filterPredicate instanceof GreaterThanOrEqual) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, true); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.LESS_THAN, jsonLeaves, + ((GreaterThanOrEqual) filterPredicate).attribute(), ((GreaterThanOrEqual) filterPredicate).value(), null); + leafIndex++; + } else if (filterPredicate instanceof LessThan) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.LESS_THAN, jsonLeaves, + ((LessThan) filterPredicate).attribute(), ((LessThan) filterPredicate).value(), null); + leafIndex++; + } else if (filterPredicate instanceof LessThanOrEqual) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.LESS_THAN_EQUALS, jsonLeaves, + ((LessThanOrEqual) filterPredicate).attribute(), ((LessThanOrEqual) filterPredicate).value(), null); + leafIndex++; + } else if (filterPredicate instanceof IsNotNull) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, true); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.IS_NULL, jsonLeaves, + ((IsNotNull) filterPredicate).attribute(), null, null); + leafIndex++; + } else if (filterPredicate instanceof IsNull) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false); 
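+            // Unlike IsNotNull above, IsNull maps directly onto an IS_NULL leaf. The
+            // OrcLeafOperator set has no GREATER_THAN or IS_NOT_NULL, which is why IsNotNull,
+            // GreaterThan and GreaterThanOrEqual are emitted with addNot = true as NOT(IS_NULL),
+            // NOT(LESS_THAN_EQUALS) and NOT(LESS_THAN) respectively.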
+ addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.IS_NULL, jsonLeaves, + ((IsNull) filterPredicate).attribute(), null, null); + leafIndex++; + } else if (filterPredicate instanceof In) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.IN, jsonLeaves, + ((In) filterPredicate).attribute(), null, Arrays.stream(((In) filterPredicate).values()).toArray()); + leafIndex++; + } else { + throw new UnsupportedOperationException("Unsupported orc push down filter operation: " + + filterPredicate.getClass().getSimpleName()); + } + } + + private void addLiteralToJsonLeaves(String leaf, OrcLeafOperator leafOperator, JSONObject jsonLeaves, + String name, Object literal, Object[] literals) { + JSONObject leafJson = new JSONObject(); + leafJson.put("op", leafOperator.ordinal()); + leafJson.put("name", name); + leafJson.put("type", getOrcPredicateDataType(name).ordinal()); + + leafJson.put("literal", getLiteralValue(literal)); + + ArrayList literalList = new ArrayList<>(); + if (literals != null) { + for (Object lit: literalList) { + literalList.add(getLiteralValue(literal)); + } + } + leafJson.put("literalList", literalList); + jsonLeaves.put(leaf, leafJson); + } + + private void addToJsonExpressionTree(String leaf, JSONObject jsonExpressionTree, boolean addNot) { + if (addNot) { + jsonExpressionTree.put("op", OrcOperator.NOT.ordinal()); + ArrayList child = new ArrayList<>(); + JSONObject subJson = new JSONObject(); + subJson.put("op", OrcOperator.LEAF.ordinal()); + subJson.put("leaf", leaf); + child.add(subJson); + jsonExpressionTree.put("child", child); + } else { + jsonExpressionTree.put("op", OrcOperator.LEAF.ordinal()); + jsonExpressionTree.put("leaf", leaf); + } + } + + private void addChildJson(JSONObject jsonExpressionTree, JSONObject jsonLeaves, + OrcOperator orcOperator, Filter ... filters) { + jsonExpressionTree.put("op", orcOperator.ordinal()); + ArrayList child = new ArrayList<>(); + for (Filter filter: filters) { + JSONObject subJson = new JSONObject(); + getExprJson(filter, subJson, jsonLeaves); + child.add(subJson); } + jsonExpressionTree.put("child", child); + } - final StringBuilder hexString = new StringBuilder(); - for (int i = 0; i < bytes.length; i++) { - if ((bytes[i] & 0xff) < 0x10) - hexString.append("0"); - hexString.append(Integer.toHexString(bytes[i] & 0xff)); + private String getLiteralValue(Object literal) { + // For null literal, the predicate will not be pushed down. + if (literal == null) { + throw new UnsupportedOperationException("Unsupported orc push down filter for literal is null"); } - return hexString.toString().toLowerCase(); + // For Decimal Type, we use the special string format to represent, which is "$decimalVal + // $precision $scale". + // e.g., Decimal(9, 3) = 123456.789, it outputs "123456.789 9 3". + // e.g., Decimal(9, 3) = 123456.7, it outputs "123456.700 9 3". + if (literal instanceof BigDecimal) { + BigDecimal value = (BigDecimal) literal; + int precision = value.precision(); + int scale = value.scale(); + String[] split = value.toString().split("\\."); + if (scale == 0) { + return split[0] + " " + precision + " " + scale; + } else { + String padded = padZeroForDecimals(split, scale); + return split[0] + "." + padded + " " + precision + " " + scale; + } + } + // For Date Type, spark uses Gregorian in default but orc uses Julian, which should be converted. 
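+        // (RebaseDateTime.rebaseGregorianToJulianDays below only shifts dates from before the
+        // 1582 Gregorian switch-over; modern dates pass through unchanged.)
+        //
+        // Putting the pieces together, a pushed filter such as GreaterThan("i_item_sk", 100L)
+        // is serialized with the enum ordinals defined above, roughly as:
+        //   expressionTree = {"op": 2, "child": [{"op": 3, "leaf": "leaf-0"}]}   NOT -> LEAF
+        //   leaves = {"leaf-0": {"op": 3, "name": "i_item_sk", "type": 0,
+        //             "literal": "100", "literalList": []}}                      LESS_THAN_EQUALS, LONG
+        // i.e. the native reader evaluates NOT(i_item_sk <= 100).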
+ if (literal instanceof LocalDate) { + int epochDay = Math.toIntExact(((LocalDate) literal).toEpochDay()); + int rebased = RebaseDateTime.rebaseGregorianToJulianDays(epochDay); + return String.valueOf(rebased); + } + if (literal instanceof String) { + return (String) literal; + } + if (literal instanceof Integer || literal instanceof Long || literal instanceof Boolean || + literal instanceof Short || literal instanceof Double) { + return literal.toString(); + } + throw new UnsupportedOperationException("Unsupported orc push down filter date type: " + + literal.getClass().getSimpleName()); } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java similarity index 57% rename from omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java rename to omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java index 93950e9f0..bd0b42463 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java @@ -22,18 +22,15 @@ import com.google.common.annotations.VisibleForTesting; import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor; import com.huawei.boostkit.spark.jni.OrcColumnarBatchScanReader; import nova.hetu.omniruntime.vector.Vec; -import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.RecordReader; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.hadoop.mapreduce.lib.input.FileSplit; -import org.apache.orc.OrcConf; -import org.apache.orc.OrcFile; -import org.apache.orc.Reader; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.OmniColumnVectorUtils; import org.apache.spark.sql.execution.vectorized.OmniColumnVector; +import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -49,26 +46,11 @@ import java.util.ArrayList; public class OmniOrcColumnarBatchReader extends RecordReader { // The capacity of vectorized batch. - private int capacity; - /** - * The column IDs of the physical ORC file schema which are required by this reader. - * -1 means this required column is partition column, or it doesn't exist in the ORC file. - * Ideally partition column should never appear in the physical file, and should only appear - * in the directory name. However, Spark allows partition columns inside physical file, - * but Spark will discard the values from the file, and use the partition value got from - * directory name. The column order will be reserved though. - */ - @VisibleForTesting - public int[] requestedDataColIds; - // Native Record reader from ORC row batch. private OrcColumnarBatchScanReader recordReader; - private StructField[] requiredFields; - private StructField[] resultFields; - // The result columnar batch for vectorized execution by whole-stage codegen. 
@VisibleForTesting public ColumnarBatch columnarBatch; @@ -82,10 +64,13 @@ public class OmniOrcColumnarBatchReader extends RecordReader orcfieldNames = recordReader.fildsNames; // save valid cols and numbers of valid cols recordReader.colsToGet = new int[requiredfieldNames.length]; - recordReader.realColsCnt = 0; - // save valid cols fieldsNames - recordReader.colToInclu = new ArrayList(); + recordReader.includedColumns = new ArrayList<>(); // collect read cols types ArrayList typeBuilder = new ArrayList<>(); + for (int i = 0; i < requiredfieldNames.length; i++) { String target = requiredfieldNames[i]; - boolean is_find = false; - for (int j = 0; j < orcfieldNames.size(); j++) { - String temp = orcfieldNames.get(j); - if (target.equals(temp)) { - requestedDataColIds[i] = i; - recordReader.colsToGet[i] = 0; - recordReader.colToInclu.add(requiredfieldNames[i]); - recordReader.realColsCnt++; - typeBuilder.add(OmniExpressionAdaptor.sparkTypeToOmniType(requiredSchema.fields()[i].dataType())); - is_find = true; - } - } - - // if invalid, set colsToGet value -1, else set colsToGet 0 - if (!is_find) { + // if not find, set colsToGet value -1, else set colsToGet 0 + if (recordReader.allFieldsNames.contains(target)) { + recordReader.colsToGet[i] = 0; + recordReader.includedColumns.add(requiredfieldNames[i]); + typeBuilder.add(OmniExpressionAdaptor.sparkTypeToOmniType(requiredSchema.fields()[i].dataType())); + } else { recordReader.colsToGet[i] = -1; } } vecTypeIds = typeBuilder.stream().mapToInt(Integer::intValue).toArray(); - - for (int i = 0; i < resultFields.length; i++) { - if (requestedPartitionColIds[i] != -1) { - requestedDataColIds[i] = -1; - } - } - - // set data members resultFields and requestedDataColIdS - this.resultFields = resultFields; - this.requestedDataColIds = requestedDataColIds; - - recordReader.requiredfieldNames = requiredfieldNames; - recordReader.precisionArray = precisionArray; - recordReader.scaleArray = scaleArray; - recordReader.initializeRecordReaderJava(options); } /** * Initialize columnar batch by setting required schema and partition information. * With this information, this creates ColumnarBatch with the full schema. * - * @param requiredFields The fields that are required to return,. - * @param resultFields All the fields that are required to return, including partition fields. - * @param requestedDataColIds Requested column ids from orcSchema. -1 if not existed. - * @param requestedPartitionColIds Requested column ids from partition schema. -1 if not existed. + * @param partitionColumns partition columns * @param partitionValues Values of partition columns. 
*/ - public void initBatch( - StructField[] requiredFields, - StructField[] resultFields, - int[] requestedDataColIds, - int[] requestedPartitionColIds, - InternalRow partitionValues) { - if (resultFields.length != requestedDataColIds.length || resultFields.length != requestedPartitionColIds.length){ - throw new UnsupportedOperationException("This operator doesn't support orc initBatch."); - } + public void initBatch(StructType partitionColumns, InternalRow partitionValues) { + StructType resultSchema = new StructType(); - this.requiredFields = requiredFields; + for (StructField f: requiredSchema.fields()) { + resultSchema = resultSchema.add(f); + } - StructType resultSchema = new StructType(resultFields); + if (partitionColumns != null) { + for (StructField f: partitionColumns.fields()) { + resultSchema = resultSchema.add(f); + } + } // Just wrap the ORC column vector instead of copying it to Spark column vector. orcVectorWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()]; templateWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()]; - for (int i = 0; i < resultFields.length; i++) { - DataType dt = resultFields[i].dataType(); - if (requestedPartitionColIds[i] != -1) { - OmniColumnVector partitionCol = new OmniColumnVector(capacity, dt, true); - OmniColumnVectorUtils.populate(partitionCol, partitionValues, requestedPartitionColIds[i]); + if (partitionColumns != null) { + int partitionIdx = requiredSchema.fields().length; + for (int i = 0; i < partitionColumns.fields().length; i++) { + OmniColumnVector partitionCol = new OmniColumnVector(capacity, partitionColumns.fields()[i].dataType(), true); + OmniColumnVectorUtils.populate(partitionCol, partitionValues, i); partitionCol.setIsConstant(); - templateWrappers[i] = partitionCol; - orcVectorWrappers[i] = new OmniColumnVector(capacity, dt, false);; + templateWrappers[i + partitionIdx] = partitionCol; + orcVectorWrappers[i + partitionIdx] = new OmniColumnVector(capacity, partitionColumns.fields()[i].dataType(), false); + } + } + + for (int i = 0; i < requiredSchema.fields().length; i++) { + DataType dt = requiredSchema.fields()[i].dataType(); + if (recordReader.colsToGet[i] == -1) { + // missing cols + OmniColumnVector missingCol = new OmniColumnVector(capacity, dt, true); + missingCol.putNulls(0, capacity); + missingCol.setIsConstant(); + templateWrappers[i] = missingCol; } else { - int colId = requestedDataColIds[i]; - // Initialize the missing columns once. 
- if (colId == -1) { - OmniColumnVector missingCol = new OmniColumnVector(capacity, dt, true); - missingCol.putNulls(0, capacity); - missingCol.setIsConstant(); - templateWrappers[i] = missingCol; - } else { - templateWrappers[i] = new OmniColumnVector(capacity, dt, false); - } - orcVectorWrappers[i] = new OmniColumnVector(capacity, dt, false); + templateWrappers[i] = new OmniColumnVector(capacity, dt, false); } + orcVectorWrappers[i] = new OmniColumnVector(capacity, dt, false); } + // init batch recordReader.initBatchJava(capacity); vecs = new Vec[orcVectorWrappers.length]; @@ -260,7 +211,7 @@ public class OmniOrcColumnarBatchReader extends RecordReader - convertibleFiltersHelper(left, dataSchema) && convertibleFiltersHelper(right, dataSchema) - case Or(left, right) => - convertibleFiltersHelper(left, dataSchema) && convertibleFiltersHelper(right, dataSchema) - case Not(pred) => - convertibleFiltersHelper(pred, dataSchema) - case other => - other match { - case EqualTo(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case EqualNullSafe(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case LessThan(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case LessThanOrEqual(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case GreaterThan(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case GreaterThanOrEqual(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case IsNull(name) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case IsNotNull(name) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case In(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case _ => false - } - } - - filters.map { filter => - convertibleFiltersHelper(filter, dataSchema) - } - } - override def buildReaderWithPartitionValues( sparkSession: SparkSession, dataSchema: StructType, @@ -101,7 +62,6 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ options: Map[String, String], hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { - val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields) val sqlConf = sparkSession.sessionState.conf val capacity = sqlConf.orcVectorizedReaderBatchSize @@ -111,21 +71,17 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis val orcFilterPushDown = sparkSession.sessionState.conf.orcFilterPushDown - val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles (file: PartitionedFile) => { val conf = broadcastedConf.value.value val filePath = new Path(new URI(file.filePath)) - val isPPDSafeValue = isPPDSafe(filters, dataSchema).reduceOption(_ && _) // ORC predicate pushdown - if (orcFilterPushDown && filters.nonEmpty && isPPDSafeValue.getOrElse(false)) { - OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach { - fileSchema => OrcFilters.createFilter(fileSchema, 
filters).foreach { f => - OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames) - } - } + val pushed = if (orcFilterPushDown) { + filters.reduceOption(And(_, _)) + } else { + None } val taskConf = new Configuration(conf) @@ -134,42 +90,16 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) // read data from vectorized reader - val batchReader = new OmniOrcColumnarBatchReader(capacity) + val batchReader = new OmniOrcColumnarBatchReader(capacity, requiredSchema, pushed.orNull) // SPARK-23399 Register a task completion listener first to call `close()` in all cases. // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM) // after opening a file. val iter = new RecordReaderIterator(batchReader) Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) - // fill requestedDataColIds with -1, fil real values int initDataColIds function - val requestedDataColIds = Array.fill(requiredSchema.length)(-1) ++ Array.fill(partitionSchema.length)(-1) - val requestedPartitionColIds = - Array.fill(requiredSchema.length)(-1) ++ Range(0, partitionSchema.length) - - // 初始化precision数组和scale数组,透传至java侧使用 - val requiredFields = requiredSchema.fields - val fieldslength = requiredFields.length - val precisionArray : Array[Int] = Array.ofDim[Int](fieldslength) - val scaleArray : Array[Int] = Array.ofDim[Int](fieldslength) - for ((reqField, index) <- requiredFields.zipWithIndex) { - val reqdatatype = reqField.dataType - if (reqdatatype.isInstanceOf[DecimalType]) { - val precision = reqdatatype.asInstanceOf[DecimalType].precision - val scale = reqdatatype.asInstanceOf[DecimalType].scale - precisionArray(index) = precision - scaleArray(index) = scale - } - } SparkMemoryUtils.init() batchReader.initialize(fileSplit, taskAttemptContext) - batchReader.initDataColIds(requiredSchema, requestedPartitionColIds, requestedDataColIds, resultSchema.fields, - precisionArray, scaleArray) - batchReader.initBatch( - requiredSchema.fields, - resultSchema.fields, - requestedDataColIds, - requestedPartitionColIds, - file.partitionValues) + batchReader.initBatch(partitionSchema, file.partitionValues) iter.asInstanceOf[Iterator[InternalRow]] } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java index fe1c55ffb..b236da644 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java @@ -68,7 +68,7 @@ public class OrcColumnarBatchJniReaderDataTypeTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java 
b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java index 995c434f6..7f6c1acb9 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java @@ -66,7 +66,7 @@ public class OrcColumnarBatchJniReaderNotPushDownTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java index c9ad9fada..2c912d919 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java @@ -66,7 +66,7 @@ public class OrcColumnarBatchJniReaderPushDownTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java index 8f4535338..cf86c0a5a 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java @@ -68,7 +68,7 @@ public class OrcColumnarBatchJniReaderSparkORCNotPushDownTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java index 27bcf5d7b..ef8d037bf 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java +++ 
b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java @@ -68,7 +68,7 @@ public class OrcColumnarBatchJniReaderSparkORCPushDownTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java index eab15fef6..19f23db00 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java @@ -18,103 +18,90 @@ package com.huawei.boostkit.spark.jni; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.io.Input; +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor; import junit.framework.TestCase; import nova.hetu.omniruntime.vector.LongVec; import nova.hetu.omniruntime.vector.VarcharVec; import nova.hetu.omniruntime.vector.Vec; -import org.apache.commons.codec.binary.Base64; -import org.apache.hadoop.hive.ql.io.sarg.SearchArgument; -import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentImpl; -import org.apache.orc.OrcFile; -import org.apache.orc.TypeDescription; -import org.apache.orc.mapred.OrcInputFormat; +import org.apache.hadoop.conf.Configuration; +import org.apache.orc.Reader; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; import org.junit.After; import org.junit.Before; import org.junit.FixMethodOrder; import org.junit.Test; import org.junit.runners.MethodSorters; -import org.apache.hadoop.conf.Configuration; + import java.io.File; import java.net.URI; import java.net.URISyntaxException; import java.util.ArrayList; -import java.util.List; import java.util.Arrays; -import org.apache.orc.Reader.Options; - -import static org.junit.Assert.*; +import java.util.List; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; +import static org.apache.spark.sql.types.DataTypes.LongType; +import static org.apache.spark.sql.types.DataTypes.StringType; @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderTest extends TestCase { public Configuration conf = new Configuration(); public OrcColumnarBatchScanReader orcColumnarBatchScanReader; - public int batchSize = 4096; + private int batchSize = 4096; + + private StructType requiredSchema; + private int[] vecTypeIds; + + private long offset = 0; + + private long length = Integer.MAX_VALUE; @Before public void setUp() throws Exception { - Configuration conf = new Configuration(); - TypeDescription schema = - TypeDescription.fromString("struct<`i_item_sk`:bigint,`i_item_id`:string>"); - Options options = new Options(conf) - .range(0, Integer.MAX_VALUE) - .useZeroCopy(false) - .skipCorruptRecords(false) - .tolerateMissingSchema(true); - - options.schema(schema); - 
options.include(OrcInputFormat.parseInclude(schema, - null)); - String kryoSarg = "AQEAb3JnLmFwYWNoZS5oYWRvb3AuaGl2ZS5xbC5pby5zYXJnLkV4cHJlc3Npb25UcmXlAQEBamF2YS51dGlsLkFycmF5TGlz9AECAQABAQEBAQEAAQAAAAEEAAEBAwEAAQEBAQEBAAEAAAIIAAEJAAEBAgEBAQIBAscBb3JnLmFwYWNoZS5oYWRvb3AuaGl2ZS5xbC5pby5zYXJnLlNlYXJjaEFyZ3VtZW50SW1wbCRQcmVkaWNhdGVMZWFmSW1wbAEBaV9pdGVtX3PrAAABBwEBAQIBEAkAAAEEEg=="; - String sargColumns = "i_item_sk,i_item_id,i_rec_start_date,i_rec_end_date,i_item_desc,i_current_price,i_wholesale_cost,i_brand_id,i_brand,i_class_id,i_class,i_category_id,i_category,i_manufact_id,i_manufact,i_size,i_formulation,i_color,i_units,i_container,i_manager_id,i_product_name"; - if (kryoSarg != null && sargColumns != null) { - byte[] sargBytes = Base64.decodeBase64(kryoSarg); - SearchArgument sarg = - new Kryo().readObject(new Input(sargBytes), SearchArgumentImpl.class); - options.searchArgument(sarg, sargColumns.split(",")); - sarg.getExpression().toString(); - } - orcColumnarBatchScanReader = new OrcColumnarBatchScanReader(); + constructSchema(); initReaderJava(); - initDataColIds(options, orcColumnarBatchScanReader); - initRecordReaderJava(options); - initBatch(options); + initDataColIds(); + initRecordReaderJava(); + initBatch(); + } + + private void constructSchema() { + requiredSchema = new StructType() + .add("i_item_sk", LongType) + .add("i_item_id", StringType); } - public void initDataColIds( - Options options, OrcColumnarBatchScanReader orcColumnarBatchScanReader) { - List allCols; - allCols = Arrays.asList(options.getColumnNames()); - orcColumnarBatchScanReader.colToInclu = new ArrayList(); - List optionField = options.getSchema().getFieldNames(); - orcColumnarBatchScanReader.colsToGet = new int[optionField.size()]; - orcColumnarBatchScanReader.realColsCnt = 0; - for (int i = 0; i < optionField.size(); i++) { - if (allCols.contains(optionField.get(i))) { - orcColumnarBatchScanReader.colToInclu.add(optionField.get(i)); - orcColumnarBatchScanReader.colsToGet[i] = 0; - orcColumnarBatchScanReader.realColsCnt++; - } else { + private void initDataColIds() { + // find requiredS fieldNames + String[] requiredfieldNames = requiredSchema.fieldNames(); + // save valid cols and numbers of valid cols + orcColumnarBatchScanReader.colsToGet = new int[requiredfieldNames.length]; + orcColumnarBatchScanReader.includedColumns = new ArrayList<>(); + // collect read cols types + ArrayList typeBuilder = new ArrayList<>(); + + for (int i = 0; i < requiredfieldNames.length; i++) { + String target = requiredfieldNames[i]; + + // if not find, set colsToGet value -1, else set colsToGet 0 + boolean is_find = false; + for (int j = 0; j < orcColumnarBatchScanReader.allFieldsNames.size(); j++) { + if (target.equals(orcColumnarBatchScanReader.allFieldsNames.get(j))) { + orcColumnarBatchScanReader.colsToGet[i] = 0; + orcColumnarBatchScanReader.includedColumns.add(requiredfieldNames[i]); + typeBuilder.add(OmniExpressionAdaptor.sparkTypeToOmniType(requiredSchema.fields()[i].dataType())); + is_find = true; + break; + } + } + + if (!is_find) { orcColumnarBatchScanReader.colsToGet[i] = -1; } } - orcColumnarBatchScanReader.requiredfieldNames = new String[optionField.size()]; - TypeDescription schema = options.getSchema(); - int[] precisionArray = new int[optionField.size()]; - int[] scaleArray = new int[optionField.size()]; - for (int i = 0; i < optionField.size(); i++) { - precisionArray[i] = schema.findSubtype(optionField.get(i)).getPrecision(); - scaleArray[i] = schema.findSubtype(optionField.get(i)).getScale(); - 
orcColumnarBatchScanReader.requiredfieldNames[i] = optionField.get(i); - } - orcColumnarBatchScanReader.precisionArray = precisionArray; - orcColumnarBatchScanReader.scaleArray = scaleArray; + vecTypeIds = typeBuilder.stream().mapToInt(Integer::intValue).toArray(); } @After @@ -122,8 +109,7 @@ public class OrcColumnarBatchJniReaderTest extends TestCase { System.out.println("orcColumnarBatchJniReader test finished"); } - public void initReaderJava() throws URISyntaxException { - OrcFile.ReaderOptions readerOptions = OrcFile.readerOptions(conf); + private void initReaderJava() { File directory = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0"); String path = directory.getAbsolutePath(); URI uri = null; @@ -133,16 +119,17 @@ public class OrcColumnarBatchJniReaderTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, readerOptions); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } - public void initRecordReaderJava(Options options) { - orcColumnarBatchScanReader.recordReader = orcColumnarBatchScanReader.initializeRecordReaderJava(options); + private void initRecordReaderJava() { + orcColumnarBatchScanReader.recordReader = orcColumnarBatchScanReader. + initializeRecordReaderJava(offset, length, null, requiredSchema); assertTrue(orcColumnarBatchScanReader.recordReader != 0); } - public void initBatch(Options options) { + private void initBatch() { orcColumnarBatchScanReader.initBatchJava(batchSize); assertTrue(orcColumnarBatchScanReader.batchReader != 0); } @@ -150,8 +137,7 @@ public class OrcColumnarBatchJniReaderTest extends TestCase { @Test public void testNext() { Vec[] vecs = new Vec[2]; - int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal()}; - long rtn = orcColumnarBatchScanReader.next(vecs, typeId); + long rtn = orcColumnarBatchScanReader.next(vecs, vecTypeIds); assertTrue(rtn == 4096); assertTrue(((LongVec) vecs[0]).get(0) == 1); String str = new String(((VarcharVec) vecs[1]).get(0)); -- Gitee From 0704341b62c6e0557ffd009515754ef0a25c8f0f Mon Sep 17 00:00:00 2001 From: liujingxiang-cs Date: Mon, 21 Oct 2024 11:52:10 +0000 Subject: [PATCH 02/12] !894 [spark extension] opt jni: row shuffle deserialized * [spark extension] opt jni: row shuffle deserialized --- .../boostkit/spark/serialize/ShuffleDataSerializer.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java index 55d56ae20..0859173ca 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java @@ -80,12 +80,12 @@ public class ShuffleDataSerializer { long[] omniVecs = new long[vecCount]; int[] omniTypes = new int[vecCount]; createEmptyVec(rowBatch, omniTypes, omniVecs, columnarVecs, vecCount, rowCount); - OmniRowDeserializer deserializer = new OmniRowDeserializer(omniTypes); + OmniRowDeserializer deserializer = new OmniRowDeserializer(omniTypes, omniVecs); for (int rowIdx = 0; rowIdx < rowCount; 
rowIdx++) { VecData.ProtoRow protoRow = rowBatch.getRows(rowIdx); byte[] array = protoRow.getData().toByteArray(); - deserializer.parse(array, omniVecs, rowIdx); + deserializer.parse(array, rowIdx); } // update initial varchar vector because it's capacity might have been expanded. -- Gitee From 8cd612abfeaffd4db0a8d2cd46f486313fe9519d Mon Sep 17 00:00:00 2001 From: hyy_cyan Date: Tue, 22 Oct 2024 01:44:47 +0000 Subject: [PATCH 03/12] !928 [Spark Extension] add empty2null expr * add test * add empty2null expr --- .../expression/OmniExpressionAdaptor.scala | 8 ++- .../OmniFileFormatDataWriter.scala | 61 ++++++++++++++----- .../spark/TableWriteBasicFunctionSuite.scala | 33 ++++++++++ 3 files changed, 87 insertions(+), 15 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index 3152d6c7c..f15672927 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti import org.apache.spark.sql.catalyst.util.CharVarcharUtils.getRawTypeString import org.apache.spark.sql.execution import org.apache.spark.sql.execution.ColumnarBloomFilterSubquery +import org.apache.spark.sql.execution.datasources.OmniFileFormatWriter.Empty2Null import org.apache.spark.sql.expression.ColumnarExpressionConverter import org.apache.spark.sql.hive.HiveUdfAdaptorUtil import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType, DateType, Decimal, DecimalType, DoubleType, IntegerType, LongType, Metadata, NullType, ShortType, StringType, TimestampType} @@ -319,7 +320,12 @@ object OmniExpressionAdaptor extends Logging { .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(subString.str, exprsIndexMap)). 
put(rewriteToOmniJsonExpressionLiteralJsonObject(subString.pos, exprsIndexMap)) .put(rewriteToOmniJsonExpressionLiteralJsonObject(subString.len, exprsIndexMap))) - + case empty2Null: Empty2Null => + new JSONObject().put("exprType", "FUNCTION") + .put("function_name", "empty2null") + .addOmniExpJsonType("returnType", empty2Null.dataType) + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject + (empty2Null.child, exprsIndexMap))) // Cast case cast: CastBase => unsupportedCastCheck(expr, cast) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatDataWriter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatDataWriter.scala index 8983a6f18..f3fe865e0 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatDataWriter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatDataWriter.scala @@ -157,6 +157,35 @@ abstract class OmniBaseDynamicPartitionDataWriter( protected val getOutputRow = UnsafeProjection.create(description.dataColumns, description.allColumns) + protected def getPartitionPath(partitionValues: Option[InternalRow], + bucketId: Option[Int]): String = { + val partDir = partitionValues.map(getPartitionPath(_)) + partDir.foreach(updatedPartitions.add) + + val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") + + // The prefix and suffix must be in a form that matches our bucketing format. See BucketingUtils + // for details. The prefix is required to represent bucket id when writing Hive-compatible + // bucketed table. + val prefix = bucketId match { + case Some(id) => description.bucketSpec.get.bucketFileNamePrefix(id) + case _ => "" + } + val suffix = f"$bucketIdStr.c$fileCounter%03d" + + description.outputWriterFactory.getFileExtension(taskAttemptContext) + val fileNameSpec = FileNameSpec(prefix, suffix) + + val customPath = partDir.flatMap { dir => + description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) + } + val currentPath = if (customPath.isDefined) { + customPath.get + fileNameSpec.toString + } else { + partDir.toString + fileNameSpec.toString + } + currentPath + } + /** * Opens a new OutputWriter given a partition key and/or a bucket id. * If bucket id is specified, we will append it to the end of the file name, but before the @@ -276,21 +305,25 @@ class OmniDynamicPartitionDataSingleWriter( val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None if (currentPartitionValues != nextPartitionValues || currentBucketId != nextBucketId) { - // See a new partition or bucket - write to a new partition dir (or a new bucket file). - if (isPartitioned && currentPartitionValues != nextPartitionValues) { - currentPartitionValues = Some(nextPartitionValues.get.copy()) - statsTrackers.foreach(_.newPartition(currentPartitionValues.get)) - } - if (isBucketed) { - currentBucketId = nextBucketId - } - - fileCounter = 0 - if (i != 0) { - writeRecord(omniInternalRow, lastIndex, i) - lastIndex = i + val isFilePathSame = getPartitionPath(currentPartitionValues, + currentBucketId) == getPartitionPath(nextPartitionValues, nextBucketId) + if (!isFilePathSame) { + // See a new partition or bucket - write to a new partition dir (or a new bucket file). 
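+          // The path comparison above is needed because distinct partition values can resolve to
+          // the same location: an empty string and null both map to the default partition
+          // directory (__HIVE_DEFAULT_PARTITION__), so rows that differ only in "" vs null must
+          // keep writing to the current file instead of renewing the writer for the same path.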
+ if (isPartitioned && currentPartitionValues != nextPartitionValues) { + currentPartitionValues = Some(nextPartitionValues.get.copy()) + statsTrackers.foreach(_.newPartition(currentPartitionValues.get)) + } + if (isBucketed) { + currentBucketId = nextBucketId + } + + fileCounter = 0 + if (i != 0) { + writeRecord(omniInternalRow, lastIndex, i) + lastIndex = i + } + renewCurrentWriter(currentPartitionValues, currentBucketId, closeCurrentWriter = true) } - renewCurrentWriter(currentPartitionValues, currentBucketId, closeCurrentWriter = true) } else if ( description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/TableWriteBasicFunctionSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/TableWriteBasicFunctionSuite.scala index 1566fb383..b32c3983d 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/TableWriteBasicFunctionSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/TableWriteBasicFunctionSuite.scala @@ -168,4 +168,37 @@ class TableWriteBasicFunctionSuite extends QueryTest with SharedSparkSession { assert(runRows(0).getDate(0).toString == "1001-01-04", "the run value is error") } + + test("empty string partition") { + val drop = spark.sql("drop table if exists table_insert_varchar") + drop.collect() + val createTable = spark.sql("create table table_insert_varchar" + + "(id int, c_varchar varchar(40)) using orc partitioned by (p_varchar varchar(40))") + createTable.collect() + val insert = spark.sql("insert into table table_insert_varchar values" + + "(5,'',''), (13,'6884578', null), (6,'72135', '666')") + insert.collect() + + val select = spark.sql("select * from table_insert_varchar order by id, c_varchar, p_varchar") + val runRows = select.collect() + val expectedRows = Seq(Row(5, "", null), Row(6, "72135", "666"), Row(13, "6884578", null)) + assert(QueryTest.sameRows(runRows, expectedRows).isEmpty, "the run value is error") + + val dropNP = spark.sql("drop table if exists table_insert_varchar_np") + dropNP.collect() + val createTableNP = spark.sql("create table table_insert_varchar_np" + + "(id int, c_varchar varchar(40)) using orc partitioned by " + + "(p_varchar1 int, p_varchar2 varchar(40), p_varchar3 varchar(40))") + createTableNP.collect() + val insertNP = spark.sql("insert into table table_insert_varchar_np values" + + "(5,'',1,'',''), (13,'6884578',6, null, null), (1,'abc',1,'',''), " + + "(3,'abcde',6,null,null), (4,'qqqqq', 8, 'a', 'b'), (6,'ooooo', 8, 'a', 'b')") + val selectNP = spark.sql("select * from table_insert_varchar_np " + + "order by id, c_varchar, p_varchar1") + val runRowsNP = selectNP.collect() + val expectedRowsNP = Seq(Row(1, "abc", 1, null, null), Row(3, "abcde", 6, null, null), + Row(4, "qqqqq", 8, "a", "b"), Row(5, "", 1, null, null), Row(6, "ooooo", 8, "a", "b"), + Row(13, "6884578", 6, null, null)) + assert(QueryTest.sameRows(runRowsNP, expectedRowsNP).isEmpty, "the run value is error") + } } -- Gitee From 2f987ef172fc312f14597d745725c7a539839ce2 Mon Sep 17 00:00:00 2001 From: zhousipei Date: Tue, 3 Sep 2024 14:31:59 +0800 Subject: [PATCH 04/12] fix parquet reader schema change bug --- .../src/jni/ParquetColumnarBatchJniReader.cpp | 45 ++++++++++-- .../src/jni/ParquetColumnarBatchJniReader.h | 16 +++++ .../cpp/src/parquet/ParquetReader.cpp | 11 +-- .../cpp/src/parquet/ParquetReader.h | 8 ++- 
.../cpp/test/tablescan/parquet_scan_test.cpp | 8 ++- .../jni/ParquetColumnarBatchJniReader.java | 6 ++ .../jni/ParquetColumnarBatchScanReader.java | 71 ++++++++++++------- .../OmniParquetColumnarBatchReader.java | 52 +++++++++++--- .../ParquetColumnarBatchJniReaderTest.java | 11 +-- 9 files changed, 177 insertions(+), 51 deletions(-) diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp index f871f2c3d..13dba0154 100644 --- a/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp +++ b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp @@ -78,6 +78,23 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJn // Get capacity for each record batch int64_t capacity = (int64_t)env->CallLongMethod(jsonObj, jsonMethodLong, env->NewStringUTF("capacity")); + std::unique_ptr rebaseInfoPtr = common::BuildTimeRebaseInfo(env, jsonObj); + + ParquetReader *pReader = new ParquetReader(rebaseInfoPtr); + auto state = pReader->InitReader(uriInfo, capacity, ugiString); + if (state != Status::OK()) { + env->ThrowNew(runtimeExceptionClass, state.ToString().c_str()); + return 0; + } + return (jlong)(pReader); + JNI_FUNC_END(runtimeExceptionClass) +} + +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader_initializeRecordReader + (JNIEnv *env, jobject jObj, jlong reader, jobject jsonObj) +{ + JNI_FUNC_START + ParquetReader *pReader = (ParquetReader *)reader; int64_t start = (int64_t)env->CallLongMethod(jsonObj, jsonMethodLong, env->NewStringUTF("start")); int64_t end = (int64_t)env->CallLongMethod(jsonObj, jsonMethodLong, env->NewStringUTF("end")); @@ -96,11 +113,7 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJn auto fieldNames = GetFieldNames(env, jsonObj); - std::unique_ptr rebaseInfoPtr = common::BuildTimeRebaseInfo(env, jsonObj); - - ParquetReader *pReader = new ParquetReader(rebaseInfoPtr); - auto state = pReader->InitRecordReader(uriInfo, start, end, capacity, hasExpressionTree, pushedFilterArray, - fieldNames, ugiString); + auto state = pReader->InitRecordReader(start, end, hasExpressionTree, pushedFilterArray, fieldNames); if (state != Status::OK()) { env->ThrowNew(runtimeExceptionClass, state.ToString().c_str()); return 0; @@ -109,6 +122,28 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJn JNI_FUNC_END(runtimeExceptionClass) } +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader_getAllFieldNames + (JNIEnv *env, jobject jObj, jlong reader, jobject allFieldNames) +{ + JNI_FUNC_START + ParquetReader *pReader = (ParquetReader *)reader; + std::shared_ptr schema; + auto state = pReader->arrow_reader->GetSchema(&schema); + if (state != Status::OK()) { + env->ThrowNew(runtimeExceptionClass, state.ToString().c_str()); + return 0; + } + std::vector columnNames = schema->field_names(); + auto num = columnNames.size(); + for (uint32_t i = 0; i < num; i++) { + jstring fieldName = env->NewStringUTF(columnNames[i].c_str()); + env->CallBooleanMethod(allFieldNames, arrayListAdd, fieldName); + env->DeleteLocalRef(fieldName); + } + return (jlong)(num); + JNI_FUNC_END(runtimeExceptionClass) +} + JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader_recordReaderNext(JNIEnv *env, jobject jObj, jlong reader, jlongArray vecNativeId) { diff --git 
a/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.h b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.h index a37456747..b5b382760 100644 --- a/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.h +++ b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.h @@ -44,6 +44,22 @@ extern "C" { JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader_initializeReader (JNIEnv* env, jobject jObj, jobject job); +/* + * Class: com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader + * Method: initializeRecordReader + * Signature: (JLorg/json/JSONObject;)J + */ +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader_initializeRecordReader + (JNIEnv *, jobject, jlong, jobject); + +/* + * Class: com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader + * Method: getAllFieldNames + * Signature: (JLjava/util/ArrayList;)J + */ +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader_getAllFieldNames + (JNIEnv *, jobject, jlong, jobject); + /* * Class: com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader * Method: recordReaderNext diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp index 83b0f3265..9eab7507f 100644 --- a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp @@ -84,8 +84,7 @@ Filesystem* omniruntime::reader::GetFileSystemPtr(UriInfo &uri, std::string& ugi return restore_filesysptr[key]; } -Status ParquetReader::InitRecordReader(UriInfo &uri, int64_t start, int64_t end, int64_t capacity, bool hasExpressionTree, - Expression pushedFilterArray, const std::vector& fieldNames, std::string& ugi) +Status ParquetReader::InitReader(UriInfo &uri, int64_t capacity, std::string& ugi) { // Configure reader settings auto reader_properties = parquet::ReaderProperties(pool); @@ -94,8 +93,6 @@ Status ParquetReader::InitRecordReader(UriInfo &uri, int64_t start, int64_t end, auto arrow_reader_properties = parquet::ArrowReaderProperties(); arrow_reader_properties.set_batch_size(capacity); - std::shared_ptr file; - // Get the file from filesystem Status result; mutex_.lock(); @@ -113,6 +110,12 @@ Status ParquetReader::InitRecordReader(UriInfo &uri, int64_t start, int64_t end, reader_builder.properties(arrow_reader_properties); ARROW_ASSIGN_OR_RAISE(arrow_reader, reader_builder.Build()); + return arrow::Status::OK(); +} + +Status ParquetReader::InitRecordReader(int64_t start, int64_t end, bool hasExpressionTree, + Expression pushedFilterArray, const std::vector& fieldNames) +{ std::vector row_group_indices; auto filesource = std::make_shared(file); ARROW_RETURN_NOT_OK(GetRowGroupIndices(*filesource, start, end, hasExpressionTree, pushedFilterArray, row_group_indices)); diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.h b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.h index 5bbd3a503..3d1645054 100644 --- a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.h +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.h @@ -53,8 +53,10 @@ namespace omniruntime::reader { ParquetReader(std::unique_ptr &rebaseInfoPtr) : rebaseInfoPtr(std::move(rebaseInfoPtr)) {} - arrow::Status InitRecordReader(UriInfo &uri, int64_t start, int64_t end, int64_t capacity, 
bool hasExpressionTree, - Expression pushedFilterArray, const std::vector& fieldNames, std::string& ugi); + arrow::Status InitReader(UriInfo &uri, int64_t capacity, std::string& ugi); + + arrow::Status InitRecordReader(int64_t start, int64_t end, bool hasExpressionTree, + Expression pushedFilterArray, const std::vector& fieldNames); arrow::Status ReadNextBatch(std::vector &batch, long *batchRowSize); @@ -84,6 +86,8 @@ namespace omniruntime::reader { const std::shared_ptr &ctx, std::unique_ptr* out); std::unique_ptr rebaseInfoPtr; + + std::shared_ptr file; }; class Filesystem { diff --git a/omnioperator/omniop-native-reader/cpp/test/tablescan/parquet_scan_test.cpp b/omnioperator/omniop-native-reader/cpp/test/tablescan/parquet_scan_test.cpp index 26a752f5d..acc83d51d 100644 --- a/omnioperator/omniop-native-reader/cpp/test/tablescan/parquet_scan_test.cpp +++ b/omnioperator/omniop-native-reader/cpp/test/tablescan/parquet_scan_test.cpp @@ -44,7 +44,9 @@ TEST(read, test_parquet_reader) ParquetReader *reader = new ParquetReader(rebaseInfoPtr); std::string ugi = "root@sample"; Expression pushedFilterArray; - auto state1 = reader->InitRecordReader(uriInfo, 0, 1000000, 1024, false, pushedFilterArray, column_indices, ugi); + auto state0 = reader->InitReader(uriInfo, 1024, ugi); + ASSERT_EQ(state0, arrow::Status::OK()); + auto state1 = reader->InitRecordReader(0, 1000000, false, pushedFilterArray, column_indices); ASSERT_EQ(state1, arrow::Status::OK()); std::vector recordBatch(column_indices.size()); @@ -113,7 +115,9 @@ TEST(read, test_varchar) ParquetReader *reader = new ParquetReader(rebaseInfoPtr); std::string ugi = "root@sample"; Expression pushedFilterArray; - auto state1 = reader->InitRecordReader(uriInfo, 0, 1000000, 4096, false, pushedFilterArray, column_indices, ugi); + auto state0 = reader->InitReader(uriInfo, 4096, ugi); + ASSERT_EQ(state0, arrow::Status::OK()); + auto state1 = reader->InitRecordReader(0, 1000000, false, pushedFilterArray, column_indices); ASSERT_EQ(state1, arrow::Status::OK()); int total_nums = 0; int iter = 0; diff --git a/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/scan/jni/ParquetColumnarBatchJniReader.java b/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/scan/jni/ParquetColumnarBatchJniReader.java index b740b726c..a02be6b9e 100644 --- a/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/scan/jni/ParquetColumnarBatchJniReader.java +++ b/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/scan/jni/ParquetColumnarBatchJniReader.java @@ -19,6 +19,8 @@ package com.huawei.boostkit.scan.jni; import org.json.JSONObject; +import java.util.ArrayList; + public class ParquetColumnarBatchJniReader { public ParquetColumnarBatchJniReader() { @@ -27,6 +29,10 @@ public class ParquetColumnarBatchJniReader { public native long initializeReader(JSONObject job); + public native long initializeRecordReader(long parquetReader, JSONObject job); + + public native long getAllFieldNames(long parquetReader, ArrayList allFieldNames); + public native long recordReaderNext(long parquetReader, long[] vecNativeId); public native void recordReaderClose(long parquetReader); diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchScanReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchScanReader.java index 3c1e7dba1..5fce3f089 100644 --- 
a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchScanReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchScanReader.java @@ -55,6 +55,7 @@ import java.math.BigDecimal; import java.net.URI; import java.time.Instant; import java.time.LocalDate; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -81,6 +82,8 @@ public class ParquetColumnarBatchScanReader { private final List parquetTypes; + private ArrayList allFieldsNames; + public ParquetColumnarBatchScanReader(StructType requiredSchema, RebaseSpec datetimeRebaseSpec, RebaseSpec int96RebaseSpec, List parquetTypes) { this.requiredSchema = requiredSchema; @@ -124,8 +127,7 @@ public class ParquetColumnarBatchScanReader { } } - public long initializeReaderJava(Path path, long start, long end, int capacity, - Filter pushedFilter) throws UnsupportedEncodingException { + public long initializeReaderJava(Path path, int capacity) throws UnsupportedEncodingException { JSONObject job = new JSONObject(); URI uri = path.toUri(); @@ -135,11 +137,7 @@ public class ParquetColumnarBatchScanReader { job.put("port", uri.getPort()); job.put("path", uri.getPath() == null ? "" : uri.getPath()); - job.put("start", start); - job.put("end", end); - job.put("capacity", capacity); - job.put("fieldNames", requiredSchema.fieldNames()); String ugi = null; try { @@ -149,15 +147,35 @@ public class ParquetColumnarBatchScanReader { } job.put("ugi", ugi); + addJulianGregorianInfo(job); + + parquetReader = jniReader.initializeReader(job); + return parquetReader; + } + + public long initializeRecordReaderJava(long start, long end, String[] fieldNames, Filter pushedFilter) + throws UnsupportedEncodingException { + JSONObject job = new JSONObject(); + job.put("start", start); + job.put("end", end); + job.put("fieldNames", fieldNames); + if (pushedFilter != null) { pushDownFilter(pushedFilter, job); } - addJulianGregorianInfo(job); - parquetReader = jniReader.initializeReader(job); + parquetReader = jniReader.initializeRecordReader(parquetReader, job); return parquetReader; } + public ArrayList getAllFieldsNames() { + if (allFieldsNames == null) { + allFieldsNames = new ArrayList<>(); + jniReader.getAllFieldNames(parquetReader, allFieldsNames); + } + return allFieldsNames; + } + private void pushDownFilter(Filter pushedFilter, JSONObject job) { try { JSONObject jsonExpressionTree = getSubJson(pushedFilter); @@ -361,43 +379,48 @@ public class ParquetColumnarBatchScanReader { } } - public int next(Vec[] vecList, List types) { - int vectorCnt = vecList.length; - long[] vecNativeIds = new long[vectorCnt]; + public int next(Vec[] vecList, boolean[] missingColumns, List types) { + int colsCount = missingColumns.length; + long[] vecNativeIds = new long[types.size()]; long rtn = jniReader.recordReaderNext(parquetReader, vecNativeIds); if (rtn == 0) { return 0; } - for (int i = 0; i < vectorCnt; i++) { - DataType type = types.get(i); + int nativeGetId = 0; + for (int i = 0; i < colsCount; i++) { + if (missingColumns[i]) { + continue; + } + DataType type = types.get(nativeGetId); if (type instanceof LongType) { - vecList[i] = new LongVec(vecNativeIds[i]); + vecList[i] = new LongVec(vecNativeIds[nativeGetId]); } else if (type instanceof BooleanType) { - vecList[i] = new BooleanVec(vecNativeIds[i]); + vecList[i] = new BooleanVec(vecNativeIds[nativeGetId]); } else if (type instanceof ShortType) { - vecList[i] = new 
ShortVec(vecNativeIds[i]); + vecList[i] = new ShortVec(vecNativeIds[nativeGetId]); } else if (type instanceof IntegerType) { - vecList[i] = new IntVec(vecNativeIds[i]); + vecList[i] = new IntVec(vecNativeIds[nativeGetId]); } else if (type instanceof DecimalType) { if (DecimalType.is64BitDecimalType(type)) { - vecList[i] = new LongVec(vecNativeIds[i]); + vecList[i] = new LongVec(vecNativeIds[nativeGetId]); } else { - vecList[i] = new Decimal128Vec(vecNativeIds[i]); + vecList[i] = new Decimal128Vec(vecNativeIds[nativeGetId]); } } else if (type instanceof DoubleType) { - vecList[i] = new DoubleVec(vecNativeIds[i]); + vecList[i] = new DoubleVec(vecNativeIds[nativeGetId]); } else if (type instanceof StringType) { - vecList[i] = new VarcharVec(vecNativeIds[i]); + vecList[i] = new VarcharVec(vecNativeIds[nativeGetId]); } else if (type instanceof DateType) { - vecList[i] = new IntVec(vecNativeIds[i]); + vecList[i] = new IntVec(vecNativeIds[nativeGetId]); } else if (type instanceof ByteType) { - vecList[i] = new VarcharVec(vecNativeIds[i]); + vecList[i] = new VarcharVec(vecNativeIds[nativeGetId]); } else if (type instanceof TimestampType) { - vecList[i] = new LongVec(vecNativeIds[i]); + vecList[i] = new LongVec(vecNativeIds[nativeGetId]); tryToAdjustTimestampVec((LongVec) vecList[i], rtn, i); } else { throw new RuntimeException("Unsupport type for ColumnarFileScan: " + type.typeName()); } + nativeGetId++; } return (int)rtn; } diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/parquet/OmniParquetColumnarBatchReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/parquet/OmniParquetColumnarBatchReader.java index 5e718b086..86d76be34 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/parquet/OmniParquetColumnarBatchReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/parquet/OmniParquetColumnarBatchReader.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet; +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor; import com.huawei.boostkit.spark.jni.ParquetColumnarBatchScanReader; import nova.hetu.omniruntime.vector.Vec; import org.apache.hadoop.mapreduce.InputSplit; @@ -46,6 +47,7 @@ public class OmniParquetColumnarBatchReader extends RecordReader allFieldsNames = reader.getAllFieldsNames(); + + ArrayList includeFieldNames = new ArrayList<>(); + for (int i = 0; i < requiredFieldNames.length; i++) { + String target = requiredFieldNames[i]; + if (allFieldsNames.contains(target)) { + missingColumns[i] = false; + includeFieldNames.add(target); + types.add(structFields[i].dataType()); + } else { + missingColumns[i] = true; + } + } + return includeFieldNames.toArray(new String[includeFieldNames.size()]); + } // Creates a columnar batch that includes the schema from the data files and the additional // partition columns appended to the end of the batch. 
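The filterSchema and next changes above handle schema evolution by mapping the dense list of vectors returned from the native reader back onto the sparse positions of the Spark required schema: nativeGetId advances only for columns that actually exist in the file, while slots flagged in missingColumns are skipped and are typically backed by all-null column vectors on the Spark side. A small standalone sketch of that index mapping (illustration only; types are simplified to strings, whereas the real code deals in native vector handles):

import java.util.Arrays;
import java.util.List;

// Illustrative sketch of the sparse/dense index mapping used when some
// required columns are missing from the Parquet file.
public class MissingColumnMappingSketch {
    static String[] scatter(List<String> nativeResults, boolean[] missingColumns) {
        String[] out = new String[missingColumns.length];
        int nativeGetId = 0;                  // advances only for columns present in the file
        for (int i = 0; i < missingColumns.length; i++) {
            if (missingColumns[i]) {
                continue;                     // left null; the caller substitutes a null vector
            }
            out[i] = nativeResults.get(nativeGetId++);
        }
        return out;
    }

    public static void main(String[] args) {
        // Required schema: [id, newly_added_col, name]; the file only contains id and name.
        boolean[] missing = {false, true, false};
        List<String> fromFile = Arrays.asList("idVec", "nameVec");
        String[] result = scatter(fromFile, missing);
        for (String slot : result) {
            System.out.println(slot);         // prints: idVec, null, nameVec
        }
    }
}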
@@ -138,7 +161,6 @@ public class OmniParquetColumnarBatchReader extends RecordReader types; @@ -59,9 +60,11 @@ public class ParquetColumnarBatchJniReaderTest extends TestCase { File file = new File("src/test/java/com/huawei/boostkit/spark/jni/parquetsrc/parquet_data_all_type"); String path = file.getAbsolutePath(); - parquetColumnarBatchScanReader.initializeReaderJava(new Path(path), 0, 100000, - 4096, null); - vecs = new Vec[9]; + parquetColumnarBatchScanReader.initializeReaderJava(new Path(path), 4096); + parquetColumnarBatchScanReader.initializeRecordReaderJava(0, 100000, schema.fieldNames(), null); + missingColumns = new boolean[schema.fieldNames().length]; + Arrays.fill(missingColumns, false); + vecs = new Vec[schema.fieldNames().length]; } private void constructSchema() { @@ -93,7 +96,7 @@ public class ParquetColumnarBatchJniReaderTest extends TestCase { @Test public void testRead() { - long num = parquetColumnarBatchScanReader.next(vecs, types); + long num = parquetColumnarBatchScanReader.next(vecs, missingColumns, types); assertTrue(num == 1); } } -- Gitee From 48eedffb7ac08869d02950ae89aadc549039d5b0 Mon Sep 17 00:00:00 2001 From: panmingyi Date: Tue, 22 Oct 2024 16:21:33 +0800 Subject: [PATCH 05/12] fix unix time function --- .../boostkit/spark/ColumnarPluginConfig.scala | 10 +++++ .../expression/OmniExpressionAdaptor.scala | 30 ++++++++------ .../expressions/ColumnarFuncSuite.scala | 40 +++++++++++++++++++ 3 files changed, 67 insertions(+), 13 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala index eb8bf478d..26b6f8f1d 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala @@ -183,6 +183,10 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { def adaptivePartialAggregationRatio: Double = conf.getConf(ADAPTIVE_PARTIAL_AGGREGATION_RATIO) + def timeParserPolicy: String = conf.getConfString("spark.sql.legacy.timeParserPolicy") + + def enableOmniUnixTimeFunc: Boolean = conf.getConf(ENABLE_OMNI_UNIXTIME_FUNCTION) + } @@ -636,4 +640,10 @@ object ColumnarPluginConfig { .doubleConf .createWithDefault(0.8) + val ENABLE_OMNI_UNIXTIME_FUNCTION = buildConf("spark.omni.sql.columnar.unixTimeFunc.enabled") + .internal() + .doc("enable omni unix_timestamp and from_unixtime") + .booleanConf + .createWithDefault(true) + } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index 3152d6c7c..c5c48af0f 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -101,21 +101,24 @@ object OmniExpressionAdaptor extends Logging { } private val timeFormatSet: Set[String] = Set("yyyy-MM-dd HH:mm:ss", "yyyy-MM-dd") - - private def unsupportedTimeFormatCheck(timeFormat: String): Unit = { - if (!timeFormatSet.contains(timeFormat)) { - throw new UnsupportedOperationException(s"Unsupported Time Format: $timeFormat") - } - } - 
private val timeZoneSet: Set[String] = Set("GMT+08:00", "Asia/Beijing", "Asia/Shanghai") - private def unsupportedTimeZoneCheck(timeZone: String): Unit = { + private def unsupportedUnixTimeFunction(timeFormat: String, timeZone: String): Unit = { + if (!ColumnarPluginConfig.getSessionConf.enableOmniUnixTimeFunc) { + throw new UnsupportedOperationException(s"Not Enabled Omni UnixTime Function") + } + if (ColumnarPluginConfig.getSessionConf.timeParserPolicy == "LEGACY") { + throw new UnsupportedOperationException(s"Unsupported Time Parser Policy: LEGACY") + } if (!timeZoneSet.contains(timeZone)) { throw new UnsupportedOperationException(s"Unsupported Time Zone: $timeZone") } + if (!timeFormatSet.contains(timeFormat)) { + throw new UnsupportedOperationException(s"Unsupported Time Format: $timeFormat") + } } + def toOmniTimeFormat(format: String): String = { format.replace("yyyy", "%Y") .replace("MM", "%m") @@ -480,20 +483,21 @@ object OmniExpressionAdaptor extends Logging { // for date time functions case unixTimestamp: UnixTimestamp => val timeZone = unixTimestamp.timeZoneId.getOrElse("") - unsupportedTimeZoneCheck(timeZone) - unsupportedTimeFormatCheck(unixTimestamp.format.toString) + unsupportedUnixTimeFunction(unixTimestamp.format.toString, timeZone) + val policy = ColumnarPluginConfig.getSessionConf.timeParserPolicy new JSONObject().put("exprType", "FUNCTION") .addOmniExpJsonType("returnType", unixTimestamp.dataType) .put("function_name", "unix_timestamp") .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(unixTimestamp.timeExp, exprsIndexMap)) .put(new JSONObject(toOmniTimeFormat(rewriteToOmniJsonExpressionLiteral(unixTimestamp.format, exprsIndexMap)))) .put(new JSONObject().put("exprType", "LITERAL").put("dataType", 15).put("isNull", timeZone.isEmpty()) - .put("value", timeZone).put("width", timeZone.length))) + .put("value", timeZone).put("width", timeZone.length)) + .put(new JSONObject().put("exprType", "LITERAL").put("dataType", 15).put("isNull", policy.isEmpty()) + .put("value", policy).put("width", policy.length))) case fromUnixTime: FromUnixTime => val timeZone = fromUnixTime.timeZoneId.getOrElse("") - unsupportedTimeZoneCheck(timeZone) - unsupportedTimeFormatCheck(fromUnixTime.format.toString) + unsupportedUnixTimeFunction(fromUnixTime.format.toString, timeZone) new JSONObject().put("exprType", "FUNCTION") .addOmniExpJsonType("returnType", fromUnixTime.dataType) .put("function_name", "from_unixtime") diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/ColumnarFuncSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/ColumnarFuncSuite.scala index 467ad35ce..20c861eea 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/ColumnarFuncSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/ColumnarFuncSuite.scala @@ -110,6 +110,46 @@ class ColumnarFuncSuite extends ColumnarSparkPlanTest { assertOmniProjectNotHappened(rollbackRes) } + test("Test Unix_timestamp Function") { + spark.conf.set("spark.sql.optimizer.excludedRules", "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") + spark.conf.set("spark.sql.session.timeZone", "Asia/Shanghai") + spark.conf.set("spark.sql.legacy.timeParserPolicy", "CORRECTED") + val res1 = spark.sql("select unix_timestamp('','yyyy-MM-dd'), unix_timestamp('123-abc', " + + "'yyyy-MM-dd 
HH:mm:ss'), unix_timestamp(NULL, 'yyyy-MM-dd')") + assertOmniProjectHappened(res1) + checkAnswer(res1, Seq(Row(null, null, null))) + + val res2 = spark.sql("select unix_timestamp('2024-10-21', 'yyyy-MM-dd'), " + + "unix_timestamp('2024-10-21 11:22:33', 'yyyy-MM-dd HH:mm:ss')") + assertOmniProjectHappened(res2) + checkAnswer(res2, Seq(Row(1729440000L, 1729480953L))) + + val res3 = spark.sql("select unix_timestamp('1986-08-10 05:05:05','yyyy-MM-dd HH:mm:ss')") + assertOmniProjectHappened(res3) + checkAnswer(res3, Seq(Row(524001905L))) + + val res4 = spark.sql("select unix_timestamp('2086-08-10 05:05:05','yyyy-MM-dd HH:mm:ss')") + assertOmniProjectHappened(res4) + checkAnswer(res4, Seq(Row(3679765505L))) + } + + test("Test from_unixtime Function") { + spark.conf.set("spark.sql.optimizer.excludedRules", "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") + spark.conf.set("spark.sql.session.timeZone", "Asia/Shanghai") + spark.conf.set("spark.sql.legacy.timeParserPolicy", "CORRECTED") + val res1 = spark.sql("select from_unixtime(1, 'yyyy-MM-dd HH:mm:ss'), from_unixtime(1, 'yyyy-MM-dd')") + assertOmniProjectHappened(res1) + checkAnswer(res1, Seq(Row("1970-01-01 08:00:01", "1970-01-01"))) + + val res2 = spark.sql("select from_unixtime(524001905, 'yyyy-MM-dd HH:mm:ss'), from_unixtime(524001905, 'yyyy-MM-dd')") + assertOmniProjectHappened(res2) + checkAnswer(res2, Seq(Row("1986-08-10 05:05:05", "1986-08-10"))) + + val res3 = spark.sql("select from_unixtime(3679765505, 'yyyy-MM-dd HH:mm:ss'), from_unixtime(3679765505, 'yyyy-MM-dd')") + assertOmniProjectHappened(res3) + checkAnswer(res3, Seq(Row("2086-08-10 05:05:05", "2086-08-10"))) + } + private def assertOmniProjectHappened(res: DataFrame) = { val executedPlan = res.queryExecution.executedPlan assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") -- Gitee From 50759918f39382d04cf987ffe2d16d7cbffaa8a3 Mon Sep 17 00:00:00 2001 From: liujingxiang Date: Wed, 23 Oct 2024 09:59:47 +0800 Subject: [PATCH 06/12] [spark extension] fix data inconsistency when the pushOrderedLimitThroughAgg rule and adaptivePartialAggregation rule are used at the same time. --- .../execution/aggregate/PushOrderedLimitThroughAgg.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala index 422435694..327c3426c 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala @@ -27,10 +27,14 @@ import org.apache.spark.sql.SparkSession case class PushOrderedLimitThroughAgg(session: SparkSession) extends Rule[SparkPlan] with PredicateHelper { override def apply(plan: SparkPlan): SparkPlan = { - if (!ColumnarPluginConfig.getSessionConf.pushOrderedLimitThroughAggEnable) { + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + // The two optimization principles are contrary and cannot be used at the same time. + // reason: the pushOrderedLimitThroughAgg rule depends on the actual aggregation result in the partial phase. 
+ // However, if the partial phase is skipped, aggregation is not performed. + if (!columnarConf.pushOrderedLimitThroughAggEnable || columnarConf.enableAdaptivePartialAggregation) { return plan } - val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + val enableColumnarTopNSort: Boolean = columnarConf.enableColumnarTopNSort plan.transform { -- Gitee From 03b959df06b757a3eb3b7bb8f352cb84f01f980f Mon Sep 17 00:00:00 2001 From: huanglong Date: Wed, 23 Oct 2024 09:30:17 +0800 Subject: [PATCH 07/12] reduce simpleProject Signed-off-by: huanglong --- .../com/huawei/boostkit/spark/ColumnarPlugin.scala | 12 ++++++++++++ .../spark/expression/OmniExpressionAdaptor.scala | 9 +++++++++ 2 files changed, 21 insertions(+) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala index d19d1a467..6beadbd19 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala @@ -145,6 +145,18 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) } else { ColumnarProjectExec(plan.projectList, child) } + case scan: ColumnarFileSourceScanExec if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleAttribute(project))) => + ColumnarFileSourceScanExec( + scan.relation, + plan.output, + scan.requiredSchema, + scan.partitionFilters, + scan.optionalBucketSet, + scan.optionalNumCoalescedBuckets, + scan.dataFilters, + scan.tableIdentifier, + scan.disableBucketedScan + ) case _ => ColumnarProjectExec(plan.projectList, child) } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index f15672927..7479ee08c 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -1035,4 +1035,13 @@ object OmniExpressionAdaptor extends Logging { false } } + + def isSimpleAttribute(project: NamedExpression): Boolean = { + project match { + case attribute: AttributeReference => + true + case _ => + false + } + } } -- Gitee From 2ca3c7ab4589055ae12fd5fb7ff94177630fa68a Mon Sep 17 00:00:00 2001 From: zhousipei Date: Thu, 24 Oct 2024 19:33:28 +0800 Subject: [PATCH 08/12] avoid throw exception for IsNotNull/IsNull/In predicate and fix bug --- .../spark/jni/OrcColumnarBatchScanReader.java | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java index 7147fb45d..1aab7425b 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java @@ -419,20 +419,21 @@ public class OrcColumnarBatchScanReader { addLiteralToJsonLeaves("leaf-" + 
leafIndex, OrcLeafOperator.LESS_THAN_EQUALS, jsonLeaves, ((LessThanOrEqual) filterPredicate).attribute(), ((LessThanOrEqual) filterPredicate).value(), null); leafIndex++; + // For IsNotNull/IsNull/In, pass literal = "" to native to avoid throwing exception. } else if (filterPredicate instanceof IsNotNull) { addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, true); addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.IS_NULL, jsonLeaves, - ((IsNotNull) filterPredicate).attribute(), null, null); + ((IsNotNull) filterPredicate).attribute(), "", null); leafIndex++; } else if (filterPredicate instanceof IsNull) { addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false); addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.IS_NULL, jsonLeaves, - ((IsNull) filterPredicate).attribute(), null, null); + ((IsNull) filterPredicate).attribute(), "", null); leafIndex++; } else if (filterPredicate instanceof In) { addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false); addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.IN, jsonLeaves, - ((In) filterPredicate).attribute(), null, Arrays.stream(((In) filterPredicate).values()).toArray()); + ((In) filterPredicate).attribute(), "", Arrays.stream(((In) filterPredicate).values()).toArray()); leafIndex++; } else { throw new UnsupportedOperationException("Unsupported orc push down filter operation: " + @@ -451,8 +452,8 @@ public class OrcColumnarBatchScanReader { ArrayList literalList = new ArrayList<>(); if (literals != null) { - for (Object lit: literalList) { - literalList.add(getLiteralValue(literal)); + for (Object lit: literals) { + literalList.add(getLiteralValue(lit)); } } leafJson.put("literalList", literalList); -- Gitee From d46655d5474131a11ec3f0556d3956fd6188729c Mon Sep 17 00:00:00 2001 From: huanglong Date: Fri, 25 Oct 2024 14:53:51 +0800 Subject: [PATCH 09/12] Revert "reduce simpleProject" This reverts commit 03b959df06b757a3eb3b7bb8f352cb84f01f980f. 
--- .../com/huawei/boostkit/spark/ColumnarPlugin.scala | 12 ------------ .../spark/expression/OmniExpressionAdaptor.scala | 9 --------- 2 files changed, 21 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala index 6beadbd19..d19d1a467 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala @@ -145,18 +145,6 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) } else { ColumnarProjectExec(plan.projectList, child) } - case scan: ColumnarFileSourceScanExec if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleAttribute(project))) => - ColumnarFileSourceScanExec( - scan.relation, - plan.output, - scan.requiredSchema, - scan.partitionFilters, - scan.optionalBucketSet, - scan.optionalNumCoalescedBuckets, - scan.dataFilters, - scan.tableIdentifier, - scan.disableBucketedScan - ) case _ => ColumnarProjectExec(plan.projectList, child) } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index dbd9635b5..128d0790b 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -1039,13 +1039,4 @@ object OmniExpressionAdaptor extends Logging { false } } - - def isSimpleAttribute(project: NamedExpression): Boolean = { - project match { - case attribute: AttributeReference => - true - case _ => - false - } - } } -- Gitee From f97436c9dd1e1caa93ef78fb82431365089adb93 Mon Sep 17 00:00:00 2001 From: zhousipei Date: Wed, 30 Oct 2024 11:21:07 +0800 Subject: [PATCH 10/12] fix parquet push down missing columns error --- .../boostkit/spark/jni/ParquetColumnarBatchScanReader.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchScanReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchScanReader.java index 5fce3f089..40fcb06d9 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchScanReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchScanReader.java @@ -276,7 +276,12 @@ public class ParquetColumnarBatchScanReader { private void putCompareOp(JSONObject json, ParquetPredicateOperator op, String field, Object value) { json.put("op", op.ordinal()); - json.put("field", field); + if (allFieldsNames.contains(field)) { + json.put("field", field); + } else { + throw new ParquetDecodingException("Unsupported parquet push down missing columns: " + field); + } + if (value == null) { json.put("type", 0); } else { -- Gitee From 3c1573521672114cd5277f2e10b6383c5397419a Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Sat, 26 Oct 2024 20:04:19 +0800 Subject: [PATCH 11/12] compatible to spark 
3.4.3 --- .../omniop-native-reader/java/pom.xml | 8 +- .../omniop-spark-extension/java/pom.xml | 4 +- .../boostkit/spark/ColumnarPlugin.scala | 2 +- .../boostkit/spark/TransformHintRule.scala | 5 + .../expression/OmniExpressionAdaptor.scala | 37 ++- .../sql/catalyst/Tree/TreePatterns.scala | 12 + .../catalyst/expressions/runtimefilter.scala | 4 +- .../optimizer/InjectRuntimeFilter.scala | 16 +- .../optimizer/MergeSubqueryFilters.scala | 2 +- .../RewriteSelfJoinInInPredicate.scala | 2 +- .../ColumnarBasicPhysicalOperators.scala | 8 +- .../ColumnarFileSourceScanExec.scala | 4 +- .../spark/sql/execution/QueryExecution.scala | 30 ++- .../adaptive/AdaptiveSparkPlanExec.scala | 8 +- .../adaptive/InsertAdaptiveSparkPlan.scala | 2 + .../adaptive/PlanAdaptiveSubqueries.scala | 6 +- .../PushOrderedLimitThroughAgg.scala | 2 +- .../datasources/OmniFileFormatWriter.scala | 238 +++++++++++++----- ...mniInsertIntoHadoopFsRelationCommand.scala | 12 +- .../datasources/orc/OmniOrcFileFormat.scala | 2 +- .../parquet/OmniParquetFileFormat.scala | 2 +- .../sql/execution/util/BroadcastUtils.scala | 1 + .../test/resources/HiveResource.properties | 6 +- .../sql/catalyst/expressions/CastSuite.scala | 37 +-- .../ColumnarDecimalCastSuite.scala | 5 +- .../expressions/DecimalOperationSuite.scala | 26 +- .../optimizer/CombiningLimitsSuite.scala | 2 +- .../optimizer/MergeSubqueryFiltersSuite.scala | 2 +- .../ColumnarAdaptiveQueryExecSuite.scala | 47 ++-- omnioperator/omniop-spark-extension/pom.xml | 4 +- 30 files changed, 369 insertions(+), 167 deletions(-) diff --git a/omnioperator/omniop-native-reader/java/pom.xml b/omnioperator/omniop-native-reader/java/pom.xml index e7ddfe6c3..3cd67b1fb 100644 --- a/omnioperator/omniop-native-reader/java/pom.xml +++ b/omnioperator/omniop-native-reader/java/pom.xml @@ -8,13 +8,13 @@ com.huawei.boostkit boostkit-omniop-native-reader jar - 3.3.1-1.6.0 + 3.4.3-1.6.0 BoostKit Spark Native Sql Engine Extension With OmniOperator 2.12 - 3.3.1 + 3.4.3 FALSE ../cpp/ ../cpp/build/releases/ @@ -35,8 +35,8 @@ org.slf4j - slf4j-api - 1.7.32 + slf4j-simple + 1.7.36 junit diff --git a/omnioperator/omniop-spark-extension/java/pom.xml b/omnioperator/omniop-spark-extension/java/pom.xml index 9cc1b9d25..138665893 100644 --- a/omnioperator/omniop-spark-extension/java/pom.xml +++ b/omnioperator/omniop-spark-extension/java/pom.xml @@ -7,7 +7,7 @@ com.huawei.kunpeng boostkit-omniop-spark-parent - 3.3.1-1.6.0 + 3.4.3-1.6.0 ../pom.xml @@ -52,7 +52,7 @@ com.huawei.boostkit boostkit-omniop-native-reader - 3.3.1-1.6.0 + 3.4.3-1.6.0 junit diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala index d19d1a467..e03798e8b 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala @@ -563,7 +563,7 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) var unSupportedColumnarCommand = false var unSupportedFileFormat = false val omniCmd = plan.cmd match { - case cmd: InsertIntoHadoopFsRelationCommand => + case cmd: OmniInsertIntoHadoopFsRelationCommand => logInfo(s"Columnar Processing for ${cmd.getClass} is currently supported.") val fileFormat: FileFormat = cmd.fileFormat match { case _: OrcFileFormat => new OmniOrcFileFormat() diff --git 
a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala index 553463d56..96c88cecc 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, OmniAQEShuffleReadExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.command.DataWritingCommandExec +import org.apache.spark.sql.execution.datasources.WriteFilesExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ColumnarBroadcastHashJoinExec, ColumnarShuffledHashJoinExec, ColumnarSortMergeJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.window.WindowExec @@ -391,6 +392,10 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { } ColumnarDataWritingCommandExec(plan.cmd, plan.child).buildCheck() TransformHints.tagTransformable(plan) + case plan: WriteFilesExec => + TransformHints.tagNotTransformable( + plan, "data writing is not support" + ) case _ => TransformHints.tagTransformable(plan) } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index 128d0790b..29a09782b 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -26,6 +26,7 @@ import nova.hetu.omniruntime.constants.FunctionType.{OMNI_AGGREGATION_TYPE_AVG, import nova.hetu.omniruntime.constants.JoinType._ import nova.hetu.omniruntime.operator.OmniExprVerify import com.huawei.boostkit.spark.ColumnarPluginConfig +import nova.hetu.omniruntime.`type`.DataType.DataTypeId import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -38,6 +39,7 @@ import org.apache.spark.sql.execution.ColumnarBloomFilterSubquery import org.apache.spark.sql.execution.datasources.OmniFileFormatWriter.Empty2Null import org.apache.spark.sql.expression.ColumnarExpressionConverter import org.apache.spark.sql.hive.HiveUdfAdaptorUtil +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType, DateType, Decimal, DecimalType, DoubleType, IntegerType, LongType, Metadata, NullType, ShortType, StringType, TimestampType} import org.json.{JSONArray, JSONObject} @@ -77,18 +79,22 @@ object OmniExpressionAdaptor extends Logging { } } - private def unsupportedCastCheck(expr: Expression, cast: CastBase): Unit = { + private def unsupportedCastCheck(expr: Expression, cast: Cast): Unit = { def doSupportCastToString(dataType: DataType): Boolean = { - dataType.isInstanceOf[DecimalType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[IntegerType] || + 
(dataType.isInstanceOf[DecimalType] && !SQLConf.get.ansiEnabled) || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[IntegerType] || dataType.isInstanceOf[LongType] || dataType.isInstanceOf[DateType] || dataType.isInstanceOf[DoubleType] || dataType.isInstanceOf[NullType] } def doSupportCastFromString(dataType: DataType): Boolean = { - dataType.isInstanceOf[DecimalType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[DateType] || + !unSupportCastFromStringToDecimal(dataType) || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[DateType] || dataType.isInstanceOf[IntegerType] || dataType.isInstanceOf[LongType] || dataType.isInstanceOf[DoubleType] } + def unSupportCastFromStringToDecimal(dataType: DataType): Boolean = { + dataType.isInstanceOf[DecimalType] && dataType.asInstanceOf[DecimalType].precision == DecimalType.MAX_PRECISION + } + // support cast(decimal/string/int/long as string) if (cast.dataType.isInstanceOf[StringType] && !doSupportCastToString(cast.child.dataType)) { throw new UnsupportedOperationException(s"Unsupported expression: $expr") @@ -101,6 +107,20 @@ object OmniExpressionAdaptor extends Logging { } + private def unsupportedModCheck(mod: Remainder): Unit = { + val leftDataType = mod.left.dataType + val rightDataType = mod.right.dataType + leftDataType match { + case decimalType: DecimalType if rightDataType.isInstanceOf[DecimalType] => + val leftOmniType = sparkTypeToOmniType(decimalType) + val rightOmniType = sparkTypeToOmniType(rightDataType.asInstanceOf[DecimalType]) + if (leftOmniType == DataTypeId.OMNI_DECIMAL64.toValue && rightOmniType == DataTypeId.OMNI_DECIMAL128.toValue && !SQLConf.get.ansiEnabled) { + throw new UnsupportedOperationException(s"Unsupported mod Type: $leftDataType % $rightDataType") + } + case _ => + } + } + private val timeFormatSet: Set[String] = Set("yyyy-MM-dd HH:mm:ss", "yyyy-MM-dd") private val timeZoneSet: Set[String] = Set("GMT+08:00", "Asia/Beijing", "Asia/Shanghai") @@ -189,8 +209,8 @@ object OmniExpressionAdaptor extends Logging { throw new UnsupportedOperationException(s"Unsupported datatype for MakeDecimal: ${makeDecimal.child.dataType}") } - case promotePrecision: PromotePrecision => - rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap) +// case promotePrecision: PromotePrecision => +// rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap) case sub: Subtract => new JSONObject().put("exprType", "BINARY") @@ -221,6 +241,7 @@ object OmniExpressionAdaptor extends Logging { .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(divide.right, exprsIndexMap)) case mod: Remainder => + unsupportedModCheck(mod) new JSONObject().put("exprType", "BINARY") .addOmniExpJsonType("returnType", returnDatatype) .put("operator", "MODULUS") @@ -330,7 +351,7 @@ object OmniExpressionAdaptor extends Logging { .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject (empty2Null.child, exprsIndexMap))) // Cast - case cast: CastBase => + case cast: Cast => unsupportedCastCheck(expr, cast) cast.child.dataType match { case NullType => @@ -650,10 +671,10 @@ object OmniExpressionAdaptor extends Logging { rewriteToOmniJsonExpressionLiteralJsonObject(children.head, exprsIndexMap) } else { children.head match { - case base: CastBase if base.child.dataType.isInstanceOf[NullType] => + case base: Cast if base.child.dataType.isInstanceOf[NullType] => rewriteToOmniJsonExpressionLiteralJsonObject(children(1), exprsIndexMap) case _ => children(1) match 
{ - case base: CastBase if base.child.dataType.isInstanceOf[NullType] => + case base: Cast if base.child.dataType.isInstanceOf[NullType] => rewriteToOmniJsonExpressionLiteralJsonObject(children.head, exprsIndexMap) case _ => new JSONObject().put("exprType", "FUNCTION") diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala index ea2712447..4d8c246ab 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala @@ -26,6 +26,7 @@ object TreePattern extends Enumeration { val AGGREGATE_EXPRESSION = Value(0) val ALIAS: Value = Value val AND_OR: Value = Value + val AND: Value = Value val ARRAYS_ZIP: Value = Value val ATTRIBUTE_REFERENCE: Value = Value val APPEND_COLUMNS: Value = Value @@ -58,6 +59,7 @@ object TreePattern extends Enumeration { val JSON_TO_STRUCT: Value = Value val LAMBDA_FUNCTION: Value = Value val LAMBDA_VARIABLE: Value = Value + val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value val LATERAL_SUBQUERY: Value = Value val LIKE_FAMLIY: Value = Value val LIST_SUBQUERY: Value = Value @@ -69,7 +71,10 @@ object TreePattern extends Enumeration { val NULL_CHECK: Value = Value val NULL_LITERAL: Value = Value val SERIALIZE_FROM_OBJECT: Value = Value + val OR: Value = Value val OUTER_REFERENCE: Value = Value + val PARAMETER: Value = Value + val PARAMETERIZED_QUERY: Value = Value val PIVOT: Value = Value val PLAN_EXPRESSION: Value = Value val PYTHON_UDF: Value = Value @@ -81,6 +86,7 @@ object TreePattern extends Enumeration { val SCALAR_SUBQUERY: Value = Value val SCALAR_SUBQUERY_REFERENCE: Value = Value val SCALA_UDF: Value = Value + val SESSION_WINDOW: Value = Value val SORT: Value = Value val SUBQUERY_ALIAS: Value = Value val SUBQUERY_WRAPPER: Value = Value @@ -89,7 +95,9 @@ object TreePattern extends Enumeration { val TIME_ZONE_AWARE_EXPRESSION: Value = Value val TRUE_OR_FALSE_LITERAL: Value = Value val WINDOW_EXPRESSION: Value = Value + val WINDOW_TIME: Value = Value val UNARY_POSITIVE: Value = Value + val UNPIVOT: Value = Value val UPDATE_FIELDS: Value = Value val UPPER_OR_LOWER: Value = Value val UP_CAST: Value = Value @@ -119,6 +127,7 @@ object TreePattern extends Enumeration { val UNION: Value = Value val UNRESOLVED_RELATION: Value = Value val UNRESOLVED_WITH: Value = Value + val TEMP_RESOLVED_COLUMN: Value = Value val TYPED_FILTER: Value = Value val WINDOW: Value = Value val WITH_WINDOW_DEFINITION: Value = Value @@ -127,6 +136,7 @@ object TreePattern extends Enumeration { val UNRESOLVED_ALIAS: Value = Value val UNRESOLVED_ATTRIBUTE: Value = Value val UNRESOLVED_DESERIALIZER: Value = Value + val UNRESOLVED_HAVING: Value = Value val UNRESOLVED_ORDINAL: Value = Value val UNRESOLVED_FUNCTION: Value = Value val UNRESOLVED_HINT: Value = Value @@ -135,6 +145,8 @@ object TreePattern extends Enumeration { // Unresolved Plan patterns (Alphabetically ordered) val UNRESOLVED_SUBQUERY_COLUMN_ALIAS: Value = Value val UNRESOLVED_FUNC: Value = Value + val UNRESOLVED_TABLE_VALUED_FUNCTION: Value = Value + val UNRESOLVED_TVF_ALIASES: Value = Value // Execution expression patterns (alphabetically ordered) val IN_SUBQUERY_EXEC: Value = Value diff --git 
a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala index 0a5d509b0..85192fc36 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala @@ -39,7 +39,7 @@ case class RuntimeFilterSubquery( exprId: ExprId = NamedExpression.newExprId, hint: Option[HintInfo] = None) extends SubqueryExpression( - filterCreationSidePlan, Seq(filterApplicationSideExp), exprId, Seq.empty) + filterCreationSidePlan, Seq(filterApplicationSideExp), exprId, Seq.empty, hint) with Unevaluable with UnaryLike[Expression] { @@ -74,6 +74,8 @@ case class RuntimeFilterSubquery( override protected def withNewChildInternal(newChild: Expression): RuntimeFilterSubquery = copy(filterApplicationSideExp = newChild) + + override def withNewHint(hint: Option[HintInfo]): RuntimeFilterSubquery = copy(hint = hint) } /** diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index 812c387bc..ac38c5399 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -218,9 +218,9 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J leftKey: Expression, rightKey: Expression): Boolean = { (left, right) match { - case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan), _) => + case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _, _), plan), _) => pruningKey.fastEquals(leftKey) || hasDynamicPruningSubquery(plan, right, leftKey, rightKey) - case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan)) => + case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _, _), plan)) => pruningKey.fastEquals(rightKey) || hasDynamicPruningSubquery(left, plan, leftKey, rightKey) case _ => false @@ -251,10 +251,10 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J rightKey: Expression): Boolean = { (left, right) match { case (Filter(InSubquery(Seq(key), - ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _), _) => + ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _, _)), _), _) => key.fastEquals(leftKey) || key.fastEquals(new Murmur3Hash(Seq(leftKey))) case (_, Filter(InSubquery(Seq(key), - ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _)) => + ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _, _)), _)) => key.fastEquals(rightKey) || key.fastEquals(new Murmur3Hash(Seq(rightKey))) case _ => false } @@ -299,7 +299,13 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J case s: Subquery if s.correlated => plan case _ if !conf.runtimeFilterSemiJoinReductionEnabled && !conf.runtimeFilterBloomFilterEnabled => plan - case _ => tryInjectRuntimeFilter(plan) + case _ => + val newPlan = tryInjectRuntimeFilter(plan) + if (conf.runtimeFilterSemiJoinReductionEnabled && 
!plan.fastEquals(newPlan)) { + RewritePredicateSubquery(newPlan) + } else { + newPlan + } } } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala index 1b5baa230..c4435379f 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala @@ -643,7 +643,7 @@ object MergeSubqueryFilters extends Rule[LogicalPlan] { val subqueryCTE = header.plan.asInstanceOf[CTERelationDef] GetStructField( ScalarSubquery( - CTERelationRef(subqueryCTE.id, _resolved = true, subqueryCTE.output), + CTERelationRef(subqueryCTE.id, _resolved = true, subqueryCTE.output, subqueryCTE.isStreaming), exprId = ssr.exprId), ssr.headerIndex) } else { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala index 9e4029025..f6ebd716d 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala @@ -61,7 +61,7 @@ object RewriteSelfJoinInInPredicate extends Rule[LogicalPlan] with PredicateHelp case f: Filter => f transformExpressions { case in @ InSubquery(_, listQuery @ ListQuery(Project(projectList, - Join(left, right, Inner, Some(joinCond), _)), _, _, _, _)) + Join(left, right, Inner, Some(joinCond), _)), _, _, _, _, _)) if left.canonicalized ne right.canonicalized => val attrMapping = AttributeMap(right.output.zip(left.output)) val subCondExprs = splitConjunctivePredicates(joinCond transform { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala index 486369843..52d784182 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala @@ -44,8 +44,8 @@ import org.apache.spark.sql.vectorized.ColumnarBatch case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryExecNode - with AliasAwareOutputPartitioning - with AliasAwareOutputOrdering { + with PartitioningPreservingUnaryExecNode + with OrderPreservingUnaryExecNode { override def supportsColumnar: Boolean = true @@ -267,8 +267,8 @@ case class ColumnarConditionProjectExec(projectList: Seq[NamedExpression], condition: Expression, child: SparkPlan) extends UnaryExecNode - with AliasAwareOutputPartitioning - with AliasAwareOutputOrdering { + with PartitioningPreservingUnaryExecNode + with OrderPreservingUnaryExecNode { override def supportsColumnar: Boolean = true diff --git 
a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala index 800dcf1a0..80d796ce1 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala @@ -479,8 +479,8 @@ abstract class BaseColumnarFileSourceScanExec( } }.groupBy { f => BucketingUtils - .getBucketId(new Path(f.filePath).getName) - .getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.filePath)) + .getBucketId(f.toPath.getName) + .getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.urlEncodedPath)) } val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index ef33a84de..3630b0f2e 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -105,12 +105,29 @@ class QueryExecution( case other => other } + // The plan that has been normalized by custom rules, so that it's more likely to hit cache. + lazy val normalized: LogicalPlan = { + val normalizationRules = sparkSession.sessionState.planNormalizationRules + if (normalizationRules.isEmpty) { + commandExecuted + } else { + val planChangeLogger = new PlanChangeLogger[LogicalPlan]() + val normalized = normalizationRules.foldLeft(commandExecuted) { (p, rule) => + val result = rule.apply(p) + planChangeLogger.logRule(rule.ruleName, p, result) + result + } + planChangeLogger.logBatch("Plan Normalization", commandExecuted, normalized) + normalized + } + } + lazy val withCachedData: LogicalPlan = sparkSession.withActive { assertAnalyzed() assertSupported() // clone the plan to avoid sharing the plan instance between different stages like analyzing, // optimizing and planning. - sparkSession.sharedState.cacheManager.useCachedData(commandExecuted.clone()) + sparkSession.sharedState.cacheManager.useCachedData(normalized.clone()) } def assertCommandExecuted(): Unit = commandExecuted @@ -227,7 +244,7 @@ class QueryExecution( // output mode does not matter since there is no `Sink`. new IncrementalExecution( sparkSession, logical, OutputMode.Append(), "", - UUID.randomUUID, UUID.randomUUID, 0, OffsetSeqMetadata(0, 0)) + UUID.randomUUID, UUID.randomUUID, 0, None ,OffsetSeqMetadata(0, 0)) } else { this } @@ -494,11 +511,10 @@ object QueryExecution { */ private[sql] def toInternalError(msg: String, e: Throwable): Throwable = e match { case e @ (_: java.lang.NullPointerException | _: java.lang.AssertionError) => - new SparkException( - errorClass = "INTERNAL_ERROR", - messageParameters = Array(msg + - " Please, fill a bug report in, and provide the full stack trace."), - cause = e) + SparkException.internalError( + msg + " You hit a bug in Spark or the Spark plugins you use. 
Please, report this bug " + + "to the corresponding communities or vendors, and provide the full stack trace.", + e) case e: Throwable => e } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 6234751dd..6f397001e 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -84,7 +84,9 @@ case class AdaptiveSparkPlanExec( @transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]() // The logical plan optimizer for re-optimizing the current logical plan. - @transient private val optimizer = new AQEOptimizer(conf) + @transient private val optimizer = new AQEOptimizer(conf, + session.sessionState.adaptiveRulesHolder.runtimeOptimizerRules + ) // `EnsureRequirements` may remove user-specified repartition and assume the query plan won't // change its output partitioning. This assumption is not true in AQE. Here we check the @@ -123,7 +125,7 @@ case class AdaptiveSparkPlanExec( RemoveRedundantSorts, DisableUnnecessaryBucketedScan, OptimizeSkewedJoin(ensureRequirements) - ) ++ context.session.sessionState.queryStagePrepRules + ) ++ context.session.sessionState.adaptiveRulesHolder.queryStagePrepRules } // A list of physical optimizer rules to be applied to a new stage before its execution. These @@ -223,6 +225,8 @@ case class AdaptiveSparkPlanExec( .map(_.toLong).filter(SQLExecution.getQueryExecution(_) eq context.qe) } + def finalPhysicalPlan: SparkPlan = withFinalPlanUpdate(identity) + private def getFinalPhysicalPlan(): SparkPlan = lock.synchronized { if (isFinalPlan) return currentPhysicalPlan diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala index a70ba852e..3cd491364 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec} +import org.apache.spark.sql.execution.datasources.WriteFilesExec import org.apache.spark.sql.execution.datasources.v2.V2CommandExec import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.internal.SQLConf @@ -44,6 +45,7 @@ case class InsertAdaptiveSparkPlan( private def applyInternal(plan: SparkPlan, isSubquery: Boolean): SparkPlan = plan match { case _ if !conf.adaptiveExecutionEnabled => plan + case _: WriteFilesExec => plan case _: ExecutedCommandExec => plan case _: CommandResultExec => plan case c: DataWritingCommandExec => c.copy(child = apply(c.child)) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala 
b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala index b5a1ad375..dfdbe2c70 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala @@ -30,11 +30,11 @@ case class PlanAdaptiveSubqueries( def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressionsWithPruning(_.containsAnyPattern( SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY, RUNTIME_FILTER_SUBQUERY)) { - case expressions.ScalarSubquery(_, _, exprId, _) => + case expressions.ScalarSubquery(_, _, exprId, _, _, _) => val subquery = SubqueryExec.createForScalarSubquery( s"subquery#${exprId.id}", subqueryMap(exprId.id)) execution.ScalarSubquery(subquery, exprId) - case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _)) => + case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _, _)) => val expr = if (values.length == 1) { values.head } else { @@ -47,7 +47,7 @@ case class PlanAdaptiveSubqueries( val subquery = SubqueryExec(s"subquery#${exprId.id}", subqueryMap(exprId.id)) InSubqueryExec(expr, subquery, exprId, shouldBroadcast = true) case expressions.DynamicPruningSubquery(value, buildPlan, - buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId) => + buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId, _) => val name = s"dynamicpruning#${exprId.id}" val subquery = SubqueryAdaptiveBroadcastExec(name, broadcastKeyIndex, onlyInBroadcast, buildPlan, buildKeys, subqueryMap(exprId.id)) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala index 327c3426c..171825da2 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala @@ -38,7 +38,7 @@ case class PushOrderedLimitThroughAgg(session: SparkSession) extends Rule[SparkP val enableColumnarTopNSort: Boolean = columnarConf.enableColumnarTopNSort plan.transform { - case orderAndProject @ TakeOrderedAndProjectExec(limit, sortOrder, projectList, orderAndProjectChild) => { + case orderAndProject @ TakeOrderedAndProjectExec(limit, sortOrder, projectList, orderAndProjectChild, _) => { orderAndProjectChild match { case finalAgg @ HashAggregateExec(_, _, _, _, _, _, _, _, finalAggChild) => finalAggChild match { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala index 465123d87..5ccfa1513 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala @@ -37,8 +37,10 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import 
org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.connector.write.WriterCommitMessage import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.datasources.FileFormatWriter.ConcurrentOutputWriterSpec +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec +import org.apache.spark.sql.execution.datasources.FileFormatWriter.{ConcurrentOutputWriterSpec, OutputSpec, executedPlan, outputOrderingMatched} import org.apache.spark.sql.execution.{ColumnarProjectExec, ColumnarSortExec, OmniColumnarToRowExec, ProjectExec, SQLExecution, SortExec, SparkPlan, UnsafeExternalRowSorter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StringType @@ -69,6 +71,19 @@ object OmniFileFormatWriter extends Logging { copy(child = newChild) } + /** + * A variable used in tests to check whether the output ordering of the query matches the + * required ordering of the write command. + */ + private[sql] var outputOrderingMatched: Boolean = false + + /** + * A variable used in tests to check the final executed plan. + */ + private[sql] var executedPlan: Option[SparkPlan] = None + + // scalastyle:off argcount + /** * Basic work flow of this command is: * 1. Driver side setup, including output committer initialization and data source specific @@ -94,8 +109,10 @@ object OmniFileFormatWriter extends Logging { partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], statsTrackers: Seq[WriteJobStatsTracker], - options: Map[String, String]) + options: Map[String, String], + numStaticPartitionCols: Int = 0) : Set[String] = { + require(partitionColumns.size >= numStaticPartitionCols) val job = Job.getInstance(hadoopConf) job.setOutputKeyClass(classOf[Void]) @@ -118,43 +135,14 @@ object OmniFileFormatWriter extends Logging { } val empty2NullPlan = if (needConvert) ColumnarProjectExec(projectList, plan) else plan - val writerBucketSpec = bucketSpec.map { spec => - val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) - - if (options.getOrElse(BucketingUtils.optionForHiveCompatibleBucketWrite, "false") == - "true") { - // Hive bucketed table: use `HiveHash` and bitwise-and as bucket id expression. - // Without the extra bitwise-and operation, we can get wrong bucket id when hash value of - // columns is negative. See Hive implementation in - // `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`. - val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue)) - val bucketIdExpression = Pmod(hashId, Literal(spec.numBuckets)) - - // The bucket file name prefix is following Hive, Presto and Trino conversion, so this - // makes sure Hive bucketed table written by Spark, can be read by other SQL engines. - // - // Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`. - // Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`. - val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_" - WriterBucketSpec(bucketIdExpression, fileNamePrefix) - } else { - // Spark bucketed table: use `HashPartitioning.partitionIdExpression` as bucket id - // expression, so that we can guarantee the data distribution is same between shuffle and - // bucketed data source, which enables us to only shuffle one side when join a bucketed - // table and a normal one. 
- val bucketIdExpression = HashPartitioning(bucketColumns, spec.numBuckets) - .partitionIdExpression - WriterBucketSpec(bucketIdExpression, (_: Int) => "") - } - } - val sortColumns = bucketSpec.toSeq.flatMap { - spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) - } + val writerBucketSpec = V1WritesUtils.getWriterBucketSpec(bucketSpec, dataColumns, options) + val sortColumns = V1WritesUtils.getBucketSortColumns(bucketSpec, dataColumns) val caseInsensitiveOptions = CaseInsensitiveMap(options) val dataSchema = dataColumns.toStructType DataSourceUtils.verifySchema(fileFormat, dataSchema) + DataSourceUtils.checkFieldNames(fileFormat, dataSchema) // Note: prepareWrite has side effect. It sets "job". val outputWriterFactory = fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataSchema) @@ -176,53 +164,89 @@ object OmniFileFormatWriter extends Logging { statsTrackers = statsTrackers ) - // We should first sort by partition columns, then bucket id, and finally sorting columns. - val requiredOrdering = - partitionColumns ++ writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns - // the sort order doesn't matter - val actualOrdering = empty2NullPlan.outputOrdering.map(_.child) - val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { - false - } else { - requiredOrdering.zip(actualOrdering).forall { - case (requiredOrder, childOutputOrder) => - requiredOrder.semanticEquals(childOutputOrder) - } + // We should first sort by dynamic partition columns, then bucket id, and finally sorting + // columns. + val requiredOrdering = partitionColumns.drop(numStaticPartitionCols) ++ + writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns + val writeFilesOpt = V1WritesUtils.getWriteFilesOpt(plan) + + // SPARK-40588: when planned writing is disabled and AQE is enabled, + // plan contains an AdaptiveSparkPlanExec, which does not know + // its final plan's ordering, so we have to materialize that plan first + // it is fine to use plan further down as the final plan is cached in that plan + def materializeAdaptiveSparkPlan(plan: SparkPlan): SparkPlan = plan match { + case a: AdaptiveSparkPlanExec => a.finalPhysicalPlan + case p: SparkPlan => p.withNewChildren(p.children.map(materializeAdaptiveSparkPlan)) } + // the sort order doesn't matter + val actualOrdering = writeFilesOpt.map(_.child) + .getOrElse(materializeAdaptiveSparkPlan(plan)) + .outputOrdering + val orderingMatched = V1WritesUtils.isOrderingMatched(requiredOrdering, actualOrdering) + SQLExecution.checkSQLExecutionId(sparkSession) // propagate the description UUID into the jobs, so that committers // get an ID guaranteed to be unique. job.getConfiguration.set("spark.sql.sources.writeJobUUID", description.uuid) - // This call shouldn't be put into the `try` block below because it only initializes and - // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. - committer.setupJob(job) - - try { + // When `PLANNED_WRITE_ENABLED` is true, the optimizer rule V1Writes will add logical sort + // operator based on the required ordering of the V1 write command. So the output + // ordering of the physical plan should always match the required ordering. Here + // we set the variable to verify this behavior in tests. + // There are two cases where FileFormatWriter still needs to add physical sort: + // 1) When the planned write config is disabled. 
+ // 2) When the concurrent writers are enabled (in this case the required ordering of a + // V1 write command will be empty). + if (Utils.isTesting) outputOrderingMatched = orderingMatched + + if (writeFilesOpt.isDefined) { + // build `WriteFilesSpec` for `WriteFiles` + val concurrentOutputWriterSpecFunc = (plan: SparkPlan) => { + val sortPlan = createSortPlan(plan, requiredOrdering, outputSpec) + createConcurrentOutputWriterSpec(sparkSession, sortPlan, sortColumns) + } + val writeSpec = WriteFilesSpec( + description = description, + committer = committer, + concurrentOutputWriterSpecFunc = concurrentOutputWriterSpecFunc + ) + executeWrite(sparkSession, plan, writeSpec, job) + } else { + executeWrite(sparkSession, plan, job, description, committer, outputSpec, + requiredOrdering, partitionColumns, sortColumns, orderingMatched) + } + } + // scalastyle:on argcount + + private def executeWrite( + sparkSession: SparkSession, + plan: SparkPlan, + job: Job, + description: WriteJobDescription, + committer: FileCommitProtocol, + outputSpec: OutputSpec, + requiredOrdering: Seq[Expression], + partitionColumns: Seq[Attribute], + sortColumns: Seq[Attribute], + orderingMatched: Boolean): Set[String] = { + val projectList = V1WritesUtils.convertEmptyToNull(plan.output, partitionColumns) + val empty2NullPlan = if (projectList.nonEmpty) ProjectExec(projectList, plan) else plan + + writeAndCommit(job, description, committer) { val (rdd, concurrentOutputWriterSpec) = if (orderingMatched) { (empty2NullPlan.executeColumnar(), None) } else { - // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and - // the physical plan may have different attribute ids due to optimizer removing some - // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. - val orderingExpr = bindReferences( - requiredOrdering.map(SortOrder(_, Ascending)), finalOutputSpec.outputColumns) - // val orderingExpr = requiredOrdering.map(SortOrder(_, Ascending)) - val sortPlan = ColumnarSortExec( - orderingExpr, - global = false, - child = empty2NullPlan) - - val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters - val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty - if (concurrentWritersEnabled) { + val sortPlan = createSortPlan(empty2NullPlan, requiredOrdering, outputSpec) + val concurrentOutputWriterSpec = createConcurrentOutputWriterSpec( + sparkSession, sortPlan, sortColumns) + if (concurrentOutputWriterSpec.isDefined) { // TODO Concurrent output write logInfo("Columnar concurrent write is not support now, use un concurrent write") - (sortPlan.executeColumnar(), None) + (empty2NullPlan.executeColumnar(), concurrentOutputWriterSpec) } else { - (sortPlan.executeColumnar(), None) + (sortPlan.executeColumnar(), concurrentOutputWriterSpec) } } @@ -254,7 +278,19 @@ object OmniFileFormatWriter extends Logging { committer.onTaskCommit(res.commitMsg) ret(index) = res }) + ret + } + } + private def writeAndCommit( + job: Job, + description: WriteJobDescription, + committer: FileCommitProtocol)(f: => Array[WriteTaskResult]): Set[String] = { + // This call shouldn't be put into the `try` block below because it only initializes and + // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. 
+ committer.setupJob(job) + try { + val ret = f val commitMsgs = ret.map(_.commitMsg) logInfo(s"Start to commit write Job ${description.uuid}.") @@ -272,7 +308,71 @@ object OmniFileFormatWriter extends Logging { case cause: Throwable => logError(s"Aborting job ${description.uuid}.", cause) committer.abortJob(job) - throw QueryExecutionErrors.jobAbortedError(cause) + throw cause + } + } + + /** + * Write files using [[SparkPlan.executeWrite]] + */ + private def executeWrite( + session: SparkSession, + planForWrites: SparkPlan, + writeFilesSpec: WriteFilesSpec, + job: Job): Set[String] = { + val committer = writeFilesSpec.committer + val description = writeFilesSpec.description + + // In testing, this is the only way to get hold of the actually executed plan written to file + if (Utils.isTesting) executedPlan = Some(planForWrites) + + writeAndCommit(job, description, committer) { + val rdd = planForWrites.executeWrite(writeFilesSpec) + val ret = new Array[WriteTaskResult](rdd.partitions.length) + session.sparkContext.runJob( + rdd, + (context: TaskContext, iter: Iterator[WriterCommitMessage]) => { + assert(iter.hasNext) + val commitMessage = iter.next() + assert(!iter.hasNext) + commitMessage + }, + rdd.partitions.indices, + (index, res: WriterCommitMessage) => { + assert(res.isInstanceOf[WriteTaskResult]) + val writeTaskResult = res.asInstanceOf[WriteTaskResult] + committer.onTaskCommit(writeTaskResult.commitMsg) + ret(index) = writeTaskResult + }) + ret + } + } + + private def createSortPlan( + plan: SparkPlan, + requiredOrdering: Seq[Expression], + outputSpec: OutputSpec): SortExec = { + // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and + // the physical plan may have different attribute ids due to optimizer removing some + // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. + val orderingExpr = bindReferences( + requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns) + SortExec( + orderingExpr, + global = false, + child = plan) + } + + private def createConcurrentOutputWriterSpec( + sparkSession: SparkSession, + sortPlan: SortExec, + sortColumns: Seq[Attribute]): Option[ConcurrentOutputWriterSpec] = { + val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters + val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty + if (concurrentWritersEnabled) { + Some(ConcurrentOutputWriterSpec(maxWriters, () => sortPlan.createSorter())) + } else { + None } } @@ -343,7 +443,7 @@ object OmniFileFormatWriter extends Logging { // We throw the exception and let Executor throw ExceptionFailure to abort the job. 
throw new TaskOutputFileAlreadyExistException(f) case t: Throwable => - throw QueryExecutionErrors.taskFailedWhileWritingRowsError(t) + throw QueryExecutionErrors.taskFailedWhileWritingRowsError(description.path, t) } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInsertIntoHadoopFsRelationCommand.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInsertIntoHadoopFsRelationCommand.scala index 9d0008e0b..7415e2065 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInsertIntoHadoopFsRelationCommand.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInsertIntoHadoopFsRelationCommand.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._ -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -59,7 +59,7 @@ case class OmniInsertIntoHadoopFsRelationCommand( catalogTable: Option[CatalogTable], fileIndex: Option[FileIndex], outputColumnNames: Seq[String]) - extends DataWritingCommand { + extends V1WriteCommand { private lazy val parameters = CaseInsensitiveMap(options) @@ -76,6 +76,10 @@ case class OmniInsertIntoHadoopFsRelationCommand( staticPartitions.size < partitionColumns.length } + override def requiredOrdering: Seq[SortOrder] = + V1WritesUtils.getSortOrder(outputColumns, partitionColumns, bucketSpec, options, + staticPartitions.size) + // Return Seq[Row] but Seq[ColumBatch] since // 1. reuse the origin interface of spark to avoid add duplicate code // 2. 
this func return a Seq.empty[Row] and this data doesn't do anything else @@ -83,7 +87,6 @@ case class OmniInsertIntoHadoopFsRelationCommand( // Most formats don't do well with duplicate columns, so lets not allow that SchemaUtils.checkColumnNameDuplication( outputColumnNames, - s"when inserting into $outputPath", sparkSession.sessionState.conf.caseSensitiveAnalysis) val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options) @@ -187,7 +190,8 @@ case class OmniInsertIntoHadoopFsRelationCommand( partitionColumns = partitionColumns, bucketSpec = bucketSpec, statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)), - options = options) + options = options, + numStaticPartitionCols = staticPartitions.size) // update metastore partition metadata diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala index 8f37ca70b..4e8d14424 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala @@ -75,7 +75,7 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ (file: PartitionedFile) => { val conf = broadcastedConf.value.value - val filePath = new Path(new URI(file.filePath)) + val filePath = file.toPath // ORC predicate pushdown val pushed = if (orcFilterPushDown) { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala index 78acf3058..4586b8ec0 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala @@ -88,7 +88,7 @@ class OmniParquetFileFormat extends FileFormat with DataSourceRegister with Logg (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) - val filePath = new Path(new URI(file.filePath)) + val filePath = file.toPath val split = new org.apache.parquet.hadoop.ParquetInputSplit( filePath, diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/BroadcastUtils.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/BroadcastUtils.scala index 2204e0fe4..ba5db895e 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/BroadcastUtils.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/BroadcastUtils.scala @@ -264,6 +264,7 @@ object BroadcastUtils { -1, -1L, -1, + -1, new TaskMemoryManager(memoryManager, -1L), new Properties, MetricsSystem.createMetricsSystem("OMNI_UNSAFE", sparkConf), diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties b/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties index 89eabe8e6..1991c9c05 100644 --- 
a/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties +++ b/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties @@ -2,11 +2,13 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. # -hive.metastore.uris=thrift://server1:9083 +#hive.metastore.uris=thrift://server1:9083 +hive.metastore.uris=thrift://OmniOperator:9083 spark.sql.warehouse.dir=/user/hive/warehouse spark.memory.offHeap.size=8G spark.sql.codegen.wholeStage=false spark.sql.extensions=com.huawei.boostkit.spark.ColumnarPlugin spark.shuffle.manager=org.apache.spark.shuffle.sort.OmniColumnarShuffleManager spark.sql.orc.impl=native -hive.db=tpcds_bin_partitioned_orc_2 \ No newline at end of file +#hive.db=tpcds_bin_partitioned_orc_2 +hive.db=tpcds_bin_partitioned_varchar_orc_2 \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index e6b786c2a..36d0a3e49 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.SparkException import org.apache.spark.sql.{Column, DataFrame} import org.apache.spark.sql.execution.ColumnarSparkPlanTest import org.apache.spark.sql.types.{DataType, Decimal} @@ -64,7 +65,7 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getByte(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as short") { @@ -72,7 +73,7 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getShort(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as int") { @@ -80,7 +81,7 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getInt(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as long") { @@ -88,7 +89,7 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getLong(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as float") { @@ -96,7 +97,7 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getFloat(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as double") { @@ -104,7 +105,7 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getDouble(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as date") { @@ -154,13 +155,13 @@ class 
CastSuite extends ColumnarSparkPlanTest { val exception4 = intercept[Exception]( result4.collect().toSeq.head.getBoolean(0) ) - assert(exception4.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception4.isInstanceOf[SparkException], s"sql: ${sql}") val result5 = spark.sql("select cast('test' as boolean);") val exception5 = intercept[Exception]( result5.collect().toSeq.head.getBoolean(0) ) - assert(exception5.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception5.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast boolean to string") { @@ -182,13 +183,13 @@ class CastSuite extends ColumnarSparkPlanTest { val exception2 = intercept[Exception]( result2.collect().toSeq.head.getByte(0) ) - assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}") val result3 = spark.sql("select cast('false' as byte);") val exception3 = intercept[Exception]( result3.collect().toSeq.head.getByte(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast byte to string") { @@ -210,13 +211,13 @@ class CastSuite extends ColumnarSparkPlanTest { val exception2 = intercept[Exception]( result2.collect().toSeq.head.getShort(0) ) - assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}") val result3 = spark.sql("select cast('false' as short);") val exception3 = intercept[Exception]( result3.collect().toSeq.head.getShort(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast short to string") { @@ -238,13 +239,13 @@ class CastSuite extends ColumnarSparkPlanTest { val exception2 = intercept[Exception]( result2.collect().toSeq.head.getInt(0) ) - assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}") val result3 = spark.sql("select cast('false' as int);") val exception3 = intercept[Exception]( result3.collect().toSeq.head.getInt(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast int to string") { @@ -266,13 +267,13 @@ class CastSuite extends ColumnarSparkPlanTest { val exception2 = intercept[Exception]( result2.collect().toSeq.head.getLong(0) ) - assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}") val result3 = spark.sql("select cast('false' as long);") val exception3 = intercept[Exception]( result3.collect().toSeq.head.getLong(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast long to string") { @@ -298,7 +299,7 @@ class CastSuite extends ColumnarSparkPlanTest { val exception3 = intercept[Exception]( result3.collect().toSeq.head.getFloat(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast float to string") { @@ -324,7 +325,7 @@ class CastSuite extends ColumnarSparkPlanTest { val exception3 = intercept[Exception]( result3.collect().toSeq.head.getDouble(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + 
assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast double to string") { diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/ColumnarDecimalCastSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/ColumnarDecimalCastSuite.scala index c7bef78bd..1f27d7dec 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/ColumnarDecimalCastSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/ColumnarDecimalCastSuite.scala @@ -552,14 +552,15 @@ class ColumnarDecimalCastSuite extends ColumnarSparkPlanTest{ ) } + // casting string to decimal is not supported when the precision equals decimal MAX_PRECISION test("Test ColumnarProjectExec happen and result is same as native " + "when cast decimal to string") { val res = spark.sql("select cast(cast(c_deci17_2_null as string) as decimal(38, 2))," + "cast(cast(c_deci38_2_null as string) as decimal(38, 2)) from deci_string") val executedPlan = res.queryExecution.executedPlan println(executedPlan) - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isDefined, s"ProjectExec not happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, s"ColumnarProjectExec happened, executedPlan as follows: \n$executedPlan") checkAnswer( res, Seq( diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalOperationSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalOperationSuite.scala index b29e062fe..9049ff17c 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalOperationSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalOperationSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.execution.{ColumnarConditionProjectExec, ColumnarSparkPlanTest} +import org.apache.spark.sql.execution.{ColumnarConditionProjectExec, ColumnarSparkPlanTest, ProjectExec} import org.apache.spark.sql.types.Decimal import org.apache.spark.sql.{Column, DataFrame} @@ -59,6 +59,14 @@ class DecimalOperationSuite extends ColumnarSparkPlanTest{ assertResult(expect, s"sql: ${sql}")(output.toString) } + private def checkFallBackResult(sql: String, expect: String): Unit = { + val result = spark.sql(sql) + val plan = result.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[ProjectExec]).isDefined) + val output = result.collect().toSeq.head.getDecimal(0) + assertResult(expect, s"sql: ${sql}")(output.toString) + } + private def checkResultNull(sql: String): Unit = { val result = spark.sql(sql) val plan = result.queryExecution.executedPlan @@ -67,6 +75,14 @@ class DecimalOperationSuite extends ColumnarSparkPlanTest{ assertResult(null, s"sql: ${sql}")(output) } + private def checkFallBackResultNull(sql: String): Unit = { + val result = spark.sql(sql) + val plan = 
result.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[ProjectExec]).isDefined) + val output = result.collect().toSeq.head.getDecimal(0) + assertResult(null, s"sql: ${sql}")(output) + } + private def checkAnsiResult(sql: String, expect: String): Unit = { spark.conf.set("spark.sql.ansi.enabled", true) checkResult(sql, expect) @@ -576,7 +592,7 @@ class DecimalOperationSuite extends ColumnarSparkPlanTest{ } test("decimal64%decimal128(0) when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci18_6%c_deci21_6 from deci_overflow where id = 4") + checkFallBackResultNull("select c_deci18_6%c_deci21_6 from deci_overflow where id = 4") } test("decimal128%decimal128(0) when spark.sql.ansi.enabled=false") { @@ -747,7 +763,7 @@ class DecimalOperationSuite extends ColumnarSparkPlanTest{ } test("decimal64(0)%decimal128 when spark.sql.ansi.enabled=false") { - checkResult("select c_deci7_2%c_deci22_6 from deci_overflow where id = 4;", "0.000000") + checkFallBackResult("select c_deci7_2%c_deci22_6 from deci_overflow where id = 4;", "0.000000") } test("decimal128(0)%decimal128 when spark.sql.ansi.enabled=false") { @@ -759,7 +775,7 @@ class DecimalOperationSuite extends ColumnarSparkPlanTest{ } test("literal(0)%decimal128 when when spark.sql.ansi.enabled=false") { - checkResult("select 0%c_deci22_6 from deci_overflow where id = 4;", "0.000000") + checkFallBackResult("select 0%c_deci22_6 from deci_overflow where id = 4;", "0.000000") } // spark.sql.ansi.enabled=true @@ -990,7 +1006,7 @@ class DecimalOperationSuite extends ColumnarSparkPlanTest{ } test("decimal64(NULL)%decimal128 when spark.sql.ansi.enabled=false") { - checkResultNull("select c_deci7_2%c_deci22_6 from deci_overflow where id = 5;") + checkFallBackResultNull("select c_deci7_2%c_deci22_6 from deci_overflow where id = 5;") } test("decimal128%decimal128(NULL) when spark.sql.ansi.enabled=false") { diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index f83edb9ca..ea52aca62 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -61,7 +61,7 @@ class CombiningLimitsSuite extends PlanTest { comparePlans(optimized1, expected1) // test child max row > limit. 
- val query2 = testRelation.select().groupBy()(count(1)).limit(0).analyze + val query2 = testRelation2.select($"x").groupBy($"x")(count(1)).limit(1).analyze val optimized2 = Optimize.execute(query2) comparePlans(optimized2, query2) diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala index aaa244cdf..e1c620e1c 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala @@ -43,7 +43,7 @@ class MergeSubqueryFiltersSuite extends PlanTest { } private def extractorExpression(cteIndex: Int, output: Seq[Attribute], fieldIndex: Int) = { - GetStructField(ScalarSubquery(CTERelationRef(cteIndex, _resolved = true, output)), fieldIndex) + GetStructField(ScalarSubquery(CTERelationRef(cteIndex, _resolved = true, output, isStreaming = false)), fieldIndex) .as("scalarsubquery()") } diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala index c34ff5bb1..17a08139b 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala @@ -19,17 +19,15 @@ package org.apache.spark.sql.execution.adaptive import java.io.File import java.net.URI - import org.apache.logging.log4j.Level import org.scalatest.PrivateMethodTester import org.scalatest.time.SpanSugar._ - import org.apache.spark.SparkException import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} -import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnaryExecNode, UnionExec} +import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnaryExecNode, UnionExec} import org.apache.spark.sql.execution.aggregate.BaseAggregateExec import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.execution.datasources.noop.NoopDataSource @@ -37,7 +35,7 @@ import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, Exchange, REPARTITION_BY_COL, REPARTITION_BY_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec} import 
org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter -import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate +import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLExecutionStart} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode @@ -1122,13 +1120,21 @@ class AdaptiveQueryExecSuite test("SPARK-30953: InsertAdaptiveSparkPlan should apply AQE on child plan of write commands") { withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + var plan: SparkPlan = null + val listener = new QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + plan = qe.executedPlan + } + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} + } + spark.listenerManager.register(listener) withTable("t1") { - val plan = sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").queryExecution.executedPlan - assert(plan.isInstanceOf[CommandResultExec]) - val commandResultExec = plan.asInstanceOf[CommandResultExec] - assert(commandResultExec.commandPhysicalPlan.isInstanceOf[DataWritingCommandExec]) - assert(commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec] - .child.isInstanceOf[AdaptiveSparkPlanExec]) + val format = classOf[NoopDataSource].getName + Seq((0, 1)).toDF("x", "y").write.format(format).mode("overwrite").save() + sparkContext.listenerBus.waitUntilEmpty() + assert(plan.isInstanceOf[V2TableWriteExec]) + assert(plan.asInstanceOf[V2TableWriteExec].child.isInstanceOf[AdaptiveSparkPlanExec]) + spark.listenerManager.unregister(listener) } } } @@ -1172,15 +1178,14 @@ class AdaptiveQueryExecSuite test("SPARK-31658: SQL UI should show write commands") { withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "false") { withTable("t1") { - var checkDone = false + var commands: Seq[SparkPlanInfo] = Seq.empty val listener = new SparkListener { override def onOtherEvent(event: SparkListenerEvent): Unit = { event match { - case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) => - assert(planInfo.nodeName == "Execute CreateDataSourceTableAsSelectCommand") - checkDone = true + case start: SparkListenerSQLExecutionStart => + commands = commands ++ Seq(start.sparkPlanInfo) case _ => // ignore other events } } @@ -1189,7 +1194,12 @@ class AdaptiveQueryExecSuite try { sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect() spark.sparkContext.listenerBus.waitUntilEmpty() - assert(checkDone) + assert(commands.size == 3) + assert(commands.head.nodeName == "Execute CreateDataSourceTableAsSelectCommand") + assert(commands(1).nodeName == "Execute InsertIntoHadoopFsRelationCommand") + assert(commands(1).children.size == 1) + assert(commands(1).children.head.nodeName == "WriteFiles") + assert(commands(2).nodeName == "CommandResult") } finally { spark.sparkContext.removeSparkListener(listener) } @@ -1574,7 +1584,7 @@ class AdaptiveQueryExecSuite test("SPARK-32932: Do not use local shuffle read at final stage on write command") { withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString, SQLConf.SHUFFLE_PARTITIONS.key -> "5", - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + 
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { val data = for ( i <- 1L to 10L; j <- 1L to 3L @@ -1584,9 +1594,8 @@ class AdaptiveQueryExecSuite var noLocalread: Boolean = false val listener = new QueryExecutionListener { override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { - qe.executedPlan match { + stripAQEPlan(qe.executedPlan) match { case plan@(_: DataWritingCommandExec | _: V2TableWriteExec) => - assert(plan.asInstanceOf[UnaryExecNode].child.isInstanceOf[AdaptiveSparkPlanExec]) noLocalread = collect(plan) { case exec: AQEShuffleReadExec if exec.isLocalRead => exec }.isEmpty diff --git a/omnioperator/omniop-spark-extension/pom.xml b/omnioperator/omniop-spark-extension/pom.xml index 4376a89be..e949cea2d 100644 --- a/omnioperator/omniop-spark-extension/pom.xml +++ b/omnioperator/omniop-spark-extension/pom.xml @@ -8,13 +8,13 @@ com.huawei.kunpeng boostkit-omniop-spark-parent pom - 3.3.1-1.6.0 + 3.4.3-1.6.0 BoostKit Spark Native Sql Engine Extension Parent Pom 2.12.10 2.12 - 3.3.1 + 3.4.3 3.2.2 UTF-8 UTF-8 -- Gitee From 121753bf5f98863073b79956bfd1a6456d788830 Mon Sep 17 00:00:00 2001 From: wangwei <14424757+wiwimao@user.noreply.gitee.com> Date: Fri, 8 Nov 2024 18:55:11 +0800 Subject: [PATCH 12/12] fix decimal calculation and testsuite --- .../boostkit/spark/ColumnarPlugin.scala | 23 +- .../boostkit/spark/TransformHintRule.scala | 5 - .../expression/OmniExpressionAdaptor.scala | 91 ++++--- .../adaptive/AdaptiveSparkPlanExec.scala | 6 +- .../datasources/OmniFileFormatWriter.scala | 234 +++++------------- ...mniInsertIntoHadoopFsRelationCommand.scala | 11 +- .../test/resources/HiveResource.properties | 6 +- .../ColumnarDecimalCastSuite.scala | 5 +- .../expressions/DecimalOperationSuite.scala | 26 +- .../ColumnarAdaptiveQueryExecSuite.scala | 4 +- 10 files changed, 148 insertions(+), 263 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala index e03798e8b..22fc8339e 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala @@ -563,25 +563,10 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) var unSupportedColumnarCommand = false var unSupportedFileFormat = false val omniCmd = plan.cmd match { - case cmd: OmniInsertIntoHadoopFsRelationCommand => - logInfo(s"Columnar Processing for ${cmd.getClass} is currently supported.") - val fileFormat: FileFormat = cmd.fileFormat match { - case _: OrcFileFormat => new OmniOrcFileFormat() - case format => - logInfo(s"Unsupported ${format.getClass} file " + - s"format for columnar data write command.") - unSupportedFileFormat = true - null - } - if (unSupportedFileFormat) { - cmd - } else { - OmniInsertIntoHadoopFsRelationCommand(cmd.outputPath, cmd.staticPartitions, - cmd.ifPartitionNotExists, cmd.partitionColumns, cmd.bucketSpec, fileFormat, - cmd.options, cmd.query, cmd.mode, cmd.catalogTable, - cmd.fileIndex, cmd.outputColumnNames - ) - } + case cmd: InsertIntoHadoopFsRelationCommand => + logInfo(s"Columnar Processing for ${cmd.getClass} is currently not supported.") + unSupportedColumnarCommand = true + cmd case cmd: DataWritingCommand => logInfo(s"Columnar Processing for ${cmd.getClass} is currently not supported.") unSupportedColumnarCommand = true 
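[Editorial illustration - not part of the patch] The "fix decimal calculation" half of this commit centers on widening the two decimal operands to a common type before the expression is handed to Omni, which is what the binaryOperatorAdjust helper added to OmniExpressionAdaptor.scala further down does. The short Scala sketch below mirrors that widening rule under the same assumptions (result scale is the larger of the operand scales, the integral part is the larger of the operands' integral digit counts, and precision is capped at DecimalType.MAX_PRECISION); the object and method names here are illustrative only and do not exist in the patch.

    import org.apache.spark.sql.types.DecimalType
    import scala.math.{max, min}

    object DecimalWideningSketch {
      // Returns the widened type plus a flag telling whether the exact result
      // would have needed more than MAX_PRECISION digits (i.e. precision was capped).
      def widerDecimalType(d1: DecimalType, d2: DecimalType): (DecimalType, Boolean) = {
        val scale = max(d1.scale, d2.scale)                               // keep the larger fractional part
        val range = max(d1.precision - d1.scale, d2.precision - d2.scale) // keep the larger integral part
        val overflowed = range + scale > DecimalType.MAX_PRECISION
        (DecimalType(min(range + scale, DecimalType.MAX_PRECISION),
          min(scale, DecimalType.MAX_SCALE)), overflowed)
      }

      def main(args: Array[String]): Unit = {
        // decimal(18,6) with decimal(21,6) widens to decimal(21,6); no capping needed.
        println(widerDecimalType(DecimalType(18, 6), DecimalType(21, 6)))
        // decimal(38,2) with decimal(22,6) would need precision 42, so it is capped
        // at decimal(38,6) and the overflow flag is reported as true.
        println(widerDecimalType(DecimalType(38, 2), DecimalType(22, 6)))
      }
    }

Roughly speaking, when the overflow flag is set and an operand's scale does not exceed the return scale, binaryOperatorAdjust casts that operand towards the expression's return type (or leaves it as-is) instead of the wider intermediate type; the DecimalOperationSuite hunks earlier in this series expect the remaining decimal64 % decimal128 cases to fall back to the row-based ProjectExec.
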
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala index 96c88cecc..553463d56 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, OmniAQEShuffleReadExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.command.DataWritingCommandExec -import org.apache.spark.sql.execution.datasources.WriteFilesExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ColumnarBroadcastHashJoinExec, ColumnarShuffledHashJoinExec, ColumnarSortMergeJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.window.WindowExec @@ -392,10 +391,6 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { } ColumnarDataWritingCommandExec(plan.cmd, plan.child).buildCheck() TransformHints.tagTransformable(plan) - case plan: WriteFilesExec => - TransformHints.tagNotTransformable( - plan, "data writing is not support" - ) case _ => TransformHints.tagTransformable(plan) } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index 29a09782b..c06c8e0a7 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -26,7 +26,6 @@ import nova.hetu.omniruntime.constants.FunctionType.{OMNI_AGGREGATION_TYPE_AVG, import nova.hetu.omniruntime.constants.JoinType._ import nova.hetu.omniruntime.operator.OmniExprVerify import com.huawei.boostkit.spark.ColumnarPluginConfig -import nova.hetu.omniruntime.`type`.DataType.DataTypeId import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -40,6 +39,7 @@ import org.apache.spark.sql.execution.datasources.OmniFileFormatWriter.Empty2Nul import org.apache.spark.sql.expression.ColumnarExpressionConverter import org.apache.spark.sql.hive.HiveUdfAdaptorUtil import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DecimalType.{MAX_PRECISION, MAX_SCALE} import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType, DateType, Decimal, DecimalType, DoubleType, IntegerType, LongType, Metadata, NullType, ShortType, StringType, TimestampType} import org.json.{JSONArray, JSONObject} @@ -87,14 +87,10 @@ object OmniExpressionAdaptor extends Logging { } def doSupportCastFromString(dataType: DataType): Boolean = { - !unSupportCastFromStringToDecimal(dataType) || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[DateType] || + dataType.isInstanceOf[DecimalType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[DateType] || 
dataType.isInstanceOf[IntegerType] || dataType.isInstanceOf[LongType] || dataType.isInstanceOf[DoubleType] } - def unSupportCastFromStringToDecimal(dataType: DataType): Boolean = { - dataType.isInstanceOf[DecimalType] && dataType.asInstanceOf[DecimalType].precision == DecimalType.MAX_PRECISION - } - // support cast(decimal/string/int/long as string) if (cast.dataType.isInstanceOf[StringType] && !doSupportCastToString(cast.child.dataType)) { throw new UnsupportedOperationException(s"Unsupported expression: $expr") @@ -107,18 +103,48 @@ object OmniExpressionAdaptor extends Logging { } - private def unsupportedModCheck(mod: Remainder): Unit = { - val leftDataType = mod.left.dataType - val rightDataType = mod.right.dataType - leftDataType match { - case decimalType: DecimalType if rightDataType.isInstanceOf[DecimalType] => - val leftOmniType = sparkTypeToOmniType(decimalType) - val rightOmniType = sparkTypeToOmniType(rightDataType.asInstanceOf[DecimalType]) - if (leftOmniType == DataTypeId.OMNI_DECIMAL64.toValue && rightOmniType == DataTypeId.OMNI_DECIMAL128.toValue && !SQLConf.get.ansiEnabled) { - throw new UnsupportedOperationException(s"Unsupported mod Type: $leftDataType % $rightDataType") + private def binaryOperatorAdjust(expr: BinaryOperator, returnDataType: DataType): Tuple2[Expression, Expression] = { + import scala.math.{max, min} + def bounded(precision: Int, scale: Int): DecimalType = { + DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE)) + } + + def widerDecimalType(d1: DecimalType, d2: DecimalType): Tuple2[DecimalType, Boolean] = { + getWiderDecimalType(d1.precision, d1.scale, d2.precision, d2.scale) + } + + def getWiderDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): Tuple2[DecimalType, Boolean] = { + val scale = max(s1, s2) + val range = max(p1 - s1, p2 - s2) + (bounded(range + scale, scale), range + scale > MAX_PRECISION) + } + + def decimalTypeCast(expr: Expression, d: DecimalType, widerType: DecimalType, returnType: DecimalType, isOverPrecision: Boolean): Expression = { + if (isOverPrecision && d.scale <= returnType.scale) { + if (returnType.precision - returnType.scale < d.precision - d.scale) { + return expr } - case _ => + Cast (expr, returnDataType) + } else { + Cast (expr, widerType) + } } + + if (DecimalType.unapply(expr.left) && DecimalType.unapply(expr.right)) { + val leftDataType = expr.left.dataType.asInstanceOf[DecimalType] + val rightDataType = expr.right.dataType.asInstanceOf[DecimalType] + val (widerType, isOverPrecision) = widerDecimalType(leftDataType, rightDataType) + val result = expr match { + case _: Add | _: Subtract => (Cast(expr.left, returnDataType), Cast(expr.right, returnDataType)) + case _: Multiply | _: Divide | _: Remainder => + val newLeft = decimalTypeCast(expr.left, leftDataType, widerType, returnDataType.asInstanceOf[DecimalType], isOverPrecision) + val newRight = decimalTypeCast(expr.right, rightDataType, widerType, returnDataType.asInstanceOf[DecimalType], isOverPrecision) + (newLeft, newRight) + case _ => (expr.left, expr.right) + } + return result + } + (expr.left, expr.right) } private val timeFormatSet: Set[String] = Set("yyyy-MM-dd HH:mm:ss", "yyyy-MM-dd") @@ -209,44 +235,45 @@ object OmniExpressionAdaptor extends Logging { throw new UnsupportedOperationException(s"Unsupported datatype for MakeDecimal: ${makeDecimal.child.dataType}") } -// case promotePrecision: PromotePrecision => -// rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap) - case sub: Subtract => + val (left, right) = 
binaryOperatorAdjust(sub, returnDatatype) new JSONObject().put("exprType", "BINARY") .addOmniExpJsonType("returnType", returnDatatype) .put("operator", "SUBTRACT") - .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(sub.left, exprsIndexMap)) - .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(sub.right, exprsIndexMap)) + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(right, exprsIndexMap)) case add: Add => + val (left, right) = binaryOperatorAdjust(add, returnDatatype) new JSONObject().put("exprType", "BINARY") .addOmniExpJsonType("returnType", returnDatatype) .put("operator", "ADD") - .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(add.left, exprsIndexMap)) - .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(add.right, exprsIndexMap)) + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(right, exprsIndexMap)) case mult: Multiply => + val (left, right) = binaryOperatorAdjust(mult, returnDatatype) new JSONObject().put("exprType", "BINARY") .addOmniExpJsonType("returnType", returnDatatype) .put("operator", "MULTIPLY") - .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(mult.left, exprsIndexMap)) - .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(mult.right, exprsIndexMap)) + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(right, exprsIndexMap)) case divide: Divide => + val (left, right) = binaryOperatorAdjust(divide, returnDatatype) new JSONObject().put("exprType", "BINARY") .addOmniExpJsonType("returnType", returnDatatype) .put("operator", "DIVIDE") - .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(divide.left, exprsIndexMap)) - .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(divide.right, exprsIndexMap)) + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(right, exprsIndexMap)) case mod: Remainder => - unsupportedModCheck(mod) + val (left, right) = binaryOperatorAdjust(mod, returnDatatype) new JSONObject().put("exprType", "BINARY") .addOmniExpJsonType("returnType", returnDatatype) .put("operator", "MODULUS") - .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(mod.left, exprsIndexMap)) - .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(mod.right, exprsIndexMap)) + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(right, exprsIndexMap)) case greaterThan: GreaterThan => new JSONObject().put("exprType", "BINARY") @@ -466,8 +493,8 @@ object OmniExpressionAdaptor extends Logging { .put("function_name", "might_contain") .put("arguments", new JSONArray() .put(rewriteToOmniJsonExpressionLiteralJsonObject( - ColumnarExpressionConverter.replaceWithColumnarExpression(bloomFilterMightContain.bloomFilterExpression), - exprsIndexMap)) + ColumnarExpressionConverter.replaceWithColumnarExpression(bloomFilterMightContain.bloomFilterExpression), + exprsIndexMap)) .put(rewriteToOmniJsonExpressionLiteralJsonObject(bloomFilterMightContain.valueExpression, exprsIndexMap, returnDatatype))) case columnarBloomFilterSubquery: ColumnarBloomFilterSubquery => diff --git 
a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 6f397001e..6db6fc3c1 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -190,7 +190,7 @@ case class AdaptiveSparkPlanExec( @volatile private var currentPhysicalPlan = initialPlan - private var isFinalPlan = false + @volatile private var _isFinalPlan = false private var currentStageId = 0 @@ -207,6 +207,8 @@ case class AdaptiveSparkPlanExec( def executedPlan: SparkPlan = currentPhysicalPlan + def isFinalPlan: Boolean = _isFinalPlan + override def conf: SQLConf = context.session.sessionState.conf override def output: Seq[Attribute] = inputPlan.output @@ -330,7 +332,7 @@ case class AdaptiveSparkPlanExec( optimizeQueryStage(result.newPlan, isFinalStage = true), postStageCreationRules(supportsColumnar), Some((planChangeLogger, "AQE Post Stage Creation"))) - isFinalPlan = true + _isFinalPlan = true executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan))) currentPhysicalPlan } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala index 5ccfa1513..a65e66c4e 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala @@ -37,10 +37,8 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.connector.write.WriterCommitMessage import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec -import org.apache.spark.sql.execution.datasources.FileFormatWriter.{ConcurrentOutputWriterSpec, OutputSpec, executedPlan, outputOrderingMatched} +import org.apache.spark.sql.execution.datasources.FileFormatWriter.ConcurrentOutputWriterSpec import org.apache.spark.sql.execution.{ColumnarProjectExec, ColumnarSortExec, OmniColumnarToRowExec, ProjectExec, SQLExecution, SortExec, SparkPlan, UnsafeExternalRowSorter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StringType @@ -71,19 +69,6 @@ object OmniFileFormatWriter extends Logging { copy(child = newChild) } - /** - * A variable used in tests to check whether the output ordering of the query matches the - * required ordering of the write command. - */ - private[sql] var outputOrderingMatched: Boolean = false - - /** - * A variable used in tests to check the final executed plan. - */ - private[sql] var executedPlan: Option[SparkPlan] = None - - // scalastyle:off argcount - /** * Basic work flow of this command is: * 1. 
Driver side setup, including output committer initialization and data source specific @@ -109,10 +94,8 @@ object OmniFileFormatWriter extends Logging { partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], statsTrackers: Seq[WriteJobStatsTracker], - options: Map[String, String], - numStaticPartitionCols: Int = 0) + options: Map[String, String]) : Set[String] = { - require(partitionColumns.size >= numStaticPartitionCols) val job = Job.getInstance(hadoopConf) job.setOutputKeyClass(classOf[Void]) @@ -135,14 +118,43 @@ object OmniFileFormatWriter extends Logging { } val empty2NullPlan = if (needConvert) ColumnarProjectExec(projectList, plan) else plan - val writerBucketSpec = V1WritesUtils.getWriterBucketSpec(bucketSpec, dataColumns, options) - val sortColumns = V1WritesUtils.getBucketSortColumns(bucketSpec, dataColumns) + val writerBucketSpec = bucketSpec.map { spec => + val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) + + if (options.getOrElse(BucketingUtils.optionForHiveCompatibleBucketWrite, "false") == + "true") { + // Hive bucketed table: use `HiveHash` and bitwise-and as bucket id expression. + // Without the extra bitwise-and operation, we can get wrong bucket id when hash value of + // columns is negative. See Hive implementation in + // `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`. + val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue)) + val bucketIdExpression = Pmod(hashId, Literal(spec.numBuckets)) + + // The bucket file name prefix is following Hive, Presto and Trino conversion, so this + // makes sure Hive bucketed table written by Spark, can be read by other SQL engines. + // + // Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`. + // Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`. + val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_" + WriterBucketSpec(bucketIdExpression, fileNamePrefix) + } else { + // Spark bucketed table: use `HashPartitioning.partitionIdExpression` as bucket id + // expression, so that we can guarantee the data distribution is same between shuffle and + // bucketed data source, which enables us to only shuffle one side when join a bucketed + // table and a normal one. + val bucketIdExpression = HashPartitioning(bucketColumns, spec.numBuckets) + .partitionIdExpression + WriterBucketSpec(bucketIdExpression, (_: Int) => "") + } + } + val sortColumns = bucketSpec.toSeq.flatMap { + spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) + } val caseInsensitiveOptions = CaseInsensitiveMap(options) val dataSchema = dataColumns.toStructType DataSourceUtils.verifySchema(fileFormat, dataSchema) - DataSourceUtils.checkFieldNames(fileFormat, dataSchema) // Note: prepareWrite has side effect. It sets "job". val outputWriterFactory = fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataSchema) @@ -164,26 +176,19 @@ object OmniFileFormatWriter extends Logging { statsTrackers = statsTrackers ) - // We should first sort by dynamic partition columns, then bucket id, and finally sorting - // columns. 
- val requiredOrdering = partitionColumns.drop(numStaticPartitionCols) ++ - writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns - val writeFilesOpt = V1WritesUtils.getWriteFilesOpt(plan) - - // SPARK-40588: when planned writing is disabled and AQE is enabled, - // plan contains an AdaptiveSparkPlanExec, which does not know - // its final plan's ordering, so we have to materialize that plan first - // it is fine to use plan further down as the final plan is cached in that plan - def materializeAdaptiveSparkPlan(plan: SparkPlan): SparkPlan = plan match { - case a: AdaptiveSparkPlanExec => a.finalPhysicalPlan - case p: SparkPlan => p.withNewChildren(p.children.map(materializeAdaptiveSparkPlan)) - } - + // We should first sort by partition columns, then bucket id, and finally sorting columns. + val requiredOrdering = + partitionColumns ++ writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns // the sort order doesn't matter - val actualOrdering = writeFilesOpt.map(_.child) - .getOrElse(materializeAdaptiveSparkPlan(plan)) - .outputOrdering - val orderingMatched = V1WritesUtils.isOrderingMatched(requiredOrdering, actualOrdering) + val actualOrdering = empty2NullPlan.outputOrdering.map(_.child) + val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { + false + } else { + requiredOrdering.zip(actualOrdering).forall { + case (requiredOrder, childOutputOrder) => + requiredOrder.semanticEquals(childOutputOrder) + } + } SQLExecution.checkSQLExecutionId(sparkSession) @@ -191,62 +196,33 @@ object OmniFileFormatWriter extends Logging { // get an ID guaranteed to be unique. job.getConfiguration.set("spark.sql.sources.writeJobUUID", description.uuid) - // When `PLANNED_WRITE_ENABLED` is true, the optimizer rule V1Writes will add logical sort - // operator based on the required ordering of the V1 write command. So the output - // ordering of the physical plan should always match the required ordering. Here - // we set the variable to verify this behavior in tests. - // There are two cases where FileFormatWriter still needs to add physical sort: - // 1) When the planned write config is disabled. - // 2) When the concurrent writers are enabled (in this case the required ordering of a - // V1 write command will be empty). 
- if (Utils.isTesting) outputOrderingMatched = orderingMatched - - if (writeFilesOpt.isDefined) { - // build `WriteFilesSpec` for `WriteFiles` - val concurrentOutputWriterSpecFunc = (plan: SparkPlan) => { - val sortPlan = createSortPlan(plan, requiredOrdering, outputSpec) - createConcurrentOutputWriterSpec(sparkSession, sortPlan, sortColumns) - } - val writeSpec = WriteFilesSpec( - description = description, - committer = committer, - concurrentOutputWriterSpecFunc = concurrentOutputWriterSpecFunc - ) - executeWrite(sparkSession, plan, writeSpec, job) - } else { - executeWrite(sparkSession, plan, job, description, committer, outputSpec, - requiredOrdering, partitionColumns, sortColumns, orderingMatched) - } - } - // scalastyle:on argcount - - private def executeWrite( - sparkSession: SparkSession, - plan: SparkPlan, - job: Job, - description: WriteJobDescription, - committer: FileCommitProtocol, - outputSpec: OutputSpec, - requiredOrdering: Seq[Expression], - partitionColumns: Seq[Attribute], - sortColumns: Seq[Attribute], - orderingMatched: Boolean): Set[String] = { - val projectList = V1WritesUtils.convertEmptyToNull(plan.output, partitionColumns) - val empty2NullPlan = if (projectList.nonEmpty) ProjectExec(projectList, plan) else plan - - writeAndCommit(job, description, committer) { + // This call shouldn't be put into the `try` block below because it only initializes and + // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. + committer.setupJob(job) + + try { val (rdd, concurrentOutputWriterSpec) = if (orderingMatched) { (empty2NullPlan.executeColumnar(), None) } else { - val sortPlan = createSortPlan(empty2NullPlan, requiredOrdering, outputSpec) - val concurrentOutputWriterSpec = createConcurrentOutputWriterSpec( - sparkSession, sortPlan, sortColumns) - if (concurrentOutputWriterSpec.isDefined) { + // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and + // the physical plan may have different attribute ids due to optimizer removing some + // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. + val orderingExpr = bindReferences( + requiredOrdering.map(SortOrder(_, Ascending)), finalOutputSpec.outputColumns) + // val orderingExpr = requiredOrdering.map(SortOrder(_, Ascending)) + val sortPlan = ColumnarSortExec( + orderingExpr, + global = false, + child = empty2NullPlan) + + val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters + val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty + if (concurrentWritersEnabled) { // TODO Concurrent output write logInfo("Columnar concurrent write is not support now, use un concurrent write") - (empty2NullPlan.executeColumnar(), concurrentOutputWriterSpec) + (sortPlan.executeColumnar(), None) } else { - (sortPlan.executeColumnar(), concurrentOutputWriterSpec) + (sortPlan.executeColumnar(), None) } } @@ -278,19 +254,7 @@ object OmniFileFormatWriter extends Logging { committer.onTaskCommit(res.commitMsg) ret(index) = res }) - ret - } - } - private def writeAndCommit( - job: Job, - description: WriteJobDescription, - committer: FileCommitProtocol)(f: => Array[WriteTaskResult]): Set[String] = { - // This call shouldn't be put into the `try` block below because it only initializes and - // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. 
- committer.setupJob(job) - try { - val ret = f val commitMsgs = ret.map(_.commitMsg) logInfo(s"Start to commit write Job ${description.uuid}.") @@ -312,70 +276,6 @@ object OmniFileFormatWriter extends Logging { } } - /** - * Write files using [[SparkPlan.executeWrite]] - */ - private def executeWrite( - session: SparkSession, - planForWrites: SparkPlan, - writeFilesSpec: WriteFilesSpec, - job: Job): Set[String] = { - val committer = writeFilesSpec.committer - val description = writeFilesSpec.description - - // In testing, this is the only way to get hold of the actually executed plan written to file - if (Utils.isTesting) executedPlan = Some(planForWrites) - - writeAndCommit(job, description, committer) { - val rdd = planForWrites.executeWrite(writeFilesSpec) - val ret = new Array[WriteTaskResult](rdd.partitions.length) - session.sparkContext.runJob( - rdd, - (context: TaskContext, iter: Iterator[WriterCommitMessage]) => { - assert(iter.hasNext) - val commitMessage = iter.next() - assert(!iter.hasNext) - commitMessage - }, - rdd.partitions.indices, - (index, res: WriterCommitMessage) => { - assert(res.isInstanceOf[WriteTaskResult]) - val writeTaskResult = res.asInstanceOf[WriteTaskResult] - committer.onTaskCommit(writeTaskResult.commitMsg) - ret(index) = writeTaskResult - }) - ret - } - } - - private def createSortPlan( - plan: SparkPlan, - requiredOrdering: Seq[Expression], - outputSpec: OutputSpec): SortExec = { - // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and - // the physical plan may have different attribute ids due to optimizer removing some - // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. - val orderingExpr = bindReferences( - requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns) - SortExec( - orderingExpr, - global = false, - child = plan) - } - - private def createConcurrentOutputWriterSpec( - sparkSession: SparkSession, - sortPlan: SortExec, - sortColumns: Seq[Attribute]): Option[ConcurrentOutputWriterSpec] = { - val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters - val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty - if (concurrentWritersEnabled) { - Some(ConcurrentOutputWriterSpec(maxWriters, () => sortPlan.createSorter())) - } else { - None - } - } - /** Writes data out in a single Spark task. 
*/ private def executeTask( description: WriteJobDescription, diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInsertIntoHadoopFsRelationCommand.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInsertIntoHadoopFsRelationCommand.scala index 7415e2065..7c6011091 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInsertIntoHadoopFsRelationCommand.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInsertIntoHadoopFsRelationCommand.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -59,7 +59,7 @@ case class OmniInsertIntoHadoopFsRelationCommand( catalogTable: Option[CatalogTable], fileIndex: Option[FileIndex], outputColumnNames: Seq[String]) - extends V1WriteCommand { + extends DataWritingCommand { private lazy val parameters = CaseInsensitiveMap(options) @@ -76,10 +76,6 @@ case class OmniInsertIntoHadoopFsRelationCommand( staticPartitions.size < partitionColumns.length } - override def requiredOrdering: Seq[SortOrder] = - V1WritesUtils.getSortOrder(outputColumns, partitionColumns, bucketSpec, options, - staticPartitions.size) - // Return Seq[Row] but Seq[ColumBatch] since // 1. reuse the origin interface of spark to avoid add duplicate code // 2. this func return a Seq.empty[Row] and this data doesn't do anything else @@ -190,8 +186,7 @@ case class OmniInsertIntoHadoopFsRelationCommand( partitionColumns = partitionColumns, bucketSpec = bucketSpec, statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)), - options = options, - numStaticPartitionCols = staticPartitions.size) + options = options) // update metastore partition metadata diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties b/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties index 1991c9c05..89eabe8e6 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties +++ b/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties @@ -2,13 +2,11 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
# -#hive.metastore.uris=thrift://server1:9083 -hive.metastore.uris=thrift://OmniOperator:9083 +hive.metastore.uris=thrift://server1:9083 spark.sql.warehouse.dir=/user/hive/warehouse spark.memory.offHeap.size=8G spark.sql.codegen.wholeStage=false spark.sql.extensions=com.huawei.boostkit.spark.ColumnarPlugin spark.shuffle.manager=org.apache.spark.shuffle.sort.OmniColumnarShuffleManager spark.sql.orc.impl=native -#hive.db=tpcds_bin_partitioned_orc_2 -hive.db=tpcds_bin_partitioned_varchar_orc_2 \ No newline at end of file +hive.db=tpcds_bin_partitioned_orc_2 \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/ColumnarDecimalCastSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/ColumnarDecimalCastSuite.scala index 1f27d7dec..c7bef78bd 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/ColumnarDecimalCastSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/ColumnarDecimalCastSuite.scala @@ -552,15 +552,14 @@ class ColumnarDecimalCastSuite extends ColumnarSparkPlanTest{ ) } - // cast string to decimal when precision is max_precision of decimal not supported test("Test ColumnarProjectExec happen and result is same as native " + "when cast decimal to string") { val res = spark.sql("select cast(cast(c_deci17_2_null as string) as decimal(38, 2))," + "cast(cast(c_deci38_2_null as string) as decimal(38, 2)) from deci_string") val executedPlan = res.queryExecution.executedPlan println(executedPlan) - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isDefined, s"ProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, s"ColumnarProjectExec happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") checkAnswer( res, Seq( diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalOperationSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalOperationSuite.scala index 9049ff17c..b29e062fe 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalOperationSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalOperationSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.execution.{ColumnarConditionProjectExec, ColumnarSparkPlanTest, ProjectExec} +import org.apache.spark.sql.execution.{ColumnarConditionProjectExec, ColumnarSparkPlanTest} import org.apache.spark.sql.types.Decimal import org.apache.spark.sql.{Column, DataFrame} @@ -59,14 +59,6 @@ class DecimalOperationSuite extends ColumnarSparkPlanTest{ assertResult(expect, s"sql: ${sql}")(output.toString) } - private def checkFallBackResult(sql: String, expect: String): Unit = { - val result = spark.sql(sql) - val plan = result.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[ProjectExec]).isDefined) - val output 
= result.collect().toSeq.head.getDecimal(0) - assertResult(expect, s"sql: ${sql}")(output.toString) - } - private def checkResultNull(sql: String): Unit = { val result = spark.sql(sql) val plan = result.queryExecution.executedPlan @@ -75,14 +67,6 @@ class DecimalOperationSuite extends ColumnarSparkPlanTest{ assertResult(null, s"sql: ${sql}")(output) } - private def checkFallBackResultNull(sql: String): Unit = { - val result = spark.sql(sql) - val plan = result.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[ProjectExec]).isDefined) - val output = result.collect().toSeq.head.getDecimal(0) - assertResult(null, s"sql: ${sql}")(output) - } - private def checkAnsiResult(sql: String, expect: String): Unit = { spark.conf.set("spark.sql.ansi.enabled", true) checkResult(sql, expect) @@ -592,7 +576,7 @@ class DecimalOperationSuite extends ColumnarSparkPlanTest{ } test("decimal64%decimal128(0) when spark.sql.ansi.enabled=false") { - checkFallBackResultNull("select c_deci18_6%c_deci21_6 from deci_overflow where id = 4") + checkResultNull("select c_deci18_6%c_deci21_6 from deci_overflow where id = 4") } test("decimal128%decimal128(0) when spark.sql.ansi.enabled=false") { @@ -763,7 +747,7 @@ class DecimalOperationSuite extends ColumnarSparkPlanTest{ } test("decimal64(0)%decimal128 when spark.sql.ansi.enabled=false") { - checkFallBackResult("select c_deci7_2%c_deci22_6 from deci_overflow where id = 4;", "0.000000") + checkResult("select c_deci7_2%c_deci22_6 from deci_overflow where id = 4;", "0.000000") } test("decimal128(0)%decimal128 when spark.sql.ansi.enabled=false") { @@ -775,7 +759,7 @@ class DecimalOperationSuite extends ColumnarSparkPlanTest{ } test("literal(0)%decimal128 when when spark.sql.ansi.enabled=false") { - checkFallBackResult("select 0%c_deci22_6 from deci_overflow where id = 4;", "0.000000") + checkResult("select 0%c_deci22_6 from deci_overflow where id = 4;", "0.000000") } // spark.sql.ansi.enabled=true @@ -1006,7 +990,7 @@ class DecimalOperationSuite extends ColumnarSparkPlanTest{ } test("decimal64(NULL)%decimal128 when spark.sql.ansi.enabled=false") { - checkFallBackResultNull("select c_deci7_2%c_deci22_6 from deci_overflow where id = 5;") + checkResultNull("select c_deci7_2%c_deci22_6 from deci_overflow where id = 5;") } test("decimal128%decimal128(NULL) when spark.sql.ansi.enabled=false") { diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala index 17a08139b..d1b295d5c 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala @@ -1178,7 +1178,7 @@ class AdaptiveQueryExecSuite test("SPARK-31658: SQL UI should show write commands") { withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "false") { + SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { withTable("t1") { var commands: Seq[SparkPlanInfo] = Seq.empty val listener = new SparkListener { @@ -1584,7 +1584,7 @@ class AdaptiveQueryExecSuite test("SPARK-32932: Do not use local shuffle read at final stage on write command") { withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> 
PartitionOverwriteMode.DYNAMIC.toString, SQLConf.SHUFFLE_PARTITIONS.key -> "5", - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { val data = for ( i <- 1L to 10L; j <- 1L to 3L -- Gitee
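
Note on the decimal handling in OmniExpressionAdaptor.scala: the binaryOperatorAdjust helper decides how to cast both sides of a decimal Add/Subtract/Multiply/Divide/Remainder before the expression is serialized to Omni JSON. The widening rule it relies on can be exercised on its own; the sketch below restates it as a free-standing object (the object name and the standalone packaging are illustrative only, the patch keeps this logic as nested defs inside the adaptor).

import org.apache.spark.sql.types.DecimalType
import org.apache.spark.sql.types.DecimalType.{MAX_PRECISION, MAX_SCALE}

object DecimalWideningSketch {
  // Clamp precision and scale to the limits Spark's DecimalType allows.
  private def bounded(precision: Int, scale: Int): DecimalType =
    DecimalType(math.min(precision, MAX_PRECISION), math.min(scale, MAX_SCALE))

  // Smallest DecimalType that can hold both inputs, plus a flag telling whether
  // the unclamped precision exceeded MAX_PRECISION (the "isOverPrecision" case
  // that steers decimalTypeCast in the patch).
  def widerDecimalType(d1: DecimalType, d2: DecimalType): (DecimalType, Boolean) = {
    val scale = math.max(d1.scale, d2.scale)
    val range = math.max(d1.precision - d1.scale, d2.precision - d2.scale)
    (bounded(range + scale, scale), range + scale > MAX_PRECISION)
  }
}

// Example: decimal(18,6) combined with decimal(21,6) widens to decimal(21,6)
// without overflow, while decimal(38,2) combined with decimal(20,18) clamps to
// decimal(38,18) and reports overflow.

For Add and Subtract the patch simply casts both sides to the return type; only Multiply, Divide and Remainder consult the wider type and the overflow flag through decimalTypeCast.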
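
The inlined writerBucketSpec in OmniFileFormatWriter.scala (replacing the call to V1WritesUtils.getWriterBucketSpec) chooses between two bucket-id expressions. A reduced sketch of that choice, pulled out as a plain helper for readability (the helper itself is not part of the patch):

import org.apache.spark.sql.catalyst.expressions.{Attribute, BitwiseAnd, Expression, HiveHash, Literal, Pmod}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning

object BucketIdSketch {
  def bucketIdExpression(
      bucketColumns: Seq[Attribute],
      numBuckets: Int,
      hiveCompatibleWrite: Boolean): Expression = {
    if (hiveCompatibleWrite) {
      // Hive-compatible table: HiveHash masked to a non-negative value, then
      // modulo numBuckets, mirroring Hive's ObjectInspectorUtils#getBucketNumber().
      val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue))
      Pmod(hashId, Literal(numBuckets))
    } else {
      // Spark bucketed table: reuse HashPartitioning's partition id so shuffle
      // and bucketed files agree on the data distribution.
      HashPartitioning(bucketColumns, numBuckets).partitionIdExpression
    }
  }
}

In the Hive-compatible branch the patch also prefixes bucket file names with f"$bucketId%05d_0_" so that Hive, Presto and Trino can locate the buckets written by Spark.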
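
The write path in OmniFileFormatWriter.scala now decides locally whether to insert a ColumnarSortExec instead of calling V1WritesUtils.isOrderingMatched. The check it performs is equivalent to the following sketch (extracted here only for illustration):

import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder}

object OrderingMatchSketch {
  // The extra sort can be skipped only when the child plan's output ordering
  // already starts with every required expression (partition columns, bucket
  // id expression, then bucket sort columns).
  def isOrderingMatched(
      requiredOrdering: Seq[Expression],
      actualOrdering: Seq[SortOrder]): Boolean = {
    if (requiredOrdering.length > actualOrdering.length) {
      false
    } else {
      requiredOrdering.zip(actualOrdering.map(_.child)).forall {
        case (required, actual) => required.semanticEquals(actual)
      }
    }
  }
}

When the ordering does not match, the patch binds requiredOrdering against finalOutputSpec.outputColumns and sorts with ColumnarSortExec(orderingExpr, global = false, child = empty2NullPlan) before writing.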