diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp index f871f2c3d3f43c371d4aeff4d0d0d5831fd3f174..13dba015427dcabbc03cb6458b1c9442338afc2c 100644 --- a/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp +++ b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp @@ -78,6 +78,23 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJn // Get capacity for each record batch int64_t capacity = (int64_t)env->CallLongMethod(jsonObj, jsonMethodLong, env->NewStringUTF("capacity")); + std::unique_ptr rebaseInfoPtr = common::BuildTimeRebaseInfo(env, jsonObj); + + ParquetReader *pReader = new ParquetReader(rebaseInfoPtr); + auto state = pReader->InitReader(uriInfo, capacity, ugiString); + if (state != Status::OK()) { + env->ThrowNew(runtimeExceptionClass, state.ToString().c_str()); + return 0; + } + return (jlong)(pReader); + JNI_FUNC_END(runtimeExceptionClass) +} + +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader_initializeRecordReader + (JNIEnv *env, jobject jObj, jlong reader, jobject jsonObj) +{ + JNI_FUNC_START + ParquetReader *pReader = (ParquetReader *)reader; int64_t start = (int64_t)env->CallLongMethod(jsonObj, jsonMethodLong, env->NewStringUTF("start")); int64_t end = (int64_t)env->CallLongMethod(jsonObj, jsonMethodLong, env->NewStringUTF("end")); @@ -96,11 +113,7 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJn auto fieldNames = GetFieldNames(env, jsonObj); - std::unique_ptr rebaseInfoPtr = common::BuildTimeRebaseInfo(env, jsonObj); - - ParquetReader *pReader = new ParquetReader(rebaseInfoPtr); - auto state = pReader->InitRecordReader(uriInfo, start, end, capacity, hasExpressionTree, pushedFilterArray, - fieldNames, ugiString); + auto state = pReader->InitRecordReader(start, end, hasExpressionTree, pushedFilterArray, fieldNames); if (state != Status::OK()) { env->ThrowNew(runtimeExceptionClass, state.ToString().c_str()); return 0; @@ -109,6 +122,28 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJn JNI_FUNC_END(runtimeExceptionClass) } +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader_getAllFieldNames + (JNIEnv *env, jobject jObj, jlong reader, jobject allFieldNames) +{ + JNI_FUNC_START + ParquetReader *pReader = (ParquetReader *)reader; + std::shared_ptr schema; + auto state = pReader->arrow_reader->GetSchema(&schema); + if (state != Status::OK()) { + env->ThrowNew(runtimeExceptionClass, state.ToString().c_str()); + return 0; + } + std::vector columnNames = schema->field_names(); + auto num = columnNames.size(); + for (uint32_t i = 0; i < num; i++) { + jstring fieldName = env->NewStringUTF(columnNames[i].c_str()); + env->CallBooleanMethod(allFieldNames, arrayListAdd, fieldName); + env->DeleteLocalRef(fieldName); + } + return (jlong)(num); + JNI_FUNC_END(runtimeExceptionClass) +} + JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader_recordReaderNext(JNIEnv *env, jobject jObj, jlong reader, jlongArray vecNativeId) { diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.h b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.h index a374567476487d13848fab31a82ffef7e038a106..b5b382760033dcb69749d470a13293fab2d612e3 100644 --- 
a/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.h +++ b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.h @@ -44,6 +44,22 @@ extern "C" { JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader_initializeReader (JNIEnv* env, jobject jObj, jobject job); +/* + * Class: com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader + * Method: initializeRecordReader + * Signature: (JLorg/json/JSONObject;)J + */ +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader_initializeRecordReader + (JNIEnv *, jobject, jlong, jobject); + +/* + * Class: com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader + * Method: getAllFieldNames + * Signature: (JLjava/util/ArrayList;)J + */ +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader_getAllFieldNames + (JNIEnv *, jobject, jlong, jobject); + /* * Class: com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader * Method: recordReaderNext diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp index 83b0f326575dc49df8066d1b8f95b90066a3a3fc..9eab7507f6f3472cf58d3ec4c82cf682502ffe87 100644 --- a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp @@ -84,8 +84,7 @@ Filesystem* omniruntime::reader::GetFileSystemPtr(UriInfo &uri, std::string& ugi return restore_filesysptr[key]; } -Status ParquetReader::InitRecordReader(UriInfo &uri, int64_t start, int64_t end, int64_t capacity, bool hasExpressionTree, - Expression pushedFilterArray, const std::vector& fieldNames, std::string& ugi) +Status ParquetReader::InitReader(UriInfo &uri, int64_t capacity, std::string& ugi) { // Configure reader settings auto reader_properties = parquet::ReaderProperties(pool); @@ -94,8 +93,6 @@ Status ParquetReader::InitRecordReader(UriInfo &uri, int64_t start, int64_t end, auto arrow_reader_properties = parquet::ArrowReaderProperties(); arrow_reader_properties.set_batch_size(capacity); - std::shared_ptr file; - // Get the file from filesystem Status result; mutex_.lock(); @@ -113,6 +110,12 @@ Status ParquetReader::InitRecordReader(UriInfo &uri, int64_t start, int64_t end, reader_builder.properties(arrow_reader_properties); ARROW_ASSIGN_OR_RAISE(arrow_reader, reader_builder.Build()); + return arrow::Status::OK(); +} + +Status ParquetReader::InitRecordReader(int64_t start, int64_t end, bool hasExpressionTree, + Expression pushedFilterArray, const std::vector& fieldNames) +{ std::vector row_group_indices; auto filesource = std::make_shared(file); ARROW_RETURN_NOT_OK(GetRowGroupIndices(*filesource, start, end, hasExpressionTree, pushedFilterArray, row_group_indices)); diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.h b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.h index 5bbd3a5031056ca73d15cd796180007a467f0968..3d1645054259a6d23c789395f590d871051d0ab1 100644 --- a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.h +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.h @@ -53,8 +53,10 @@ namespace omniruntime::reader { ParquetReader(std::unique_ptr &rebaseInfoPtr) : rebaseInfoPtr(std::move(rebaseInfoPtr)) {} - arrow::Status InitRecordReader(UriInfo &uri, int64_t start, int64_t end, int64_t capacity, bool hasExpressionTree, - Expression pushedFilterArray, const std::vector& 
fieldNames, std::string& ugi); + arrow::Status InitReader(UriInfo &uri, int64_t capacity, std::string& ugi); + + arrow::Status InitRecordReader(int64_t start, int64_t end, bool hasExpressionTree, + Expression pushedFilterArray, const std::vector& fieldNames); arrow::Status ReadNextBatch(std::vector &batch, long *batchRowSize); @@ -84,6 +86,8 @@ namespace omniruntime::reader { const std::shared_ptr &ctx, std::unique_ptr* out); std::unique_ptr rebaseInfoPtr; + + std::shared_ptr file; }; class Filesystem { diff --git a/omnioperator/omniop-native-reader/cpp/test/tablescan/parquet_scan_test.cpp b/omnioperator/omniop-native-reader/cpp/test/tablescan/parquet_scan_test.cpp index 26a752f5d2200d80854b84182bd9a20ce0a790bc..acc83d51de8cffe05ed08b9e079ab4572f1b0fe3 100644 --- a/omnioperator/omniop-native-reader/cpp/test/tablescan/parquet_scan_test.cpp +++ b/omnioperator/omniop-native-reader/cpp/test/tablescan/parquet_scan_test.cpp @@ -44,7 +44,9 @@ TEST(read, test_parquet_reader) ParquetReader *reader = new ParquetReader(rebaseInfoPtr); std::string ugi = "root@sample"; Expression pushedFilterArray; - auto state1 = reader->InitRecordReader(uriInfo, 0, 1000000, 1024, false, pushedFilterArray, column_indices, ugi); + auto state0 = reader->InitReader(uriInfo, 1024, ugi); + ASSERT_EQ(state0, arrow::Status::OK()); + auto state1 = reader->InitRecordReader(0, 1000000, false, pushedFilterArray, column_indices); ASSERT_EQ(state1, arrow::Status::OK()); std::vector recordBatch(column_indices.size()); @@ -113,7 +115,9 @@ TEST(read, test_varchar) ParquetReader *reader = new ParquetReader(rebaseInfoPtr); std::string ugi = "root@sample"; Expression pushedFilterArray; - auto state1 = reader->InitRecordReader(uriInfo, 0, 1000000, 4096, false, pushedFilterArray, column_indices, ugi); + auto state0 = reader->InitReader(uriInfo, 4096, ugi); + ASSERT_EQ(state0, arrow::Status::OK()); + auto state1 = reader->InitRecordReader(0, 1000000, false, pushedFilterArray, column_indices); ASSERT_EQ(state1, arrow::Status::OK()); int total_nums = 0; int iter = 0; diff --git a/omnioperator/omniop-native-reader/java/pom.xml b/omnioperator/omniop-native-reader/java/pom.xml index e7ddfe6c3bc7764df4b1642e4a137afcef64f6cd..3cd67b1fb55204f0ea5668b8c27db39c66c809bd 100644 --- a/omnioperator/omniop-native-reader/java/pom.xml +++ b/omnioperator/omniop-native-reader/java/pom.xml @@ -8,13 +8,13 @@ com.huawei.boostkit boostkit-omniop-native-reader jar - 3.3.1-1.6.0 + 3.4.3-1.6.0 BoostKit Spark Native Sql Engine Extension With OmniOperator 2.12 - 3.3.1 + 3.4.3 FALSE ../cpp/ ../cpp/build/releases/ @@ -35,8 +35,8 @@ org.slf4j - slf4j-api - 1.7.32 + slf4j-simple + 1.7.36 junit diff --git a/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/scan/jni/ParquetColumnarBatchJniReader.java b/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/scan/jni/ParquetColumnarBatchJniReader.java index b740b726ce9b3ea08d30ee516cbdb4d8c9ee7cdb..a02be6b9ee1a4c4ee7240e49729daa49c5f41be9 100644 --- a/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/scan/jni/ParquetColumnarBatchJniReader.java +++ b/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/scan/jni/ParquetColumnarBatchJniReader.java @@ -19,6 +19,8 @@ package com.huawei.boostkit.scan.jni; import org.json.JSONObject; +import java.util.ArrayList; + public class ParquetColumnarBatchJniReader { public ParquetColumnarBatchJniReader() { @@ -27,6 +29,10 @@ public class ParquetColumnarBatchJniReader { public native long 
initializeReader(JSONObject job); + public native long initializeRecordReader(long parquetReader, JSONObject job); + + public native long getAllFieldNames(long parquetReader, ArrayList allFieldNames); + public native long recordReaderNext(long parquetReader, long[] vecNativeId); public native void recordReaderClose(long parquetReader); diff --git a/omnioperator/omniop-spark-extension/java/pom.xml b/omnioperator/omniop-spark-extension/java/pom.xml index 9cc1b9d25848f62caf13359f920db280820b04f0..13866589316a4e5add47eb2d38ea74815e7eb6e9 100644 --- a/omnioperator/omniop-spark-extension/java/pom.xml +++ b/omnioperator/omniop-spark-extension/java/pom.xml @@ -7,7 +7,7 @@ com.huawei.kunpeng boostkit-omniop-spark-parent - 3.3.1-1.6.0 + 3.4.3-1.6.0 ../pom.xml @@ -52,7 +52,7 @@ com.huawei.boostkit boostkit-omniop-native-reader - 3.3.1-1.6.0 + 3.4.3-1.6.0 junit diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java index df8c564b5a9168f3589c7293e6c6595e4dcea5a0..1aab7425b4d703e3f20310277db3ca22db566dee 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java @@ -25,65 +25,78 @@ import nova.hetu.omniruntime.type.DataType; import nova.hetu.omniruntime.vector.*; import org.apache.orc.impl.writer.TimestampTreeWriter; +import org.apache.spark.sql.catalyst.util.CharVarcharUtils; import org.apache.spark.sql.catalyst.util.RebaseDateTime; -import org.apache.hadoop.hive.ql.io.sarg.ExpressionTree; -import org.apache.hadoop.hive.ql.io.sarg.PredicateLeaf; -import org.apache.orc.OrcFile.ReaderOptions; -import org.apache.orc.Reader.Options; +import org.apache.spark.sql.sources.And; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.GreaterThanOrEqual; +import org.apache.spark.sql.sources.In; +import org.apache.spark.sql.sources.IsNotNull; +import org.apache.spark.sql.sources.IsNull; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.LessThanOrEqual; +import org.apache.spark.sql.sources.Not; +import org.apache.spark.sql.sources.Or; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; import org.json.JSONObject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.math.BigDecimal; import java.net.URI; -import java.sql.Date; -import java.sql.Timestamp; +import java.time.LocalDate; import java.text.DateFormat; import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.ArrayList; -import java.util.List; +import java.util.Arrays; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import java.util.TimeZone; public class OrcColumnarBatchScanReader { private static final 
Logger LOGGER = LoggerFactory.getLogger(OrcColumnarBatchScanReader.class); private boolean nativeSupportTimestampRebase; + private static final Pattern CHAR_TYPE = Pattern.compile("char\\(\\s*(\\d+)\\s*\\)"); + + private static final int MAX_LEAF_THRESHOLD = 256; + public long reader; public long recordReader; public long batchReader; - public int[] colsToGet; - public int realColsCnt; - public ArrayList fildsNames; + // All ORC fieldNames + public ArrayList allFieldsNames; - public ArrayList colToInclu; + // Indicate columns to read + public int[] colsToGet; - public String[] requiredfieldNames; + // Actual columns to read + public ArrayList includedColumns; - public int[] precisionArray; + // max threshold for leaf node + private int leafIndex; - public int[] scaleArray; + // spark required schema + private StructType requiredSchema; public OrcColumnarBatchJniReader jniReader; public OrcColumnarBatchScanReader() { jniReader = new OrcColumnarBatchJniReader(); - fildsNames = new ArrayList(); - } - - public JSONObject getSubJson(ExpressionTree node) { - JSONObject jsonObject = new JSONObject(); - jsonObject.put("op", node.getOperator().ordinal()); - if (node.getOperator().toString().equals("LEAF")) { - jsonObject.put("leaf", node.toString()); - return jsonObject; - } - ArrayList child = new ArrayList(); - for (ExpressionTree childNode : node.getChildren()) { - JSONObject rtnJson = getSubJson(childNode); - child.add(rtnJson); - } - jsonObject.put("child", child); - return jsonObject; + allFieldsNames = new ArrayList(); } public String padZeroForDecimals(String [] decimalStrArray, int decimalScale) { @@ -95,91 +108,6 @@ public class OrcColumnarBatchScanReader { return String.format("%1$-" + decimalScale + "s", decimalVal).replace(' ', '0'); } - public int getPrecision(String colname) { - for (int i = 0; i < requiredfieldNames.length; i++) { - if (colname.equals(requiredfieldNames[i])) { - return precisionArray[i]; - } - } - - return -1; - } - - public int getScale(String colname) { - for (int i = 0; i < requiredfieldNames.length; i++) { - if (colname.equals(requiredfieldNames[i])) { - return scaleArray[i]; - } - } - - return -1; - } - - public JSONObject getLeavesJson(List leaves) { - JSONObject jsonObjectList = new JSONObject(); - for (int i = 0; i < leaves.size(); i++) { - PredicateLeaf pl = leaves.get(i); - JSONObject jsonObject = new JSONObject(); - jsonObject.put("op", pl.getOperator().ordinal()); - jsonObject.put("name", pl.getColumnName()); - jsonObject.put("type", pl.getType().ordinal()); - if (pl.getLiteral() != null) { - if (pl.getType() == PredicateLeaf.Type.DATE) { - jsonObject.put("literal", ((int)Math.ceil(((Date)pl.getLiteral()).getTime()* 1.0/3600/24/1000)) + ""); - } else if (pl.getType() == PredicateLeaf.Type.DECIMAL) { - int decimalP = getPrecision(pl.getColumnName()); - int decimalS = getScale(pl.getColumnName()); - String[] spiltValues = pl.getLiteral().toString().split("\\."); - if (decimalS == 0) { - jsonObject.put("literal", spiltValues[0] + " " + decimalP + " " + decimalS); - } else { - String scalePadZeroStr = padZeroForDecimals(spiltValues, decimalS); - jsonObject.put("literal", spiltValues[0] + "." 
+ scalePadZeroStr + " " + decimalP + " " + decimalS); - } - } else if (pl.getType() == PredicateLeaf.Type.TIMESTAMP) { - Timestamp t = (Timestamp)pl.getLiteral(); - jsonObject.put("literal", formatSecs(t.getTime() / TimestampTreeWriter.MILLIS_PER_SECOND) + " " + formatNanos(t.getNanos())); - } else { - jsonObject.put("literal", pl.getLiteral().toString()); - } - } else { - jsonObject.put("literal", ""); - } - if ((pl.getLiteralList() != null) && (pl.getLiteralList().size() != 0)){ - List lst = new ArrayList<>(); - for (Object ob : pl.getLiteralList()) { - if (ob == null) { - lst.add(null); - continue; - } - if (pl.getType() == PredicateLeaf.Type.DECIMAL) { - int decimalP = getPrecision(pl.getColumnName()); - int decimalS = getScale(pl.getColumnName()); - String[] spiltValues = ob.toString().split("\\."); - if (decimalS == 0) { - lst.add(spiltValues[0] + " " + decimalP + " " + decimalS); - } else { - String scalePadZeroStr = padZeroForDecimals(spiltValues, decimalS); - lst.add(spiltValues[0] + "." + scalePadZeroStr + " " + decimalP + " " + decimalS); - } - } else if (pl.getType() == PredicateLeaf.Type.DATE) { - lst.add(((int)Math.ceil(((Date)ob).getTime()* 1.0/3600/24/1000)) + ""); - } else if (pl.getType() == PredicateLeaf.Type.TIMESTAMP) { - Timestamp t = (Timestamp)pl.getLiteral(); - lst.add(formatSecs(t.getTime() / TimestampTreeWriter.MILLIS_PER_SECOND) + " " + formatNanos(t.getNanos())); - } else { - lst.add(ob.toString()); - } - } - jsonObject.put("literalList", lst); - } else { - jsonObject.put("literalList", new ArrayList()); - } - jsonObjectList.put("leaf-" + i, jsonObject); - } - return jsonObjectList; - } - private long formatSecs(long secs) { DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); long epoch; @@ -224,15 +152,11 @@ public class OrcColumnarBatchScanReader { * Init Orc reader. * * @param uri split file path - * @param options split file options */ - public long initializeReaderJava(URI uri, ReaderOptions options) { + public long initializeReaderJava(URI uri) { JSONObject job = new JSONObject(); - if (options.getOrcTail() == null) { - job.put("serializedTail", ""); - } else { - job.put("serializedTail", options.getOrcTail().getSerializedTail().toString()); - } + + job.put("serializedTail", ""); job.put("tailLocation", 9223372036854775807L); job.put("scheme", uri.getScheme() == null ? "" : uri.getScheme()); @@ -240,38 +164,37 @@ public class OrcColumnarBatchScanReader { job.put("port", uri.getPort()); job.put("path", uri.getPath() == null ? "" : uri.getPath()); - reader = jniReader.initializeReader(job, fildsNames); + reader = jniReader.initializeReader(job, allFieldsNames); return reader; } /** * Init Orc RecordReader. * - * @param options split file options + * @param offset split file offset + * @param length split file read length + * @param pushedFilter the filter push down to native + * @param requiredSchema the columns read from native */ - public long initializeRecordReaderJava(Options options) { + public long initializeRecordReaderJava(long offset, long length, Filter pushedFilter, StructType requiredSchema) { + this.requiredSchema = requiredSchema; JSONObject job = new JSONObject(); - if (options.getInclude() == null) { - job.put("include", ""); - } else { - job.put("include", options.getInclude().toString()); - } - job.put("offset", options.getOffset()); - job.put("length", options.getLength()); - // When the number of pushedFilters > hive.CNF_COMBINATIONS_THRESHOLD, the expression is rewritten to - // 'YES_NO_NULL'. 
Under the circumstances, filter push down will be skipped. - if (options.getSearchArgument() != null - && !options.getSearchArgument().toString().contains("YES_NO_NULL")) { - LOGGER.debug("SearchArgument: {}", options.getSearchArgument().toString()); - JSONObject jsonexpressionTree = getSubJson(options.getSearchArgument().getExpression()); - job.put("expressionTree", jsonexpressionTree); - JSONObject jsonleaves = getLeavesJson(options.getSearchArgument().getLeaves()); - job.put("leaves", jsonleaves); + + job.put("offset", offset); + job.put("length", length); + + if (pushedFilter != null) { + JSONObject jsonExpressionTree = new JSONObject(); + JSONObject jsonLeaves = new JSONObject(); + boolean flag = canPushDown(pushedFilter, jsonExpressionTree, jsonLeaves); + if (flag) { + job.put("expressionTree", jsonExpressionTree); + job.put("leaves", jsonLeaves); + } } - job.put("includedColumns", colToInclu.toArray()); + job.put("includedColumns", includedColumns.toArray()); addJulianGregorianInfo(job); - recordReader = jniReader.initializeRecordReader(reader, job); return recordReader; } @@ -318,13 +241,13 @@ public class OrcColumnarBatchScanReader { } public int next(Vec[] vecList, int[] typeIds) { - long[] vecNativeIds = new long[realColsCnt]; + long[] vecNativeIds = new long[typeIds.length]; long rtn = jniReader.recordReaderNext(recordReader, batchReader, typeIds, vecNativeIds); if (rtn == 0) { return 0; } int nativeGetId = 0; - for (int i = 0; i < realColsCnt; i++) { + for (int i = 0; i < colsToGet.length; i++) { if (colsToGet[i] != 0) { continue; } @@ -372,7 +295,7 @@ public class OrcColumnarBatchScanReader { } default: { throw new RuntimeException("UnSupport type for ColumnarFileScan:" + - DataType.DataTypeId.values()[typeIds[i]]); + DataType.DataTypeId.values()[typeIds[i]]); } } nativeGetId++; @@ -380,18 +303,226 @@ public class OrcColumnarBatchScanReader { return (int)rtn; } - private static String bytesToHexString(byte[] bytes) { - if (bytes == null || bytes.length < 1) { - throw new IllegalArgumentException("this bytes must not be null or empty"); + enum OrcOperator { + OR, + AND, + NOT, + LEAF, + CONSTANT + } + + enum OrcLeafOperator { + EQUALS, + NULL_SAFE_EQUALS, + LESS_THAN, + LESS_THAN_EQUALS, + IN, + BETWEEN, // not use, spark transfers it to gt and lt + IS_NULL + } + + enum OrcPredicateDataType { + LONG, // all of integer types + FLOAT, // float and double + STRING, // string, char, varchar + DATE, + DECIMAL, + TIMESTAMP, + BOOLEAN + } + + private OrcPredicateDataType getOrcPredicateDataType(String attribute) { + StructField field = requiredSchema.apply(attribute); + org.apache.spark.sql.types.DataType dataType = field.dataType(); + if (dataType instanceof ShortType || dataType instanceof IntegerType || + dataType instanceof LongType) { + return OrcPredicateDataType.LONG; + } else if (dataType instanceof DoubleType) { + return OrcPredicateDataType.FLOAT; + } else if (dataType instanceof StringType) { + if (isCharType(field.metadata())) { + throw new UnsupportedOperationException("Unsupported orc push down filter data type: char"); + } + return OrcPredicateDataType.STRING; + } else if (dataType instanceof DateType) { + return OrcPredicateDataType.DATE; + } else if (dataType instanceof DecimalType) { + return OrcPredicateDataType.DECIMAL; + } else if (dataType instanceof BooleanType) { + return OrcPredicateDataType.BOOLEAN; + } else { + throw new UnsupportedOperationException("Unsupported orc push down filter data type: " + + dataType.getClass().getSimpleName()); + } + } + + // 
Check the type whether is char type, which orc native does not support push down + private boolean isCharType(Metadata metadata) { + if (metadata != null) { + String rawTypeString = CharVarcharUtils.getRawTypeString(metadata).getOrElse(null); + if (rawTypeString != null) { + Matcher matcher = CHAR_TYPE.matcher(rawTypeString); + return matcher.matches(); + } + } + return false; + } + + private boolean canPushDown(Filter pushedFilter, JSONObject jsonExpressionTree, + JSONObject jsonLeaves) { + try { + getExprJson(pushedFilter, jsonExpressionTree, jsonLeaves); + if (leafIndex > MAX_LEAF_THRESHOLD) { + throw new UnsupportedOperationException("leaf node nums is " + leafIndex + + ", which is bigger than max threshold " + MAX_LEAF_THRESHOLD + "."); + } + return true; + } catch (Exception e) { + LOGGER.info("Unable to push down orc filter because " + e.getMessage()); + return false; + } + } + + private void getExprJson(Filter filterPredicate, JSONObject jsonExpressionTree, + JSONObject jsonLeaves) { + if (filterPredicate instanceof And) { + addChildJson(jsonExpressionTree, jsonLeaves, OrcOperator.AND, + ((And) filterPredicate).left(), ((And) filterPredicate).right()); + } else if (filterPredicate instanceof Or) { + addChildJson(jsonExpressionTree, jsonLeaves, OrcOperator.OR, + ((Or) filterPredicate).left(), ((Or) filterPredicate).right()); + } else if (filterPredicate instanceof Not) { + addChildJson(jsonExpressionTree, jsonLeaves, OrcOperator.NOT, + ((Not) filterPredicate).child()); + } else if (filterPredicate instanceof EqualTo) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.EQUALS, jsonLeaves, + ((EqualTo) filterPredicate).attribute(), ((EqualTo) filterPredicate).value(), null); + leafIndex++; + } else if (filterPredicate instanceof GreaterThan) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, true); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.LESS_THAN_EQUALS, jsonLeaves, + ((GreaterThan) filterPredicate).attribute(), ((GreaterThan) filterPredicate).value(), null); + leafIndex++; + } else if (filterPredicate instanceof GreaterThanOrEqual) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, true); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.LESS_THAN, jsonLeaves, + ((GreaterThanOrEqual) filterPredicate).attribute(), ((GreaterThanOrEqual) filterPredicate).value(), null); + leafIndex++; + } else if (filterPredicate instanceof LessThan) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.LESS_THAN, jsonLeaves, + ((LessThan) filterPredicate).attribute(), ((LessThan) filterPredicate).value(), null); + leafIndex++; + } else if (filterPredicate instanceof LessThanOrEqual) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.LESS_THAN_EQUALS, jsonLeaves, + ((LessThanOrEqual) filterPredicate).attribute(), ((LessThanOrEqual) filterPredicate).value(), null); + leafIndex++; + // For IsNotNull/IsNull/In, pass literal = "" to native to avoid throwing exception. 
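// Sketch (hypothetical, not part of the patch): the JSON that the getExprJson()
// comparison branches above emit for a single pushed filter GreaterThan("a", 10),
// which is rewritten as NOT(a <= 10). Ordinal values mirror the OrcOperator,
// OrcLeafOperator and OrcPredicateDataType enums declared earlier in this file.
import org.json.JSONObject;
import java.util.ArrayList;

public class OrcPushDownJsonSketch {
    public static void main(String[] args) {
        JSONObject leaf = new JSONObject()
                .put("op", 3)                           // OrcLeafOperator.LESS_THAN_EQUALS
                .put("name", "a")
                .put("type", 0)                         // OrcPredicateDataType.LONG
                .put("literal", "10")
                .put("literalList", new ArrayList<String>());
        JSONObject jsonLeaves = new JSONObject().put("leaf-0", leaf);

        ArrayList<JSONObject> child = new ArrayList<>();
        child.add(new JSONObject().put("op", 3).put("leaf", "leaf-0"));        // OrcOperator.LEAF
        JSONObject jsonExpressionTree = new JSONObject().put("op", 2).put("child", child); // OrcOperator.NOT

        System.out.println(jsonExpressionTree); // e.g. {"op":2,"child":[{"op":3,"leaf":"leaf-0"}]} (key order may vary)
        System.out.println(jsonLeaves);
    }
}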
+ } else if (filterPredicate instanceof IsNotNull) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, true); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.IS_NULL, jsonLeaves, + ((IsNotNull) filterPredicate).attribute(), "", null); + leafIndex++; + } else if (filterPredicate instanceof IsNull) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.IS_NULL, jsonLeaves, + ((IsNull) filterPredicate).attribute(), "", null); + leafIndex++; + } else if (filterPredicate instanceof In) { + addToJsonExpressionTree("leaf-" + leafIndex, jsonExpressionTree, false); + addLiteralToJsonLeaves("leaf-" + leafIndex, OrcLeafOperator.IN, jsonLeaves, + ((In) filterPredicate).attribute(), "", Arrays.stream(((In) filterPredicate).values()).toArray()); + leafIndex++; + } else { + throw new UnsupportedOperationException("Unsupported orc push down filter operation: " + + filterPredicate.getClass().getSimpleName()); + } + } + + private void addLiteralToJsonLeaves(String leaf, OrcLeafOperator leafOperator, JSONObject jsonLeaves, + String name, Object literal, Object[] literals) { + JSONObject leafJson = new JSONObject(); + leafJson.put("op", leafOperator.ordinal()); + leafJson.put("name", name); + leafJson.put("type", getOrcPredicateDataType(name).ordinal()); + + leafJson.put("literal", getLiteralValue(literal)); + + ArrayList literalList = new ArrayList<>(); + if (literals != null) { + for (Object lit: literals) { + literalList.add(getLiteralValue(lit)); + } + } + leafJson.put("literalList", literalList); + jsonLeaves.put(leaf, leafJson); + } + + private void addToJsonExpressionTree(String leaf, JSONObject jsonExpressionTree, boolean addNot) { + if (addNot) { + jsonExpressionTree.put("op", OrcOperator.NOT.ordinal()); + ArrayList child = new ArrayList<>(); + JSONObject subJson = new JSONObject(); + subJson.put("op", OrcOperator.LEAF.ordinal()); + subJson.put("leaf", leaf); + child.add(subJson); + jsonExpressionTree.put("child", child); + } else { + jsonExpressionTree.put("op", OrcOperator.LEAF.ordinal()); + jsonExpressionTree.put("leaf", leaf); + } + } + + private void addChildJson(JSONObject jsonExpressionTree, JSONObject jsonLeaves, + OrcOperator orcOperator, Filter ... filters) { + jsonExpressionTree.put("op", orcOperator.ordinal()); + ArrayList child = new ArrayList<>(); + for (Filter filter: filters) { + JSONObject subJson = new JSONObject(); + getExprJson(filter, subJson, jsonLeaves); + child.add(subJson); } + jsonExpressionTree.put("child", child); + } - final StringBuilder hexString = new StringBuilder(); - for (int i = 0; i < bytes.length; i++) { - if ((bytes[i] & 0xff) < 0x10) - hexString.append("0"); - hexString.append(Integer.toHexString(bytes[i] & 0xff)); + private String getLiteralValue(Object literal) { + // For null literal, the predicate will not be pushed down. + if (literal == null) { + throw new UnsupportedOperationException("Unsupported orc push down filter for literal is null"); } - return hexString.toString().toLowerCase(); + // For Decimal Type, we use the special string format to represent, which is "$decimalVal + // $precision $scale". + // e.g., Decimal(9, 3) = 123456.789, it outputs "123456.789 9 3". + // e.g., Decimal(9, 3) = 123456.7, it outputs "123456.700 9 3". 
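// Sketch (hypothetical, not part of the patch): the "$value $precision $scale"
// decimal encoding described in the comment above, as produced by the BigDecimal
// branch that follows together with padZeroForDecimals().
import java.math.BigDecimal;

public class DecimalLiteralSketch {
    public static void main(String[] args) {
        BigDecimal value = new BigDecimal("123456.700");    // a Decimal(9, 3) literal
        int precision = value.precision();                  // 9
        int scale = value.scale();                          // 3
        String[] split = value.toString().split("\\.");
        // Right-pad the fractional digits with zeros up to the scale,
        // mirroring padZeroForDecimals().
        String padded = String.format("%1$-" + scale + "s", split[1]).replace(' ', '0');
        System.out.println(split[0] + "." + padded + " " + precision + " " + scale);
        // prints: 123456.700 9 3
    }
}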
+ if (literal instanceof BigDecimal) { + BigDecimal value = (BigDecimal) literal; + int precision = value.precision(); + int scale = value.scale(); + String[] split = value.toString().split("\\."); + if (scale == 0) { + return split[0] + " " + precision + " " + scale; + } else { + String padded = padZeroForDecimals(split, scale); + return split[0] + "." + padded + " " + precision + " " + scale; + } + } + // For Date Type, spark uses Gregorian in default but orc uses Julian, which should be converted. + if (literal instanceof LocalDate) { + int epochDay = Math.toIntExact(((LocalDate) literal).toEpochDay()); + int rebased = RebaseDateTime.rebaseGregorianToJulianDays(epochDay); + return String.valueOf(rebased); + } + if (literal instanceof String) { + return (String) literal; + } + if (literal instanceof Integer || literal instanceof Long || literal instanceof Boolean || + literal instanceof Short || literal instanceof Double) { + return literal.toString(); + } + throw new UnsupportedOperationException("Unsupported orc push down filter date type: " + + literal.getClass().getSimpleName()); } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchScanReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchScanReader.java index 3c1e7dba10e56e63e94ed27748c86198c0b3ca63..40fcb06d9203e2b7ac0a95ea1c79b02959cf511b 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchScanReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchScanReader.java @@ -55,6 +55,7 @@ import java.math.BigDecimal; import java.net.URI; import java.time.Instant; import java.time.LocalDate; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -81,6 +82,8 @@ public class ParquetColumnarBatchScanReader { private final List parquetTypes; + private ArrayList allFieldsNames; + public ParquetColumnarBatchScanReader(StructType requiredSchema, RebaseSpec datetimeRebaseSpec, RebaseSpec int96RebaseSpec, List parquetTypes) { this.requiredSchema = requiredSchema; @@ -124,8 +127,7 @@ public class ParquetColumnarBatchScanReader { } } - public long initializeReaderJava(Path path, long start, long end, int capacity, - Filter pushedFilter) throws UnsupportedEncodingException { + public long initializeReaderJava(Path path, int capacity) throws UnsupportedEncodingException { JSONObject job = new JSONObject(); URI uri = path.toUri(); @@ -135,11 +137,7 @@ public class ParquetColumnarBatchScanReader { job.put("port", uri.getPort()); job.put("path", uri.getPath() == null ? 
"" : uri.getPath()); - job.put("start", start); - job.put("end", end); - job.put("capacity", capacity); - job.put("fieldNames", requiredSchema.fieldNames()); String ugi = null; try { @@ -149,15 +147,35 @@ public class ParquetColumnarBatchScanReader { } job.put("ugi", ugi); + addJulianGregorianInfo(job); + + parquetReader = jniReader.initializeReader(job); + return parquetReader; + } + + public long initializeRecordReaderJava(long start, long end, String[] fieldNames, Filter pushedFilter) + throws UnsupportedEncodingException { + JSONObject job = new JSONObject(); + job.put("start", start); + job.put("end", end); + job.put("fieldNames", fieldNames); + if (pushedFilter != null) { pushDownFilter(pushedFilter, job); } - addJulianGregorianInfo(job); - parquetReader = jniReader.initializeReader(job); + parquetReader = jniReader.initializeRecordReader(parquetReader, job); return parquetReader; } + public ArrayList getAllFieldsNames() { + if (allFieldsNames == null) { + allFieldsNames = new ArrayList<>(); + jniReader.getAllFieldNames(parquetReader, allFieldsNames); + } + return allFieldsNames; + } + private void pushDownFilter(Filter pushedFilter, JSONObject job) { try { JSONObject jsonExpressionTree = getSubJson(pushedFilter); @@ -258,7 +276,12 @@ public class ParquetColumnarBatchScanReader { private void putCompareOp(JSONObject json, ParquetPredicateOperator op, String field, Object value) { json.put("op", op.ordinal()); - json.put("field", field); + if (allFieldsNames.contains(field)) { + json.put("field", field); + } else { + throw new ParquetDecodingException("Unsupported parquet push down missing columns: " + field); + } + if (value == null) { json.put("type", 0); } else { @@ -361,43 +384,48 @@ public class ParquetColumnarBatchScanReader { } } - public int next(Vec[] vecList, List types) { - int vectorCnt = vecList.length; - long[] vecNativeIds = new long[vectorCnt]; + public int next(Vec[] vecList, boolean[] missingColumns, List types) { + int colsCount = missingColumns.length; + long[] vecNativeIds = new long[types.size()]; long rtn = jniReader.recordReaderNext(parquetReader, vecNativeIds); if (rtn == 0) { return 0; } - for (int i = 0; i < vectorCnt; i++) { - DataType type = types.get(i); + int nativeGetId = 0; + for (int i = 0; i < colsCount; i++) { + if (missingColumns[i]) { + continue; + } + DataType type = types.get(nativeGetId); if (type instanceof LongType) { - vecList[i] = new LongVec(vecNativeIds[i]); + vecList[i] = new LongVec(vecNativeIds[nativeGetId]); } else if (type instanceof BooleanType) { - vecList[i] = new BooleanVec(vecNativeIds[i]); + vecList[i] = new BooleanVec(vecNativeIds[nativeGetId]); } else if (type instanceof ShortType) { - vecList[i] = new ShortVec(vecNativeIds[i]); + vecList[i] = new ShortVec(vecNativeIds[nativeGetId]); } else if (type instanceof IntegerType) { - vecList[i] = new IntVec(vecNativeIds[i]); + vecList[i] = new IntVec(vecNativeIds[nativeGetId]); } else if (type instanceof DecimalType) { if (DecimalType.is64BitDecimalType(type)) { - vecList[i] = new LongVec(vecNativeIds[i]); + vecList[i] = new LongVec(vecNativeIds[nativeGetId]); } else { - vecList[i] = new Decimal128Vec(vecNativeIds[i]); + vecList[i] = new Decimal128Vec(vecNativeIds[nativeGetId]); } } else if (type instanceof DoubleType) { - vecList[i] = new DoubleVec(vecNativeIds[i]); + vecList[i] = new DoubleVec(vecNativeIds[nativeGetId]); } else if (type instanceof StringType) { - vecList[i] = new VarcharVec(vecNativeIds[i]); + vecList[i] = new VarcharVec(vecNativeIds[nativeGetId]); } else if (type 
instanceof DateType) { - vecList[i] = new IntVec(vecNativeIds[i]); + vecList[i] = new IntVec(vecNativeIds[nativeGetId]); } else if (type instanceof ByteType) { - vecList[i] = new VarcharVec(vecNativeIds[i]); + vecList[i] = new VarcharVec(vecNativeIds[nativeGetId]); } else if (type instanceof TimestampType) { - vecList[i] = new LongVec(vecNativeIds[i]); + vecList[i] = new LongVec(vecNativeIds[nativeGetId]); tryToAdjustTimestampVec((LongVec) vecList[i], rtn, i); } else { throw new RuntimeException("Unsupport type for ColumnarFileScan: " + type.typeName()); } + nativeGetId++; } return (int)rtn; } diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java index 55d56ae20203d4fdfa551f974a4d92b4210e1a76..0859173ca18d6f8ae4730b57c6be0f474d15f94f 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java @@ -80,12 +80,12 @@ public class ShuffleDataSerializer { long[] omniVecs = new long[vecCount]; int[] omniTypes = new int[vecCount]; createEmptyVec(rowBatch, omniTypes, omniVecs, columnarVecs, vecCount, rowCount); - OmniRowDeserializer deserializer = new OmniRowDeserializer(omniTypes); + OmniRowDeserializer deserializer = new OmniRowDeserializer(omniTypes, omniVecs); for (int rowIdx = 0; rowIdx < rowCount; rowIdx++) { VecData.ProtoRow protoRow = rowBatch.getRows(rowIdx); byte[] array = protoRow.getData().toByteArray(); - deserializer.parse(array, omniVecs, rowIdx); + deserializer.parse(array, rowIdx); } // update initial varchar vector because it's capacity might have been expanded. 
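The OmniOrcColumnarBatchReader and OmniParquetColumnarBatchReader changes that follow stop passing precision/scale and column-id arrays from the planner and instead match Spark's required schema against the field names reported by the native reader (allFieldsNames / getAllFieldNames), marking absent columns as missing. A minimal, self-contained sketch of that selection step (field and column names below are made up for illustration):

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class IncludedColumnsSketch {
    public static void main(String[] args) {
        // Field names reported by the native reader.
        List<String> allFieldsNames = Arrays.asList("id", "name", "price");
        // Columns requested by Spark; "comment" does not exist in the file.
        String[] requiredFieldNames = {"id", "comment", "price"};

        int[] colsToGet = new int[requiredFieldNames.length];
        ArrayList<String> includedColumns = new ArrayList<>();
        for (int i = 0; i < requiredFieldNames.length; i++) {
            if (allFieldsNames.contains(requiredFieldNames[i])) {
                colsToGet[i] = 0;                   // read from the native reader
                includedColumns.add(requiredFieldNames[i]);
            } else {
                colsToGet[i] = -1;                  // missing: filled with constant null vectors
            }
        }
        System.out.println(Arrays.toString(colsToGet));  // [0, -1, 0]
        System.out.println(includedColumns);             // [id, price]
    }
}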
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java similarity index 57% rename from omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java rename to omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java index 93950e9f0f1b9339587893b95f99c70967024acc..bd0b42463daff40bd7696706c92de1554c5f6ea3 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java @@ -22,18 +22,15 @@ import com.google.common.annotations.VisibleForTesting; import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor; import com.huawei.boostkit.spark.jni.OrcColumnarBatchScanReader; import nova.hetu.omniruntime.vector.Vec; -import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.RecordReader; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.hadoop.mapreduce.lib.input.FileSplit; -import org.apache.orc.OrcConf; -import org.apache.orc.OrcFile; -import org.apache.orc.Reader; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.OmniColumnVectorUtils; import org.apache.spark.sql.execution.vectorized.OmniColumnVector; +import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -49,26 +46,11 @@ import java.util.ArrayList; public class OmniOrcColumnarBatchReader extends RecordReader { // The capacity of vectorized batch. - private int capacity; - /** - * The column IDs of the physical ORC file schema which are required by this reader. - * -1 means this required column is partition column, or it doesn't exist in the ORC file. - * Ideally partition column should never appear in the physical file, and should only appear - * in the directory name. However, Spark allows partition columns inside physical file, - * but Spark will discard the values from the file, and use the partition value got from - * directory name. The column order will be reserved though. - */ - @VisibleForTesting - public int[] requestedDataColIds; - // Native Record reader from ORC row batch. private OrcColumnarBatchScanReader recordReader; - private StructField[] requiredFields; - private StructField[] resultFields; - // The result columnar batch for vectorized execution by whole-stage codegen. 
@VisibleForTesting public ColumnarBatch columnarBatch; @@ -82,10 +64,13 @@ public class OmniOrcColumnarBatchReader extends RecordReader orcfieldNames = recordReader.fildsNames; // save valid cols and numbers of valid cols recordReader.colsToGet = new int[requiredfieldNames.length]; - recordReader.realColsCnt = 0; - // save valid cols fieldsNames - recordReader.colToInclu = new ArrayList(); + recordReader.includedColumns = new ArrayList<>(); // collect read cols types ArrayList typeBuilder = new ArrayList<>(); + for (int i = 0; i < requiredfieldNames.length; i++) { String target = requiredfieldNames[i]; - boolean is_find = false; - for (int j = 0; j < orcfieldNames.size(); j++) { - String temp = orcfieldNames.get(j); - if (target.equals(temp)) { - requestedDataColIds[i] = i; - recordReader.colsToGet[i] = 0; - recordReader.colToInclu.add(requiredfieldNames[i]); - recordReader.realColsCnt++; - typeBuilder.add(OmniExpressionAdaptor.sparkTypeToOmniType(requiredSchema.fields()[i].dataType())); - is_find = true; - } - } - - // if invalid, set colsToGet value -1, else set colsToGet 0 - if (!is_find) { + // if not find, set colsToGet value -1, else set colsToGet 0 + if (recordReader.allFieldsNames.contains(target)) { + recordReader.colsToGet[i] = 0; + recordReader.includedColumns.add(requiredfieldNames[i]); + typeBuilder.add(OmniExpressionAdaptor.sparkTypeToOmniType(requiredSchema.fields()[i].dataType())); + } else { recordReader.colsToGet[i] = -1; } } vecTypeIds = typeBuilder.stream().mapToInt(Integer::intValue).toArray(); - - for (int i = 0; i < resultFields.length; i++) { - if (requestedPartitionColIds[i] != -1) { - requestedDataColIds[i] = -1; - } - } - - // set data members resultFields and requestedDataColIdS - this.resultFields = resultFields; - this.requestedDataColIds = requestedDataColIds; - - recordReader.requiredfieldNames = requiredfieldNames; - recordReader.precisionArray = precisionArray; - recordReader.scaleArray = scaleArray; - recordReader.initializeRecordReaderJava(options); } /** * Initialize columnar batch by setting required schema and partition information. * With this information, this creates ColumnarBatch with the full schema. * - * @param requiredFields The fields that are required to return,. - * @param resultFields All the fields that are required to return, including partition fields. - * @param requestedDataColIds Requested column ids from orcSchema. -1 if not existed. - * @param requestedPartitionColIds Requested column ids from partition schema. -1 if not existed. + * @param partitionColumns partition columns * @param partitionValues Values of partition columns. 
*/ - public void initBatch( - StructField[] requiredFields, - StructField[] resultFields, - int[] requestedDataColIds, - int[] requestedPartitionColIds, - InternalRow partitionValues) { - if (resultFields.length != requestedDataColIds.length || resultFields.length != requestedPartitionColIds.length){ - throw new UnsupportedOperationException("This operator doesn't support orc initBatch."); - } + public void initBatch(StructType partitionColumns, InternalRow partitionValues) { + StructType resultSchema = new StructType(); - this.requiredFields = requiredFields; + for (StructField f: requiredSchema.fields()) { + resultSchema = resultSchema.add(f); + } - StructType resultSchema = new StructType(resultFields); + if (partitionColumns != null) { + for (StructField f: partitionColumns.fields()) { + resultSchema = resultSchema.add(f); + } + } // Just wrap the ORC column vector instead of copying it to Spark column vector. orcVectorWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()]; templateWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()]; - for (int i = 0; i < resultFields.length; i++) { - DataType dt = resultFields[i].dataType(); - if (requestedPartitionColIds[i] != -1) { - OmniColumnVector partitionCol = new OmniColumnVector(capacity, dt, true); - OmniColumnVectorUtils.populate(partitionCol, partitionValues, requestedPartitionColIds[i]); + if (partitionColumns != null) { + int partitionIdx = requiredSchema.fields().length; + for (int i = 0; i < partitionColumns.fields().length; i++) { + OmniColumnVector partitionCol = new OmniColumnVector(capacity, partitionColumns.fields()[i].dataType(), true); + OmniColumnVectorUtils.populate(partitionCol, partitionValues, i); partitionCol.setIsConstant(); - templateWrappers[i] = partitionCol; - orcVectorWrappers[i] = new OmniColumnVector(capacity, dt, false);; + templateWrappers[i + partitionIdx] = partitionCol; + orcVectorWrappers[i + partitionIdx] = new OmniColumnVector(capacity, partitionColumns.fields()[i].dataType(), false); + } + } + + for (int i = 0; i < requiredSchema.fields().length; i++) { + DataType dt = requiredSchema.fields()[i].dataType(); + if (recordReader.colsToGet[i] == -1) { + // missing cols + OmniColumnVector missingCol = new OmniColumnVector(capacity, dt, true); + missingCol.putNulls(0, capacity); + missingCol.setIsConstant(); + templateWrappers[i] = missingCol; } else { - int colId = requestedDataColIds[i]; - // Initialize the missing columns once. 
- if (colId == -1) { - OmniColumnVector missingCol = new OmniColumnVector(capacity, dt, true); - missingCol.putNulls(0, capacity); - missingCol.setIsConstant(); - templateWrappers[i] = missingCol; - } else { - templateWrappers[i] = new OmniColumnVector(capacity, dt, false); - } - orcVectorWrappers[i] = new OmniColumnVector(capacity, dt, false); + templateWrappers[i] = new OmniColumnVector(capacity, dt, false); } + orcVectorWrappers[i] = new OmniColumnVector(capacity, dt, false); } + // init batch recordReader.initBatchJava(capacity); vecs = new Vec[orcVectorWrappers.length]; @@ -260,7 +211,7 @@ public class OmniOrcColumnarBatchReader extends RecordReader allFieldsNames = reader.getAllFieldsNames(); + + ArrayList includeFieldNames = new ArrayList<>(); + for (int i = 0; i < requiredFieldNames.length; i++) { + String target = requiredFieldNames[i]; + if (allFieldsNames.contains(target)) { + missingColumns[i] = false; + includeFieldNames.add(target); + types.add(structFields[i].dataType()); + } else { + missingColumns[i] = true; + } + } + return includeFieldNames.toArray(new String[includeFieldNames.size()]); + } // Creates a columnar batch that includes the schema from the data files and the additional // partition columns appended to the end of the batch. @@ -138,7 +161,6 @@ public class OmniParquetColumnarBatchReader extends RecordReader - logInfo(s"Columnar Processing for ${cmd.getClass} is currently supported.") - val fileFormat: FileFormat = cmd.fileFormat match { - case _: OrcFileFormat => new OmniOrcFileFormat() - case format => - logInfo(s"Unsupported ${format.getClass} file " + - s"format for columnar data write command.") - unSupportedFileFormat = true - null - } - if (unSupportedFileFormat) { - cmd - } else { - OmniInsertIntoHadoopFsRelationCommand(cmd.outputPath, cmd.staticPartitions, - cmd.ifPartitionNotExists, cmd.partitionColumns, cmd.bucketSpec, fileFormat, - cmd.options, cmd.query, cmd.mode, cmd.catalogTable, - cmd.fileIndex, cmd.outputColumnNames - ) - } + logInfo(s"Columnar Processing for ${cmd.getClass} is currently not supported.") + unSupportedColumnarCommand = true + cmd case cmd: DataWritingCommand => logInfo(s"Columnar Processing for ${cmd.getClass} is currently not supported.") unSupportedColumnarCommand = true diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala index eb8bf478da56835aa50878dd30567f314e295717..26b6f8f1d5caefbf11c3c32186ed79f5ad933e47 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala @@ -183,6 +183,10 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { def adaptivePartialAggregationRatio: Double = conf.getConf(ADAPTIVE_PARTIAL_AGGREGATION_RATIO) + def timeParserPolicy: String = conf.getConfString("spark.sql.legacy.timeParserPolicy") + + def enableOmniUnixTimeFunc: Boolean = conf.getConf(ENABLE_OMNI_UNIXTIME_FUNCTION) + } @@ -636,4 +640,10 @@ object ColumnarPluginConfig { .doubleConf .createWithDefault(0.8) + val ENABLE_OMNI_UNIXTIME_FUNCTION = buildConf("spark.omni.sql.columnar.unixTimeFunc.enabled") + .internal() + .doc("enable omni unix_timestamp and from_unixtime") + .booleanConf + .createWithDefault(true) + } diff --git 
a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index 3152d6c7c4f8d928d5df4941cf75826dd5cbe33e..c06c8e0a72e0e15a0f5fbe52f53b65aedfd7f180 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -35,8 +35,11 @@ import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti import org.apache.spark.sql.catalyst.util.CharVarcharUtils.getRawTypeString import org.apache.spark.sql.execution import org.apache.spark.sql.execution.ColumnarBloomFilterSubquery +import org.apache.spark.sql.execution.datasources.OmniFileFormatWriter.Empty2Null import org.apache.spark.sql.expression.ColumnarExpressionConverter import org.apache.spark.sql.hive.HiveUdfAdaptorUtil +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DecimalType.{MAX_PRECISION, MAX_SCALE} import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType, DateType, Decimal, DecimalType, DoubleType, IntegerType, LongType, Metadata, NullType, ShortType, StringType, TimestampType} import org.json.{JSONArray, JSONObject} @@ -76,9 +79,9 @@ object OmniExpressionAdaptor extends Logging { } } - private def unsupportedCastCheck(expr: Expression, cast: CastBase): Unit = { + private def unsupportedCastCheck(expr: Expression, cast: Cast): Unit = { def doSupportCastToString(dataType: DataType): Boolean = { - dataType.isInstanceOf[DecimalType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[IntegerType] || + (dataType.isInstanceOf[DecimalType] && !SQLConf.get.ansiEnabled) || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[IntegerType] || dataType.isInstanceOf[LongType] || dataType.isInstanceOf[DateType] || dataType.isInstanceOf[DoubleType] || dataType.isInstanceOf[NullType] } @@ -100,22 +103,69 @@ object OmniExpressionAdaptor extends Logging { } - private val timeFormatSet: Set[String] = Set("yyyy-MM-dd HH:mm:ss", "yyyy-MM-dd") + private def binaryOperatorAdjust(expr: BinaryOperator, returnDataType: DataType): Tuple2[Expression, Expression] = { + import scala.math.{max, min} + def bounded(precision: Int, scale: Int): DecimalType = { + DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE)) + } - private def unsupportedTimeFormatCheck(timeFormat: String): Unit = { - if (!timeFormatSet.contains(timeFormat)) { - throw new UnsupportedOperationException(s"Unsupported Time Format: $timeFormat") + def widerDecimalType(d1: DecimalType, d2: DecimalType): Tuple2[DecimalType, Boolean] = { + getWiderDecimalType(d1.precision, d1.scale, d2.precision, d2.scale) + } + + def getWiderDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): Tuple2[DecimalType, Boolean] = { + val scale = max(s1, s2) + val range = max(p1 - s1, p2 - s2) + (bounded(range + scale, scale), range + scale > MAX_PRECISION) + } + + def decimalTypeCast(expr: Expression, d: DecimalType, widerType: DecimalType, returnType: DecimalType, isOverPrecision: Boolean): Expression = { + if (isOverPrecision && d.scale <= returnType.scale) { + if (returnType.precision - returnType.scale < d.precision - d.scale) { + return expr + } + Cast (expr, returnDataType) + } else { + Cast (expr, widerType) + } + } + + if (DecimalType.unapply(expr.left) 
&& DecimalType.unapply(expr.right)) { + val leftDataType = expr.left.dataType.asInstanceOf[DecimalType] + val rightDataType = expr.right.dataType.asInstanceOf[DecimalType] + val (widerType, isOverPrecision) = widerDecimalType(leftDataType, rightDataType) + val result = expr match { + case _: Add | _: Subtract => (Cast(expr.left, returnDataType), Cast(expr.right, returnDataType)) + case _: Multiply | _: Divide | _: Remainder => + val newLeft = decimalTypeCast(expr.left, leftDataType, widerType, returnDataType.asInstanceOf[DecimalType], isOverPrecision) + val newRight = decimalTypeCast(expr.right, rightDataType, widerType, returnDataType.asInstanceOf[DecimalType], isOverPrecision) + (newLeft, newRight) + case _ => (expr.left, expr.right) + } + return result } + (expr.left, expr.right) } + private val timeFormatSet: Set[String] = Set("yyyy-MM-dd HH:mm:ss", "yyyy-MM-dd") private val timeZoneSet: Set[String] = Set("GMT+08:00", "Asia/Beijing", "Asia/Shanghai") - private def unsupportedTimeZoneCheck(timeZone: String): Unit = { + private def unsupportedUnixTimeFunction(timeFormat: String, timeZone: String): Unit = { + if (!ColumnarPluginConfig.getSessionConf.enableOmniUnixTimeFunc) { + throw new UnsupportedOperationException(s"Not Enabled Omni UnixTime Function") + } + if (ColumnarPluginConfig.getSessionConf.timeParserPolicy == "LEGACY") { + throw new UnsupportedOperationException(s"Unsupported Time Parser Policy: LEGACY") + } if (!timeZoneSet.contains(timeZone)) { throw new UnsupportedOperationException(s"Unsupported Time Zone: $timeZone") } + if (!timeFormatSet.contains(timeFormat)) { + throw new UnsupportedOperationException(s"Unsupported Time Format: $timeFormat") + } } + def toOmniTimeFormat(format: String): String = { format.replace("yyyy", "%Y") .replace("MM", "%m") @@ -185,43 +235,45 @@ object OmniExpressionAdaptor extends Logging { throw new UnsupportedOperationException(s"Unsupported datatype for MakeDecimal: ${makeDecimal.child.dataType}") } - case promotePrecision: PromotePrecision => - rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap) - case sub: Subtract => + val (left, right) = binaryOperatorAdjust(sub, returnDatatype) new JSONObject().put("exprType", "BINARY") .addOmniExpJsonType("returnType", returnDatatype) .put("operator", "SUBTRACT") - .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(sub.left, exprsIndexMap)) - .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(sub.right, exprsIndexMap)) + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(right, exprsIndexMap)) case add: Add => + val (left, right) = binaryOperatorAdjust(add, returnDatatype) new JSONObject().put("exprType", "BINARY") .addOmniExpJsonType("returnType", returnDatatype) .put("operator", "ADD") - .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(add.left, exprsIndexMap)) - .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(add.right, exprsIndexMap)) + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(right, exprsIndexMap)) case mult: Multiply => + val (left, right) = binaryOperatorAdjust(mult, returnDatatype) new JSONObject().put("exprType", "BINARY") .addOmniExpJsonType("returnType", returnDatatype) .put("operator", "MULTIPLY") - .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(mult.left, exprsIndexMap)) - .put("right", 
rewriteToOmniJsonExpressionLiteralJsonObject(mult.right, exprsIndexMap)) + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(right, exprsIndexMap)) case divide: Divide => + val (left, right) = binaryOperatorAdjust(divide, returnDatatype) new JSONObject().put("exprType", "BINARY") .addOmniExpJsonType("returnType", returnDatatype) .put("operator", "DIVIDE") - .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(divide.left, exprsIndexMap)) - .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(divide.right, exprsIndexMap)) + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(right, exprsIndexMap)) case mod: Remainder => + val (left, right) = binaryOperatorAdjust(mod, returnDatatype) new JSONObject().put("exprType", "BINARY") .addOmniExpJsonType("returnType", returnDatatype) .put("operator", "MODULUS") - .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(mod.left, exprsIndexMap)) - .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(mod.right, exprsIndexMap)) + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(right, exprsIndexMap)) case greaterThan: GreaterThan => new JSONObject().put("exprType", "BINARY") @@ -319,9 +371,14 @@ object OmniExpressionAdaptor extends Logging { .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(subString.str, exprsIndexMap)). put(rewriteToOmniJsonExpressionLiteralJsonObject(subString.pos, exprsIndexMap)) .put(rewriteToOmniJsonExpressionLiteralJsonObject(subString.len, exprsIndexMap))) - + case empty2Null: Empty2Null => + new JSONObject().put("exprType", "FUNCTION") + .put("function_name", "empty2null") + .addOmniExpJsonType("returnType", empty2Null.dataType) + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject + (empty2Null.child, exprsIndexMap))) // Cast - case cast: CastBase => + case cast: Cast => unsupportedCastCheck(expr, cast) cast.child.dataType match { case NullType => @@ -436,8 +493,8 @@ object OmniExpressionAdaptor extends Logging { .put("function_name", "might_contain") .put("arguments", new JSONArray() .put(rewriteToOmniJsonExpressionLiteralJsonObject( - ColumnarExpressionConverter.replaceWithColumnarExpression(bloomFilterMightContain.bloomFilterExpression), - exprsIndexMap)) + ColumnarExpressionConverter.replaceWithColumnarExpression(bloomFilterMightContain.bloomFilterExpression), + exprsIndexMap)) .put(rewriteToOmniJsonExpressionLiteralJsonObject(bloomFilterMightContain.valueExpression, exprsIndexMap, returnDatatype))) case columnarBloomFilterSubquery: ColumnarBloomFilterSubquery => @@ -480,20 +537,21 @@ object OmniExpressionAdaptor extends Logging { // for date time functions case unixTimestamp: UnixTimestamp => val timeZone = unixTimestamp.timeZoneId.getOrElse("") - unsupportedTimeZoneCheck(timeZone) - unsupportedTimeFormatCheck(unixTimestamp.format.toString) + unsupportedUnixTimeFunction(unixTimestamp.format.toString, timeZone) + val policy = ColumnarPluginConfig.getSessionConf.timeParserPolicy new JSONObject().put("exprType", "FUNCTION") .addOmniExpJsonType("returnType", unixTimestamp.dataType) .put("function_name", "unix_timestamp") .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(unixTimestamp.timeExp, exprsIndexMap)) .put(new 
JSONObject(toOmniTimeFormat(rewriteToOmniJsonExpressionLiteral(unixTimestamp.format, exprsIndexMap)))) .put(new JSONObject().put("exprType", "LITERAL").put("dataType", 15).put("isNull", timeZone.isEmpty()) - .put("value", timeZone).put("width", timeZone.length))) + .put("value", timeZone).put("width", timeZone.length)) + .put(new JSONObject().put("exprType", "LITERAL").put("dataType", 15).put("isNull", policy.isEmpty()) + .put("value", policy).put("width", policy.length))) case fromUnixTime: FromUnixTime => val timeZone = fromUnixTime.timeZoneId.getOrElse("") - unsupportedTimeZoneCheck(timeZone) - unsupportedTimeFormatCheck(fromUnixTime.format.toString) + unsupportedUnixTimeFunction(fromUnixTime.format.toString, timeZone) new JSONObject().put("exprType", "FUNCTION") .addOmniExpJsonType("returnType", fromUnixTime.dataType) .put("function_name", "from_unixtime") @@ -640,10 +698,10 @@ object OmniExpressionAdaptor extends Logging { rewriteToOmniJsonExpressionLiteralJsonObject(children.head, exprsIndexMap) } else { children.head match { - case base: CastBase if base.child.dataType.isInstanceOf[NullType] => + case base: Cast if base.child.dataType.isInstanceOf[NullType] => rewriteToOmniJsonExpressionLiteralJsonObject(children(1), exprsIndexMap) case _ => children(1) match { - case base: CastBase if base.child.dataType.isInstanceOf[NullType] => + case base: Cast if base.child.dataType.isInstanceOf[NullType] => rewriteToOmniJsonExpressionLiteralJsonObject(children.head, exprsIndexMap) case _ => new JSONObject().put("exprType", "FUNCTION") diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala index ea271244784cc5466dbbce00b4c942a8e22777d0..4d8c246ab52ebd7f11ea416c18080c216f1e963e 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/Tree/TreePatterns.scala @@ -26,6 +26,7 @@ object TreePattern extends Enumeration { val AGGREGATE_EXPRESSION = Value(0) val ALIAS: Value = Value val AND_OR: Value = Value + val AND: Value = Value val ARRAYS_ZIP: Value = Value val ATTRIBUTE_REFERENCE: Value = Value val APPEND_COLUMNS: Value = Value @@ -58,6 +59,7 @@ object TreePattern extends Enumeration { val JSON_TO_STRUCT: Value = Value val LAMBDA_FUNCTION: Value = Value val LAMBDA_VARIABLE: Value = Value + val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value val LATERAL_SUBQUERY: Value = Value val LIKE_FAMLIY: Value = Value val LIST_SUBQUERY: Value = Value @@ -69,7 +71,10 @@ object TreePattern extends Enumeration { val NULL_CHECK: Value = Value val NULL_LITERAL: Value = Value val SERIALIZE_FROM_OBJECT: Value = Value + val OR: Value = Value val OUTER_REFERENCE: Value = Value + val PARAMETER: Value = Value + val PARAMETERIZED_QUERY: Value = Value val PIVOT: Value = Value val PLAN_EXPRESSION: Value = Value val PYTHON_UDF: Value = Value @@ -81,6 +86,7 @@ object TreePattern extends Enumeration { val SCALAR_SUBQUERY: Value = Value val SCALAR_SUBQUERY_REFERENCE: Value = Value val SCALA_UDF: Value = Value + val SESSION_WINDOW: Value = Value val SORT: Value = Value val SUBQUERY_ALIAS: Value = Value val SUBQUERY_WRAPPER: Value = Value @@ -89,7 +95,9 @@ object TreePattern extends Enumeration { val TIME_ZONE_AWARE_EXPRESSION: Value = Value val TRUE_OR_FALSE_LITERAL: Value = Value val 
WINDOW_EXPRESSION: Value = Value + val WINDOW_TIME: Value = Value val UNARY_POSITIVE: Value = Value + val UNPIVOT: Value = Value val UPDATE_FIELDS: Value = Value val UPPER_OR_LOWER: Value = Value val UP_CAST: Value = Value @@ -119,6 +127,7 @@ object TreePattern extends Enumeration { val UNION: Value = Value val UNRESOLVED_RELATION: Value = Value val UNRESOLVED_WITH: Value = Value + val TEMP_RESOLVED_COLUMN: Value = Value val TYPED_FILTER: Value = Value val WINDOW: Value = Value val WITH_WINDOW_DEFINITION: Value = Value @@ -127,6 +136,7 @@ object TreePattern extends Enumeration { val UNRESOLVED_ALIAS: Value = Value val UNRESOLVED_ATTRIBUTE: Value = Value val UNRESOLVED_DESERIALIZER: Value = Value + val UNRESOLVED_HAVING: Value = Value val UNRESOLVED_ORDINAL: Value = Value val UNRESOLVED_FUNCTION: Value = Value val UNRESOLVED_HINT: Value = Value @@ -135,6 +145,8 @@ object TreePattern extends Enumeration { // Unresolved Plan patterns (Alphabetically ordered) val UNRESOLVED_SUBQUERY_COLUMN_ALIAS: Value = Value val UNRESOLVED_FUNC: Value = Value + val UNRESOLVED_TABLE_VALUED_FUNCTION: Value = Value + val UNRESOLVED_TVF_ALIASES: Value = Value // Execution expression patterns (alphabetically ordered) val IN_SUBQUERY_EXEC: Value = Value diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala index 0a5d509b05af809fccbedcb28a74cbaea73b40b8..85192fc36038815c38167b64da74e51896d688fa 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/expressions/runtimefilter.scala @@ -39,7 +39,7 @@ case class RuntimeFilterSubquery( exprId: ExprId = NamedExpression.newExprId, hint: Option[HintInfo] = None) extends SubqueryExpression( - filterCreationSidePlan, Seq(filterApplicationSideExp), exprId, Seq.empty) + filterCreationSidePlan, Seq(filterApplicationSideExp), exprId, Seq.empty, hint) with Unevaluable with UnaryLike[Expression] { @@ -74,6 +74,8 @@ case class RuntimeFilterSubquery( override protected def withNewChildInternal(newChild: Expression): RuntimeFilterSubquery = copy(filterApplicationSideExp = newChild) + + override def withNewHint(hint: Option[HintInfo]): RuntimeFilterSubquery = copy(hint = hint) } /** diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index 812c387bc6b8e5d253f82efb6ac639eebde71c22..ac38c5399325c48475df32f346b2af3a13e38361 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -218,9 +218,9 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J leftKey: Expression, rightKey: Expression): Boolean = { (left, right) match { - case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan), _) => + case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _, _), plan), _) => pruningKey.fastEquals(leftKey) || hasDynamicPruningSubquery(plan, right, leftKey, rightKey) - case (_, 
Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan)) => + case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _, _), plan)) => pruningKey.fastEquals(rightKey) || hasDynamicPruningSubquery(left, plan, leftKey, rightKey) case _ => false @@ -251,10 +251,10 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J rightKey: Expression): Boolean = { (left, right) match { case (Filter(InSubquery(Seq(key), - ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _), _) => + ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _, _)), _), _) => key.fastEquals(leftKey) || key.fastEquals(new Murmur3Hash(Seq(leftKey))) case (_, Filter(InSubquery(Seq(key), - ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _)) => + ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _, _)), _)) => key.fastEquals(rightKey) || key.fastEquals(new Murmur3Hash(Seq(rightKey))) case _ => false } @@ -299,7 +299,13 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J case s: Subquery if s.correlated => plan case _ if !conf.runtimeFilterSemiJoinReductionEnabled && !conf.runtimeFilterBloomFilterEnabled => plan - case _ => tryInjectRuntimeFilter(plan) + case _ => + val newPlan = tryInjectRuntimeFilter(plan) + if (conf.runtimeFilterSemiJoinReductionEnabled && !plan.fastEquals(newPlan)) { + RewritePredicateSubquery(newPlan) + } else { + newPlan + } } } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala index 1b5baa23080e816868934da02c8cbe87fa21c4d8..c4435379ffb153b5f2d211a1eec044102f2a1f6c 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala @@ -643,7 +643,7 @@ object MergeSubqueryFilters extends Rule[LogicalPlan] { val subqueryCTE = header.plan.asInstanceOf[CTERelationDef] GetStructField( ScalarSubquery( - CTERelationRef(subqueryCTE.id, _resolved = true, subqueryCTE.output), + CTERelationRef(subqueryCTE.id, _resolved = true, subqueryCTE.output, subqueryCTE.isStreaming), exprId = ssr.exprId), ssr.headerIndex) } else { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala index 9e402902518a90df7f68db4b420d325a53f9bf41..f6ebd716dc1ebb25d3f36bcb5705168f0d76934e 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala @@ -61,7 +61,7 @@ object RewriteSelfJoinInInPredicate extends Rule[LogicalPlan] with PredicateHelp case f: Filter => f transformExpressions { case in @ InSubquery(_, listQuery @ ListQuery(Project(projectList, - Join(left, right, Inner, Some(joinCond), _)), _, _, _, _)) + Join(left, right, Inner, Some(joinCond), _)), _, _, _, _, _)) if left.canonicalized ne right.canonicalized => val 
attrMapping = AttributeMap(right.output.zip(left.output)) val subCondExprs = splitConjunctivePredicates(joinCond transform { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala index 4863698430e78896de76218897b30870bad46634..52d78418288baa2dde40ae819f561fd2ccdbb85f 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala @@ -44,8 +44,8 @@ import org.apache.spark.sql.vectorized.ColumnarBatch case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryExecNode - with AliasAwareOutputPartitioning - with AliasAwareOutputOrdering { + with PartitioningPreservingUnaryExecNode + with OrderPreservingUnaryExecNode { override def supportsColumnar: Boolean = true @@ -267,8 +267,8 @@ case class ColumnarConditionProjectExec(projectList: Seq[NamedExpression], condition: Expression, child: SparkPlan) extends UnaryExecNode - with AliasAwareOutputPartitioning - with AliasAwareOutputOrdering { + with PartitioningPreservingUnaryExecNode + with OrderPreservingUnaryExecNode { override def supportsColumnar: Boolean = true diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala index 800dcf1a0c047b206603b24a2a963ef7f3e48db1..80d796ce126774970594d75195a186be096f5de3 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala @@ -479,8 +479,8 @@ abstract class BaseColumnarFileSourceScanExec( } }.groupBy { f => BucketingUtils - .getBucketId(new Path(f.filePath).getName) - .getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.filePath)) + .getBucketId(f.toPath.getName) + .getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.urlEncodedPath)) } val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index ef33a84decfd9aabd62b4ea149dca78b7e0ed988..3630b0f2ea14bf7a4c4d87ba739d6734b84a870e 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -105,12 +105,29 @@ class QueryExecution( case other => other } + // The plan that has been normalized by custom rules, so that it's more likely to hit cache. 
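// Editor's note — illustrative sketch only, not part of this patch. Stripped of the
// PlanChangeLogger calls, the `normalized` lazy val added just below is a left fold of the
// session's plan-normalization rules over the command-executed plan. A minimal standalone
// equivalent (assuming only the catalyst Rule and LogicalPlan types) looks like this:
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

private def normalizePlanSketch(plan: LogicalPlan, rules: Seq[Rule[LogicalPlan]]): LogicalPlan =
  rules.foldLeft(plan) { (p, rule) => rule.apply(p) }  // with no rules registered, the plan comes back unchanged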
+ lazy val normalized: LogicalPlan = { + val normalizationRules = sparkSession.sessionState.planNormalizationRules + if (normalizationRules.isEmpty) { + commandExecuted + } else { + val planChangeLogger = new PlanChangeLogger[LogicalPlan]() + val normalized = normalizationRules.foldLeft(commandExecuted) { (p, rule) => + val result = rule.apply(p) + planChangeLogger.logRule(rule.ruleName, p, result) + result + } + planChangeLogger.logBatch("Plan Normalization", commandExecuted, normalized) + normalized + } + } + lazy val withCachedData: LogicalPlan = sparkSession.withActive { assertAnalyzed() assertSupported() // clone the plan to avoid sharing the plan instance between different stages like analyzing, // optimizing and planning. - sparkSession.sharedState.cacheManager.useCachedData(commandExecuted.clone()) + sparkSession.sharedState.cacheManager.useCachedData(normalized.clone()) } def assertCommandExecuted(): Unit = commandExecuted @@ -227,7 +244,7 @@ class QueryExecution( // output mode does not matter since there is no `Sink`. new IncrementalExecution( sparkSession, logical, OutputMode.Append(), "", - UUID.randomUUID, UUID.randomUUID, 0, OffsetSeqMetadata(0, 0)) + UUID.randomUUID, UUID.randomUUID, 0, None ,OffsetSeqMetadata(0, 0)) } else { this } @@ -494,11 +511,10 @@ object QueryExecution { */ private[sql] def toInternalError(msg: String, e: Throwable): Throwable = e match { case e @ (_: java.lang.NullPointerException | _: java.lang.AssertionError) => - new SparkException( - errorClass = "INTERNAL_ERROR", - messageParameters = Array(msg + - " Please, fill a bug report in, and provide the full stack trace."), - cause = e) + SparkException.internalError( + msg + " You hit a bug in Spark or the Spark plugins you use. Please, report this bug " + + "to the corresponding communities or vendors, and provide the full stack trace.", + e) case e: Throwable => e } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 6234751dde4231261f56d55be2a07bc42c65ea03..6db6fc3c1166fb73e732ea9019cfb28bce6ef8c5 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -84,7 +84,9 @@ case class AdaptiveSparkPlanExec( @transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]() // The logical plan optimizer for re-optimizing the current logical plan. - @transient private val optimizer = new AQEOptimizer(conf) + @transient private val optimizer = new AQEOptimizer(conf, + session.sessionState.adaptiveRulesHolder.runtimeOptimizerRules + ) // `EnsureRequirements` may remove user-specified repartition and assume the query plan won't // change its output partitioning. This assumption is not true in AQE. Here we check the @@ -123,7 +125,7 @@ case class AdaptiveSparkPlanExec( RemoveRedundantSorts, DisableUnnecessaryBucketedScan, OptimizeSkewedJoin(ensureRequirements) - ) ++ context.session.sessionState.queryStagePrepRules + ) ++ context.session.sessionState.adaptiveRulesHolder.queryStagePrepRules } // A list of physical optimizer rules to be applied to a new stage before its execution. 
These @@ -188,7 +190,7 @@ case class AdaptiveSparkPlanExec( @volatile private var currentPhysicalPlan = initialPlan - private var isFinalPlan = false + @volatile private var _isFinalPlan = false private var currentStageId = 0 @@ -205,6 +207,8 @@ case class AdaptiveSparkPlanExec( def executedPlan: SparkPlan = currentPhysicalPlan + def isFinalPlan: Boolean = _isFinalPlan + override def conf: SQLConf = context.session.sessionState.conf override def output: Seq[Attribute] = inputPlan.output @@ -223,6 +227,8 @@ case class AdaptiveSparkPlanExec( .map(_.toLong).filter(SQLExecution.getQueryExecution(_) eq context.qe) } + def finalPhysicalPlan: SparkPlan = withFinalPlanUpdate(identity) + private def getFinalPhysicalPlan(): SparkPlan = lock.synchronized { if (isFinalPlan) return currentPhysicalPlan @@ -326,7 +332,7 @@ case class AdaptiveSparkPlanExec( optimizeQueryStage(result.newPlan, isFinalStage = true), postStageCreationRules(supportsColumnar), Some((planChangeLogger, "AQE Post Stage Creation"))) - isFinalPlan = true + _isFinalPlan = true executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan))) currentPhysicalPlan } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala index a70ba852e7a3b0adf145afde24427997d08658e9..3cd4913643b213d407609edc547077a6af188aa9 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec} +import org.apache.spark.sql.execution.datasources.WriteFilesExec import org.apache.spark.sql.execution.datasources.v2.V2CommandExec import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.internal.SQLConf @@ -44,6 +45,7 @@ case class InsertAdaptiveSparkPlan( private def applyInternal(plan: SparkPlan, isSubquery: Boolean): SparkPlan = plan match { case _ if !conf.adaptiveExecutionEnabled => plan + case _: WriteFilesExec => plan case _: ExecutedCommandExec => plan case _: CommandResultExec => plan case c: DataWritingCommandExec => c.copy(child = apply(c.child)) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala index b5a1ad3754929c738a1dcc1dda5da9047a27305f..dfdbe2c7038298794a375bf701157acfe03956ec 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala @@ -30,11 +30,11 @@ case class PlanAdaptiveSubqueries( def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressionsWithPruning(_.containsAnyPattern( SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY, RUNTIME_FILTER_SUBQUERY)) { - case expressions.ScalarSubquery(_, _, exprId, _) => + case 
expressions.ScalarSubquery(_, _, exprId, _, _, _) => val subquery = SubqueryExec.createForScalarSubquery( s"subquery#${exprId.id}", subqueryMap(exprId.id)) execution.ScalarSubquery(subquery, exprId) - case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _)) => + case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _, _)) => val expr = if (values.length == 1) { values.head } else { @@ -47,7 +47,7 @@ case class PlanAdaptiveSubqueries( val subquery = SubqueryExec(s"subquery#${exprId.id}", subqueryMap(exprId.id)) InSubqueryExec(expr, subquery, exprId, shouldBroadcast = true) case expressions.DynamicPruningSubquery(value, buildPlan, - buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId) => + buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId, _) => val name = s"dynamicpruning#${exprId.id}" val subquery = SubqueryAdaptiveBroadcastExec(name, broadcastKeyIndex, onlyInBroadcast, buildPlan, buildKeys, subqueryMap(exprId.id)) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala index 422435694a86284aed9e8515a94dfaca88bbfd6a..171825da27aee4aff3ac0371cb49d2d0d2dee858 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala @@ -27,14 +27,18 @@ import org.apache.spark.sql.SparkSession case class PushOrderedLimitThroughAgg(session: SparkSession) extends Rule[SparkPlan] with PredicateHelper { override def apply(plan: SparkPlan): SparkPlan = { - if (!ColumnarPluginConfig.getSessionConf.pushOrderedLimitThroughAggEnable) { + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + // The two optimization principles are contrary and cannot be used at the same time. + // reason: the pushOrderedLimitThroughAgg rule depends on the actual aggregation result in the partial phase. + // However, if the partial phase is skipped, aggregation is not performed. 
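// Editor's note — illustrative sketch only, not part of this patch. The guard added below is
// equivalent to the predicate sketched here: the rule fires only when ordered-limit pushdown is
// enabled and adaptive partial aggregation is disabled, since skipping the partial phase removes
// the partial aggregation results that the pushed-down ordered limit depends on.
private def shouldPushOrderedLimitThroughAgg(pushEnabled: Boolean,
                                             adaptivePartialAggEnabled: Boolean): Boolean =
  pushEnabled && !adaptivePartialAggEnabled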
+ if (!columnarConf.pushOrderedLimitThroughAggEnable || columnarConf.enableAdaptivePartialAggregation) { return plan } - val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + val enableColumnarTopNSort: Boolean = columnarConf.enableColumnarTopNSort plan.transform { - case orderAndProject @ TakeOrderedAndProjectExec(limit, sortOrder, projectList, orderAndProjectChild) => { + case orderAndProject @ TakeOrderedAndProjectExec(limit, sortOrder, projectList, orderAndProjectChild, _) => { orderAndProjectChild match { case finalAgg @ HashAggregateExec(_, _, _, _, _, _, _, _, finalAggChild) => finalAggChild match { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatDataWriter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatDataWriter.scala index 8983a6f180fdc618fe465144e461f41417902e4f..f3fe865e0b85e88277ba10776bde67092415a245 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatDataWriter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatDataWriter.scala @@ -157,6 +157,35 @@ abstract class OmniBaseDynamicPartitionDataWriter( protected val getOutputRow = UnsafeProjection.create(description.dataColumns, description.allColumns) + protected def getPartitionPath(partitionValues: Option[InternalRow], + bucketId: Option[Int]): String = { + val partDir = partitionValues.map(getPartitionPath(_)) + partDir.foreach(updatedPartitions.add) + + val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") + + // The prefix and suffix must be in a form that matches our bucketing format. See BucketingUtils + // for details. The prefix is required to represent bucket id when writing Hive-compatible + // bucketed table. + val prefix = bucketId match { + case Some(id) => description.bucketSpec.get.bucketFileNamePrefix(id) + case _ => "" + } + val suffix = f"$bucketIdStr.c$fileCounter%03d" + + description.outputWriterFactory.getFileExtension(taskAttemptContext) + val fileNameSpec = FileNameSpec(prefix, suffix) + + val customPath = partDir.flatMap { dir => + description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) + } + val currentPath = if (customPath.isDefined) { + customPath.get + fileNameSpec.toString + } else { + partDir.toString + fileNameSpec.toString + } + currentPath + } + /** * Opens a new OutputWriter given a partition key and/or a bucket id. * If bucket id is specified, we will append it to the end of the file name, but before the @@ -276,21 +305,25 @@ class OmniDynamicPartitionDataSingleWriter( val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None if (currentPartitionValues != nextPartitionValues || currentBucketId != nextBucketId) { - // See a new partition or bucket - write to a new partition dir (or a new bucket file). 
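// Editor's note — illustrative sketch only, not part of this patch; bucket id 2, file counter 0
// and an ".orc" extension are assumed example values. The format strings in getPartitionPath
// above produce names like the following; the resulting path is then only compared for equality
// (isFilePathSame, below) to decide whether the writer actually has to roll over to a new file.
val exampleBucketIdStr = f"_${2}%05d"                          // "_00002" (BucketingUtils.bucketIdToString pads to 5 digits)
val exampleSuffix = f"$exampleBucketIdStr.c${0}%03d" + ".orc"  // "_00002.c000.orc"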
- if (isPartitioned && currentPartitionValues != nextPartitionValues) { - currentPartitionValues = Some(nextPartitionValues.get.copy()) - statsTrackers.foreach(_.newPartition(currentPartitionValues.get)) - } - if (isBucketed) { - currentBucketId = nextBucketId - } - - fileCounter = 0 - if (i != 0) { - writeRecord(omniInternalRow, lastIndex, i) - lastIndex = i + val isFilePathSame = getPartitionPath(currentPartitionValues, + currentBucketId) == getPartitionPath(nextPartitionValues, nextBucketId) + if (!isFilePathSame) { + // See a new partition or bucket - write to a new partition dir (or a new bucket file). + if (isPartitioned && currentPartitionValues != nextPartitionValues) { + currentPartitionValues = Some(nextPartitionValues.get.copy()) + statsTrackers.foreach(_.newPartition(currentPartitionValues.get)) + } + if (isBucketed) { + currentBucketId = nextBucketId + } + + fileCounter = 0 + if (i != 0) { + writeRecord(omniInternalRow, lastIndex, i) + lastIndex = i + } + renewCurrentWriter(currentPartitionValues, currentBucketId, closeCurrentWriter = true) } - renewCurrentWriter(currentPartitionValues, currentBucketId, closeCurrentWriter = true) } else if ( description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala index 465123d8730d39656d242bd4785e22680312ce05..a65e66c4edf3ccd9bad15fc34e16e9155438270e 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala @@ -272,7 +272,7 @@ object OmniFileFormatWriter extends Logging { case cause: Throwable => logError(s"Aborting job ${description.uuid}.", cause) committer.abortJob(job) - throw QueryExecutionErrors.jobAbortedError(cause) + throw cause } } @@ -343,7 +343,7 @@ object OmniFileFormatWriter extends Logging { // We throw the exception and let Executor throw ExceptionFailure to abort the job. 
throw new TaskOutputFileAlreadyExistException(f) case t: Throwable => - throw QueryExecutionErrors.taskFailedWhileWritingRowsError(t) + throw QueryExecutionErrors.taskFailedWhileWritingRowsError(description.path, t) } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInsertIntoHadoopFsRelationCommand.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInsertIntoHadoopFsRelationCommand.scala index 9d0008e0b2cf1a57443e019bf41113f60c8855fd..7c60110911fedccbc88ccce0e9ffc970d9ec5e23 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInsertIntoHadoopFsRelationCommand.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInsertIntoHadoopFsRelationCommand.scala @@ -83,7 +83,6 @@ case class OmniInsertIntoHadoopFsRelationCommand( // Most formats don't do well with duplicate columns, so lets not allow that SchemaUtils.checkColumnNameDuplication( outputColumnNames, - s"when inserting into $outputPath", sparkSession.sessionState.conf.caseSensitiveAnalysis) val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala index 397196c7cd982b266429585a5f60a45dc7c79561..4e8d144240c6c7e58b3a76fa013fc2311c8bd56d 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala @@ -17,25 +17,24 @@ package org.apache.spark.sql.execution.datasources.orc -import java.io.Serializable -import java.net.URI import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileSplit import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.orc.OrcConf import org.apache.orc.OrcConf.COMPRESS -import org.apache.orc.{OrcConf, OrcFile, TypeDescription} -import org.apache.orc.TypeDescription.Category._ -import org.apache.orc.mapreduce.OrcInputFormat import org.apache.spark.TaskContext import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.util.SparkMemoryUtils import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{DecimalType, StringType, StructType, TimestampType} -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +import java.io.Serializable +import java.net.URI class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializable { @@ -54,44 +53,6 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ OrcUtils.inferSchema(sparkSession, files, options) } - private def isPPDSafe(filters: Seq[Filter], dataSchema: StructType): Seq[Boolean] = { - def convertibleFiltersHelper(filter: Filter, - dataSchema: StructType): Boolean = filter match { - case And(left, right) => 
- convertibleFiltersHelper(left, dataSchema) && convertibleFiltersHelper(right, dataSchema) - case Or(left, right) => - convertibleFiltersHelper(left, dataSchema) && convertibleFiltersHelper(right, dataSchema) - case Not(pred) => - convertibleFiltersHelper(pred, dataSchema) - case other => - other match { - case EqualTo(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case EqualNullSafe(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case LessThan(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case LessThanOrEqual(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case GreaterThan(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case GreaterThanOrEqual(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case IsNull(name) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case IsNotNull(name) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case In(name, _) => - dataSchema.apply(name).dataType != StringType && dataSchema.apply(name).dataType != TimestampType - case _ => false - } - } - - filters.map { filter => - convertibleFiltersHelper(filter, dataSchema) - } - } - override def buildReaderWithPartitionValues( sparkSession: SparkSession, dataSchema: StructType, @@ -101,7 +62,6 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ options: Map[String, String], hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { - val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields) val sqlConf = sparkSession.sessionState.conf val capacity = sqlConf.orcVectorizedReaderBatchSize @@ -111,21 +71,17 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis val orcFilterPushDown = sparkSession.sessionState.conf.orcFilterPushDown - val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles (file: PartitionedFile) => { val conf = broadcastedConf.value.value - val filePath = new Path(new URI(file.filePath)) - val isPPDSafeValue = isPPDSafe(filters, dataSchema).reduceOption(_ && _) + val filePath = file.toPath // ORC predicate pushdown - if (orcFilterPushDown && filters.nonEmpty && isPPDSafeValue.getOrElse(false)) { - OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach { - fileSchema => OrcFilters.createFilter(fileSchema, filters).foreach { f => - OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames) - } - } + val pushed = if (orcFilterPushDown) { + filters.reduceOption(And(_, _)) + } else { + None } val taskConf = new Configuration(conf) @@ -134,42 +90,16 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) // read data from vectorized reader - val batchReader = new OmniOrcColumnarBatchReader(capacity) + val batchReader = new OmniOrcColumnarBatchReader(capacity, requiredSchema, pushed.orNull) // SPARK-23399 Register a task 
completion listener first to call `close()` in all cases. // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM) // after opening a file. val iter = new RecordReaderIterator(batchReader) Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) - // fill requestedDataColIds with -1, fil real values int initDataColIds function - val requestedDataColIds = Array.fill(requiredSchema.length)(-1) ++ Array.fill(partitionSchema.length)(-1) - val requestedPartitionColIds = - Array.fill(requiredSchema.length)(-1) ++ Range(0, partitionSchema.length) - - // 初始化precision数组和scale数组,透传至java侧使用 - val requiredFields = requiredSchema.fields - val fieldslength = requiredFields.length - val precisionArray : Array[Int] = Array.ofDim[Int](fieldslength) - val scaleArray : Array[Int] = Array.ofDim[Int](fieldslength) - for ((reqField, index) <- requiredFields.zipWithIndex) { - val reqdatatype = reqField.dataType - if (reqdatatype.isInstanceOf[DecimalType]) { - val precision = reqdatatype.asInstanceOf[DecimalType].precision - val scale = reqdatatype.asInstanceOf[DecimalType].scale - precisionArray(index) = precision - scaleArray(index) = scale - } - } SparkMemoryUtils.init() batchReader.initialize(fileSplit, taskAttemptContext) - batchReader.initDataColIds(requiredSchema, requestedPartitionColIds, requestedDataColIds, resultSchema.fields, - precisionArray, scaleArray) - batchReader.initBatch( - requiredSchema.fields, - resultSchema.fields, - requestedDataColIds, - requestedPartitionColIds, - file.partitionValues) + batchReader.initBatch(partitionSchema, file.partitionValues) iter.asInstanceOf[Iterator[InternalRow]] } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala index 78acf30584d9a3c784b9bf7ec5ca52bc4abf252e..4586b8ec078ac0d226fd370f68db9fb9bc050b18 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala @@ -88,7 +88,7 @@ class OmniParquetFileFormat extends FileFormat with DataSourceRegister with Logg (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) - val filePath = new Path(new URI(file.filePath)) + val filePath = file.toPath val split = new org.apache.parquet.hadoop.ParquetInputSplit( filePath, diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/BroadcastUtils.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/BroadcastUtils.scala index 2204e0fe4b0e3c04ede12b58260ce06cbbb15ce1..ba5db895e8311d3b1c42a37a7fdf429f3b249421 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/BroadcastUtils.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/BroadcastUtils.scala @@ -264,6 +264,7 @@ object BroadcastUtils { -1, -1L, -1, + -1, new TaskMemoryManager(memoryManager, -1L), new Properties, MetricsSystem.createMetricsSystem("OMNI_UNSAFE", sparkConf), diff --git 
a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java index fe1c55ffb6eeecc7a16521430759e3239ca77dd4..b236da6446a8d46c9d6bd6b95e039d7679e0e6df 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java @@ -68,7 +68,7 @@ public class OrcColumnarBatchJniReaderDataTypeTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java index 995c434f66a8b03a6a76830b8fb0b08be9a3223e..7f6c1acb9543aed62eb535f1a0abf512876aa471 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java @@ -66,7 +66,7 @@ public class OrcColumnarBatchJniReaderNotPushDownTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java index c9ad9fadaf4ae13b1286467f965a23f66788c37f..2c912d919360bdd19c7ac9d20693decedf425919 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java @@ -66,7 +66,7 @@ public class OrcColumnarBatchJniReaderPushDownTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java 
b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java index 8f4535338cc3715e33dbf336e5f15fd4eb91569f..cf86c0a5ae2b3770474509dc69f6ebc10b4598a1 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java @@ -68,7 +68,7 @@ public class OrcColumnarBatchJniReaderSparkORCNotPushDownTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java index 27bcf5d7bdbd4787f42e90e620d1627678f70910..ef8d037bf7a9a7d58c33124c0d8745ce0b584ca9 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java @@ -68,7 +68,7 @@ public class OrcColumnarBatchJniReaderSparkORCPushDownTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java index eab15fef660250780e0beb311b489e3ceeb8ff5b..19f23db00d133ddb47c475c42003ace1ebecf675 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java @@ -18,103 +18,90 @@ package com.huawei.boostkit.spark.jni; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.io.Input; +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor; import junit.framework.TestCase; import nova.hetu.omniruntime.vector.LongVec; import nova.hetu.omniruntime.vector.VarcharVec; import nova.hetu.omniruntime.vector.Vec; -import org.apache.commons.codec.binary.Base64; -import org.apache.hadoop.hive.ql.io.sarg.SearchArgument; -import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentImpl; -import org.apache.orc.OrcFile; -import org.apache.orc.TypeDescription; -import org.apache.orc.mapred.OrcInputFormat; +import org.apache.hadoop.conf.Configuration; +import org.apache.orc.Reader; +import org.apache.spark.sql.types.StructField; 
+import org.apache.spark.sql.types.StructType; import org.junit.After; import org.junit.Before; import org.junit.FixMethodOrder; import org.junit.Test; import org.junit.runners.MethodSorters; -import org.apache.hadoop.conf.Configuration; + import java.io.File; import java.net.URI; import java.net.URISyntaxException; import java.util.ArrayList; -import java.util.List; import java.util.Arrays; -import org.apache.orc.Reader.Options; - -import static org.junit.Assert.*; +import java.util.List; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; -import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; +import static org.apache.spark.sql.types.DataTypes.LongType; +import static org.apache.spark.sql.types.DataTypes.StringType; @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderTest extends TestCase { public Configuration conf = new Configuration(); public OrcColumnarBatchScanReader orcColumnarBatchScanReader; - public int batchSize = 4096; + private int batchSize = 4096; + + private StructType requiredSchema; + private int[] vecTypeIds; + + private long offset = 0; + + private long length = Integer.MAX_VALUE; @Before public void setUp() throws Exception { - Configuration conf = new Configuration(); - TypeDescription schema = - TypeDescription.fromString("struct<`i_item_sk`:bigint,`i_item_id`:string>"); - Options options = new Options(conf) - .range(0, Integer.MAX_VALUE) - .useZeroCopy(false) - .skipCorruptRecords(false) - .tolerateMissingSchema(true); - - options.schema(schema); - options.include(OrcInputFormat.parseInclude(schema, - null)); - String kryoSarg = "AQEAb3JnLmFwYWNoZS5oYWRvb3AuaGl2ZS5xbC5pby5zYXJnLkV4cHJlc3Npb25UcmXlAQEBamF2YS51dGlsLkFycmF5TGlz9AECAQABAQEBAQEAAQAAAAEEAAEBAwEAAQEBAQEBAAEAAAIIAAEJAAEBAgEBAQIBAscBb3JnLmFwYWNoZS5oYWRvb3AuaGl2ZS5xbC5pby5zYXJnLlNlYXJjaEFyZ3VtZW50SW1wbCRQcmVkaWNhdGVMZWFmSW1wbAEBaV9pdGVtX3PrAAABBwEBAQIBEAkAAAEEEg=="; - String sargColumns = "i_item_sk,i_item_id,i_rec_start_date,i_rec_end_date,i_item_desc,i_current_price,i_wholesale_cost,i_brand_id,i_brand,i_class_id,i_class,i_category_id,i_category,i_manufact_id,i_manufact,i_size,i_formulation,i_color,i_units,i_container,i_manager_id,i_product_name"; - if (kryoSarg != null && sargColumns != null) { - byte[] sargBytes = Base64.decodeBase64(kryoSarg); - SearchArgument sarg = - new Kryo().readObject(new Input(sargBytes), SearchArgumentImpl.class); - options.searchArgument(sarg, sargColumns.split(",")); - sarg.getExpression().toString(); - } - orcColumnarBatchScanReader = new OrcColumnarBatchScanReader(); + constructSchema(); initReaderJava(); - initDataColIds(options, orcColumnarBatchScanReader); - initRecordReaderJava(options); - initBatch(options); + initDataColIds(); + initRecordReaderJava(); + initBatch(); + } + + private void constructSchema() { + requiredSchema = new StructType() + .add("i_item_sk", LongType) + .add("i_item_id", StringType); } - public void initDataColIds( - Options options, OrcColumnarBatchScanReader orcColumnarBatchScanReader) { - List allCols; - allCols = Arrays.asList(options.getColumnNames()); - orcColumnarBatchScanReader.colToInclu = new ArrayList(); - List optionField = options.getSchema().getFieldNames(); - orcColumnarBatchScanReader.colsToGet = new int[optionField.size()]; - orcColumnarBatchScanReader.realColsCnt = 0; - for (int i = 0; i < optionField.size(); i++) { - if (allCols.contains(optionField.get(i))) { - orcColumnarBatchScanReader.colToInclu.add(optionField.get(i)); - 
orcColumnarBatchScanReader.colsToGet[i] = 0; - orcColumnarBatchScanReader.realColsCnt++; - } else { + private void initDataColIds() { + // find requiredS fieldNames + String[] requiredfieldNames = requiredSchema.fieldNames(); + // save valid cols and numbers of valid cols + orcColumnarBatchScanReader.colsToGet = new int[requiredfieldNames.length]; + orcColumnarBatchScanReader.includedColumns = new ArrayList<>(); + // collect read cols types + ArrayList typeBuilder = new ArrayList<>(); + + for (int i = 0; i < requiredfieldNames.length; i++) { + String target = requiredfieldNames[i]; + + // if not find, set colsToGet value -1, else set colsToGet 0 + boolean is_find = false; + for (int j = 0; j < orcColumnarBatchScanReader.allFieldsNames.size(); j++) { + if (target.equals(orcColumnarBatchScanReader.allFieldsNames.get(j))) { + orcColumnarBatchScanReader.colsToGet[i] = 0; + orcColumnarBatchScanReader.includedColumns.add(requiredfieldNames[i]); + typeBuilder.add(OmniExpressionAdaptor.sparkTypeToOmniType(requiredSchema.fields()[i].dataType())); + is_find = true; + break; + } + } + + if (!is_find) { orcColumnarBatchScanReader.colsToGet[i] = -1; } } - orcColumnarBatchScanReader.requiredfieldNames = new String[optionField.size()]; - TypeDescription schema = options.getSchema(); - int[] precisionArray = new int[optionField.size()]; - int[] scaleArray = new int[optionField.size()]; - for (int i = 0; i < optionField.size(); i++) { - precisionArray[i] = schema.findSubtype(optionField.get(i)).getPrecision(); - scaleArray[i] = schema.findSubtype(optionField.get(i)).getScale(); - orcColumnarBatchScanReader.requiredfieldNames[i] = optionField.get(i); - } - orcColumnarBatchScanReader.precisionArray = precisionArray; - orcColumnarBatchScanReader.scaleArray = scaleArray; + vecTypeIds = typeBuilder.stream().mapToInt(Integer::intValue).toArray(); } @After @@ -122,8 +109,7 @@ public class OrcColumnarBatchJniReaderTest extends TestCase { System.out.println("orcColumnarBatchJniReader test finished"); } - public void initReaderJava() throws URISyntaxException { - OrcFile.ReaderOptions readerOptions = OrcFile.readerOptions(conf); + private void initReaderJava() { File directory = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0"); String path = directory.getAbsolutePath(); URI uri = null; @@ -133,16 +119,17 @@ public class OrcColumnarBatchJniReaderTest extends TestCase { // if URISyntaxException thrown, next line assertNotNull will interrupt the test } assertNotNull(uri); - orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, readerOptions); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); assertTrue(orcColumnarBatchScanReader.reader != 0); } - public void initRecordReaderJava(Options options) { - orcColumnarBatchScanReader.recordReader = orcColumnarBatchScanReader.initializeRecordReaderJava(options); + private void initRecordReaderJava() { + orcColumnarBatchScanReader.recordReader = orcColumnarBatchScanReader. 
+ initializeRecordReaderJava(offset, length, null, requiredSchema); assertTrue(orcColumnarBatchScanReader.recordReader != 0); } - public void initBatch(Options options) { + private void initBatch() { orcColumnarBatchScanReader.initBatchJava(batchSize); assertTrue(orcColumnarBatchScanReader.batchReader != 0); } @@ -150,8 +137,7 @@ public class OrcColumnarBatchJniReaderTest extends TestCase { @Test public void testNext() { Vec[] vecs = new Vec[2]; - int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal()}; - long rtn = orcColumnarBatchScanReader.next(vecs, typeId); + long rtn = orcColumnarBatchScanReader.next(vecs, vecTypeIds); assertTrue(rtn == 4096); assertTrue(((LongVec) vecs[0]).get(0) == 1); String str = new String(((VarcharVec) vecs[1]).get(0)); diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReaderTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReaderTest.java index 703149f462338695e3e6aa35cfe11deebdb32e58..3fb5b29284e197c664cc0090392e14b357a37a05 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReaderTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReaderTest.java @@ -47,6 +47,7 @@ public class ParquetColumnarBatchJniReaderTest extends TestCase { private Vec[] vecs; + private boolean[] missingColumns; private StructType schema; private List types; @@ -59,9 +60,11 @@ public class ParquetColumnarBatchJniReaderTest extends TestCase { File file = new File("src/test/java/com/huawei/boostkit/spark/jni/parquetsrc/parquet_data_all_type"); String path = file.getAbsolutePath(); - parquetColumnarBatchScanReader.initializeReaderJava(new Path(path), 0, 100000, - 4096, null); - vecs = new Vec[9]; + parquetColumnarBatchScanReader.initializeReaderJava(new Path(path), 4096); + parquetColumnarBatchScanReader.initializeRecordReaderJava(0, 100000, schema.fieldNames(), null); + missingColumns = new boolean[schema.fieldNames().length]; + Arrays.fill(missingColumns, false); + vecs = new Vec[schema.fieldNames().length]; } private void constructSchema() { @@ -93,7 +96,7 @@ public class ParquetColumnarBatchJniReaderTest extends TestCase { @Test public void testRead() { - long num = parquetColumnarBatchScanReader.next(vecs, types); + long num = parquetColumnarBatchScanReader.next(vecs, missingColumns, types); assertTrue(num == 1); } } diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/TableWriteBasicFunctionSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/TableWriteBasicFunctionSuite.scala index 1566fb383e5e3cf64b113124ab68c27a2efcaf94..b32c3983d8331a4fa94cc5351f21f7d78da6727b 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/TableWriteBasicFunctionSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/TableWriteBasicFunctionSuite.scala @@ -168,4 +168,37 @@ class TableWriteBasicFunctionSuite extends QueryTest with SharedSparkSession { assert(runRows(0).getDate(0).toString == "1001-01-04", "the run value is error") } + + test("empty string partition") { + val drop = spark.sql("drop table if exists table_insert_varchar") + drop.collect() + val createTable = spark.sql("create table table_insert_varchar" + + "(id int, c_varchar 
+    createTable.collect()
+    val insert = spark.sql("insert into table table_insert_varchar values" +
+      "(5,'',''), (13,'6884578', null), (6,'72135', '666')")
+    insert.collect()
+
+    val select = spark.sql("select * from table_insert_varchar order by id, c_varchar, p_varchar")
+    val runRows = select.collect()
+    val expectedRows = Seq(Row(5, "", null), Row(6, "72135", "666"), Row(13, "6884578", null))
+    assert(QueryTest.sameRows(runRows, expectedRows).isEmpty, "the run value is error")
+
+    val dropNP = spark.sql("drop table if exists table_insert_varchar_np")
+    dropNP.collect()
+    val createTableNP = spark.sql("create table table_insert_varchar_np" +
+      "(id int, c_varchar varchar(40)) using orc partitioned by " +
+      "(p_varchar1 int, p_varchar2 varchar(40), p_varchar3 varchar(40))")
+    createTableNP.collect()
+    val insertNP = spark.sql("insert into table table_insert_varchar_np values" +
+      "(5,'',1,'',''), (13,'6884578',6, null, null), (1,'abc',1,'',''), " +
+      "(3,'abcde',6,null,null), (4,'qqqqq', 8, 'a', 'b'), (6,'ooooo', 8, 'a', 'b')")
+    val selectNP = spark.sql("select * from table_insert_varchar_np " +
+      "order by id, c_varchar, p_varchar1")
+    val runRowsNP = selectNP.collect()
+    val expectedRowsNP = Seq(Row(1, "abc", 1, null, null), Row(3, "abcde", 6, null, null),
+      Row(4, "qqqqq", 8, "a", "b"), Row(5, "", 1, null, null), Row(6, "ooooo", 8, "a", "b"),
+      Row(13, "6884578", 6, null, null))
+    assert(QueryTest.sameRows(runRowsNP, expectedRowsNP).isEmpty, "the run value is error")
+  }
 }
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index e6b786c2a2030b4414feb47e4f81a786e9ee2427..36d0a3e4905e44ba9e32ffe8a9e58cb8b2c2deac 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.{Column, DataFrame}
 import org.apache.spark.sql.execution.ColumnarSparkPlanTest
 import org.apache.spark.sql.types.{DataType, Decimal}
@@ -64,7 +65,7 @@ class CastSuite extends ColumnarSparkPlanTest {
     val exception = intercept[Exception](
       result.collect().toSeq.head.getByte(0)
     )
-    assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+    assert(exception.isInstanceOf[SparkException], s"sql: ${sql}")
   }
 
   test("cast null as short") {
@@ -72,7 +73,7 @@
     val exception = intercept[Exception](
       result.collect().toSeq.head.getShort(0)
     )
-    assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+    assert(exception.isInstanceOf[SparkException], s"sql: ${sql}")
   }
 
   test("cast null as int") {
@@ -80,7 +81,7 @@
     val exception = intercept[Exception](
       result.collect().toSeq.head.getInt(0)
     )
-    assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
+    assert(exception.isInstanceOf[SparkException], s"sql: ${sql}")
  }
 
   test("cast null as long") {
@@ -88,7 +89,7 @@
     val exception = intercept[Exception](
       result.collect().toSeq.head.getLong(0)
     )
-    assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}")
s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as float") { @@ -96,7 +97,7 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getFloat(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as double") { @@ -104,7 +105,7 @@ class CastSuite extends ColumnarSparkPlanTest { val exception = intercept[Exception]( result.collect().toSeq.head.getDouble(0) ) - assert(exception.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast null as date") { @@ -154,13 +155,13 @@ class CastSuite extends ColumnarSparkPlanTest { val exception4 = intercept[Exception]( result4.collect().toSeq.head.getBoolean(0) ) - assert(exception4.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception4.isInstanceOf[SparkException], s"sql: ${sql}") val result5 = spark.sql("select cast('test' as boolean);") val exception5 = intercept[Exception]( result5.collect().toSeq.head.getBoolean(0) ) - assert(exception5.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception5.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast boolean to string") { @@ -182,13 +183,13 @@ class CastSuite extends ColumnarSparkPlanTest { val exception2 = intercept[Exception]( result2.collect().toSeq.head.getByte(0) ) - assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}") val result3 = spark.sql("select cast('false' as byte);") val exception3 = intercept[Exception]( result3.collect().toSeq.head.getByte(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast byte to string") { @@ -210,13 +211,13 @@ class CastSuite extends ColumnarSparkPlanTest { val exception2 = intercept[Exception]( result2.collect().toSeq.head.getShort(0) ) - assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}") val result3 = spark.sql("select cast('false' as short);") val exception3 = intercept[Exception]( result3.collect().toSeq.head.getShort(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast short to string") { @@ -238,13 +239,13 @@ class CastSuite extends ColumnarSparkPlanTest { val exception2 = intercept[Exception]( result2.collect().toSeq.head.getInt(0) ) - assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}") val result3 = spark.sql("select cast('false' as int);") val exception3 = intercept[Exception]( result3.collect().toSeq.head.getInt(0) ) - assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}") } test("cast int to string") { @@ -266,13 +267,13 @@ class CastSuite extends ColumnarSparkPlanTest { val exception2 = intercept[Exception]( result2.collect().toSeq.head.getLong(0) ) - assert(exception2.isInstanceOf[NullPointerException], s"sql: ${sql}") + assert(exception2.isInstanceOf[SparkException], s"sql: ${sql}") val result3 = spark.sql("select cast('false' as long);") val exception3 = intercept[Exception]( 
       result3.collect().toSeq.head.getLong(0)
     )
-    assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+    assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}")
   }
 
   test("cast long to string") {
@@ -298,7 +299,7 @@
     val exception3 = intercept[Exception](
       result3.collect().toSeq.head.getFloat(0)
     )
-    assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+    assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}")
   }
 
   test("cast float to string") {
@@ -324,7 +325,7 @@
     val exception3 = intercept[Exception](
       result3.collect().toSeq.head.getDouble(0)
     )
-    assert(exception3.isInstanceOf[NullPointerException], s"sql: ${sql}")
+    assert(exception3.isInstanceOf[SparkException], s"sql: ${sql}")
   }
 
   test("cast double to string") {
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/ColumnarFuncSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/ColumnarFuncSuite.scala
index 467ad35cef0c97162ad4567cafdfa8c80670d5ba..20c861eeae56f4049d46f2d0cd809bce764f7c27 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/ColumnarFuncSuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/expressions/ColumnarFuncSuite.scala
@@ -110,6 +110,46 @@ class ColumnarFuncSuite extends ColumnarSparkPlanTest {
     assertOmniProjectNotHappened(rollbackRes)
   }
 
+  test("Test Unix_timestamp Function") {
+    spark.conf.set("spark.sql.optimizer.excludedRules", "org.apache.spark.sql.catalyst.optimizer.ConstantFolding")
+    spark.conf.set("spark.sql.session.timeZone", "Asia/Shanghai")
+    spark.conf.set("spark.sql.legacy.timeParserPolicy", "CORRECTED")
+    val res1 = spark.sql("select unix_timestamp('','yyyy-MM-dd'), unix_timestamp('123-abc', " +
+      "'yyyy-MM-dd HH:mm:ss'), unix_timestamp(NULL, 'yyyy-MM-dd')")
+    assertOmniProjectHappened(res1)
+    checkAnswer(res1, Seq(Row(null, null, null)))
+
+    val res2 = spark.sql("select unix_timestamp('2024-10-21', 'yyyy-MM-dd'), " +
+      "unix_timestamp('2024-10-21 11:22:33', 'yyyy-MM-dd HH:mm:ss')")
+    assertOmniProjectHappened(res2)
+    checkAnswer(res2, Seq(Row(1729440000L, 1729480953L)))
+
+    val res3 = spark.sql("select unix_timestamp('1986-08-10 05:05:05','yyyy-MM-dd HH:mm:ss')")
+    assertOmniProjectHappened(res3)
+    checkAnswer(res3, Seq(Row(524001905L)))
+
+    val res4 = spark.sql("select unix_timestamp('2086-08-10 05:05:05','yyyy-MM-dd HH:mm:ss')")
+    assertOmniProjectHappened(res4)
+    checkAnswer(res4, Seq(Row(3679765505L)))
+  }
+
+  test("Test from_unixtime Function") {
+    spark.conf.set("spark.sql.optimizer.excludedRules", "org.apache.spark.sql.catalyst.optimizer.ConstantFolding")
+    spark.conf.set("spark.sql.session.timeZone", "Asia/Shanghai")
+    spark.conf.set("spark.sql.legacy.timeParserPolicy", "CORRECTED")
+    val res1 = spark.sql("select from_unixtime(1, 'yyyy-MM-dd HH:mm:ss'), from_unixtime(1, 'yyyy-MM-dd')")
+    assertOmniProjectHappened(res1)
+    checkAnswer(res1, Seq(Row("1970-01-01 08:00:01", "1970-01-01")))
+
+    val res2 = spark.sql("select from_unixtime(524001905, 'yyyy-MM-dd HH:mm:ss'), from_unixtime(524001905, 'yyyy-MM-dd')")
+    assertOmniProjectHappened(res2)
+    checkAnswer(res2, Seq(Row("1986-08-10 05:05:05", "1986-08-10")))
+
+    val res3 = spark.sql("select from_unixtime(3679765505, 'yyyy-MM-dd HH:mm:ss'), from_unixtime(3679765505, 'yyyy-MM-dd')")
+    assertOmniProjectHappened(res3)
+    checkAnswer(res3, Seq(Row("2086-08-10 05:05:05", "2086-08-10")))
+  }
+
   private def assertOmniProjectHappened(res: DataFrame) = {
     val executedPlan = res.queryExecution.executedPlan
     assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan")
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
index f83edb9ca91b045e3ff6c419697cff381bf31bcc..ea52aca621e61ce4c48cf794fc4ddce06a24c632 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
@@ -61,7 +61,7 @@ class CombiningLimitsSuite extends PlanTest {
     comparePlans(optimized1, expected1)
 
     // test child max row > limit.
-    val query2 = testRelation.select().groupBy()(count(1)).limit(0).analyze
+    val query2 = testRelation2.select($"x").groupBy($"x")(count(1)).limit(1).analyze
     val optimized2 = Optimize.execute(query2)
     comparePlans(optimized2, query2)
 
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala
index aaa244cdf65c04699ba2bf4c5c443e8727faf9f5..e1c620e1c0448b85f9d2164f7cda749a6712c5ce 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala
@@ -43,7 +43,7 @@ class MergeSubqueryFiltersSuite extends PlanTest {
   }
 
   private def extractorExpression(cteIndex: Int, output: Seq[Attribute], fieldIndex: Int) = {
-    GetStructField(ScalarSubquery(CTERelationRef(cteIndex, _resolved = true, output)), fieldIndex)
+    GetStructField(ScalarSubquery(CTERelationRef(cteIndex, _resolved = true, output, isStreaming = false)), fieldIndex)
       .as("scalarsubquery()")
   }
 
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala
index c34ff5bb15acf36c4dcd96f57c70554f6e974ce9..d1b295d5c5dda37432d08778c8a92e7e05398091 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala
@@ -19,17 +19,15 @@ package org.apache.spark.sql.execution.adaptive
 
 import java.io.File
 import java.net.URI
-
 import org.apache.logging.log4j.Level
 import org.scalatest.PrivateMethodTester
 import org.scalatest.time.SpanSugar._
-
 import org.apache.spark.SparkException
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
-import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnaryExecNode, UnionExec}
+import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnaryExecNode, UnionExec}
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
@@ -37,7 +35,7 @@ import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, Exchange, REPARTITION_BY_COL, REPARTITION_BY_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
 import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec}
 import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
-import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
+import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLExecutionStart}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
@@ -1122,13 +1120,21 @@
   test("SPARK-30953: InsertAdaptiveSparkPlan should apply AQE on child plan of write commands") {
     withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
       SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
+      var plan: SparkPlan = null
+      val listener = new QueryExecutionListener {
+        override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+          plan = qe.executedPlan
+        }
+        override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+      }
+      spark.listenerManager.register(listener)
       withTable("t1") {
-        val plan = sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").queryExecution.executedPlan
-        assert(plan.isInstanceOf[CommandResultExec])
-        val commandResultExec = plan.asInstanceOf[CommandResultExec]
-        assert(commandResultExec.commandPhysicalPlan.isInstanceOf[DataWritingCommandExec])
-        assert(commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec]
-          .child.isInstanceOf[AdaptiveSparkPlanExec])
+        val format = classOf[NoopDataSource].getName
+        Seq((0, 1)).toDF("x", "y").write.format(format).mode("overwrite").save()
+        sparkContext.listenerBus.waitUntilEmpty()
+        assert(plan.isInstanceOf[V2TableWriteExec])
+        assert(plan.asInstanceOf[V2TableWriteExec].child.isInstanceOf[AdaptiveSparkPlanExec])
+        spark.listenerManager.unregister(listener)
       }
     }
   }
@@ -1174,13 +1180,12 @@
     withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
       SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
       withTable("t1") {
-        var checkDone = false
+        var commands: Seq[SparkPlanInfo] = Seq.empty
        val listener = new SparkListener {
           override def onOtherEvent(event: SparkListenerEvent): Unit = {
             event match {
-              case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) =>
-                assert(planInfo.nodeName == "Execute CreateDataSourceTableAsSelectCommand")
-                checkDone = true
+              case start: SparkListenerSQLExecutionStart =>
+                commands = commands ++ Seq(start.sparkPlanInfo)
               case _ => // ignore other events
             }
           }
@@ -1189,7 +1194,12 @@
         try {
           sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect()
           spark.sparkContext.listenerBus.waitUntilEmpty()
-          assert(checkDone)
+          assert(commands.size == 3)
+          assert(commands.head.nodeName == "Execute CreateDataSourceTableAsSelectCommand")
+          assert(commands(1).nodeName == "Execute InsertIntoHadoopFsRelationCommand")
+          assert(commands(1).children.size == 1)
+          assert(commands(1).children.head.nodeName == "WriteFiles")
+          assert(commands(2).nodeName == "CommandResult")
         } finally {
           spark.sparkContext.removeSparkListener(listener)
         }
@@ -1584,9 +1594,8 @@
     var noLocalread: Boolean = false
     val listener = new QueryExecutionListener {
       override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-        qe.executedPlan match {
+        stripAQEPlan(qe.executedPlan) match {
           case plan@(_: DataWritingCommandExec | _: V2TableWriteExec) =>
-            assert(plan.asInstanceOf[UnaryExecNode].child.isInstanceOf[AdaptiveSparkPlanExec])
             noLocalread = collect(plan) {
               case exec: AQEShuffleReadExec if exec.isLocalRead => exec
             }.isEmpty
diff --git a/omnioperator/omniop-spark-extension/pom.xml b/omnioperator/omniop-spark-extension/pom.xml
index 4376a89bef1c869c7f9feb19bc24f381bd01a9c0..e949cea2d673a51d90262735fa860ab3d2b46ebe 100644
--- a/omnioperator/omniop-spark-extension/pom.xml
+++ b/omnioperator/omniop-spark-extension/pom.xml
@@ -8,13 +8,13 @@
     com.huawei.kunpeng
     boostkit-omniop-spark-parent
     pom
-    3.3.1-1.6.0
+    3.4.3-1.6.0
     BoostKit Spark Native Sql Engine Extension Parent Pom
 
         2.12.10
         2.12
-        3.3.1
+        3.4.3
         3.2.2
         UTF-8
         UTF-8