diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index 26555cc23b21a78180403ada2e9a3c3921543c6b..5a184d84458ded27ffd782a37ff260d5a4696098 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -300,9 +300,10 @@ object OmniExpressionAdaptor extends Logging { private def unsupportedCastCheck(expr: Expression, cast: Cast): Unit = { def isDecimalOrStringType(dataType: DataType): Boolean = (dataType.isInstanceOf[DecimalType]) || (dataType.isInstanceOf[StringType]) - // not support Cast(string as !(decimal/string)) and Cast(!(decimal/string) as string) + // not support Cast(!(decimal/string) as string) and Cast(string as !(decimal/string/date)) if ((cast.dataType.isInstanceOf[StringType] && !isDecimalOrStringType(cast.child.dataType)) || - (!isDecimalOrStringType(cast.dataType) && cast.child.dataType.isInstanceOf[StringType])) { + (!isDecimalOrStringType(cast.dataType) && !cast.dataType.isInstanceOf[DateType] + && cast.child.dataType.isInstanceOf[StringType])) { throw new UnsupportedOperationException(s"Unsupported expression: $expr") } diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarStringCastSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarStringCastSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..dd189d73e15159a10b614c503f0715d50ba4f0f4 --- /dev/null +++
b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarStringCastSuite.scala @@ -0,0 +1,208 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.forsql + +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.execution.{ColumnarProjectExec, ColumnarSparkPlanTest, ProjectExec} + +class ColumnarStringCastSuite extends ColumnarSparkPlanTest{ + import testImplicits.{localSeqToDatasetHolder, newProductEncoder} + + private var stringDateDf: DataFrame = _ + + protected override def beforeAll(): Unit = { + super.beforeAll() + stringDateDf = Seq[(String, String, String, String, String, String, String)]( + ("2010", "2010-1", "2010-01", "2010-02-4", "2010-02-04", "2022-01-021234", "20 10"), + (" 2011 ", " 2011-1 ", " 2011-01 ", " 2011-03-4 ", " 2011-03-04 ", "2022-01-021 12", "abcd"), + (null, "2010-8", null, "2010-08-4", null, null, null), + ("2012", null, "2012-11", null, "2012-04-10", "2022-01-02abd", "2010-1 3") + ).toDF("yyyy", "yyyy-m", "yyyy-mm", "yyyy-mm-d", "yyyy-mm-dd", "yyyy-mm-dd-x", "invalid") + + stringDateDf.createOrReplaceTempView("string_date") + } + + // string to date + // cast yyyy to date + test("Test ColumnarProjectExec happen and result is same as native " + + "when cast yyyy to date") { + val sql = "select cast(yyyy as date) from string_date" + val expected = Seq( + Row("2010-01-01"), + Row("2011-01-01"), + Row("2012-01-01"), + Row(null) + ) + checkResult(sql, expected) + } + + // cast yyyy-m to date + test("Test ColumnarProjectExec happen and result is same as native " + + "when cast yyyy-m to date") { + val sql = "select cast(`yyyy-m` as date) from string_date" + val expected = Seq( + Row("2010-01-01"), + Row("2010-08-01"), + Row("2011-01-01"), + Row(null) + ) + checkResult(sql, expected) + } + + // cast yyyy-mm to date + test("Test ColumnarProjectExec happen and result is same as native " + + "when cast yyyy-mm to date") { + val sql = "select cast(`yyyy-mm` as date) from string_date" + val expected = Seq( + Row("2010-01-01"), + Row("2011-01-01"), + Row("2012-11-01"), + Row(null) + ) + checkResult(sql, expected) + } + + // cast 
yyyy-mm-d to date + test("Test ColumnarProjectExec happen and result is same as native " + + "when cast yyyy-mm-d to date") { + val sql = "select cast(`yyyy-mm-d` as date) from string_date" + val expected = Seq( + Row("2010-02-04"), + Row("2010-08-04"), + Row("2011-03-04"), + Row(null) + ) + checkResult(sql, expected) + } + + // cast yyyy-mm-dd to date + test("Test ColumnarProjectExec happen and result is same as native " + + "when cast yyyy-mm-dd to date") { + val sql = "select cast(`yyyy-mm-dd` as date) from string_date" + val expected = Seq( + Row("2010-02-04"), + Row("2011-03-04"), + Row("2012-04-10"), + Row(null) + ) + checkResult(sql, expected) + } + + // cast yyyy-mm-dd* to date + test("Test ColumnarProjectExec happen and result is same as native " + + "when cast yyyy-mm-dd* to date") { + val sql = "select cast(`yyyy-mm-dd-x` as date) from string_date" + val expected = Seq( + Row("2022-01-02"), + Row("2022-01-02"), + Row("2022-01-02"), + Row(null) + ) + checkResult(sql, expected) + } + + // cast literal string yyyy to date + test("Test ColumnarProjectExec happen and result is same as native " + + "when cast literal string yyyy to date") { + val sql = "select cast('1998' as date)" + val expected = Seq( + Row("1998-01-01") + ) + checkResult(sql, expected) + } + + // cast literal string yyyy-m to date + test("Test ColumnarProjectExec happen and result is same as native " + + "when cast literal string yyyy-m to date") { + val sql = "select cast('1998-8' as date)" + val expected = Seq( + Row("1998-08-01") + ) + checkResult(sql, expected) + } + + // cast literal string yyyy-mm to date + test("Test ColumnarProjectExec happen and result is same as native " + + "when cast literal string yyyy-mm to date") { + val sql = "select cast('1998-10' as date)" + val expected = Seq( + Row("1998-10-01") + ) + checkResult(sql, expected) + } + + // cast literal string yyyy-mm-d to date + test("Test ColumnarProjectExec happen and result is same as native " + + "when cast literal string 
yyyy-mm-d to date") { + val sql = "select cast('1998-10-3' as date)" + val expected = Seq( + Row("1998-10-03") + ) + checkResult(sql, expected) + } + + // cast literal string yyyy-mm-dd to date + test("Test ColumnarProjectExec happen and result is same as native " + + "when cast literal string yyyy-mm-dd to date") { + val sql = "select cast('1998-10-10' as date)" + val expected = Seq( + Row("1998-10-10") + ) + checkResult(sql, expected) + } + + // cast literal string yyyy-mm-dd* to date + test("Test ColumnarProjectExec happen and result is same as native " + + "when cast literal string yyyy-mm-dd* to date") { + val sql = "select cast('1998-10-101234' as date)" + val expected = Seq( + Row("1998-10-10") + ) + checkResult(sql, expected) + } + + // cast invalid string to date + test("Test ColumnarProjectExec happen and result is same as native " + + "when cast invalid string to date") { + val sql = "select cast(invalid as date) from string_date" + val expected = Seq( + Row(null), + Row(null), + Row(null), + Row(null) + ) + checkResult(sql, expected) + } + + def checkResult(sql: String, expected: Seq[Row], isUseOmni: Boolean = true): Unit = { + def assertOmniProjectHappen(res: DataFrame): Unit = { + val executedPlan = res.queryExecution.executedPlan + assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") + } + def assertOmniProjectNotHappen(res: DataFrame): Unit = { + val executedPlan = res.queryExecution.executedPlan + assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, s"ColumnarProjectExec happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isDefined, s"ProjectExec not happened, executedPlan as follows: \n$executedPlan") + } + val res = spark.sql(sql) + if 
(isUseOmni) assertOmniProjectHappen(res) else assertOmniProjectNotHappen(res) + checkAnswer(res, expected) + } +}