diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala
index 063a83b01b479b16653d2a33ca2c56da2f75898f..bb3bdc2fb25a99c7b6b6ddfaab50bdbd3fc259bf 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala
@@ -160,6 +160,8 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging {
 
   def enableOmniUnixTimeFunc: Boolean = conf.getConf(ENABLE_OMNI_UNIXTIME_FUNCTION)
 
   def enableVecPredicateFilter: Boolean = conf.getConf(ENABLE_VEC_PREDICATE_FILTER)
+
+  def enableOmniStaticInvoke: Boolean = conf.getConf(ENABLED_OMNI_STATICINVOKE)
 
 }
@@ -541,4 +543,10 @@ object ColumnarPluginConfig {
     .doc("enable vectorized predicate filtering")
     .booleanConf
     .createWithDefault(false)
+
+  val ENABLED_OMNI_STATICINVOKE = buildConf("spark.omni.sql.columnar.staticInvoke")
+    .internal()
+    .doc("enable omni staticInvoke")
+    .booleanConf
+    .createWithDefault(false)
 }
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
index f4a088aa58774e97496dbc0f83f5671c382a76c5..f5ec23374b2cd7c661abadef42450b2846649827 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
@@ -516,7 +516,7 @@ object OmniExpressionAdaptor extends Logging {
       }
     }
 
-    case staticInvoke: StaticInvoke =>
+    case staticInvoke: StaticInvoke if ColumnarPluginConfig.getSessionConf.enableOmniStaticInvoke =>
     {
       val funcName = staticInvoke.functionName
       funcName match {
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarBasicFunctionSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarBasicFunctionSuite.scala
index 2d765d166e7f5bedf7b23ff724110feb47da11a2..c7c9ff79f9c094ae181861d69bb2d1010838d177 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarBasicFunctionSuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarBasicFunctionSuite.scala
@@ -51,33 +51,35 @@ class ColumnarBasicFunctionSuite extends QueryTest with SharedSparkSession {
   }
 
   test("Unsupported StaticInvoke function varcharTypeWriteSideCheck") {
-    val drop = spark.sql("drop table if exists source_table")
-    drop.collect()
-    val createTable = spark.sql("create table source_table" +
-      "(id int, name string, amount int) using parquet")
-    createTable.collect()
-    val dropNP = spark.sql("drop table if exists target_table")
-    dropNP.collect()
-    val createTableNP = spark.sql("create table target_table" +
-      "(name varchar(5), total_amount long) using parquet")
-    createTableNP.collect()
-    var insert = spark.sql("insert into table source_table values" +
-      "(1, 'Bob', 250), (2, '測試中文', 250), (3, NULL, 250), (4, 'abide', 250)")
-    insert.collect()
-    insert = spark.sql("insert into table target_table select UPPER(name) as name, SUM(amount) as total_amount from " +
-      "source_table where amount >= 10 GROUP BY UPPER(name)")
-    insert.collect()
-    assert(insert.queryExecution.executedPlan.toString().contains("OmniColumnarHashAggregate"),
-      "use columnar data writing command")
-    val columnarDataWrite = insert.queryExecution.executedPlan
-      .find({
-        case _: HashAggregateExec => true
-        case _ => false
-      })
-    assert(columnarDataWrite.isEmpty, "use columnar data writing command")
-    val select = spark.sql("select * from target_table order by name")
-    val runRows = select.collect()
-    val expectedRows = Seq(Row(null, 250), Row("ABIDE", 250), Row("BOB", 250), Row("測試中文", 250))
-    assert(QueryTest.sameRows(runRows, expectedRows).isEmpty, "the run value is error")
+    withSQLConf(("spark.omni.sql.columnar.staticInvoke", "true")) {
+      val drop = spark.sql("drop table if exists source_table")
+      drop.collect()
+      val createTable = spark.sql("create table source_table" +
+        "(id int, name string, amount int) using parquet")
+      createTable.collect()
+      val dropNP = spark.sql("drop table if exists target_table")
+      dropNP.collect()
+      val createTableNP = spark.sql("create table target_table" +
+        "(name varchar(5), total_amount long) using parquet")
+      createTableNP.collect()
+      var insert = spark.sql("insert into table source_table values" +
+        "(1, 'Bob', 250), (2, '測試中文', 250), (3, NULL, 250), (4, 'abide', 250)")
+      insert.collect()
+      insert = spark.sql("insert into table target_table select UPPER(name) as name, SUM(amount) as total_amount from " +
+        "source_table where amount >= 10 GROUP BY UPPER(name)")
+      insert.collect()
+      assert(insert.queryExecution.executedPlan.toString().contains("OmniColumnarHashAggregate"),
+        "OmniColumnarHashAggregate should appear in the executed plan")
+      val rowHashAggregate = insert.queryExecution.executedPlan
+        .find({
+          case _: HashAggregateExec => true
+          case _ => false
+        })
+      assert(rowHashAggregate.isEmpty, "row-based HashAggregateExec should not appear in the executed plan")
+      val select = spark.sql("select * from target_table order by name")
+      val runRows = select.collect()
+      val expectedRows = Seq(Row(null, 250), Row("ABIDE", 250), Row("BOB", 250), Row("測試中文", 250))
+      assert(QueryTest.sameRows(runRows, expectedRows).isEmpty, "query result does not match the expected rows")
+    }
   }
 }
\ No newline at end of file
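
Usage note (not part of the patch): the new flag is internal and defaults to false, so StaticInvoke offload stays disabled unless it is switched on explicitly, as the test above does via withSQLConf. Below is a minimal sketch of enabling it from an application, assuming only the config key introduced in ColumnarPluginConfig above; the object and app names are illustrative.

import org.apache.spark.sql.SparkSession

object OmniStaticInvokeExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("omni-staticinvoke-example")
      .master("local[*]")
      .getOrCreate()

    // "spark.omni.sql.columnar.staticInvoke" is the key added in
    // ColumnarPluginConfig; it defaults to false.
    spark.conf.set("spark.omni.sql.columnar.staticInvoke", "true")

    // With the flag on, OmniExpressionAdaptor is allowed to match
    // StaticInvoke expressions (e.g. the write-side varchar length
    // check exercised by the test) for columnar offload.
    spark.sql("select upper('abc')").show()

    spark.stop()
  }
}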