diff --git a/omnioperator/omniop-native-reader/java/pom.xml b/omnioperator/omniop-native-reader/java/pom.xml index e7ddfe6c3bc7764df4b1642e4a137afcef64f6cd..49b621e55fc6baa82974617caf7be4a6f4d292df 100644 --- a/omnioperator/omniop-native-reader/java/pom.xml +++ b/omnioperator/omniop-native-reader/java/pom.xml @@ -8,13 +8,13 @@ com.huawei.boostkit boostkit-omniop-native-reader jar - 3.3.1-1.6.0 + 3.4.3-1.6.0 BoostKit Spark Native Sql Engine Extension With OmniOperator 2.12 - 3.3.1 + 3.4.3 FALSE ../cpp/ ../cpp/build/releases/ @@ -34,9 +34,10 @@ 1.6.0 + org.slf4j - slf4j-api - 1.7.32 + slf4j-simple + 1.7.36 junit diff --git a/omnioperator/omniop-spark-extension/java/pom.xml b/omnioperator/omniop-spark-extension/java/pom.xml index 9cc1b9d25848f62caf13359f920db280820b04f0..13866589316a4e5add47eb2d38ea74815e7eb6e9 100644 --- a/omnioperator/omniop-spark-extension/java/pom.xml +++ b/omnioperator/omniop-spark-extension/java/pom.xml @@ -7,7 +7,7 @@ com.huawei.kunpeng boostkit-omniop-spark-parent - 3.3.1-1.6.0 + 3.4.3-1.6.0 ../pom.xml @@ -52,7 +52,7 @@ com.huawei.boostkit boostkit-omniop-native-reader - 3.3.1-1.6.0 + 3.4.3-1.6.0 junit diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java index df8c564b5a9168f3589c7293e6c6595e4dcea5a0..80889789b98f7f317bc34ab752c5b72faed195b2 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java @@ -25,65 +25,78 @@ import nova.hetu.omniruntime.type.DataType; import nova.hetu.omniruntime.vector.*; import org.apache.orc.impl.writer.TimestampTreeWriter; +import org.apache.spark.sql.catalyst.util.CharVarcharUtils; import 
org.apache.spark.sql.catalyst.util.RebaseDateTime; -import org.apache.hadoop.hive.ql.io.sarg.ExpressionTree; -import org.apache.hadoop.hive.ql.io.sarg.PredicateLeaf; -import org.apache.orc.OrcFile.ReaderOptions; -import org.apache.orc.Reader.Options; +import org.apache.spark.sql.sources.And; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.GreaterThanOrEqual; +import org.apache.spark.sql.sources.In; +import org.apache.spark.sql.sources.IsNotNull; +import org.apache.spark.sql.sources.IsNull; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.LessThanOrEqual; +import org.apache.spark.sql.sources.Not; +import org.apache.spark.sql.sources.Or; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; import org.json.JSONObject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.math.BigDecimal; import java.net.URI; -import java.sql.Date; -import java.sql.Timestamp; +import java.time.LocalDate; import java.text.DateFormat; import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.ArrayList; -import java.util.List; +import java.util.Arrays; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import java.util.TimeZone; public class OrcColumnarBatchScanReader { private static final Logger LOGGER = LoggerFactory.getLogger(OrcColumnarBatchScanReader.class); private boolean nativeSupportTimestampRebase; + 
private static final Pattern CHAR_TYPE = Pattern.compile("char\\(\\s*(\\d+)\\s*\\)"); + + private static final int MAX_LEAF_THRESHOLD = 256; + public long reader; public long recordReader; public long batchReader; - public int[] colsToGet; - public int realColsCnt; - public ArrayList fildsNames; + // All ORC fieldNames + public ArrayList allFieldsNames; - public ArrayList colToInclu; + // Indicate columns to read + public int[] colsToGet; - public String[] requiredfieldNames; + // Actual columns to read + public ArrayList includedColumns; - public int[] precisionArray; + // max threshold for leaf node + private int leafIndex; - public int[] scaleArray; + // spark required schema + private StructType requiredSchema; public OrcColumnarBatchJniReader jniReader; public OrcColumnarBatchScanReader() { jniReader = new OrcColumnarBatchJniReader(); - fildsNames = new ArrayList(); - } - - public JSONObject getSubJson(ExpressionTree node) { - JSONObject jsonObject = new JSONObject(); - jsonObject.put("op", node.getOperator().ordinal()); - if (node.getOperator().toString().equals("LEAF")) { - jsonObject.put("leaf", node.toString()); - return jsonObject; - } - ArrayList child = new ArrayList(); - for (ExpressionTree childNode : node.getChildren()) { - JSONObject rtnJson = getSubJson(childNode); - child.add(rtnJson); - } - jsonObject.put("child", child); - return jsonObject; + allFieldsNames = new ArrayList(); } public String padZeroForDecimals(String [] decimalStrArray, int decimalScale) { @@ -95,144 +108,15 @@ public class OrcColumnarBatchScanReader { return String.format("%1$-" + decimalScale + "s", decimalVal).replace(' ', '0'); } - public int getPrecision(String colname) { - for (int i = 0; i < requiredfieldNames.length; i++) { - if (colname.equals(requiredfieldNames[i])) { - return precisionArray[i]; - } - } - - return -1; - } - - public int getScale(String colname) { - for (int i = 0; i < requiredfieldNames.length; i++) { - if (colname.equals(requiredfieldNames[i])) { - 
return scaleArray[i]; - } - } - - return -1; - } - - public JSONObject getLeavesJson(List leaves) { - JSONObject jsonObjectList = new JSONObject(); - for (int i = 0; i < leaves.size(); i++) { - PredicateLeaf pl = leaves.get(i); - JSONObject jsonObject = new JSONObject(); - jsonObject.put("op", pl.getOperator().ordinal()); - jsonObject.put("name", pl.getColumnName()); - jsonObject.put("type", pl.getType().ordinal()); - if (pl.getLiteral() != null) { - if (pl.getType() == PredicateLeaf.Type.DATE) { - jsonObject.put("literal", ((int)Math.ceil(((Date)pl.getLiteral()).getTime()* 1.0/3600/24/1000)) + ""); - } else if (pl.getType() == PredicateLeaf.Type.DECIMAL) { - int decimalP = getPrecision(pl.getColumnName()); - int decimalS = getScale(pl.getColumnName()); - String[] spiltValues = pl.getLiteral().toString().split("\\."); - if (decimalS == 0) { - jsonObject.put("literal", spiltValues[0] + " " + decimalP + " " + decimalS); - } else { - String scalePadZeroStr = padZeroForDecimals(spiltValues, decimalS); - jsonObject.put("literal", spiltValues[0] + "." 
+ scalePadZeroStr + " " + decimalP + " " + decimalS); - } - } else if (pl.getType() == PredicateLeaf.Type.TIMESTAMP) { - Timestamp t = (Timestamp)pl.getLiteral(); - jsonObject.put("literal", formatSecs(t.getTime() / TimestampTreeWriter.MILLIS_PER_SECOND) + " " + formatNanos(t.getNanos())); - } else { - jsonObject.put("literal", pl.getLiteral().toString()); - } - } else { - jsonObject.put("literal", ""); - } - if ((pl.getLiteralList() != null) && (pl.getLiteralList().size() != 0)){ - List lst = new ArrayList<>(); - for (Object ob : pl.getLiteralList()) { - if (ob == null) { - lst.add(null); - continue; - } - if (pl.getType() == PredicateLeaf.Type.DECIMAL) { - int decimalP = getPrecision(pl.getColumnName()); - int decimalS = getScale(pl.getColumnName()); - String[] spiltValues = ob.toString().split("\\."); - if (decimalS == 0) { - lst.add(spiltValues[0] + " " + decimalP + " " + decimalS); - } else { - String scalePadZeroStr = padZeroForDecimals(spiltValues, decimalS); - lst.add(spiltValues[0] + "." 
+ scalePadZeroStr + " " + decimalP + " " + decimalS); - } - } else if (pl.getType() == PredicateLeaf.Type.DATE) { - lst.add(((int)Math.ceil(((Date)ob).getTime()* 1.0/3600/24/1000)) + ""); - } else if (pl.getType() == PredicateLeaf.Type.TIMESTAMP) { - Timestamp t = (Timestamp)pl.getLiteral(); - lst.add(formatSecs(t.getTime() / TimestampTreeWriter.MILLIS_PER_SECOND) + " " + formatNanos(t.getNanos())); - } else { - lst.add(ob.toString()); - } - } - jsonObject.put("literalList", lst); - } else { - jsonObject.put("literalList", new ArrayList()); - } - jsonObjectList.put("leaf-" + i, jsonObject); - } - return jsonObjectList; - } - - private long formatSecs(long secs) { - DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); - long epoch; - try { - epoch = dateFormat.parse(TimestampTreeWriter.BASE_TIMESTAMP_STRING).getTime() / - TimestampTreeWriter.MILLIS_PER_SECOND; - } catch (ParseException e) { - throw new RuntimeException(e); - } - return secs - epoch; - } - - private long formatNanos(int nanos) { - if (nanos == 0) { - return 0; - } else if (nanos % 100 != 0) { - return ((long) nanos) << 3; - } else { - nanos /= 100; - int trailingZeros = 1; - while (nanos % 10 == 0 && trailingZeros < 7) { - nanos /= 10; - trailingZeros += 1; - } - return ((long) nanos) << 3 | trailingZeros; - } - } - - private void addJulianGregorianInfo(JSONObject job) { - TimestampUtil instance = TimestampUtil.getInstance(); - JulianGregorianRebase julianObject = instance.getJulianObject(TimeZone.getDefault().getID()); - if (julianObject == null) { - return; - } - job.put("tz", julianObject.getTz()); - job.put("switches", julianObject.getSwitches()); - job.put("diffs", julianObject.getDiffs()); - nativeSupportTimestampRebase = true; - } - /** * Init Orc reader. 
* * @param uri split file path - * @param options split file options */ - public long initializeReaderJava(URI uri, ReaderOptions options) { + public long initializeReaderJava(URI uri) { JSONObject job = new JSONObject(); - if (options.getOrcTail() == null) { - job.put("serializedTail", ""); - } else { - job.put("serializedTail", options.getOrcTail().getSerializedTail().toString()); - } + + job.put("serializedTail", ""); job.put("tailLocation", 9223372036854775807L); job.put("scheme", uri.getScheme() == null ? "" : uri.getScheme()); @@ -240,38 +124,36 @@ public class OrcColumnarBatchScanReader { job.put("port", uri.getPort()); job.put("path", uri.getPath() == null ? "" : uri.getPath()); - reader = jniReader.initializeReader(job, fildsNames); + reader = jniReader.initializeReader(job, allFieldsNames); return reader; } /** * Init Orc RecordReader. * - * @param options split file options + * @param offset split file offset + * @param length split file read length + * @param pushedFilter the filter push down to native + * @param requiredSchema the columns read from native */ - public long initializeRecordReaderJava(Options options) { + public long initializeRecordReaderJava(long offset, long length, Filter pushedFilter, StructType requiredSchema) { + this.requiredSchema = requiredSchema; JSONObject job = new JSONObject(); - if (options.getInclude() == null) { - job.put("include", ""); - } else { - job.put("include", options.getInclude().toString()); - } - job.put("offset", options.getOffset()); - job.put("length", options.getLength()); - // When the number of pushedFilters > hive.CNF_COMBINATIONS_THRESHOLD, the expression is rewritten to - // 'YES_NO_NULL'. Under the circumstances, filter push down will be skipped. 
- if (options.getSearchArgument() != null - && !options.getSearchArgument().toString().contains("YES_NO_NULL")) { - LOGGER.debug("SearchArgument: {}", options.getSearchArgument().toString()); - JSONObject jsonexpressionTree = getSubJson(options.getSearchArgument().getExpression()); - job.put("expressionTree", jsonexpressionTree); - JSONObject jsonleaves = getLeavesJson(options.getSearchArgument().getLeaves()); - job.put("leaves", jsonleaves); - } - job.put("includedColumns", colToInclu.toArray()); - addJulianGregorianInfo(job); + job.put("offset", offset); + job.put("length", length); + + if (pushedFilter != null) { + JSONObject jsonExpressionTree = new JSONObject(); + JSONObject jsonLeaves = new JSONObject(); + boolean flag = canPushDown(pushedFilter, jsonExpressionTree, jsonLeaves); + if (flag) { + job.put("expressionTree", jsonExpressionTree); + job.put("leaves", jsonLeaves); + } + } + job.put("includedColumns", includedColumns.toArray()); recordReader = jniReader.initializeRecordReader(reader, job); return recordReader; } @@ -318,13 +200,13 @@ public class OrcColumnarBatchScanReader { } public int next(Vec[] vecList, int[] typeIds) { - long[] vecNativeIds = new long[realColsCnt]; + long[] vecNativeIds = new long[typeIds.length]; long rtn = jniReader.recordReaderNext(recordReader, batchReader, typeIds, vecNativeIds); if (rtn == 0) { return 0; } int nativeGetId = 0; - for (int i = 0; i < realColsCnt; i++) { + for (int i = 0; i < colsToGet.length; i++) { if (colsToGet[i] != 0) { continue; } @@ -351,13 +233,6 @@ public class OrcColumnarBatchScanReader { vecList[i] = new LongVec(vecNativeIds[nativeGetId]); break; } - case OMNI_TIMESTAMP: { - vecList[i] = new LongVec(vecNativeIds[nativeGetId]); - if (!this.nativeSupportTimestampRebase) { - convertJulianToGregorianMicros((LongVec)(vecList[i]), rtn); - } - break; - } case OMNI_DOUBLE: { vecList[i] = new DoubleVec(vecNativeIds[nativeGetId]); break; @@ -372,7 +247,7 @@ public class OrcColumnarBatchScanReader { } default: 
{ throw new RuntimeException("UnSupport type for ColumnarFileScan:" + - DataType.DataTypeId.values()[typeIds[i]]); + DataType.DataTypeId.values()[typeIds[i]]); } } nativeGetId++; @@ -380,6 +255,228 @@ public class OrcColumnarBatchScanReader { return (int)rtn; } + enum OrcOperator { + OR, + AND, + NOT, + LEAF, + CONSTANT + } + + enum OrcLeafOperator { + EQUALS, + NULL_SAFE_EQUALS, + LESS_THAN, + LESS_THAN_EQUALS, + IN, + BETWEEN, // not use, spark transfers it to gt and lt + IS_NULL + } + + enum OrcPredicateDataType { + LONG, // all of integer types + FLOAT, // float and double + STRING, // string, char, varchar + DATE, + DECIMAL, + TIMESTAMP, + BOOLEAN + } + + private OrcPredicateDataType getOrcPredicateDataType(String attribute) { + StructField field = requiredSchema.apply(attribute); + org.apache.spark.sql.types.DataType dataType = field.dataType(); + if (dataType instanceof ShortType || dataType instanceof IntegerType || + dataType instanceof LongType) { + return OrcPredicateDataType.LONG; + } else if (dataType instanceof DoubleType) { + return OrcPredicateDataType.FLOAT; + } else if (dataType instanceof StringType) { + if (isCharType(field.metadata())) { + throw new UnsupportedOperationException("Unsupported orc push down filter data type: char"); + } + return OrcPredicateDataType.STRING; + } else if (dataType instanceof DateType) { + return OrcPredicateDataType.DATE; + } else if (dataType instanceof DecimalType) { + return OrcPredicateDataType.DECIMAL; + } else if (dataType instanceof BooleanType) { + return OrcPredicateDataType.BOOLEAN; + } else { + throw new UnsupportedOperationException("Unsupported orc push down filter data type: " + + dataType.getClass().getSimpleName()); + } + } + + // Check the type whether is char type, which orc native does not support push down + private boolean isCharType(Metadata metadata) { + if (metadata != null) { + String rawTypeString = CharVarcharUtils.getRawTypeString(metadata).getOrElse(null); + if (rawTypeString != null) { 
+                Matcher matcher = CHAR_TYPE.matcher(rawTypeString);
+                return matcher.matches();
+            }
+        }
+        return false;
+    }
+    /** Fills both JSON objects from pushedFilter; returns false (and logs) when it cannot be pushed down. */
+    private boolean canPushDown(Filter pushedFilter, JSONObject jsonExpressionTree, JSONObject jsonLeaves) {
+        leafIndex = 0; // instance state: restart numbering so every call emits leaf-0, leaf-1, ... and is counted alone
+        try {
+            getExprJson(pushedFilter, jsonExpressionTree, jsonLeaves);
+            if (leafIndex > MAX_LEAF_THRESHOLD) {
+                throw new UnsupportedOperationException("leaf node nums is " + leafIndex +
+                    ", which is bigger than max threshold " + MAX_LEAF_THRESHOLD + ".");
+            }
+            return true;
+        } catch (Exception e) { // any failure while translating only disables pushdown for this reader
+            LOGGER.info("Unable to push down orc filter because {}", e.getMessage());
+            return false;
+        }
+    }
+    /** Recursively converts filterPredicate into jsonExpressionTree, registering each leaf in jsonLeaves. */
+    private void getExprJson(Filter filterPredicate, JSONObject jsonExpressionTree,
+                             JSONObject jsonLeaves) {
+        if (filterPredicate instanceof And) {
+            addChildJson(jsonExpressionTree, jsonLeaves, OrcOperator.AND,
+                ((And) filterPredicate).left(), ((And) filterPredicate).right());
+        } else if (filterPredicate instanceof Or) {
+            addChildJson(jsonExpressionTree, jsonLeaves, OrcOperator.OR,
+                ((Or) filterPredicate).left(), ((Or) filterPredicate).right());
+        } else if (filterPredicate instanceof Not) {
+            addChildJson(jsonExpressionTree, jsonLeaves, OrcOperator.NOT,
+                ((Not) filterPredicate).child());
+        } else if (filterPredicate instanceof EqualTo) {
+            addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false);
+            addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.EQUALS, jsonLeaves,
+                ((EqualTo) filterPredicate).attribute(), ((EqualTo) filterPredicate).value(), null);
+            leafIndex++;
+        } else if (filterPredicate instanceof GreaterThan) { // a > v is encoded as NOT (a <= v)
+            addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, true);
+            addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.LESS_THAN_EQUALS, jsonLeaves,
+                ((GreaterThan) filterPredicate).attribute(), ((GreaterThan) filterPredicate).value(), null);
+            leafIndex++;
+        } else if (filterPredicate instanceof GreaterThanOrEqual) {
+            addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree,
true); // a >= v is encoded as NOT (a < v)
+            addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.LESS_THAN, jsonLeaves,
+                ((GreaterThanOrEqual) filterPredicate).attribute(), ((GreaterThanOrEqual) filterPredicate).value(), null);
+            leafIndex++;
+        } else if (filterPredicate instanceof LessThan) {
+            addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false);
+            addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.LESS_THAN, jsonLeaves,
+                ((LessThan) filterPredicate).attribute(), ((LessThan) filterPredicate).value(), null);
+            leafIndex++;
+        } else if (filterPredicate instanceof LessThanOrEqual) {
+            addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false);
+            addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.LESS_THAN_EQUALS, jsonLeaves,
+                ((LessThanOrEqual) filterPredicate).attribute(), ((LessThanOrEqual) filterPredicate).value(), null);
+            leafIndex++;
+        } else if (filterPredicate instanceof IsNotNull) { // IS NOT NULL is encoded as NOT (IS_NULL)
+            addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, true);
+            addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.IS_NULL, jsonLeaves,
+                ((IsNotNull) filterPredicate).attribute(), null, null);
+            leafIndex++;
+        } else if (filterPredicate instanceof IsNull) {
+            addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false);
+            addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.IS_NULL, jsonLeaves,
+                ((IsNull) filterPredicate).attribute(), null, null);
+            leafIndex++;
+        } else if (filterPredicate instanceof In) { // IN passes its values as the literals array, no single literal
+            addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false);
+            addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.IN, jsonLeaves,
+                ((In) filterPredicate).attribute(), null, Arrays.stream(((In) filterPredicate).values()).toArray());
+            leafIndex++;
+        } else { // any other Filter subclass is rejected
+            throw new UnsupportedOperationException("Unsupported orc push down filter operation: " +
+                filterPredicate.getClass().getSimpleName());
+        }
+    }
+    /** Stores one leaf description (op, name, type, literal, literalList) in jsonLeaves under the key leaf. */
+    private void addLiteralToJsonLeaves(String leaf, OrcLeafOperator leafOperator,
JSONObject jsonLeaves,
+                                        String name, Object literal, Object[] literals) {
+        JSONObject leafJson = new JSONObject();
+        leafJson.put("op", leafOperator.ordinal());
+        leafJson.put("name", name);
+        leafJson.put("type", getOrcPredicateDataType(name).ordinal());
+        boolean hasLiteral = leafOperator != OrcLeafOperator.IS_NULL && leafOperator != OrcLeafOperator.IN;
+        leafJson.put("literal", hasLiteral ? getLiteralValue(literal) : ""); // "" as the removed getLeavesJson wrote
+        // Convert every IN value; a null element makes getLiteralValue throw, so that filter is not pushed down.
+        ArrayList<String> literalList = new ArrayList<>();
+        if (literals != null) {
+            for (Object lit : literals) {
+                literalList.add(getLiteralValue(lit));
+            }
+        }
+        leafJson.put("literalList", literalList);
+        jsonLeaves.put(leaf, leafJson);
+    }
+    /** Writes a LEAF node for leaf into jsonExpressionTree, wrapped in a NOT node when addNot is true. */
+    private void addToJsonExpressionTree(String leaf, JSONObject jsonExpressionTree, boolean addNot) {
+        if (addNot) {
+            jsonExpressionTree.put("op", OrcOperator.NOT.ordinal());
+            ArrayList child = new ArrayList<>();
+            JSONObject subJson = new JSONObject();
+            subJson.put("op", OrcOperator.LEAF.ordinal());
+            subJson.put("leaf", leaf);
+            child.add(subJson);
+            jsonExpressionTree.put("child", child);
+        } else {
+            jsonExpressionTree.put("op", OrcOperator.LEAF.ordinal());
+            jsonExpressionTree.put("leaf", leaf);
+        }
+    }
+    /** Writes an orcOperator node whose children are the trees built by getExprJson for each filter. */
+    private void addChildJson(JSONObject jsonExpressionTree, JSONObject jsonLeaves,
+                              OrcOperator orcOperator, Filter ... filters) {
+        jsonExpressionTree.put("op", orcOperator.ordinal());
+        ArrayList child = new ArrayList<>();
+        for (Filter filter: filters) {
+            JSONObject subJson = new JSONObject();
+            getExprJson(filter, subJson, jsonLeaves);
+            child.add(subJson);
+        }
+        jsonExpressionTree.put("child", child);
+    }
+
+    private String getLiteralValue(Object literal) {
+        // For null literal, the predicate will not be pushed down.
+        if (literal == null) {
+            throw new UnsupportedOperationException("Unsupported orc push down filter for literal is null");
+        }
+
+        // For Decimal Type, we use the special string format to represent, which is "$decimalVal
+        // $precision $scale".
+        // e.g., Decimal(9, 3) = 123456.789, it outputs "123456.789 9 3".
+        // e.g., Decimal(9, 3) = 123456.7, it outputs "123456.700 9 3".
+        if (literal instanceof BigDecimal) {
+            BigDecimal value = (BigDecimal) literal;
+            int precision = value.precision(); // NOTE(review): the literal's own precision, not the column's - confirm
+            int scale = value.scale();
+            String[] split = value.toPlainString().split("\\."); // toString() can return scientific notation ("1E-8")
+            if (scale == 0) {
+                return split[0] + " " + precision + " " + scale;
+            } else {
+                String padded = padZeroForDecimals(split, scale);
+                return split[0] + "." + padded + " " + precision + " " + scale;
+            }
+        }
+        // For Date Type, spark uses Gregorian in default but orc uses Julian, which should be converted.
+        if (literal instanceof LocalDate) {
+            int epochDay = Math.toIntExact(((LocalDate) literal).toEpochDay());
+            int rebased = RebaseDateTime.rebaseGregorianToJulianDays(epochDay);
+            return String.valueOf(rebased);
+        }
+        if (literal instanceof String) {
+            return (String) literal;
+        }
+        if (literal instanceof Integer || literal instanceof Long || literal instanceof Boolean ||
+            literal instanceof Short || literal instanceof Double) {
+            return literal.toString();
+        }
+        throw new UnsupportedOperationException("Unsupported orc push down filter data type: " +
+            literal.getClass().getSimpleName());
+    }
+
     private static String bytesToHexString(byte[] bytes) {
         if (bytes == null || bytes.length < 1) {
             throw new IllegalArgumentException("this bytes must not be null or empty");
diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java
index 55d56ae20203d4fdfa551f974a4d92b4210e1a76..0859173ca18d6f8ae4730b57c6be0f474d15f94f 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java
+++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java
@@ -80,12 +80,12 @@ public class ShuffleDataSerializer {
         long[] omniVecs = new long[vecCount];
         int[] omniTypes = new
int[vecCount]; createEmptyVec(rowBatch, omniTypes, omniVecs, columnarVecs, vecCount, rowCount); - OmniRowDeserializer deserializer = new OmniRowDeserializer(omniTypes); + OmniRowDeserializer deserializer = new OmniRowDeserializer(omniTypes, omniVecs); for (int rowIdx = 0; rowIdx < rowCount; rowIdx++) { VecData.ProtoRow protoRow = rowBatch.getRows(rowIdx); byte[] array = protoRow.getData().toByteArray(); - deserializer.parse(array, omniVecs, rowIdx); + deserializer.parse(array, rowIdx); } // update initial varchar vector because it's capacity might have been expanded. diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index 3152d6c7c4f8d928d5df4941cf75826dd5cbe33e..8e25208f7c7a7ac7418ddd0fe46c2ecc603a5518 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -76,7 +76,7 @@ object OmniExpressionAdaptor extends Logging { } } - private def unsupportedCastCheck(expr: Expression, cast: CastBase): Unit = { + private def unsupportedCastCheck(expr: Expression, cast: Cast): Unit = { def doSupportCastToString(dataType: DataType): Boolean = { dataType.isInstanceOf[DecimalType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[IntegerType] || dataType.isInstanceOf[LongType] || dataType.isInstanceOf[DateType] || dataType.isInstanceOf[DoubleType] || @@ -321,7 +321,7 @@ object OmniExpressionAdaptor extends Logging { .put(rewriteToOmniJsonExpressionLiteralJsonObject(subString.len, exprsIndexMap))) // Cast - case cast: CastBase => + case cast: Cast => unsupportedCastCheck(expr, cast) cast.child.dataType match { case NullType => @@ -640,10 +640,10 @@ object 
OmniExpressionAdaptor extends Logging { rewriteToOmniJsonExpressionLiteralJsonObject(children.head, exprsIndexMap) } else { children.head match { - case base: CastBase if base.child.dataType.isInstanceOf[NullType] => + case base: Cast if base.child.dataType.isInstanceOf[NullType] => rewriteToOmniJsonExpressionLiteralJsonObject(children(1), exprsIndexMap) case _ => children(1) match { - case base: CastBase if base.child.dataType.isInstanceOf[NullType] => + case base: Cast if base.child.dataType.isInstanceOf[NullType] => rewriteToOmniJsonExpressionLiteralJsonObject(children.head, exprsIndexMap) case _ => new JSONObject().put("exprType", "FUNCTION") diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala index ea271244784cc5466dbbce00b4c942a8e22777d0..4d8c246ab52ebd7f11ea416c18080c216f1e963e 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala @@ -26,6 +26,7 @@ object TreePattern extends Enumeration { val AGGREGATE_EXPRESSION = Value(0) val ALIAS: Value = Value val AND_OR: Value = Value + val AND: Value = Value val ARRAYS_ZIP: Value = Value val ATTRIBUTE_REFERENCE: Value = Value val APPEND_COLUMNS: Value = Value @@ -58,6 +59,7 @@ object TreePattern extends Enumeration { val JSON_TO_STRUCT: Value = Value val LAMBDA_FUNCTION: Value = Value val LAMBDA_VARIABLE: Value = Value + val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value val LATERAL_SUBQUERY: Value = Value val LIKE_FAMLIY: Value = Value val LIST_SUBQUERY: Value = Value @@ -69,7 +71,10 @@ object TreePattern extends Enumeration { val NULL_CHECK: Value = Value val NULL_LITERAL: Value = Value val SERIALIZE_FROM_OBJECT: Value = Value + val OR: Value = Value val 
OUTER_REFERENCE: Value = Value + val PARAMETER: Value = Value + val PARAMETERIZED_QUERY: Value = Value val PIVOT: Value = Value val PLAN_EXPRESSION: Value = Value val PYTHON_UDF: Value = Value @@ -81,6 +86,7 @@ object TreePattern extends Enumeration { val SCALAR_SUBQUERY: Value = Value val SCALAR_SUBQUERY_REFERENCE: Value = Value val SCALA_UDF: Value = Value + val SESSION_WINDOW: Value = Value val SORT: Value = Value val SUBQUERY_ALIAS: Value = Value val SUBQUERY_WRAPPER: Value = Value @@ -89,7 +95,9 @@ object TreePattern extends Enumeration { val TIME_ZONE_AWARE_EXPRESSION: Value = Value val TRUE_OR_FALSE_LITERAL: Value = Value val WINDOW_EXPRESSION: Value = Value + val WINDOW_TIME: Value = Value val UNARY_POSITIVE: Value = Value + val UNPIVOT: Value = Value val UPDATE_FIELDS: Value = Value val UPPER_OR_LOWER: Value = Value val UP_CAST: Value = Value @@ -119,6 +127,7 @@ object TreePattern extends Enumeration { val UNION: Value = Value val UNRESOLVED_RELATION: Value = Value val UNRESOLVED_WITH: Value = Value + val TEMP_RESOLVED_COLUMN: Value = Value val TYPED_FILTER: Value = Value val WINDOW: Value = Value val WITH_WINDOW_DEFINITION: Value = Value @@ -127,6 +136,7 @@ object TreePattern extends Enumeration { val UNRESOLVED_ALIAS: Value = Value val UNRESOLVED_ATTRIBUTE: Value = Value val UNRESOLVED_DESERIALIZER: Value = Value + val UNRESOLVED_HAVING: Value = Value val UNRESOLVED_ORDINAL: Value = Value val UNRESOLVED_FUNCTION: Value = Value val UNRESOLVED_HINT: Value = Value @@ -135,6 +145,8 @@ object TreePattern extends Enumeration { // Unresolved Plan patterns (Alphabetically ordered) val UNRESOLVED_SUBQUERY_COLUMN_ALIAS: Value = Value val UNRESOLVED_FUNC: Value = Value + val UNRESOLVED_TABLE_VALUED_FUNCTION: Value = Value + val UNRESOLVED_TVF_ALIASES: Value = Value // Execution expression patterns (alphabetically ordered) val IN_SUBQUERY_EXEC: Value = Value diff --git 
a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala new file mode 100644 index 0000000000000000000000000000000000000000..5a04f02cedf51834f86b4a1e48154a5814c2c0e9 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal._ +import org.apache.spark.sql.types._ + + +// scalastyle:off +/** + * Calculates and propagates precision for fixed-precision decimals. 
Hive has a number of
+ * rules for this based on the SQL standard and MS SQL:
+ * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
+ * https://msdn.microsoft.com/en-us/library/ms190476.aspx
+ *
+ * In particular, if we have expressions e1 and e2 with precision/scale p1/s1 and p2/s2
+ * respectively, then the following operations have the following precision / scale:
+ *
+ *   Operation    Result Precision                        Result Scale
+ *   ------------------------------------------------------------------------
+ *   e1 + e2      max(s1, s2) + max(p1-s1, p2-s2) + 1     max(s1, s2)
+ *   e1 - e2      max(s1, s2) + max(p1-s1, p2-s2) + 1     max(s1, s2)
+ *   e1 * e2      p1 + p2 + 1                             s1 + s2
+ *   e1 / e2      p1 - s1 + s2 + max(6, s1 + p2 + 1)      max(6, s1 + p2 + 1)
+ *   e1 % e2      min(p1-s1, p2-s2) + max(s1, s2)         max(s1, s2)
+ *   e1 union e2  max(s1, s2) + max(p1-s1, p2-s2)         max(s1, s2)
+ *
+ * When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale
+ * needed are out of the range of available values, the scale is reduced up to 6, in order to
+ * prevent the truncation of the integer part of the decimals.
+ *
+ * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited
+ * precision, do the math on unlimited-precision numbers, then introduce casts back to the
+ * required fixed precision. This allows us to do all rounding and overflow handling in the
+ * cast-to-fixed-precision operator.
+ *
+ * In addition, when mixing non-decimal types with decimals, we use the following rules:
+ * - BYTE gets turned into DECIMAL(3, 0)
+ * - SHORT gets turned into DECIMAL(5, 0)
+ * - INT gets turned into DECIMAL(10, 0)
+ * - LONG gets turned into DECIMAL(20, 0)
+ * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE
+ * - Literals INT and LONG get turned into DECIMAL with the precision strictly needed by the value
+ */
+// scalastyle:on
+object DecimalPrecision extends TypeCoercionRule {
+  import scala.math.{max, min}
+
+  private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
+
+  // Returns the wider decimal type that's wider than both of them
+  def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = {
+    widerDecimalType(d1.precision, d1.scale, d2.precision, d2.scale)
+  }
+  // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
+  def widerDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
+    val scale = max(s1, s2)
+    val range = max(p1 - s1, p2 - s2)
+    DecimalType.bounded(range + scale, scale)
+  }
+
+  private def promotePrecision(e: Expression, dataType: DataType): Expression = {
+    PromotePrecision(Cast(e, dataType))
+  }
+
+  override def transform: PartialFunction[Expression, Expression] = {
+    decimalAndDecimal()
+      .orElse(integralAndDecimalLiteral)
+      .orElse(nondecimalAndDecimal(conf.literalPickMinimumPrecision))
+  }
+
+  private[catalyst] def decimalAndDecimal(): PartialFunction[Expression, Expression] = {
+    decimalAndDecimal(conf.decimalOperationsAllowPrecisionLoss, !conf.ansiEnabled)
+  }
+
+  /** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */
+  private[catalyst] def decimalAndDecimal(allowPrecisionLoss: Boolean, nullOnOverflow: Boolean)
+    : PartialFunction[Expression, Expression] = {
+    // Skip nodes whose children have not been resolved yet
+    case e if !e.childrenResolved => e
+
+    // Skip nodes that are already promoted
+    case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e
+
+    case a @ Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
+      val resultScale = max(s1, s2)
+      val resultType = if (allowPrecisionLoss) {
+        DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
+          resultScale)
+      } else {
+        DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
+      }
+      CheckOverflow(
+        a.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)),
+        resultType, nullOnOverflow)
+
+    case s @ Subtract(e1 @ DecimalType.Expression(p1, s1),
+        e2 @ DecimalType.Expression(p2, s2), _) =>
+      val resultScale = max(s1, s2)
+      val resultType = if (allowPrecisionLoss) {
+        DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
+          resultScale)
+      } else {
+        DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
+      }
+      CheckOverflow(
+        s.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)),
+        resultType, nullOnOverflow)
+
+    case m @ Multiply(
+        e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
+      val resultType = if (allowPrecisionLoss) {
+        DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2)
+      } else {
+        DecimalType.bounded(p1 + p2 + 1, s1 + s2)
+      }
+      val widerType = widerDecimalType(p1, s1, p2, s2)
+      CheckOverflow(
+        m.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
+        resultType, nullOnOverflow)
+
+    case d @ Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
+      val resultType = if (allowPrecisionLoss) {
+        // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
+        // Scale: max(6, s1 + p2 + 1)
+        val intDig = p1 - s1 + s2
+        val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1)
+        val prec = intDig + scale
+        DecimalType.adjustPrecisionScale(prec, scale)
+      } else {
+        var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2)
+        var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1))
+        val diff = (intDig + decDig) - DecimalType.MAX_SCALE
+        if (diff > 0) {
+          decDig -= diff / 2 + 1
+          intDig = DecimalType.MAX_SCALE - decDig
+        }
+        DecimalType.bounded(intDig + decDig, decDig)
+      }
+      val widerType = widerDecimalType(p1, s1, p2, s2)
+      CheckOverflow(
+        d.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
+        resultType, nullOnOverflow)
+
+    case r @ Remainder(
+        e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
+      val resultType = if (allowPrecisionLoss) {
+        DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+      } else {
+        DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+      }
+      // resultType may have lower precision, so we cast them into wider type first.
+      val widerType = widerDecimalType(p1, s1, p2, s2)
+      CheckOverflow(
+        r.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
+        resultType, nullOnOverflow)
+
+    case p @ Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
+      val resultType = if (allowPrecisionLoss) {
+        DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+      } else {
+        DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+      }
+      // resultType may have lower precision, so we cast them into wider type first.
+      val widerType = widerDecimalType(p1, s1, p2, s2)
+      CheckOverflow(
+        p.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
+        resultType, nullOnOverflow)
+
+    case expr @ IntegralDivide(
+        e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
+      val widerType = widerDecimalType(p1, s1, p2, s2)
+      val promotedExpr = expr.copy(
+        left = promotePrecision(e1, widerType),
+        right = promotePrecision(e2, widerType))
+      if (expr.dataType.isInstanceOf[DecimalType]) {
+        // This follows division rule
+        val intDig = p1 - s1 + s2
+        // No precision loss can happen as the result scale is 0.
+        // Overflow can happen only in the promote precision of the operands, but if none of them
+        // overflows in that phase, no overflow can happen, but CheckOverflow is needed in order
+        // to return a decimal with the proper scale and precision
+        CheckOverflow(promotedExpr, DecimalType.bounded(intDig, 0), nullOnOverflow)
+      } else {
+        promotedExpr
+      }
+
+    case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
+        e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+      val resultType = widerDecimalType(p1, s1, p2, s2)
+      val newE1 = if (e1.dataType == resultType) e1 else Cast(e1, resultType)
+      val newE2 = if (e2.dataType == resultType) e2 else Cast(e2, resultType)
+      b.makeCopy(Array(newE1, newE2))
+  }
+
+  /**
+   * Strength reduction for comparing integral expressions with decimal literals.
+   * 1. int_col > decimal_literal => int_col > floor(decimal_literal)
+   * 2. int_col >= decimal_literal => int_col >= ceil(decimal_literal)
+   * 3. int_col < decimal_literal => int_col < ceil(decimal_literal)
+   * 4. int_col <= decimal_literal => int_col <= floor(decimal_literal)
+   * 5. decimal_literal > int_col => ceil(decimal_literal) > int_col
+   * 6. decimal_literal >= int_col => floor(decimal_literal) >= int_col
+   * 7. decimal_literal < int_col => floor(decimal_literal) < int_col
+   * 8. decimal_literal <= int_col => ceil(decimal_literal) <= int_col
+   *
+   * Note that technically this is an "optimization" and should go into the optimizer. However,
+   * by the time the optimizer runs, these comparison expressions would be pretty hard to pattern
+   * match because there are multiple (at least 2) levels of casts involved.
+   *
+   * There are a lot more possible rules we can implement, but we don't do them
+   * because we are not sure how common they are.
+   */
+  private val integralAndDecimalLiteral: PartialFunction[Expression, Expression] = {
+
+    case GreaterThan(i @ IntegralType(), DecimalLiteral(value)) =>
+      if (DecimalLiteral.smallerThanSmallestLong(value)) {
+        TrueLiteral
+      } else if (DecimalLiteral.largerThanLargestLong(value)) {
+        FalseLiteral
+      } else {
+        GreaterThan(i, Literal(value.floor.toLong))
+      }
+
+    case GreaterThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) =>
+      if (DecimalLiteral.smallerThanSmallestLong(value)) {
+        TrueLiteral
+      } else if (DecimalLiteral.largerThanLargestLong(value)) {
+        FalseLiteral
+      } else {
+        GreaterThanOrEqual(i, Literal(value.ceil.toLong))
+      }
+
+    case LessThan(i @ IntegralType(), DecimalLiteral(value)) =>
+      if (DecimalLiteral.smallerThanSmallestLong(value)) {
+        FalseLiteral
+      } else if (DecimalLiteral.largerThanLargestLong(value)) {
+        TrueLiteral
+      } else {
+        LessThan(i, Literal(value.ceil.toLong))
+      }
+
+    case LessThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) =>
+      if (DecimalLiteral.smallerThanSmallestLong(value)) {
+        FalseLiteral
+      } else if (DecimalLiteral.largerThanLargestLong(value)) {
+        TrueLiteral
+      } else {
+        LessThanOrEqual(i, Literal(value.floor.toLong))
+      }
+
+    case GreaterThan(DecimalLiteral(value), i @ IntegralType()) =>
+      if (DecimalLiteral.smallerThanSmallestLong(value)) {
+        FalseLiteral
+      } else if (DecimalLiteral.largerThanLargestLong(value)) {
+        TrueLiteral
+      } else {
+        GreaterThan(Literal(value.ceil.toLong), i)
+      }
+
+    case GreaterThanOrEqual(DecimalLiteral(value), i @ IntegralType()) =>
+      if (DecimalLiteral.smallerThanSmallestLong(value)) {
+        FalseLiteral
+      } else if (DecimalLiteral.largerThanLargestLong(value)) {
+        TrueLiteral
+      } else {
+        GreaterThanOrEqual(Literal(value.floor.toLong), i)
+      }
+
+    case LessThan(DecimalLiteral(value), i @ IntegralType()) =>
+      if (DecimalLiteral.smallerThanSmallestLong(value)) {
+        TrueLiteral
+      } else if (DecimalLiteral.largerThanLargestLong(value)) {
+        FalseLiteral
+      } else {
+        LessThan(Literal(value.floor.toLong), i)
+      }
+
+    case LessThanOrEqual(DecimalLiteral(value), i @ IntegralType()) =>
+      if (DecimalLiteral.smallerThanSmallestLong(value)) {
+        TrueLiteral
+      } else if (DecimalLiteral.largerThanLargestLong(value)) {
+        FalseLiteral
+      } else {
+        LessThanOrEqual(Literal(value.ceil.toLong), i)
+      }
+  }
+
+  /**
+   * Type coercion for BinaryOperator in which one side is a non-decimal numeric, and the other
+   * side is a decimal.
+   */
+  private def nondecimalAndDecimal(literalPickMinimumPrecision: Boolean)
+    : PartialFunction[Expression, Expression] = {
+    // Promote integers inside a binary expression with fixed-precision decimals to decimals,
+    // and fixed-precision decimals in an expression with floats / doubles to doubles
+    case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
+      (left, right) match {
+        // Promote literal integers inside a binary expression with fixed-precision decimals to
+        // decimals. The precision and scale are the ones strictly needed by the integer value.
+        // Requiring more precision than necessary may lead to a useless loss of precision.
+        // Consider the following example: multiplying a column which is DECIMAL(38, 18) by 2.
+        // If we use the default precision and scale for the integer type, 2 is considered a
+        // DECIMAL(10, 0). According to the rules, the result would be DECIMAL(38 + 10 + 1, 18),
+        // which is out of range and therefore it will become DECIMAL(38, 7), leading to
+        // potentially losing 11 digits of the fractional part. Using only the precision needed
+        // by the Literal, instead, the result would be DECIMAL(38 + 1 + 1, 18), which would
+        // become DECIMAL(38, 16), safely having a much lower precision loss.
+        case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] &&
+            l.dataType.isInstanceOf[IntegralType] &&
+            literalPickMinimumPrecision =>
+          b.makeCopy(Array(Cast(l, DecimalType.fromLiteral(l)), r))
+        case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] &&
+            r.dataType.isInstanceOf[IntegralType] &&
+            literalPickMinimumPrecision =>
+          b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r))))
+        // Promote integers inside a binary expression with fixed-precision decimals to decimals,
+        // and fixed-precision decimals in an expression with floats / doubles to doubles
+        case (l @ IntegralType(), r @ DecimalType.Expression(_, _)) =>
+          b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r))
+        case (l @ DecimalType.Expression(_, _), r @ IntegralType()) =>
+          b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType))))
+        case (l, r @ DecimalType.Expression(_, _)) if isFloat(l.dataType) =>
+          b.makeCopy(Array(l, Cast(r, DoubleType)))
+        case (l @ DecimalType.Expression(_, _), r) if isFloat(r.dataType) =>
+          b.makeCopy(Array(Cast(l, DoubleType), r))
+        case _ => b
+      }
+  }
+
+}
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
new file mode 100644
index 0000000000000000000000000000000000000000..09e016c974c796422885fac5a14f428c81de4dcc
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.SQLQueryContext
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+/**
+ * Return the unscaled Long value of a Decimal, assuming it fits in a Long.
+ * Note: this expression is internal and created only by the optimizer,
+ * we don't need to do type check for it.
+ */
+case class UnscaledValue(child: Expression) extends UnaryExpression with NullIntolerant {
+
+  override def dataType: DataType = LongType
+  override def toString: String = s"UnscaledValue($child)"
+
+  protected override def nullSafeEval(input: Any): Any =
+    input.asInstanceOf[Decimal].toUnscaledLong
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()")
+  }
+
+  override protected def withNewChildInternal(newChild: Expression): UnscaledValue =
+    copy(child = newChild)
+}
+
+/**
+ * Create a Decimal from an unscaled Long value.
+ * Note: this expression is internal and created only by the optimizer,
+ * we don't need to do type check for it.
+ */
+case class MakeDecimal(
+    child: Expression,
+    precision: Int,
+    scale: Int,
+    nullOnOverflow: Boolean) extends UnaryExpression with NullIntolerant {
+
+  def this(child: Expression, precision: Int, scale: Int) = {
+    this(child, precision, scale, !SQLConf.get.ansiEnabled) // null on overflow unless ANSI is on
+  }
+
+  override def dataType: DataType = DecimalType(precision, scale)
+  override def nullable: Boolean = child.nullable || nullOnOverflow
+  override def toString: String = s"MakeDecimal($child,$precision,$scale)"
+
+  protected override def nullSafeEval(input: Any): Any = {
+    val longInput = input.asInstanceOf[Long]
+    val result = new Decimal()
+    if (nullOnOverflow) {
+      result.setOrNull(longInput, precision, scale)
+    } else {
+      result.set(longInput, precision, scale)
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, eval => {
+      val setMethod = if (nullOnOverflow) {
+        "setOrNull"
+      } else {
+        "set"
+      }
+      val setNull = if (nullable) {
+        s"${ev.isNull} = ${ev.value} == null;"
+      } else {
+        ""
+      }
+      s"""
+         |${ev.value} = (new Decimal()).$setMethod($eval, $precision, $scale);
+         |$setNull
+         |""".stripMargin
+    })
+  }
+
+  override protected def withNewChildInternal(newChild: Expression): MakeDecimal =
+    copy(child = newChild)
+}
+
+object MakeDecimal {
+  def apply(child: Expression, precision: Int, scale: Int): MakeDecimal = {
+    new MakeDecimal(child, precision, scale)
+  }
+}
+
+/**
+ * An expression used to wrap the children when promoting the precision of DecimalType, to avoid
+ * promoting them multiple times. 
+ */
+case class PromotePrecision(child: Expression) extends UnaryExpression {
+  override def dataType: DataType = child.dataType
+  override def eval(input: InternalRow): Any = child.eval(input)
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    child.genCode(ctx)
+  override def prettyName: String = "promote_precision"
+  override def sql: String = child.sql
+  override lazy val canonicalized: Expression = child.canonicalized
+
+  override protected def withNewChildInternal(newChild: Expression): Expression =
+    copy(child = newChild)
+}
+
+/**
+ * Rounds the decimal to given scale and check whether the decimal can fit in provided precision
+ * or not. If not, if `nullOnOverflow` is `true`, it returns `null`; otherwise an
+ * `ArithmeticException` is thrown (so the query context is only captured in that case).
+ */
+case class CheckOverflow(
+    child: Expression,
+    dataType: DecimalType,
+    nullOnOverflow: Boolean) extends UnaryExpression with SupportQueryContext {
+
+  override def nullable: Boolean = true
+
+  override def nullSafeEval(input: Any): Any =
+    input.asInstanceOf[Decimal].toPrecision(
+      dataType.precision,
+      dataType.scale,
+      Decimal.ROUND_HALF_UP,
+      nullOnOverflow,
+      getContextOrNull())
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val errorContextCode = if (nullOnOverflow) {
+      "\"\""
+    } else {
+      // NOTE(review): confirm the referenced object's type matches toPrecision's context param
+      ctx.addReferenceObj("errCtx", queryContext)
+    }
+    nullSafeCodeGen(ctx, ev, eval => {
+      // scalastyle:off line.size.limit
+      s"""
+         |${ev.value} = $eval.toPrecision(
+         | ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow, $errorContextCode);
+         |${ev.isNull} = ${ev.value} == null;
+       """.stripMargin
+      // scalastyle:on line.size.limit
+    })
+  }
+
+  override def toString: String = s"CheckOverflow($child, $dataType)"
+
+  override def sql: String = child.sql
+
+  override protected def withNewChildInternal(newChild: Expression): CheckOverflow =
+    copy(child = newChild)
+
+  override def initQueryContext(): Option[SQLQueryContext] = if (!nullOnOverflow) {
+    Some(origin.context) // only the throwing (non-null-on-overflow) path reports a context
+  } else {
+    None
+  }
+}
+
+// A variant `CheckOverflow`, which treats null as overflow. This is necessary in `Sum`.
+case class CheckOverflowInSum(
+    child: Expression,
+    dataType: DecimalType,
+    nullOnOverflow: Boolean,
+    context: SQLQueryContext) extends UnaryExpression {
+
+  override def nullable: Boolean = true
+
+  override def eval(input: InternalRow): Any = {
+    val value = child.eval(input)
+    if (value == null) {
+      if (nullOnOverflow) null
+      else throw QueryExecutionErrors.overflowInSumOfDecimalError(context)
+    } else {
+      value.asInstanceOf[Decimal].toPrecision(
+        dataType.precision,
+        dataType.scale,
+        Decimal.ROUND_HALF_UP,
+        nullOnOverflow,
+        context)
+    }
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val childGen = child.genCode(ctx)
+    val errorContextCode = if (nullOnOverflow) {
+      "\"\""
+    } else {
+      ctx.addReferenceObj("errCtx", context)
+    }
+    val nullHandling = if (nullOnOverflow) {
+      ""
+    } else {
+      s"throw QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode);"
+    }
+    // scalastyle:off line.size.limit
+    val code = code"""
+       |${childGen.code}
+       |boolean ${ev.isNull} = ${childGen.isNull};
+       |Decimal ${ev.value} = null;
+       |if (${childGen.isNull}) {
+       | $nullHandling
+       |} else {
+       | ${ev.value} = ${childGen.value}.toPrecision(
+       | ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow, $errorContextCode);
+       | ${ev.isNull} = ${ev.value} == null;
+       |}
+       |""".stripMargin
+    // scalastyle:on line.size.limit
+
+    ev.copy(code = code)
+  }
+
+  override def toString: String = s"CheckOverflowInSum($child, $dataType, $nullOnOverflow)"
+
+  override def sql: String = child.sql
+
+  override protected def withNewChildInternal(newChild: Expression): CheckOverflowInSum =
+    copy(child = newChild)
+}
\ No newline at end of file
diff --git 
a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala index 0a5d509b05af809fccbedcb28a74cbaea73b40b8..85192fc36038815c38167b64da74e51896d688fa 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala @@ -39,7 +39,7 @@ case class RuntimeFilterSubquery( exprId: ExprId = NamedExpression.newExprId, hint: Option[HintInfo] = None) extends SubqueryExpression( - filterCreationSidePlan, Seq(filterApplicationSideExp), exprId, Seq.empty) + filterCreationSidePlan, Seq(filterApplicationSideExp), exprId, Seq.empty, hint) with Unevaluable with UnaryLike[Expression] { @@ -74,6 +74,8 @@ case class RuntimeFilterSubquery( override protected def withNewChildInternal(newChild: Expression): RuntimeFilterSubquery = copy(filterApplicationSideExp = newChild) + + override def withNewHint(hint: Option[HintInfo]): RuntimeFilterSubquery = copy(hint = hint) } /** diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index 812c387bc6b8e5d253f82efb6ac639eebde71c22..ac38c5399325c48475df32f346b2af3a13e38361 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -218,9 +218,9 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J leftKey: Expression, rightKey: Expression): Boolean = { 
(left, right) match { - case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan), _) => + case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _, _), plan), _) => pruningKey.fastEquals(leftKey) || hasDynamicPruningSubquery(plan, right, leftKey, rightKey) - case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan)) => + case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _, _), plan)) => pruningKey.fastEquals(rightKey) || hasDynamicPruningSubquery(left, plan, leftKey, rightKey) case _ => false @@ -251,10 +251,10 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J rightKey: Expression): Boolean = { (left, right) match { case (Filter(InSubquery(Seq(key), - ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _), _) => + ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _, _)), _), _) => key.fastEquals(leftKey) || key.fastEquals(new Murmur3Hash(Seq(leftKey))) case (_, Filter(InSubquery(Seq(key), - ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _)) => + ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _, _)), _)) => key.fastEquals(rightKey) || key.fastEquals(new Murmur3Hash(Seq(rightKey))) case _ => false } @@ -299,7 +299,13 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J case s: Subquery if s.correlated => plan case _ if !conf.runtimeFilterSemiJoinReductionEnabled && !conf.runtimeFilterBloomFilterEnabled => plan - case _ => tryInjectRuntimeFilter(plan) + case _ => + val newPlan = tryInjectRuntimeFilter(plan) + if (conf.runtimeFilterSemiJoinReductionEnabled && !plan.fastEquals(newPlan)) { + RewritePredicateSubquery(newPlan) + } else { + newPlan + } } } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala 
b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala index 1b5baa23080e816868934da02c8cbe87fa21c4d8..c4435379ffb153b5f2d211a1eec044102f2a1f6c 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala @@ -643,7 +643,7 @@ object MergeSubqueryFilters extends Rule[LogicalPlan] { val subqueryCTE = header.plan.asInstanceOf[CTERelationDef] GetStructField( ScalarSubquery( - CTERelationRef(subqueryCTE.id, _resolved = true, subqueryCTE.output), + CTERelationRef(subqueryCTE.id, _resolved = true, subqueryCTE.output, subqueryCTE.isStreaming), exprId = ssr.exprId), ssr.headerIndex) } else { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala index 9e402902518a90df7f68db4b420d325a53f9bf41..f6ebd716dc1ebb25d3f36bcb5705168f0d76934e 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala @@ -61,7 +61,7 @@ object RewriteSelfJoinInInPredicate extends Rule[LogicalPlan] with PredicateHelp case f: Filter => f transformExpressions { case in @ InSubquery(_, listQuery @ ListQuery(Project(projectList, - Join(left, right, Inner, Some(joinCond), _)), _, _, _, _)) + Join(left, right, Inner, Some(joinCond), _)), _, _, _, _, _)) if left.canonicalized ne right.canonicalized => val attrMapping = AttributeMap(right.output.zip(left.output)) val subCondExprs = 
splitConjunctivePredicates(joinCond transform { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala index 4863698430e78896de76218897b30870bad46634..b037930804539c896dd2c67b71828dd7813035f6 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala @@ -41,11 +41,12 @@ import org.apache.spark.sql.execution.vectorized.OmniColumnVector import org.apache.spark.sql.expression.ColumnarExpressionConverter import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.sql.catalyst.plans.{AliasAwareOutputExpression, AliasAwareQueryOutputOrdering} case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryExecNode - with AliasAwareOutputPartitioning - with AliasAwareOutputOrdering { + with AliasAwareOutputExpression + with AliasAwareQueryOutputOrdering[SparkPlan] { override def supportsColumnar: Boolean = true @@ -267,8 +268,8 @@ case class ColumnarConditionProjectExec(projectList: Seq[NamedExpression], condition: Expression, child: SparkPlan) extends UnaryExecNode - with AliasAwareOutputPartitioning - with AliasAwareOutputOrdering { + with AliasAwareOutputExpression + with AliasAwareQueryOutputOrdering[SparkPlan] { override def supportsColumnar: Boolean = true diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala index 
800dcf1a0c047b206603b24a2a963ef7f3e48db1..80d796ce126774970594d75195a186be096f5de3 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala @@ -479,8 +479,8 @@ abstract class BaseColumnarFileSourceScanExec( } }.groupBy { f => BucketingUtils - .getBucketId(new Path(f.filePath).getName) - .getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.filePath)) + .getBucketId(f.toPath.getName) + .getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.urlEncodedPath)) } val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index ef33a84decfd9aabd62b4ea149dca78b7e0ed988..3630b0f2ea14bf7a4c4d87ba739d6734b84a870e 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -105,12 +105,29 @@ class QueryExecution( case other => other } + // The plan that has been normalized by custom rules, so that it's more likely to hit cache. 
+ lazy val normalized: LogicalPlan = { + val normalizationRules = sparkSession.sessionState.planNormalizationRules + if (normalizationRules.isEmpty) { + commandExecuted + } else { + val planChangeLogger = new PlanChangeLogger[LogicalPlan]() + val normalized = normalizationRules.foldLeft(commandExecuted) { (p, rule) => + val result = rule.apply(p) + planChangeLogger.logRule(rule.ruleName, p, result) + result + } + planChangeLogger.logBatch("Plan Normalization", commandExecuted, normalized) + normalized + } + } + lazy val withCachedData: LogicalPlan = sparkSession.withActive { assertAnalyzed() assertSupported() // clone the plan to avoid sharing the plan instance between different stages like analyzing, // optimizing and planning. - sparkSession.sharedState.cacheManager.useCachedData(commandExecuted.clone()) + sparkSession.sharedState.cacheManager.useCachedData(normalized.clone()) } def assertCommandExecuted(): Unit = commandExecuted @@ -227,7 +244,7 @@ class QueryExecution( // output mode does not matter since there is no `Sink`. new IncrementalExecution( sparkSession, logical, OutputMode.Append(), "", - UUID.randomUUID, UUID.randomUUID, 0, OffsetSeqMetadata(0, 0)) + UUID.randomUUID, UUID.randomUUID, 0, None ,OffsetSeqMetadata(0, 0)) } else { this } @@ -494,11 +511,10 @@ object QueryExecution { */ private[sql] def toInternalError(msg: String, e: Throwable): Throwable = e match { case e @ (_: java.lang.NullPointerException | _: java.lang.AssertionError) => - new SparkException( - errorClass = "INTERNAL_ERROR", - messageParameters = Array(msg + - " Please, fill a bug report in, and provide the full stack trace."), - cause = e) + SparkException.internalError( + msg + " You hit a bug in Spark or the Spark plugins you use. 
Please, report this bug " + + "to the corresponding communities or vendors, and provide the full stack trace.", + e) case e: Throwable => e } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 6234751dde4231261f56d55be2a07bc42c65ea03..f964f64902ca843647499b7b7bf4940d9d8b337b 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -84,7 +84,9 @@ case class AdaptiveSparkPlanExec( @transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]() // The logical plan optimizer for re-optimizing the current logical plan. - @transient private val optimizer = new AQEOptimizer(conf) + @transient private val optimizer = new AQEOptimizer(conf, + session.sessionState.adaptiveRulesHolder.runtimeOptimizerRules + ) // `EnsureRequirements` may remove user-specified repartition and assume the query plan won't // change its output partitioning. This assumption is not true in AQE. Here we check the @@ -123,7 +125,7 @@ case class AdaptiveSparkPlanExec( RemoveRedundantSorts, DisableUnnecessaryBucketedScan, OptimizeSkewedJoin(ensureRequirements) - ) ++ context.session.sessionState.queryStagePrepRules + ) ++ context.session.sessionState.adaptiveRulesHolder.queryStagePrepRules } // A list of physical optimizer rules to be applied to a new stage before its execution. 
These diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala index b5a1ad3754929c738a1dcc1dda5da9047a27305f..dfdbe2c7038298794a375bf701157acfe03956ec 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala @@ -30,11 +30,11 @@ case class PlanAdaptiveSubqueries( def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressionsWithPruning(_.containsAnyPattern( SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY, RUNTIME_FILTER_SUBQUERY)) { - case expressions.ScalarSubquery(_, _, exprId, _) => + case expressions.ScalarSubquery(_, _, exprId, _, _, _) => val subquery = SubqueryExec.createForScalarSubquery( s"subquery#${exprId.id}", subqueryMap(exprId.id)) execution.ScalarSubquery(subquery, exprId) - case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _)) => + case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _, _)) => val expr = if (values.length == 1) { values.head } else { @@ -47,7 +47,7 @@ case class PlanAdaptiveSubqueries( val subquery = SubqueryExec(s"subquery#${exprId.id}", subqueryMap(exprId.id)) InSubqueryExec(expr, subquery, exprId, shouldBroadcast = true) case expressions.DynamicPruningSubquery(value, buildPlan, - buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId) => + buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId, _) => val name = s"dynamicpruning#${exprId.id}" val subquery = SubqueryAdaptiveBroadcastExec(name, broadcastKeyIndex, onlyInBroadcast, buildPlan, buildKeys, subqueryMap(exprId.id)) diff --git 
a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala index 422435694a86284aed9e8515a94dfaca88bbfd6a..cba7f366fdea0b3e6665a27eec1faa8f54321c42 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala @@ -34,7 +34,7 @@ case class PushOrderedLimitThroughAgg(session: SparkSession) extends Rule[SparkP val enableColumnarTopNSort: Boolean = columnarConf.enableColumnarTopNSort plan.transform { - case orderAndProject @ TakeOrderedAndProjectExec(limit, sortOrder, projectList, orderAndProjectChild) => { + case orderAndProject @ TakeOrderedAndProjectExec(limit, sortOrder, projectList, orderAndProjectChild, _) => { orderAndProjectChild match { case finalAgg @ HashAggregateExec(_, _, _, _, _, _, _, _, finalAggChild) => finalAggChild match { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala index 465123d8730d39656d242bd4785e22680312ce05..a65e66c4edf3ccd9bad15fc34e16e9155438270e 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala @@ -272,7 +272,7 @@ object OmniFileFormatWriter extends Logging { case cause: Throwable => logError(s"Aborting job ${description.uuid}.", cause) committer.abortJob(job) - throw 
QueryExecutionErrors.jobAbortedError(cause) + throw cause } } @@ -343,7 +343,7 @@ object OmniFileFormatWriter extends Logging { // We throw the exception and let Executor throw ExceptionFailure to abort the job. throw new TaskOutputFileAlreadyExistException(f) case t: Throwable => - throw QueryExecutionErrors.taskFailedWhileWritingRowsError(t) + throw QueryExecutionErrors.taskFailedWhileWritingRowsError(description.path, t) } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInsertIntoHadoopFsRelationCommand.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInsertIntoHadoopFsRelationCommand.scala index 9d0008e0b2cf1a57443e019bf41113f60c8855fd..7c60110911fedccbc88ccce0e9ffc970d9ec5e23 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInsertIntoHadoopFsRelationCommand.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInsertIntoHadoopFsRelationCommand.scala @@ -83,7 +83,6 @@ case class OmniInsertIntoHadoopFsRelationCommand( // Most formats don't do well with duplicate columns, so lets not allow that SchemaUtils.checkColumnNameDuplication( outputColumnNames, - s"when inserting into $outputPath", sparkSession.sessionState.conf.caseSensitiveAnalysis) val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java index 93950e9f0f1b9339587893b95f99c70967024acc..1880208fbeaed55b542cc5144cac7cc72d55c500 100644 --- 
a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java @@ -22,18 +22,15 @@ import com.google.common.annotations.VisibleForTesting; import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor; import com.huawei.boostkit.spark.jni.OrcColumnarBatchScanReader; import nova.hetu.omniruntime.vector.Vec; -import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.RecordReader; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.hadoop.mapreduce.lib.input.FileSplit; -import org.apache.orc.OrcConf; -import org.apache.orc.OrcFile; -import org.apache.orc.Reader; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.OmniColumnVectorUtils; import org.apache.spark.sql.execution.vectorized.OmniColumnVector; +import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -49,26 +46,11 @@ import java.util.ArrayList; public class OmniOrcColumnarBatchReader extends RecordReader { // The capacity of vectorized batch. - private int capacity; - /** - * The column IDs of the physical ORC file schema which are required by this reader. - * -1 means this required column is partition column, or it doesn't exist in the ORC file. - * Ideally partition column should never appear in the physical file, and should only appear - * in the directory name. However, Spark allows partition columns inside physical file, - * but Spark will discard the values from the file, and use the partition value got from - * directory name. The column order will be reserved though. 
- */ - @VisibleForTesting - public int[] requestedDataColIds; - // Native Record reader from ORC row batch. private OrcColumnarBatchScanReader recordReader; - private StructField[] requiredFields; - private StructField[] resultFields; - // The result columnar batch for vectorized execution by whole-stage codegen. @VisibleForTesting public ColumnarBatch columnarBatch; @@ -82,13 +64,15 @@ public class OmniOrcColumnarBatchReader extends RecordReader orcfieldNames = recordReader.fildsNames; // save valid cols and numbers of valid cols recordReader.colsToGet = new int[requiredfieldNames.length]; - recordReader.realColsCnt = 0; - // save valid cols fieldsNames - recordReader.colToInclu = new ArrayList(); + recordReader.includedColumns = new ArrayList<>(); // collect read cols types ArrayList typeBuilder = new ArrayList<>(); + for (int i = 0; i < requiredfieldNames.length; i++) { String target = requiredfieldNames[i]; - boolean is_find = false; - for (int j = 0; j < orcfieldNames.size(); j++) { - String temp = orcfieldNames.get(j); - if (target.equals(temp)) { - requestedDataColIds[i] = i; - recordReader.colsToGet[i] = 0; - recordReader.colToInclu.add(requiredfieldNames[i]); - recordReader.realColsCnt++; - typeBuilder.add(OmniExpressionAdaptor.sparkTypeToOmniType(requiredSchema.fields()[i].dataType())); - is_find = true; - } - } - - // if invalid, set colsToGet value -1, else set colsToGet 0 - if (!is_find) { + // if not find, set colsToGet value -1, else set colsToGet 0 + if (recordReader.allFieldsNames.contains(target)) { + recordReader.colsToGet[i] = 0; + recordReader.includedColumns.add(requiredfieldNames[i]); + typeBuilder.add(OmniExpressionAdaptor.sparkTypeToOmniType(requiredSchema.fields()[i].dataType())); + } else { recordReader.colsToGet[i] = -1; } } vecTypeIds = typeBuilder.stream().mapToInt(Integer::intValue).toArray(); - - for (int i = 0; i < resultFields.length; i++) { - if (requestedPartitionColIds[i] != -1) { - requestedDataColIds[i] = -1; - } - } - - // 
set data members resultFields and requestedDataColIdS - this.resultFields = resultFields; - this.requestedDataColIds = requestedDataColIds; - - recordReader.requiredfieldNames = requiredfieldNames; - recordReader.precisionArray = precisionArray; - recordReader.scaleArray = scaleArray; - recordReader.initializeRecordReaderJava(options); } /** * Initialize columnar batch by setting required schema and partition information. * With this information, this creates ColumnarBatch with the full schema. * - * @param requiredFields The fields that are required to return,. - * @param resultFields All the fields that are required to return, including partition fields. - * @param requestedDataColIds Requested column ids from orcSchema. -1 if not existed. - * @param requestedPartitionColIds Requested column ids from partition schema. -1 if not existed. + * @param partitionColumns partition columns * @param partitionValues Values of partition columns. */ - public void initBatch( - StructField[] requiredFields, - StructField[] resultFields, - int[] requestedDataColIds, - int[] requestedPartitionColIds, - InternalRow partitionValues) { - if (resultFields.length != requestedDataColIds.length || resultFields.length != requestedPartitionColIds.length){ - throw new UnsupportedOperationException("This operator doesn't support orc initBatch."); - } + public void initBatch(StructType partitionColumns, InternalRow partitionValues) { + StructType resultSchema = new StructType(); - this.requiredFields = requiredFields; + for (StructField f: requiredSchema.fields()) { + resultSchema = resultSchema.add(f); + } - StructType resultSchema = new StructType(resultFields); + if (partitionColumns != null) { + for (StructField f: partitionColumns.fields()) { + resultSchema = resultSchema.add(f); + } + } // Just wrap the ORC column vector instead of copying it to Spark column vector. 
orcVectorWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()]; templateWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()]; - for (int i = 0; i < resultFields.length; i++) { - DataType dt = resultFields[i].dataType(); - if (requestedPartitionColIds[i] != -1) { - OmniColumnVector partitionCol = new OmniColumnVector(capacity, dt, true); - OmniColumnVectorUtils.populate(partitionCol, partitionValues, requestedPartitionColIds[i]); + if (partitionColumns != null) { + int partitionIdx = requiredSchema.fields().length; + for (int i = 0; i < partitionColumns.fields().length; i++) { + OmniColumnVector partitionCol = new OmniColumnVector(capacity, partitionColumns.fields()[i].dataType(), true); + OmniColumnVectorUtils.populate(partitionCol, partitionValues, i); partitionCol.setIsConstant(); - templateWrappers[i] = partitionCol; - orcVectorWrappers[i] = new OmniColumnVector(capacity, dt, false);; + templateWrappers[i + partitionIdx] = partitionCol; + orcVectorWrappers[i + partitionIdx] = new OmniColumnVector(capacity, partitionColumns.fields()[i].dataType(), false); + } + } + + for (int i = 0; i < requiredSchema.fields().length; i++) { + DataType dt = requiredSchema.fields()[i].dataType(); + if (recordReader.colsToGet[i] == -1) { + // missing cols + OmniColumnVector missingCol = new OmniColumnVector(capacity, dt, true); + missingCol.putNulls(0, capacity); + missingCol.setIsConstant(); + templateWrappers[i] = missingCol; } else { - int colId = requestedDataColIds[i]; - // Initialize the missing columns once. 
- if (colId == -1) { - OmniColumnVector missingCol = new OmniColumnVector(capacity, dt, true); - missingCol.putNulls(0, capacity); - missingCol.setIsConstant(); - templateWrappers[i] = missingCol; - } else { - templateWrappers[i] = new OmniColumnVector(capacity, dt, false); - } - orcVectorWrappers[i] = new OmniColumnVector(capacity, dt, false); + templateWrappers[i] = new OmniColumnVector(capacity, dt, false); } + orcVectorWrappers[i] = new OmniColumnVector(capacity, dt, false); } + // init batch recordReader.initBatchJava(capacity); vecs = new Vec[orcVectorWrappers.length]; @@ -260,7 +210,7 @@ public class OmniOrcColumnarBatchReader extends RecordReader - convertibleFiltersHelper(left, dataSchema) && convertibleFiltersHelper(right, dataSchema) - case Or(left, right) => - convertibleFiltersHelper(left, dataSchema) && convertibleFiltersHelper(right, dataSchema) - case Not(pred) => - convertibleFiltersHelper(pred, dataSchema) - case other => - other match { - case EqualTo(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case EqualNullSafe(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case LessThan(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case LessThanOrEqual(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case GreaterThan(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case GreaterThanOrEqual(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case IsNull(name) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case IsNotNull(name) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != 
TimestampType - case In(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case _ => false - } - } - - filters.map { filter => - convertibleFiltersHelper(filter, dataSchema) - } - } - override def buildReaderWithPartitionValues( sparkSession: SparkSession, dataSchema: StructType, @@ -101,7 +61,6 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ options: Map[String, String], hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { - val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields) val sqlConf = sparkSession.sessionState.conf val capacity = sqlConf.orcVectorizedReaderBatchSize @@ -111,21 +70,17 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis val orcFilterPushDown = sparkSession.sessionState.conf.orcFilterPushDown - val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles (file: PartitionedFile) => { val conf = broadcastedConf.value.value - val filePath = new Path(new URI(file.filePath)) - val isPPDSafeValue = isPPDSafe(filters, dataSchema).reduceOption(_ && _) + val filePath = file.toPath // ORC predicate pushdown - if (orcFilterPushDown && filters.nonEmpty && isPPDSafeValue.getOrElse(false)) { - OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach { - fileSchema => OrcFilters.createFilter(fileSchema, filters).foreach { f => - OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames) - } - } + val pushed = if (orcFilterPushDown) { + filters.reduceOption(And(_, _)) + } else { + None } val taskConf = new Configuration(conf) @@ -134,42 +89,16 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) // 
read data from vectorized reader - val batchReader = new OmniOrcColumnarBatchReader(capacity) + val batchReader = new OmniOrcColumnarBatchReader(capacity, requiredSchema, pushed.orNull) // SPARK-23399 Register a task completion listener first to call `close()` in all cases. // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM) // after opening a file. val iter = new RecordReaderIterator(batchReader) Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) - // fill requestedDataColIds with -1, fil real values int initDataColIds function - val requestedDataColIds = Array.fill(requiredSchema.length)(-1) ++ Array.fill(partitionSchema.length)(-1) - val requestedPartitionColIds = - Array.fill(requiredSchema.length)(-1) ++ Range(0, partitionSchema.length) - - // 初始化precision数组和scale数组,透传至java侧使用 - val requiredFields = requiredSchema.fields - val fieldslength = requiredFields.length - val precisionArray : Array[Int] = Array.ofDim[Int](fieldslength) - val scaleArray : Array[Int] = Array.ofDim[Int](fieldslength) - for ((reqField, index) <- requiredFields.zipWithIndex) { - val reqdatatype = reqField.dataType - if (reqdatatype.isInstanceOf[DecimalType]) { - val precision = reqdatatype.asInstanceOf[DecimalType].precision - val scale = reqdatatype.asInstanceOf[DecimalType].scale - precisionArray(index) = precision - scaleArray(index) = scale - } - } SparkMemoryUtils.init() batchReader.initialize(fileSplit, taskAttemptContext) - batchReader.initDataColIds(requiredSchema, requestedPartitionColIds, requestedDataColIds, resultSchema.fields, - precisionArray, scaleArray) - batchReader.initBatch( - requiredSchema.fields, - resultSchema.fields, - requestedDataColIds, - requestedPartitionColIds, - file.partitionValues) + batchReader.initBatch(partitionSchema, file.partitionValues) iter.asInstanceOf[Iterator[InternalRow]] } @@ -180,22 +109,6 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ 
job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - new OutputWriterFactory { - override def getFileExtension(context: TaskAttemptContext): String = { - val compressionExtension: String = { - val name = context.getConfiguration.get(COMPRESS.getAttribute) - OrcUtils.extensionsForCompressionCodecNames.getOrElse(name, "") - } - - compressionExtension + ".orc" - } - - override def newInstance( - path: String, - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter = { - new OmniOrcOutputWriter(path, dataSchema, context) - } - } + throw new UnsupportedOperationException() } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala index 78acf30584d9a3c784b9bf7ec5ca52bc4abf252e..4586b8ec078ac0d226fd370f68db9fb9bc050b18 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala @@ -88,7 +88,7 @@ class OmniParquetFileFormat extends FileFormat with DataSourceRegister with Logg (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) - val filePath = new Path(new URI(file.filePath)) + val filePath = file.toPath val split = new org.apache.parquet.hadoop.ParquetInputSplit( filePath, diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/BroadcastUtils.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/BroadcastUtils.scala index 2204e0fe4b0e3c04ede12b58260ce06cbbb15ce1..ba5db895e8311d3b1c42a37a7fdf429f3b249421 100644 --- 
a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/BroadcastUtils.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/BroadcastUtils.scala @@ -264,6 +264,7 @@ object BroadcastUtils { -1, -1L, -1, + -1, new TaskMemoryManager(memoryManager, -1L), new Properties, MetricsSystem.createMetricsSystem("OMNI_UNSAFE", sparkConf), diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java index fe1c55ffb6eeecc7a16521430759e3239ca77dd4..b236da6446a8d46c9d6bd6b95e039d7679e0e6df 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java @@ -68,7 +68,7 @@ public class OrcColumnarBatchJniReaderDataTypeTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java index 995c434f66a8b03a6a76830b8fb0b08be9a3223e..7f6c1acb9543aed62eb535f1a0abf512876aa471 100644 --- 
a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java @@ -66,7 +66,7 @@ public class OrcColumnarBatchJniReaderNotPushDownTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java index c9ad9fadaf4ae13b1286467f965a23f66788c37f..2c912d919360bdd19c7ac9d20693decedf425919 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java @@ -66,7 +66,7 @@ public class OrcColumnarBatchJniReaderPushDownTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } diff --git 
a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java index 8f4535338cc3715e33dbf336e5f15fd4eb91569f..cf86c0a5ae2b3770474509dc69f6ebc10b4598a1 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java @@ -68,7 +68,7 @@ public class OrcColumnarBatchJniReaderSparkORCNotPushDownTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java index 27bcf5d7bdbd4787f42e90e620d1627678f70910..ef8d037bf7a9a7d58c33124c0d8745ce0b584ca9 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java @@ -68,7 +68,7 @@ public class OrcColumnarBatchJniReaderSparkORCPushDownTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } 
assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java index eab15fef660250780e0beb311b489e3ceeb8ff5b..19f23db00d133ddb47c475c42003ace1ebecf675 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java @@ -18,103 +18,90 @@ package com.huawei.boostkit.spark.jni; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.io.Input; +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor; import junit.framework.TestCase; import nova.hetu.omniruntime.vector.LongVec; import nova.hetu.omniruntime.vector.VarcharVec; import nova.hetu.omniruntime.vector.Vec; -import org.apache.commons.codec.binary.Base64; -import org.apache.hadoop.hive.ql.io.sarg.SearchArgument; -import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentImpl; -import org.apache.orc.OrcFile; -import org.apache.orc.TypeDescription; -import org.apache.orc.mapred.OrcInputFormat; +import org.apache.hadoop.conf.Configuration; +import org.apache.orc.Reader; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; import org.junit.After; import org.junit.Before; import org.junit.FixMethodOrder; import org.junit.Test; import org.junit.runners.MethodSorters; -import org.apache.hadoop.conf.Configuration; + import java.io.File; import java.net.URI; import 
java.net.URISyntaxException; import java.util.ArrayList; -import java.util.List; import java.util.Arrays; -import org.apache.orc.Reader.Options; - -import static org.junit.Assert.*; +import java.util.List; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; +import static org.apache.spark.sql.types.DataTypes.LongType; +import static org.apache.spark.sql.types.DataTypes.StringType; @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderTest extends TestCase { public Configuration conf = new Configuration(); public OrcColumnarBatchScanReader orcColumnarBatchScanReader; - public int batchSize = 4096; + private int batchSize = 4096; + + private StructType requiredSchema; + private int[] vecTypeIds; + + private long offset = 0; + + private long length = Integer.MAX_VALUE; @Before public void setUp() throws Exception { - Configuration conf = new Configuration(); - TypeDescription schema = - TypeDescription.fromString("struct<`i_item_sk`:bigint,`i_item_id`:string>"); - Options options = new Options(conf) - .range(0, Integer.MAX_VALUE) - .useZeroCopy(false) - .skipCorruptRecords(false) - .tolerateMissingSchema(true); - - options.schema(schema); - options.include(OrcInputFormat.parseInclude(schema, - null)); - String kryoSarg = "AQEAb3JnLmFwYWNoZS5oYWRvb3AuaGl2ZS5xbC5pby5zYXJnLkV4cHJlc3Npb25UcmXlAQEBamF2YS51dGlsLkFycmF5TGlz9AECAQABAQEBAQEAAQAAAAEEAAEBAwEAAQEBAQEBAAEAAAIIAAEJAAEBAgEBAQIBAscBb3JnLmFwYWNoZS5oYWRvb3AuaGl2ZS5xbC5pby5zYXJnLlNlYXJjaEFyZ3VtZW50SW1wbCRQcmVkaWNhdGVMZWFmSW1wbAEBaV9pdGVtX3PrAAABBwEBAQIBEAkAAAEEEg=="; - String sargColumns = "i_item_sk,i_item_id,i_rec_start_date,i_rec_end_date,i_item_desc,i_current_price,i_wholesale_cost,i_brand_id,i_brand,i_class_id,i_class,i_category_id,i_category,i_manufact_id,i_manufact,i_size,i_formulation,i_color,i_units,i_container,i_manager_id,i_product_name"; - if (kryoSarg != null && sargColumns 
!= null) { - byte[] sargBytes = Base64.decodeBase64(kryoSarg); - SearchArgument sarg = - new Kryo().readObject(new Input(sargBytes), SearchArgumentImpl.class); - options.searchArgument(sarg, sargColumns.split(",")); - sarg.getExpression().toString(); - } - orcColumnarBatchScanReader = new OrcColumnarBatchScanReader(); + constructSchema(); initReaderJava(); - initDataColIds(options, orcColumnarBatchScanReader); - initRecordReaderJava(options); - initBatch(options); + initDataColIds(); + initRecordReaderJava(); + initBatch(); + } + + private void constructSchema() { + requiredSchema = new StructType() + .add("i_item_sk", LongType) + .add("i_item_id", StringType); } - public void initDataColIds( - Options options, OrcColumnarBatchScanReader orcColumnarBatchScanReader) { - List allCols; - allCols = Arrays.asList(options.getColumnNames()); - orcColumnarBatchScanReader.colToInclu = new ArrayList(); - List optionField = options.getSchema().getFieldNames(); - orcColumnarBatchScanReader.colsToGet = new int[optionField.size()]; - orcColumnarBatchScanReader.realColsCnt = 0; - for (int i = 0; i < optionField.size(); i++) { - if (allCols.contains(optionField.get(i))) { - orcColumnarBatchScanReader.colToInclu.add(optionField.get(i)); - orcColumnarBatchScanReader.colsToGet[i] = 0; - orcColumnarBatchScanReader.realColsCnt++; - } else { + private void initDataColIds() { + // find requiredS fieldNames + String[] requiredfieldNames = requiredSchema.fieldNames(); + // save valid cols and numbers of valid cols + orcColumnarBatchScanReader.colsToGet = new int[requiredfieldNames.length]; + orcColumnarBatchScanReader.includedColumns = new ArrayList<>(); + // collect read cols types + ArrayList typeBuilder = new ArrayList<>(); + + for (int i = 0; i < requiredfieldNames.length; i++) { + String target = requiredfieldNames[i]; + + // if not find, set colsToGet value -1, else set colsToGet 0 + boolean is_find = false; + for (int j = 0; j < orcColumnarBatchScanReader.allFieldsNames.size(); j++) 
{ + if (target.equals(orcColumnarBatchScanReader.allFieldsNames.get(j))) { + orcColumnarBatchScanReader.colsToGet[i] = 0; + orcColumnarBatchScanReader.includedColumns.add(requiredfieldNames[i]); + typeBuilder.add(OmniExpressionAdaptor.sparkTypeToOmniType(requiredSchema.fields()[i].dataType())); + is_find = true; + break; + } + } + + if (!is_find) { orcColumnarBatchScanReader.colsToGet[i] = -1; } } - orcColumnarBatchScanReader.requiredfieldNames = new String[optionField.size()]; - TypeDescription schema = options.getSchema(); - int[] precisionArray = new int[optionField.size()]; - int[] scaleArray = new int[optionField.size()]; - for (int i = 0; i < optionField.size(); i++) { - precisionArray[i] = schema.findSubtype(optionField.get(i)).getPrecision(); - scaleArray[i] = schema.findSubtype(optionField.get(i)).getScale(); - orcColumnarBatchScanReader.requiredfieldNames[i] = optionField.get(i); - } - orcColumnarBatchScanReader.precisionArray = precisionArray; - orcColumnarBatchScanReader.scaleArray = scaleArray; + vecTypeIds = typeBuilder.stream().mapToInt(Integer::intValue).toArray(); } @After @@ -122,8 +109,7 @@ public class OrcColumnarBatchJniReaderTest extends TestCase { System.out.println("orcColumnarBatchJniReader test finished"); } - public void initReaderJava() throws URISyntaxException { - OrcFile.ReaderOptions readerOptions = OrcFile.readerOptions(conf); + private void initReaderJava() { File directory = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0"); String path = directory.getAbsolutePath(); URI uri = null; @@ -133,16 +119,17 @@ public class OrcColumnarBatchJniReaderTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, readerOptions); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 
0); } - public void initRecordReaderJava(Options options) { - orcColumnarBatchScanReader.recordReader = orcColumnarBatchScanReader.initializeRecordReaderJava(options); + private void initRecordReaderJava() { + orcColumnarBatchScanReader.recordReader = orcColumnarBatchScanReader. + initializeRecordReaderJava(offset, length, null, requiredSchema); assertTrue(orcColumnarBatchScanReader.recordReader != 0); } - public void initBatch(Options options) { + private void initBatch() { orcColumnarBatchScanReader.initBatchJava(batchSize); assertTrue(orcColumnarBatchScanReader.batchReader != 0); } @@ -150,8 +137,7 @@ public class OrcColumnarBatchJniReaderTest extends TestCase { @Test public void testNext() { Vec[] vecs = new Vec[2]; - int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal()}; - long rtn = orcColumnarBatchScanReader.next(vecs, typeId); + long rtn = orcColumnarBatchScanReader.next(vecs, vecTypeIds); assertTrue(rtn == 4096); assertTrue(((LongVec) vecs[0]).get(0) == 1); String str = new String(((VarcharVec) vecs[1]).get(0)); diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties b/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties index 89eabe8e6a2fa36aa9523293e7368a3076856dd1..099e28e8dccf4fb3e9c050292253f10bfb0dedd4 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties +++ b/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties @@ -2,11 +2,13 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
# -hive.metastore.uris=thrift://server1:9083 +#hive.metastore.uris=thrift://server1:9083 +hive.metastore.uris=thrift://OmniOperator:9083 spark.sql.warehouse.dir=/user/hive/warehouse spark.memory.offHeap.size=8G spark.sql.codegen.wholeStage=false spark.sql.extensions=com.huawei.boostkit.spark.ColumnarPlugin spark.shuffle.manager=org.apache.spark.shuffle.sort.OmniColumnarShuffleManager spark.sql.orc.impl=native -hive.db=tpcds_bin_partitioned_orc_2 \ No newline at end of file +#hive.db=tpcds_bin_partitioned_orc_2 +hive.db=tpcds_bin_partitioned_varchar_orc_2 diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index e6b786c2a2030b4414feb47e4f81a786e9ee2427..329295bac6215119a62d1177a7dda3c306c9cf64 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.SparkException import org.apache.spark.sql.{Column, DataFrame} import org.apache.spark.sql.execution.ColumnarSparkPlanTest import org.apache.spark.sql.types.{DataType, Decimal} @@ -64,7 +65,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getByte(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as short") { @@ -72,7 +74,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getShort(0) ) - assert(exception.isInstanceOf[NullPointerException], 
s"sql: ${sql}") +// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as int") { @@ -80,7 +83,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getInt(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as long") { @@ -88,7 +92,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getLong(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as float") { @@ -96,7 +101,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getFloat(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as double") { @@ -104,7 +110,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getDouble(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as date") { @@ -154,13 +161,15 @@ class CastSuite extends ColumnarSparkPlanTest { val exception4 = intercept[Exception]( result4.collect().toSeq.head.getBoolean(0) ) - assert(exception4.isInstanceOf[NullPointerException], s"sql: ${sql}") +// 
assert(exception4.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception4.isInstanceOf[SparkException], s"sql: ${sql}") val result5 = spark.sql("select cast('test' as boolean);") val exception5 = intercept[Exception]( result5.collect().toSeq.head.getBoolean(0) ) - assert(exception5.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception5.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception5.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast boolean to string") { @@ -182,13 +191,15 @@ class CastSuite extends ColumnarSparkPlanTest { val exception2 = intercept[Exception]( result2.collect().toSeq.head.getByte(0) ) - assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}") val result3 = spark.sql("select cast('false' as byte);") val exception3 = intercept[Exception]( result3.collect().toSeq.head.getByte(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast byte to string") { @@ -210,13 +221,15 @@ class CastSuite extends ColumnarSparkPlanTest { val exception2 = intercept[Exception]( result2.collect().toSeq.head.getShort(0) ) - assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}") val result3 = spark.sql("select cast('false' as short);") val exception3 = intercept[Exception]( result3.collect().toSeq.head.getShort(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast 
short to string") { @@ -238,13 +251,15 @@ class CastSuite extends ColumnarSparkPlanTest { val exception2 = intercept[Exception]( result2.collect().toSeq.head.getInt(0) ) - assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}") val result3 = spark.sql("select cast('false' as int);") val exception3 = intercept[Exception]( result3.collect().toSeq.head.getInt(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast int to string") { @@ -266,13 +281,15 @@ class CastSuite extends ColumnarSparkPlanTest { val exception2 = intercept[Exception]( result2.collect().toSeq.head.getLong(0) ) - assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}") val result3 = spark.sql("select cast('false' as long);") val exception3 = intercept[Exception]( result3.collect().toSeq.head.getLong(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast long to string") { @@ -298,7 +315,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception3 = intercept[Exception]( result3.collect().toSeq.head.getFloat(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast float to string") { @@ -324,7 +342,8 @@ class CastSuite extends ColumnarSparkPlanTest { val exception3 = 
intercept[Exception]( result3.collect().toSeq.head.getDouble(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") +// assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast double to string") { diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index f83edb9ca91b045e3ff6c419697cff381bf31bcc..ea52aca621e61ce4c48cf794fc4ddce06a24c632 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -61,7 +61,7 @@ class CombiningLimitsSuite extends PlanTest { comparePlans(optimized1, expected1) // test child max row > limit. 
- val query2 = testRelation.select().groupBy()(count(1)).limit(0).analyze + val query2 = testRelation2.select($"x").groupBy($"x")(count(1)).limit(1).analyze val optimized2 = Optimize.execute(query2) comparePlans(optimized2, query2) diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala index aaa244cdf65c04699ba2bf4c5c443e8727faf9f5..e1c620e1c0448b85f9d2164f7cda749a6712c5ce 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala @@ -43,7 +43,7 @@ class MergeSubqueryFiltersSuite extends PlanTest { } private def extractorExpression(cteIndex: Int, output: Seq[Attribute], fieldIndex: Int) = { - GetStructField(ScalarSubquery(CTERelationRef(cteIndex, _resolved = true, output)), fieldIndex) + GetStructField(ScalarSubquery(CTERelationRef(cteIndex, _resolved = true, output, isStreaming = false)), fieldIndex) .as("scalarsubquery()") } diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala index c34ff5bb15acf36c4dcd96f57c70554f6e974ce9..c0be72f31e5d80d9ab339e79339bea833193dd11 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala @@ -19,17 +19,15 @@ package org.apache.spark.sql.execution.adaptive 
import java.io.File import java.net.URI - import org.apache.logging.log4j.Level import org.scalatest.PrivateMethodTester import org.scalatest.time.SpanSugar._ - import org.apache.spark.SparkException import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} -import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnaryExecNode, UnionExec} +import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnionExec} import org.apache.spark.sql.execution.aggregate.BaseAggregateExec import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.execution.datasources.noop.NoopDataSource @@ -37,7 +35,7 @@ import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, Exchange, REPARTITION_BY_COL, REPARTITION_BY_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec} import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter -import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate +import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLExecutionStart} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import 
org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode @@ -1122,13 +1120,21 @@ class AdaptiveQueryExecSuite test("SPARK-30953: InsertAdaptiveSparkPlan should apply AQE on child plan of write commands") { withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + var plan: SparkPlan = null + val listener = new QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + plan = qe.executedPlan + } + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} + } + spark.listenerManager.register(listener) withTable("t1") { - val plan = sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").queryExecution.executedPlan - assert(plan.isInstanceOf[CommandResultExec]) - val commandResultExec = plan.asInstanceOf[CommandResultExec] - assert(commandResultExec.commandPhysicalPlan.isInstanceOf[DataWritingCommandExec]) - assert(commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec] - .child.isInstanceOf[AdaptiveSparkPlanExec]) + val format = classOf[NoopDataSource].getName + Seq((0, 1)).toDF("x", "y").write.format(format).mode("overwrite").save() + sparkContext.listenerBus.waitUntilEmpty() + assert(plan.isInstanceOf[V2TableWriteExec]) + assert(plan.asInstanceOf[V2TableWriteExec].child.isInstanceOf[AdaptiveSparkPlanExec]) + spark.listenerManager.unregister(listener) } } } @@ -1172,15 +1178,14 @@ class AdaptiveQueryExecSuite test("SPARK-31658: SQL UI should show write commands") { withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "false") { withTable("t1") { - var checkDone = false + var commands: Seq[SparkPlanInfo] = Seq.empty val listener = new SparkListener { override def onOtherEvent(event: SparkListenerEvent): Unit = { event match { - case SparkListenerSQLAdaptiveExecutionUpdate(_, _, 
planInfo) => - assert(planInfo.nodeName == "Execute CreateDataSourceTableAsSelectCommand") - checkDone = true + case start: SparkListenerSQLExecutionStart => + commands = commands ++ Seq(start.sparkPlanInfo) case _ => // ignore other events } } @@ -1189,7 +1194,12 @@ class AdaptiveQueryExecSuite try { sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect() spark.sparkContext.listenerBus.waitUntilEmpty() - assert(checkDone) + assert(commands.size == 3) + assert(commands.head.nodeName == "Execute CreateDataSourceTableAsSelectCommand") + assert(commands(1).nodeName == "Execute InsertIntoHadoopFsRelationCommand") + assert(commands(1).children.size == 1) + assert(commands(1).children.head.nodeName == "WriteFiles") + assert(commands(2).nodeName == "CommandResult") } finally { spark.sparkContext.removeSparkListener(listener) } @@ -1574,7 +1584,7 @@ class AdaptiveQueryExecSuite test("SPARK-32932: Do not use local shuffle read at final stage on write command") { withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString, SQLConf.SHUFFLE_PARTITIONS.key -> "5", - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { val data = for ( i <- 1L to 10L; j <- 1L to 3L @@ -1584,9 +1594,8 @@ class AdaptiveQueryExecSuite var noLocalread: Boolean = false val listener = new QueryExecutionListener { override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { - qe.executedPlan match { + stripAQEPlan(qe.executedPlan) match { case plan@(_: DataWritingCommandExec | _: V2TableWriteExec) => - assert(plan.asInstanceOf[UnaryExecNode].child.isInstanceOf[AdaptiveSparkPlanExec]) noLocalread = collect(plan) { case exec: AQEShuffleReadExec if exec.isLocalRead => exec }.isEmpty diff --git a/omnioperator/omniop-spark-extension/pom.xml b/omnioperator/omniop-spark-extension/pom.xml index 4376a89bef1c869c7f9feb19bc24f381bd01a9c0..e949cea2d673a51d90262735fa860ab3d2b46ebe 100644 --- 
a/omnioperator/omniop-spark-extension/pom.xml +++ b/omnioperator/omniop-spark-extension/pom.xml @@ -8,13 +8,13 @@ com.huawei.kunpeng boostkit-omniop-spark-parent pom - 3.3.1-1.6.0 + 3.4.3-1.6.0 BoostKit Spark Native Sql Engine Extension Parent Pom 2.12.10 2.12 - 3.3.1 + 3.4.3 3.2.2 UTF-8 UTF-8