diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index f5ec23374b2cd7c661abadef42450b2846649827..e0a2359e0f376e10337edb8aba83efdb558ef8d9 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -416,6 +416,10 @@ object OmniExpressionAdaptor extends Logging { case concat: Concat => getConcatJsonStr(concat, exprsIndexMap) + + case concatWs: ConcatWs => + getConcatWsJsonStr(concatWs, exprsIndexMap) + case greatest: Greatest => getGreatestJsonStr(greatest, exprsIndexMap) @@ -435,6 +439,13 @@ object OmniExpressionAdaptor extends Logging { .put("arguments", new JsonArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(inStr.str, exprsIndexMap)) .put(rewriteToOmniJsonExpressionLiteralJsonObject(inStr.substr, exprsIndexMap))) + case rlike: RLike => + new JsonObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", rlike.dataType) + .put("function_name", "RLike") + .put("arguments", new JsonArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(rlike.left, exprsIndexMap)) + .put(rewriteToOmniJsonExpressionLiteralJsonObject(rlike.right, exprsIndexMap))) + // for floating numbers normalize case normalizeNaNAndZero: NormalizeNaNAndZero => new JsonObject().put("exprType", "FUNCTION") @@ -516,6 +527,18 @@ object OmniExpressionAdaptor extends Logging { } } + case regExpReplace: RegExpReplace => + getRegExpReplaceStr(regExpReplace, exprsIndexMap) + + case trim: StringTrim => + getTrimStr(trim, exprsIndexMap) + + case floor: Floor => + new JsonObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", floor.dataType) + .put("function_name", "floor") + .put("arguments", new JsonArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(floor.child, exprsIndexMap))) + case staticInvoke: StaticInvoke if ColumnarPluginConfig.getSessionConf.enableOmniStaticInvoke => { val funcName = staticInvoke.functionName @@ -549,6 +572,34 @@ object OmniExpressionAdaptor extends Logging { } } + private def getRegExpReplaceStr(regExpReplace: RegExpReplace, exprsIndexMap: Map[ExprId, Int]): JsonObject = { + val children: Seq[Expression] = regExpReplace.children + val arguments = new JsonArray() + for (i <- children.indices) { + arguments.put(rewriteToOmniJsonExpressionLiteralJsonObject(children(i), exprsIndexMap)) + } + new JsonObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", regExpReplace.dataType) + .put("function_name", "regexpReplace") + .put("arguments", arguments) + } + + private def getTrimStr(trim: StringTrim, exprsIndexMap: Map[ExprId, Int]): JsonObject = { + val children: Seq[Expression] = trim.children + val arguments = new JsonArray() + arguments.put(rewriteToOmniJsonExpressionLiteralJsonObject(children.head, exprsIndexMap)) + if (children.size == 1) { + arguments.put(new JsonObject().put("exprType", "LITERAL").put("dataType", 15).put("isNull", false) + .put("value", " ").put("width", 1)) + } else { + arguments.put(rewriteToOmniJsonExpressionLiteralJsonObject(children(1), exprsIndexMap)) + } + new JsonObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", trim.dataType) + .put("function_name", "trim") + .put("arguments", arguments) + } + private def getJsonExprArgumentsByChildren(children: Seq[Expression], exprsIndexMap: Map[ExprId, Int]): JsonArray = { val size = children.size @@ -591,6 +642,39 @@ object OmniExpressionAdaptor extends Logging { res } + private def getConcatWsJsonStr(concatWs: ConcatWs, exprsIndexMap: Map[ExprId, Int]): JsonObject = { + val children: Seq[Expression] = concatWs.children + checkInputDataTypes(children) + + val separator = rewriteToOmniJsonExpressionLiteralJsonObject(children.head, exprsIndexMap) + val res = new JsonObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", concatWs.dataType) + .put("function_name", "concatWs") + if (children.length == 1) { + res.put("arguments", new JsonArray().put(separator) + .put(createNullLiteralJson(concatWs.dataType)) + .put(createNullLiteralJson(concatWs.dataType))) + } else if (children.length == 2) { + res.put("arguments", new JsonArray().put(separator) + .put(rewriteToOmniJsonExpressionLiteralJsonObject(children(1), exprsIndexMap)) + .put(createNullLiteralJson(concatWs.dataType))) + } else { + res.put("arguments", new JsonArray().put(separator) + .put(rewriteToOmniJsonExpressionLiteralJsonObject(children(1), exprsIndexMap)) + .put(rewriteToOmniJsonExpressionLiteralJsonObject(children(2), exprsIndexMap))) + for (i <- 3 until children.length) { + val preResJson = new JsonObject().addAll(res) + res.put("arguments", new JsonArray().put(separator).put(preResJson) + .put(rewriteToOmniJsonExpressionLiteralJsonObject(children(i), exprsIndexMap))) + } + } + res + } + + private def createNullLiteralJson(dataType: DataType): JsonObject = { + new JsonObject().put("exprType", "LITERAL").addOmniExpJsonType("dataType", dataType).put("isNull", true) + } + private def getGreatestJsonStr(greatest: Greatest, exprsIndexMap: Map[ExprId, Int]): JsonObject = { if (greatest.children.length != 2) { throw new UnsupportedOperationException(s"Number of parameters is ${greatest.children.length}. " + diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarFuncSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarFuncSuite.scala index 1c85618f384bb66d49725d7bb914f97d0141510a..ae401004ed4d460dd3370af4b7e7984e28c3e38c 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarFuncSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarFuncSuite.scala @@ -125,6 +125,61 @@ class ColumnarFuncSuite extends ColumnarSparkPlanTest { checkAnswer(res3, Seq(Row("2086-08-10 05:05:05", "2086-08-10"))) } + test("Test concat_ws Function") { + spark.conf.set("spark.sql.optimizer.excludedRules", "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") + val res1 = spark.sql("select concat_ws('--', 'aaa', 'bbb')") + assertOmniProjectHappened(res1) + checkAnswer(res1, Seq(Row("aaa--bbb"))) + + val res2 = spark.sql("select concat_ws('一一', '哈哈哈', '啦啦啦')") + assertOmniProjectHappened(res2) + checkAnswer(res2, Seq(Row("哈哈哈一一啦啦啦"))) + } + + test("Test regexp Function") { + spark.conf.set("spark.sql.optimizer.excludedRules", "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") + val res1 = spark.sql("select 'hello' regexp 'hel.o' as test1") + assertOmniProjectHappened(res1) + checkAnswer(res1, Seq(Row(true))) + + val res2 = spark.sql("select 'aaa' regexp 'a{2,4}' as test2") + assertOmniProjectHappened(res2) + checkAnswer(res2, Seq(Row(true))) + } + + test("Test regexp_replace Function") { + spark.conf.set("spark.sql.optimizer.excludedRules", "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") + val res1 = spark.sql("select regexp_replace('abcabc', 'a', 'x')") + assertOmniProjectHappened(res1) + checkAnswer(res1, Seq(Row("xbcxbc"))) + + val res2 = spark.sql("select regexp_replace('你好世界', '好', '差')") + assertOmniProjectHappened(res2) + checkAnswer(res2, Seq(Row("你差世界"))) + } + + test("Test trim Function") { + spark.conf.set("spark.sql.optimizer.excludedRules", "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") + val res1 = spark.sql("select trim(' hello ')") + assertOmniProjectHappened(res1) + checkAnswer(res1, Seq(Row("hello"))) + + val res2 = spark.sql("select trim(both '空' from '空稀少珍稀空')") + assertOmniProjectHappened(res2) + checkAnswer(res2, Seq(Row("稀少珍稀"))) + } + + test("Test floor Function") { + spark.conf.set("spark.sql.optimizer.excludedRules", "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") + val res1 = spark.sql("select floor(1.9)") + assertOmniProjectHappened(res1) + checkAnswer(res1, Seq(Row(1L))) + + val res2 = spark.sql("select floor(-1.9)") + assertOmniProjectHappened(res2) + checkAnswer(res2, Seq(Row(-2L))) + } + private def assertOmniProjectHappened(res: DataFrame) = { val executedPlan = res.queryExecution.executedPlan assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan")