diff --git a/README.md b/README.md
index 97752cfa5c6e3ce50c899d757946af1bb43e14b4..c47c81a448759ea3dc81024ee69983380ad4a84d 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,5 @@
 # boostkit-bigdata
 BoostKit Acceleration Packages —— Big Data Component Adaptation Layer
+
+## Notice
+The boostkit-bigdata repository contains acceleration plugins and patches for several open-source projects, including openLooKeng, Apache Spark, Hive, and HBase. Using these plugins and patches depends on other open-source software, which is available in the central repository. You understand and agree that when using that other software you must strictly comply with its open-source licenses and fulfill the obligations they specify. Vulnerabilities and security issues in that software are resolved by the corresponding upstream communities through their own vulnerability and security response mechanisms, so please follow the notifications and version updates they release. The Kunpeng Compute community assumes no responsibility for vulnerabilities or security issues in the preceding open-source software.
diff --git a/omnidata/omnidata-hive-connector/build.sh b/omnidata/omnidata-hive-connector/build.sh
new file mode 100644
index 0000000000000000000000000000000000000000..98c426e22cc430cc1268816c9355bc13d98b8c9f
--- /dev/null
+++ b/omnidata/omnidata-hive-connector/build.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+mvn clean package
+jar_name=$(basename "$(ls connector/target/*.jar | grep hive-exec)")
+dir_name=${jar_name%.jar}
+rm -rf "$dir_name"
+rm -f "$dir_name.zip"
+mkdir -p "$dir_name"
+cp "connector/target/$jar_name" "$dir_name"
+cd "$dir_name"
+wget https://mirrors.huaweicloud.com/repository/maven/org/bouncycastle/bcpkix-jdk15on/1.68/bcpkix-jdk15on-1.68.jar
+wget https://mirrors.huaweicloud.com/repository/maven/org/bouncycastle/bcprov-jdk15on/1.68/bcprov-jdk15on-1.68.jar
+wget https://mirrors.huaweicloud.com/repository/maven/com/google/guava/guava/31.1-jre/guava-31.1-jre.jar
+wget https://mirrors.huaweicloud.com/repository/maven/io/hetu/core/hetu-transport/1.6.1/hetu-transport-1.6.1.jar
+wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/core/jackson-annotations/2.12.4/jackson-annotations-2.12.4.jar
+wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/core/jackson-core/2.12.4/jackson-core-2.12.4.jar
+wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/core/jackson-databind/2.12.4/jackson-databind-2.12.4.jar
+wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/datatype/jackson-datatype-guava/2.12.4/jackson-datatype-guava-2.12.4.jar
+wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/datatype/jackson-datatype-jdk8/2.12.4/jackson-datatype-jdk8-2.12.4.jar
+wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/datatype/jackson-datatype-joda/2.12.4/jackson-datatype-joda-2.12.4.jar
+wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/datatype/jackson-datatype-jsr310/2.12.4/jackson-datatype-jsr310-2.12.4.jar
+wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/module/jackson-module-parameter-names/2.12.4/jackson-module-parameter-names-2.12.4.jar
+wget https://mirrors.huaweicloud.com/repository/maven/org/jasypt/jasypt/1.9.3/jasypt-1.9.3.jar
+wget https://mirrors.huaweicloud.com/repository/maven/org/openjdk/jol/jol-core/0.2/jol-core-0.2.jar
+wget https://repo1.maven.org/maven2/io/airlift/joni/2.1.5.3/joni-2.1.5.3.jar
+wget https://mirrors.huaweicloud.com/repository/maven/com/esotericsoftware/kryo-shaded/4.0.2/kryo-shaded-4.0.2.jar
+wget https://mirrors.huaweicloud.com/repository/maven/io/airlift/log/0.193/log-0.193.jar
+wget https://mirrors.huaweicloud.com/repository/maven/io/perfmark/perfmark-api/0.23.0/perfmark-api-0.23.0.jar
+wget https://mirrors.huaweicloud.com/repository/maven/io/hetu/core/presto-main/1.6.1/presto-main-1.6.1.jar
+wget https://mirrors.huaweicloud.com/repository/maven/io/hetu/core/presto-spi/1.6.1/presto-spi-1.6.1.jar
+wget https://mirrors.huaweicloud.com/repository/maven/com/google/protobuf/protobuf-java/3.12.0/protobuf-java-3.12.0.jar
+wget https://mirrors.huaweicloud.com/repository/maven/io/airlift/slice/0.38/slice-0.38.jar
+cd ..
+zip -r -o "$dir_name.zip" "$dir_name"
+rm -rf "$dir_name"
\ No newline at end of file
diff --git a/omnidata/omnidata-hive-connector/connector/src/main/java/org/apache/hadoop/hive/ql/omnidata/physical/NdpPlanResolver.java b/omnidata/omnidata-hive-connector/connector/src/main/java/org/apache/hadoop/hive/ql/omnidata/physical/NdpPlanResolver.java
index 208d1436588812387fd4d72d45999ba1f1ab4428..e27f0b59969ecfe3f16765c06c6b0cf40d152f0d 100644
--- a/omnidata/omnidata-hive-connector/connector/src/main/java/org/apache/hadoop/hive/ql/omnidata/physical/NdpPlanResolver.java
+++ b/omnidata/omnidata-hive-connector/connector/src/main/java/org/apache/hadoop/hive/ql/omnidata/physical/NdpPlanResolver.java
@@ -205,6 +205,10 @@ public class NdpPlanResolver implements PhysicalPlanResolver {
         // get OmniData filter expression
         if (isPushDownFilter) {
             filter = getOmniDataFilter(omniDataPredicate);
+            if (!filter.isPresent()) {
+                isPushDownFilter = false;
+                isPushDownAgg = false;
+            }
         }
         // get OmniData agg expression
         if (isPushDownAgg) {
@@ -304,7 +308,6 @@
             // The AGG does not support part push down
             isPushDownAgg = false;
         } else if (mode.equals(NdpFilter.NdpFilterMode.NONE)) {
-            isPushDownFilter = false;
             return Optional.empty();
         }
         OmniDataFilter omniDataFilter = new OmniDataFilter(omniDataPredicate);
@@ -312,7 +315,6 @@
         RowExpression filterRowExpression = omniDataFilter.getFilterExpression(
             (ExprNodeGenericFuncDesc) filterDesc.clone(), ndpFilter);
         if (filterRowExpression == null) {
-            isPushDownFilter = false;
             return Optional.empty();
         }
         return Optional.of(filterRowExpression);
diff --git a/omnidata/omnidata-hive-connector/connector/src/main/java/org/apache/hadoop/hive/ql/omnidata/reader/OmniDataAdapter.java b/omnidata/omnidata-hive-connector/connector/src/main/java/org/apache/hadoop/hive/ql/omnidata/reader/OmniDataAdapter.java
index 4d7c90fa19efb84fb90d74cdc796898901d18ae9..8d1bcedefd8340e8aaf01cb548adcd1929d1a1d6 100644
--- a/omnidata/omnidata-hive-connector/connector/src/main/java/org/apache/hadoop/hive/ql/omnidata/reader/OmniDataAdapter.java
+++ b/omnidata/omnidata-hive-connector/connector/src/main/java/org/apache/hadoop/hive/ql/omnidata/reader/OmniDataAdapter.java
@@ -49,12 +49,8 @@
 import java.io.IOException;
 import java.io.Serializable;
 import java.net.InetAddress;
 import java.net.UnknownHostException;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Properties;
-import java.util.Queue;
+import java.util.*;
+
 /**
  * Obtains data from OmniData through OmniDataAdapter and converts the data into Hive List.
@@ -116,7 +112,19 @@ public class OmniDataAdapter implements Serializable {
             .getFileBlockLocations(fileSplit.getPath(), fileSplit.getStart(), fileSplit.getLength());
         for (BlockLocation block : blockLocations) {
             for (String host : block.getHosts()) {
-                hosts.add(host);
+                if ("localhost".equals(host)) {
+                    List<String> dataNodeHosts = new ArrayList<>(
+                        Arrays.asList(conf.get(NdpStatusManager.NDP_DATANODE_HOSTNAMES)
+                            .split(NdpStatusManager.NDP_DATANODE_HOSTNAME_SEPARATOR)));
+                    if (dataNodeHosts.size() > ndpReplicationNum) {
+                        hosts.addAll(dataNodeHosts.subList(0, ndpReplicationNum));
+                    } else {
+                        hosts.addAll(dataNodeHosts);
+                    }
+                    return hosts;
+                } else {
+                    hosts.add(host);
+                }
                 if (ndpReplicationNum == hosts.size()) {
                     return hosts;
                 }
@@ -159,7 +167,6 @@
                     pages.addAll(page);
                 }
             } while (!dataReader.isFinished());
-            dataReader.close();
             break;
         } catch (OmniDataException omniDataException) {
             LOGGER.warn("OmniDataAdapter failed node info [hostname :{}]", omniDataHost);
@@ -195,7 +202,9 @@
                 LOGGER.error("OmniDataAdapter getBatchFromOmnidata() has error:", e);
                 failedTimes++;
             } finally {
-                dataReader.close();
+                if (dataReader != null) {
+                    dataReader.close();
+                }
             }
         }
         int retryTime = Math.min(TASK_FAILED_TIMES, omniDataHosts.size());
diff --git a/omnidata/omnidata-hiveudf-loader/etc/function-namespace/hive.properties b/omnidata/omnidata-hiveudf-loader/etc/function-namespace/hive.properties
new file mode 100644
index 0000000000000000000000000000000000000000..7c1702063f7e81fdba240bb5b9d6319aa0a400db
--- /dev/null
+++ b/omnidata/omnidata-hiveudf-loader/etc/function-namespace/hive.properties
@@ -0,0 +1,3 @@
+function-namespace-manager.name=hive-functions
+external-functions.dir=src/test/resources/
+UdfTest com.test.udf.UdfTest
\ No newline at end of file
diff --git a/omnidata/omnidata-hiveudf-loader/pom.xml b/omnidata/omnidata-hiveudf-loader/pom.xml
index 5a8a6b1d99d6ffc90a8dbfd0d5c57924a6e314bd..162bd544e71f489760e7ad46cf4deef2eff3dcbf 100644
--- a/omnidata/omnidata-hiveudf-loader/pom.xml
+++ b/omnidata/omnidata-hiveudf-loader/pom.xml
@@ -263,6 +263,52 @@
         <dependency>
             <groupId>io.airlift</groupId>
             <artifactId>bytecode</artifactId>
         </dependency>
+        <dependency>
+            <groupId>org.testng</groupId>
+            <artifactId>testng</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.mockito</groupId>
+            <artifactId>mockito-core</artifactId>
+            <version>1.10.19</version>
+        </dependency>
+        <dependency>
+            <groupId>io.hetu.core</groupId>
+            <artifactId>presto-spi</artifactId>
+            <type>test-jar</type>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>io.hetu.core</groupId>
+            <artifactId>hetu-filesystem-client</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>io.hetu.core</groupId>
+            <artifactId>hetu-metastore</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.jetbrains</groupId>
+            <artifactId>annotations</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>io.hetu.core</groupId>
+            <artifactId>presto-client</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.locationtech.jts</groupId>
+            <artifactId>jts-core</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>io.hetu.core</groupId>
+            <artifactId>hetu-common</artifactId>
+            <scope>test</scope>
+        </dependency>
     </dependencies>
 
diff --git a/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/AbstractTestHiveFunctions.java b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/AbstractTestHiveFunctions.java
new file mode 100644
index 0000000000000000000000000000000000000000..9593a59deb32524dea6e398b38b9b26f43f2c3d3
--- /dev/null
+++ b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/AbstractTestHiveFunctions.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.functions; + +import com.google.common.base.Splitter; +import com.google.common.io.Files; +import com.google.inject.Key; +import io.airlift.log.Logger; +import io.prestosql.server.testing.TestingPrestoServer; +import io.prestosql.spi.type.TimeZoneKey; +import io.prestosql.spi.type.Type; +import io.prestosql.spi.type.TypeManager; +import io.prestosql.tests.TestingPrestoClient; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.BeforeClass; + +import java.io.File; +import java.util.Optional; + +import static io.prestosql.plugin.hive.functions.HiveFunctionsTestUtils.createTestingPrestoServer; +import static io.prestosql.testing.TestingSession.testSessionBuilder; +import static java.nio.charset.StandardCharsets.UTF_8; + +public abstract class AbstractTestHiveFunctions +{ + private static final Logger log = Logger + .get(AbstractTestHiveFunctions.class); + + protected TestingPrestoServer server; + protected TestingPrestoClient client; + protected TypeManager typeManager; + protected ClassLoader classLoader; + + @BeforeClass + public void setup() + throws Exception + { + // TODO: Use DistributedQueryRunner to perform query + server = createTestingPrestoServer(); + client = new TestingPrestoClient(server, testSessionBuilder() + .setTimeZoneKey(TimeZoneKey.getTimeZoneKey( + "America/Bahia_Banderas")).build()); + typeManager = server.getInstance(Key.get(TypeManager.class)); + classLoader = Thread.currentThread().getContextClassLoader(); + + if (getInitScript().isPresent()) { + String sql = Files.asCharSource( + getInitScript().get(), UTF_8).read(); + Iterable initQueries = Splitter.on("----\n") + .omitEmptyStrings().trimResults().split(sql); + for (@Language("SQL") String query : initQueries) { + log.debug("Executing %s", query); + client.execute(query); + } + } + } + + protected Optional getInitScript() + { + return Optional.empty(); + } + + public static class Column + { + private final Type type; + + private final Object[] values; + + private Column(Type type, Object[] values) + { + this.type = type; + this.values = values; + } + + public static Column of(Type type, Object... values) + { + return new Column(type, values); + } + } +} diff --git a/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/HiveFunctionsTestUtils.java b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/HiveFunctionsTestUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..b06aa3db4cc9336fe3fa6fffb549473c4bd7aeee --- /dev/null +++ b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/HiveFunctionsTestUtils.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.functions; + +import com.google.common.collect.ImmutableMap; +import com.google.inject.Key; +import io.hetu.core.common.filesystem.TempFolder; +import io.hetu.core.filesystem.HetuFileSystemClientPlugin; +import io.hetu.core.metastore.HetuMetastorePlugin; +import io.prestosql.metadata.FunctionAndTypeManager; +import io.prestosql.metastore.HetuMetaStoreManager; +import io.prestosql.plugin.memory.MemoryPlugin; +import io.prestosql.server.testing.TestingPrestoServer; + +import java.util.HashMap; +import java.util.Map; + +public final class HiveFunctionsTestUtils +{ + private HiveFunctionsTestUtils() + { + } + + public static TestingPrestoServer createTestingPrestoServer() + throws Exception + { + TempFolder folder = new TempFolder().create(); + Runtime.getRuntime().addShutdownHook(new Thread(folder::close)); + HashMap metastoreConfig = new HashMap<>(); + metastoreConfig.put("hetu.metastore.type", "hetufilesystem"); + metastoreConfig.put("hetu.metastore.hetufilesystem.profile-name", + "default"); + metastoreConfig.put("hetu.metastore.hetufilesystem.path", + folder.newFolder("metastore").getAbsolutePath()); + + TestingPrestoServer server = new TestingPrestoServer(); + server.installPlugin(new HetuFileSystemClientPlugin()); + server.installPlugin(new MemoryPlugin()); + server.installPlugin(new HetuMetastorePlugin()); + server.installPlugin(new HiveFunctionNamespacePlugin()); + server.loadMetastore(metastoreConfig); + + server.createCatalog("memory", "memory", + ImmutableMap.of("memory.spill-path", + folder.newFolder("memory-connector") + .getAbsolutePath())); + + FunctionAndTypeManager functionAndTypeManager = + server.getInstance(Key.get(FunctionAndTypeManager.class)); + functionAndTypeManager.loadFunctionNamespaceManager( + new HetuMetaStoreManager(), + "hive-functions", + "hive", + getNamespaceManagerCreationProperties()); + server.refreshNodes(); + return server; + } + + public static Map getNamespaceManagerCreationProperties() + { + HashMap namespaceManagerCreationPropertie = new HashMap<>(); + namespaceManagerCreationPropertie.put("external-functions.dir", "test/test"); + return namespaceManagerCreationPropertie; + } +} diff --git a/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestFunctionRegistry.java b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestFunctionRegistry.java new file mode 100644 index 0000000000000000000000000000000000000000..47bea1f2ca6f6d74876e690099ba9df32784c530 --- /dev/null +++ b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestFunctionRegistry.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.functions; + +import io.prestosql.spi.connector.QualifiedObjectName; +import org.testng.annotations.Test; + +import static org.mockito.Mockito.mock; + +public class TestFunctionRegistry +{ + @Test + public void testAddFunction() throws ClassNotFoundException + { + HiveFunctionRegistry mockHiveFunctionRegistry = mock(HiveFunctionRegistry.class); + Class functionClass = mockHiveFunctionRegistry.getClass(mock(QualifiedObjectName.class)); + + try { + FunctionRegistry.addFunction("test", functionClass); + } + catch (NullPointerException ignored) { + } + } +} diff --git a/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestHiveFunction.java b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestHiveFunction.java new file mode 100644 index 0000000000000000000000000000000000000000..8a2c1c0b4568d0a8344df3a68618a4927eb849a7 --- /dev/null +++ b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestHiveFunction.java @@ -0,0 +1,69 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.prestosql.plugin.hive.functions; + +import io.prestosql.spi.connector.QualifiedObjectName; +import io.prestosql.spi.function.FunctionMetadata; +import io.prestosql.spi.function.Signature; +import io.prestosql.spi.type.TypeSignature; +import io.prestosql.testing.assertions.Assert; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.List; + +import static io.prestosql.spi.function.FunctionKind.SCALAR; +import static java.util.Collections.emptyList; +import static org.mockito.Mockito.mock; + +public class TestHiveFunction +{ + @Test + public void testGetName() + { + QualifiedObjectName mockName = mock(QualifiedObjectName.class); + TypeSignature mockReturnType = mock(TypeSignature.class); + List argumentTypes = new ArrayList<>(); + Signature signature = new Signature(mockName, SCALAR, emptyList(), + emptyList(), mockReturnType, argumentTypes, false); + TmpHiveFunction tmpHiveFunction = new TmpHiveFunction(mockName, + signature, false, false, false, "test"); + + Assert.assertEquals(mockName, tmpHiveFunction.getName()); + Assert.assertEquals(signature, tmpHiveFunction.getSignature()); + Assert.assertFalse(tmpHiveFunction.isDeterministic()); + Assert.assertFalse(tmpHiveFunction.isCalledOnNullInput()); + Assert.assertEquals("test", tmpHiveFunction.getDescription()); + Assert.assertFalse(tmpHiveFunction.isHidden()); + } + + static class TmpHiveFunction + extends HiveFunction + { + TmpHiveFunction(final QualifiedObjectName name, final Signature signature, + final boolean hidden, final boolean deterministic, + final boolean calledOnNullInput, final String description) + { + super(name, signature, hidden, deterministic, + calledOnNullInput, description); + } + + @Override + public FunctionMetadata getFunctionMetadata() + { + return null; + } + } +} diff --git a/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestHiveFunctionErrorCode.java b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestHiveFunctionErrorCode.java new file mode 100644 index 0000000000000000000000000000000000000000..06def4556ed03efa9e7ffcfdc5a9e6c873cdaec2 --- /dev/null +++ b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestHiveFunctionErrorCode.java @@ -0,0 +1,107 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.prestosql.plugin.hive.functions; + +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.connector.QualifiedObjectName; +import io.prestosql.testing.assertions.Assert; +import org.testng.annotations.Test; + +import java.util.regex.Pattern; + +import static io.prestosql.plugin.hive.functions.HiveFunctionErrorCode.HIVE_FUNCTION_EXECUTION_ERROR; +import static io.prestosql.plugin.hive.functions.HiveFunctionErrorCode.HIVE_FUNCTION_INITIALIZATION_ERROR; +import static io.prestosql.plugin.hive.functions.HiveFunctionErrorCode.unsupportedNamespace; +import static io.prestosql.spi.StandardErrorCode.FUNCTION_NOT_FOUND; +import static java.lang.String.format; +import static org.mockito.Mockito.mock; + +public class TestHiveFunctionErrorCode +{ + private Throwable mockThrowable = mock(Throwable.class); + private QualifiedObjectName mockQualifiedObjectName = mock(QualifiedObjectName.class); + + @Test + public void testFunctionNotFound() + { + String name = "test"; + + PrestoException result = HiveFunctionErrorCode.functionNotFound(name, mock(ClassNotFoundException.class)); + + Assert.assertTrue(result.getErrorCode().toString().startsWith( + FUNCTION_NOT_FOUND.toString())); + Assert.assertEquals(result.getMessage(), + format("Function %s not registered. %s", name, "null")); + } + + @Test + public void testInitializationError() + { + // Throwable t + PrestoException result = + HiveFunctionErrorCode.initializationError(mockThrowable); + + Assert.assertEquals(result.getMessage(), + HIVE_FUNCTION_INITIALIZATION_ERROR.toString()); + + // String filePath, Exception e + Assert.assertEquals(HiveFunctionErrorCode.initializationError("test", + mock(Exception.class)).getMessage(), + "Fail to read the configuration test. null"); + } + + @Test + public void testExecutionError() + { + PrestoException result = + HiveFunctionErrorCode.executionError(mockThrowable); + + Assert.assertEquals(result.getMessage(), + HIVE_FUNCTION_EXECUTION_ERROR.toString()); + } + + @Test + public void testUnsupportedFunctionType() throws ClassNotFoundException + { + HiveFunctionRegistry mockHiveFunctionRegistry = mock(HiveFunctionRegistry.class); + Class functionClass = mockHiveFunctionRegistry.getClass(mockQualifiedObjectName); + + try { + HiveFunctionErrorCode.unsupportedFunctionType(functionClass); + } + catch (NullPointerException ignored) { + } + + try { + HiveFunctionErrorCode.unsupportedFunctionType(functionClass, mockThrowable); + } + catch (NullPointerException ignored) { + } + } + + @Test + public void testUnsupportedNamespace() + { + Assert.assertTrue(unsupportedNamespace(mockQualifiedObjectName).toString() + .contains("Hive udf unsupported namespace null. Its schema should be default.")); + } + + @Test + public void testInvalidParatemers() + { + Assert.assertTrue(Pattern.matches("The input path .* is invalid. 
.*", + HiveFunctionErrorCode.invalidParatemers("test", mock(Exception.class)).getMessage())); + } +} diff --git a/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestHiveFunctionNamespaceManager.java b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestHiveFunctionNamespaceManager.java new file mode 100644 index 0000000000000000000000000000000000000000..344d7cd58fabe89911cdde7a1bd9c610c44f3b03 --- /dev/null +++ b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestHiveFunctionNamespaceManager.java @@ -0,0 +1,370 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.functions; + +import com.google.common.collect.ImmutableList; +import io.prestosql.spi.Page; +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.connector.QualifiedObjectName; +import io.prestosql.spi.function.BuiltInFunctionHandle; +import io.prestosql.spi.function.FunctionHandle; +import io.prestosql.spi.function.FunctionKind; +import io.prestosql.spi.function.FunctionNamespaceTransactionHandle; +import io.prestosql.spi.function.Signature; +import io.prestosql.spi.function.SqlInvokedFunction; +import io.prestosql.spi.type.DoubleType; +import io.prestosql.spi.type.TypeManager; +import io.prestosql.spi.type.TypeSignature; +import io.prestosql.testing.assertions.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.lang.reflect.AccessibleObject; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static io.prestosql.spi.function.FunctionKind.SCALAR; +import static java.util.Collections.emptyList; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.testng.Assert.assertEquals; + +public class TestHiveFunctionNamespaceManager +{ + private QualifiedObjectName mockQualifiedObjectName = mock(QualifiedObjectName.class); + private BuiltInFunctionHandle mockBuiltInFunctionHandle = mock(BuiltInFunctionHandle.class); + private TypeManager mockTypeManager = mock(TypeManager.class); + private HiveFunctionNamespaceManager hiveFunctionNamespaceManager; + + @BeforeClass + public void setup() + { + HiveFunctionNamespacePlugin mockHiveFunctionNamespaceManager = + mock(HiveFunctionNamespacePlugin.class); + HiveFunctionRegistry mockHiveFunctionRegistry = + mock(HiveFunctionRegistry.class); + ClassLoader mockClassLoader = + mockHiveFunctionNamespaceManager.getClass().getClassLoader(); + + this.hiveFunctionNamespaceManager = new HiveFunctionNamespaceManager("hive", + mockClassLoader, mockHiveFunctionRegistry, mockTypeManager); + } + + @Test + public void testBeginTransaction() + { + Assert.assertTrue( + 
hiveFunctionNamespaceManager.beginTransaction().toString().contains("EmptyTransactionHandle")); + } + + @Test + public void testCommit() + { + // Null function + hiveFunctionNamespaceManager.commit(mock(FunctionNamespaceTransactionHandle.class)); + } + + @Test + public void testAbort() + { + // Null function + hiveFunctionNamespaceManager.abort(mock(FunctionNamespaceTransactionHandle.class)); + } + + @Test(expectedExceptions = IllegalStateException.class, + expectedExceptionsMessageRegExp = "Cannot create function in hive" + + " function namespace: test.test.test") + public void testCreateFunction() + { + SqlInvokedFunction mockSqlInvokedFunction = + mock(SqlInvokedFunction.class); + Signature signature = new Signature( + QualifiedObjectName.valueOf("test.test.test"), + FunctionKind.AGGREGATE, + DoubleType.DOUBLE.getTypeSignature(), + ImmutableList.of(DoubleType.DOUBLE.getTypeSignature())); + when(mockSqlInvokedFunction.getSignature()).thenReturn(signature); + + hiveFunctionNamespaceManager.createFunction( + mockSqlInvokedFunction, false); + } + + @Test + public void testListFunctions() + { + Collection result = + hiveFunctionNamespaceManager.listFunctions(); + + assertEquals(0, result.size()); + } + + @Test(expectedExceptions = IllegalStateException.class, + expectedExceptionsMessageRegExp = "Get function is not supported") + public void testGetFunctions() + { + hiveFunctionNamespaceManager.getFunctions(Optional.empty(), + mockQualifiedObjectName); + } + + @Test + public void testGetFunctionHandle() + { + TypeSignature mockReturnType = mock(TypeSignature.class); + List argumentTypes = new ArrayList<>(); + Signature signature = new Signature(mockQualifiedObjectName, SCALAR, emptyList(), + emptyList(), mockReturnType, argumentTypes, false); + + Assert.assertEquals(hiveFunctionNamespaceManager.getFunctionHandle( + Optional.empty(), signature), + new BuiltInFunctionHandle(signature)); + } + + @Test + public void testCanResolveFunction() + { + Assert.assertTrue(hiveFunctionNamespaceManager.canResolveFunction()); + } + + @Test(expectedExceptions = NullPointerException.class) + public void testGetFunctionMetadata() + { + hiveFunctionNamespaceManager.getFunctionMetadata(mockBuiltInFunctionHandle); + } + + @Test(expectedExceptions = NullPointerException.class) + public void testGetScalarFunctionImplementation() + { + hiveFunctionNamespaceManager.getScalarFunctionImplementation(mockBuiltInFunctionHandle); + } + + @Test(expectedExceptions = IllegalStateException.class, + expectedExceptionsMessageRegExp = + "Execute function is not supported") + public void testExecuteFunction() + { + FunctionHandle mockFunctionHandle = mock(FunctionHandle.class); + Page input = new Page(1); + List channels = new ArrayList<>(); + + hiveFunctionNamespaceManager.executeFunction( + mockFunctionHandle, input, channels, mockTypeManager); + } + + @Test + public void testResolveFunction() + { + try { + hiveFunctionNamespaceManager.resolveFunction(Optional.empty(), + QualifiedObjectName.valueOf("test.test.test"), + ImmutableList.of(DoubleType.DOUBLE.getTypeSignature())); + } + catch (PrestoException e) { + Assert.assertEquals(e.getMessage(), + "Hive udf unsupported namespace test.test. 
Its schema should be default."); + } + + try { + hiveFunctionNamespaceManager.resolveFunction(Optional.empty(), + QualifiedObjectName.valueOf("test.default.test"), + ImmutableList.of(DoubleType.DOUBLE.getTypeSignature())); + } + catch (NullPointerException ignored) { + } + } + + @Test + public void testCreateDummyHiveScalarFunction() throws IllegalAccessException, + NoSuchMethodException, InvocationTargetException + { + Class clas = HiveFunctionNamespaceManager.class; + Method method = clas.getDeclaredMethod("createDummyHiveScalarFunction", String.class); + method.setAccessible(true); + + Assert.assertTrue(method.invoke(hiveFunctionNamespaceManager, "test") + .toString().contains("DummyHiveScalarFunction")); + } + + @Test + public void testInnerGetFunctionMetadata() throws InvocationTargetException, InstantiationException, + IllegalAccessException, NoSuchMethodException + { + Class[] declaredClasses = HiveFunctionNamespaceManager.class.getDeclaredClasses(); + Signature signature = new Signature( + QualifiedObjectName.valueOf("test.test.test"), + FunctionKind.AGGREGATE, + DoubleType.DOUBLE.getTypeSignature(), + ImmutableList.of(DoubleType.DOUBLE.getTypeSignature())); + + for (Class c : declaredClasses) { + int mod = c.getModifiers(); + String modifier = Modifier.toString(mod); + if (modifier.contains("private") && c.getName().contains("DummyHiveScalarFunction")) { + Constructor[] declaredConstructors = c.getDeclaredConstructors(); + AccessibleObject.setAccessible(declaredConstructors, true); + + Constructor constructor = declaredConstructors[0]; + Object object = constructor.newInstance(signature); + Method method = c.getMethod("getFunctionMetadata"); + method.setAccessible(true); + + try { + method.invoke(object); + } + catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + if (cause instanceof IllegalStateException) { + Assert.assertEquals(cause.getMessage(), "Get function metadata is not supported"); + } + } + } + } + } + + @Test + public void testGetName() throws InvocationTargetException, InstantiationException, + IllegalAccessException, NoSuchMethodException + { + Class[] declaredClasses = HiveFunctionNamespaceManager.class.getDeclaredClasses(); + QualifiedObjectName name = QualifiedObjectName.valueOf("test.test.test"); + + for (Class c : declaredClasses) { + int mod = c.getModifiers(); + String modifier = Modifier.toString(mod); + if (modifier.contains("private") && c.getName().contains("FunctionKey")) { + Constructor[] declaredConstructors = c.getDeclaredConstructors(); + AccessibleObject.setAccessible(declaredConstructors, true); + + Constructor constructor = declaredConstructors[0]; + List argumentTypes = ImmutableList.of(DoubleType.DOUBLE.getTypeSignature()); + Object object = constructor.newInstance(name, argumentTypes); + Method method = c.getMethod("getName"); + method.setAccessible(true); + + Assert.assertEquals(method.invoke(object), name); + } + } + } + + @Test + public void testGetArgumentTypes() throws InvocationTargetException, InstantiationException, + IllegalAccessException, NoSuchMethodException + { + Class[] declaredClasses = HiveFunctionNamespaceManager.class.getDeclaredClasses(); + QualifiedObjectName name = QualifiedObjectName.valueOf("test.test.test"); + + for (Class c : declaredClasses) { + int mod = c.getModifiers(); + String modifier = Modifier.toString(mod); + if (modifier.contains("private") && c.getName().contains("FunctionKey")) { + Constructor[] declaredConstructors = c.getDeclaredConstructors(); + 
AccessibleObject.setAccessible(declaredConstructors, true); + + Constructor constructor = declaredConstructors[0]; + List argumentTypes = ImmutableList.of(DoubleType.DOUBLE.getTypeSignature()); + Object object = constructor.newInstance(name, argumentTypes); + Method method = c.getMethod("getArgumentTypes"); + method.setAccessible(true); + + Assert.assertEquals(method.invoke(object), argumentTypes); + } + } + } + + @Test + public void testHashCode() throws InvocationTargetException, InstantiationException, + IllegalAccessException, NoSuchMethodException + { + Class[] declaredClasses = HiveFunctionNamespaceManager.class.getDeclaredClasses(); + QualifiedObjectName name = QualifiedObjectName.valueOf("test.test.test"); + + for (Class c : declaredClasses) { + int mod = c.getModifiers(); + String modifier = Modifier.toString(mod); + if (modifier.contains("private") && c.getName().contains("FunctionKey")) { + Constructor[] declaredConstructors = c.getDeclaredConstructors(); + AccessibleObject.setAccessible(declaredConstructors, true); + + Constructor constructor = declaredConstructors[0]; + List argumentTypes = ImmutableList.of(DoubleType.DOUBLE.getTypeSignature()); + Object object = constructor.newInstance(name, argumentTypes); + Method method = c.getMethod("hashCode"); + method.setAccessible(true); + + Assert.assertEquals(method.invoke(object), Objects.hash(name, argumentTypes)); + } + } + } + + @Test + public void testToString() throws InvocationTargetException, InstantiationException, + IllegalAccessException, NoSuchMethodException + { + Class[] declaredClasses = HiveFunctionNamespaceManager.class.getDeclaredClasses(); + QualifiedObjectName name = QualifiedObjectName.valueOf("test.test.test"); + + for (Class c : declaredClasses) { + int mod = c.getModifiers(); + String modifier = Modifier.toString(mod); + if (modifier.contains("private") && c.getName().contains("FunctionKey")) { + Constructor[] declaredConstructors = c.getDeclaredConstructors(); + AccessibleObject.setAccessible(declaredConstructors, true); + + Constructor constructor = declaredConstructors[0]; + List argumentTypes = ImmutableList.of(DoubleType.DOUBLE.getTypeSignature()); + Object object = constructor.newInstance(name, argumentTypes); + Method method = c.getMethod("toString"); + method.setAccessible(true); + + Assert.assertEquals(method.invoke(object), "FunctionKey{name=test.test.test, arguments=[double]}"); + } + } + } + + @Test + public void testEquals() throws InvocationTargetException, InstantiationException, + IllegalAccessException, NoSuchMethodException + { + Class[] declaredClasses = HiveFunctionNamespaceManager.class.getDeclaredClasses(); + QualifiedObjectName name = QualifiedObjectName.valueOf("test.test.test"); + + for (Class c : declaredClasses) { + int mod = c.getModifiers(); + String modifier = Modifier.toString(mod); + if (modifier.contains("private") && c.getName().contains("FunctionKey")) { + Constructor[] declaredConstructors = c.getDeclaredConstructors(); + AccessibleObject.setAccessible(declaredConstructors, true); + + Constructor constructor = declaredConstructors[0]; + List argumentTypes = ImmutableList.of(DoubleType.DOUBLE.getTypeSignature()); + Object object = constructor.newInstance(name, argumentTypes); + Method method = c.getMethod("equals", Object.class); + method.setAccessible(true); + + Assert.assertEquals(method.invoke(object, name.getSchemaName()), false); + + method = c.getMethod("equals", Object.class); + method.setAccessible(true); + Assert.assertEquals(method.invoke(object, object), true); + 
} + } + } +} diff --git a/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestHiveFunctionNamespaceManagerFactory.java b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestHiveFunctionNamespaceManagerFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..be6f5e517e85292180f1e939d0635428652d72c2 --- /dev/null +++ b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestHiveFunctionNamespaceManagerFactory.java @@ -0,0 +1,154 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.functions; + +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.function.FunctionNamespaceManagerContext; +import io.prestosql.spi.type.RealType; +import io.prestosql.spi.type.Type; +import io.prestosql.testing.MaterializedResult; +import io.prestosql.testing.assertions.Assert; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.Test; + +import java.io.File; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.regex.Pattern; + +import static com.google.common.math.DoubleMath.fuzzyEquals; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; +import static org.testng.Assert.fail; + +@SuppressWarnings("UnknownLanguage") +public class TestHiveFunctionNamespaceManagerFactory + extends AbstractTestHiveFunctions +{ + private static final String FUNCTION_PREFIX = "hive.default."; + private static final String TABLE_NAME = "memory.default.function_testing"; + + @Override + protected Optional getInitScript() + { + return Optional.of(new File("src/test/sql/function-testing.sql")); + } + + private static void assertNaN(Object o) + { + if (o instanceof Double) { + assertEquals(o, Double.NaN); + } + else if (o instanceof Float) { + assertEquals((Float) o, Float.NaN); + } + else { + fail("Unexpected " + o); + } + } + + private void check(@Language("SQL") String query, Type expectedType, Object expectedValue) + { + MaterializedResult result = client.execute(query).getResult(); + assertEquals(result.getRowCount(), 1); + assertEquals(result.getTypes().get(0), expectedType); + Object actual = result.getMaterializedRows().get(0).getField(0); + + if (expectedType.equals(DOUBLE) || expectedType.equals(RealType.REAL)) { + if (expectedValue == null) { + assertNaN(actual); + } + else { + assertTrue(fuzzyEquals(((Number) actual).doubleValue(), ((Number) expectedValue).doubleValue(), 0.000001)); + } + } + else { + assertEquals(actual, expectedValue); + } + } + + private static String selectF(String function, 
String... args) + { + StringBuilder builder = new StringBuilder(); + builder.append("SELECT ").append(FUNCTION_PREFIX).append(function); + builder.append("("); + if (args != null) { + builder.append(String.join(", ", args)); + } + builder.append(")").append(" FROM ").append(TABLE_NAME); + return builder.toString(); + } + + @Test + public void testCreate() + { + FunctionNamespaceManagerContext mockFunctionNamespaceManagerContext = + mock(FunctionNamespaceManagerContext.class); + HiveFunctionNamespaceManagerFactory instance = new HiveFunctionNamespaceManagerFactory(classLoader); + Map config = new HashMap(); + try { + instance.create("test", config, mockFunctionNamespaceManagerContext); + } + catch (PrestoException e) { + assertTrue(Pattern.matches("The configuration .* should contain the parameter " + + "external-functions.dir.", e.getMessage())); + } + config.put("external-functions.dir", "false"); + config.put("test", "test"); + + try { + instance.create("test", config, mockFunctionNamespaceManagerContext); + } + catch (PrestoException e) { + assertTrue(Pattern.matches("The input path .* is invalid.", e.getMessage())); + } + + // Test registration, loading and calling UDF + config.clear(); + config.put("external-functions.dir", System.getProperty("user.dir") + File.separatorChar + "src/test/resources/"); + config.put("UdfTest", "com.test.udf.UdfTest"); + + when(mockFunctionNamespaceManagerContext.getTypeManager()).thenReturn(Optional.of(typeManager)); + instance.create("hive", config, mockFunctionNamespaceManagerContext); + + check(selectF("UdfTest", "c_varchar"), VARCHAR, "UdfTest varchar"); + } + + @Test + public void testGetURLs() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException, + InstantiationException + { + Constructor constructor = + HiveFunctionNamespaceManagerFactory.class.getDeclaredConstructor(ClassLoader.class); + Object instance = constructor.newInstance(classLoader); + + Class clas = + HiveFunctionNamespaceManagerFactory.class; + Method method = clas.getDeclaredMethod("getURLs", String.class, List.class); + method.setAccessible(true); + + Assert.assertNull(method.invoke(instance, HiveFunctionNamespaceManagerFactory.NAME, null)); + Assert.assertNull(method.invoke(instance, "etc/null-dir", null)); + } +} diff --git a/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestScalarMethodHandles.java b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestScalarMethodHandles.java new file mode 100644 index 0000000000000000000000000000000000000000..6214408c369b500bf536d9a7faa446465885cca7 --- /dev/null +++ b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestScalarMethodHandles.java @@ -0,0 +1,168 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.prestosql.plugin.hive.functions; + +import com.google.inject.Key; +import io.airlift.slice.Slices; +import io.prestosql.plugin.hive.functions.scalar.ScalarFunctionInvoker; +import io.prestosql.server.testing.TestingPrestoServer; +import io.prestosql.spi.block.ArrayBlock; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.ByteArrayBlock; +import io.prestosql.spi.connector.QualifiedObjectName; +import io.prestosql.spi.function.FunctionKind; +import io.prestosql.spi.function.Signature; +import io.prestosql.spi.type.TypeManager; +import io.prestosql.spi.type.TypeSignature; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.lang.invoke.MethodHandle; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import static io.prestosql.plugin.hive.functions.HiveFunctionsTestUtils.createTestingPrestoServer; +import static io.prestosql.plugin.hive.functions.gen.ScalarMethodHandles.generateUnbound; +import static io.prestosql.plugin.hive.functions.type.TestObjectEncoders.THREE_LONG_VAL; +import static io.prestosql.spi.type.TypeSignature.parseTypeSignature; +import static org.testng.Assert.assertEquals; + +public class TestScalarMethodHandles +{ + private AtomicInteger number; + private TestingPrestoServer server; + private TypeManager typeManager; + + private static final double DOUBLE_VAL = 2.1; + + @BeforeClass + public void setup() + throws Exception + { + this.number = new AtomicInteger(0); + this.server = createTestingPrestoServer(); + this.typeManager = server.getInstance(Key.get(TypeManager.class)); + } + + @Test + public void generateUnboundTest() + throws Throwable + { + runCase("boolean", tokens("boolean"), true, false); + runCase("bigint", tokens("bigint,bigint"), THREE_LONG_VAL, 2L, 1L); + runCase("varchar", tokens("varchar,double"), Slices.utf8Slice("output"), + Slices.utf8Slice("input"), DOUBLE_VAL); + runCase("array(bigint)", tokens("bigint,bigint"), + createArrayBlock(), 2L, 1L); + } + + private void runCase(final String resultType, final String[] argumentTypes, + final Object result, final Object... arguments) + throws Throwable + { + Signature signature = createScalarSignature(resultType, argumentTypes); + MethodHandle methodHandle = generateUnbound(signature, typeManager); + TestingScalarFunctionInvoker invoker = + new TestingScalarFunctionInvoker(signature, result); + final Object output; + if (arguments.length == 1) { + output = methodHandle.invoke(invoker, arguments[0]); + } + else if (arguments.length == 2) { + output = methodHandle.invoke(invoker, arguments[0], arguments[1]); + } + else if (arguments.length == 3) { + output = methodHandle.invoke( + invoker, arguments[0], arguments[1], arguments[2]); + } + else if (arguments.length == 4) { + output = methodHandle.invoke( + invoker, arguments[0], arguments[1], + arguments[2], arguments[3]); + } + else { + throw new RuntimeException("Not supported yet"); + } + Object[] inputs = invoker.getInputs(); + + assertEquals(output, result); + assertEquals(inputs, arguments); + } + + private String[] tokens(final String s) + { + return s.split(","); + } + + private Signature createScalarSignature(final String returnType, + final String... 
argumentTypes) + { + return new Signature(QualifiedObjectName.valueOf( + "hive.default.testing_" + number.incrementAndGet()), + FunctionKind.SCALAR, + parseTypeSignature(returnType), + Stream.of(argumentTypes) + .map(TypeSignature::parseTypeSignature) + .toArray(TypeSignature[]::new)); + } + + private Block createEmptyBlock() + { + return new ByteArrayBlock(0, Optional.empty(), new byte[0]); + } + + private Block createArrayBlock() + { + Block emptyValueBlock = createEmptyBlock(); + return ArrayBlock.fromElementBlock(1, + Optional.empty(), IntStream.range(0, 2).toArray(), + emptyValueBlock); + } + + private static class TestingScalarFunctionInvoker + implements ScalarFunctionInvoker + { + private final Signature signature; + private final Object result; + private Object[] inputs; + + TestingScalarFunctionInvoker(final Signature signature, + final Object result) + { + this.signature = signature; + this.result = result; + } + + @Override + public Signature getSignature() + { + return signature; + } + + @Override + public Object evaluate(final Object... inputs) + { + this.inputs = inputs; + return result; + } + + public Object[] getInputs() + { + return inputs; + } + } +} diff --git a/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestStaticHiveFunctionRegistry.java b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestStaticHiveFunctionRegistry.java new file mode 100644 index 0000000000000000000000000000000000000000..56714d198fa7f2c499d2bb90c870e457b523db46 --- /dev/null +++ b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/TestStaticHiveFunctionRegistry.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.prestosql.plugin.hive.functions; + +import io.prestosql.spi.connector.QualifiedObjectName; +import org.testng.annotations.Test; + +import static org.mockito.Mockito.mock; + +public class TestStaticHiveFunctionRegistry +{ + @Test(expectedExceptions = ClassNotFoundException.class, + expectedExceptionsMessageRegExp = "Class of function .* not found") + public void testGetClass() throws ClassNotFoundException + { + HiveFunctionNamespacePlugin mockHiveFunctionNamespaceManager = + mock(HiveFunctionNamespacePlugin.class); + ClassLoader mockClassLoader = + mockHiveFunctionNamespaceManager.getClass().getClassLoader(); + StaticHiveFunctionRegistry staticHiveFunctionRegistry = + new StaticHiveFunctionRegistry(mockClassLoader); + QualifiedObjectName mockQualifiedObjectName = + mock(QualifiedObjectName.class); + QualifiedObjectName.valueOf("test.test.test"); + + staticHiveFunctionRegistry.getClass(mockQualifiedObjectName); + } +} diff --git a/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/scalar/TestHiveScalarFunction.java b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/scalar/TestHiveScalarFunction.java new file mode 100644 index 0000000000000000000000000000000000000000..be1477effd81e5040191623a6b965c00405a0de8 --- /dev/null +++ b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/scalar/TestHiveScalarFunction.java @@ -0,0 +1,217 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.prestosql.plugin.hive.functions.scalar; + +import com.google.common.collect.ImmutableList; +import io.prestosql.plugin.hive.functions.HiveFunctionRegistry; +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.connector.QualifiedObjectName; +import io.prestosql.spi.function.FunctionImplementationType; +import io.prestosql.spi.function.FunctionKind; +import io.prestosql.spi.function.FunctionMetadata; +import io.prestosql.spi.function.InvocationConvention; +import io.prestosql.spi.function.ScalarFunctionImplementation; +import io.prestosql.spi.function.Signature; +import io.prestosql.spi.type.DoubleType; +import io.prestosql.spi.type.TypeManager; +import io.prestosql.spi.type.TypeSignature; +import io.prestosql.testing.assertions.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.lang.invoke.MethodHandle; +import java.lang.reflect.AccessibleObject; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.List; +import java.util.Optional; + +import static io.prestosql.plugin.hive.functions.HiveFunctionErrorCode.HIVE_FUNCTION_INITIALIZATION_ERROR; +import static io.prestosql.plugin.hive.functions.scalar.HiveScalarFunction.createHiveScalarFunction; +import static io.prestosql.spi.function.FunctionKind.SCALAR; +import static org.mockito.Mockito.mock; + +public class TestHiveScalarFunction +{ + private Signature signature; + private FunctionMetadata functionMetadata; + private ScalarFunctionImplementation implementation; + private Object instance; + private QualifiedObjectName mockQualifiedObjectName = mock(QualifiedObjectName.class); + private InvocationConvention mockInvocationConvention = mock(InvocationConvention.class); + private MethodHandle mockMethodHandle = mock(MethodHandle.class); + Class clas; + + @BeforeClass + public void setup() throws InvocationTargetException, InstantiationException, + IllegalAccessException, NoSuchMethodException + { + this.signature = new Signature( + QualifiedObjectName.valueOf("test.test.test"), + FunctionKind.AGGREGATE, + DoubleType.DOUBLE.getTypeSignature(), + ImmutableList.of(DoubleType.DOUBLE.getTypeSignature())); + + this.functionMetadata = new FunctionMetadata(mockQualifiedObjectName, + signature.getArgumentTypes(), + signature.getReturnType(), + SCALAR, + FunctionImplementationType.BUILTIN, + true, + true); + + this.implementation = mock(ScalarFunctionImplementation.class); + + Constructor constructor = HiveScalarFunction.class.getDeclaredConstructor( + FunctionMetadata.class, Signature.class, String.class, ScalarFunctionImplementation.class); + constructor.setAccessible(true); + this.instance = constructor.newInstance(functionMetadata, signature, "test", implementation); + this.clas = HiveScalarFunction.class; + } + + @Test + public void testCreateHiveScalarFunction() throws ClassNotFoundException + { + QualifiedObjectName name = QualifiedObjectName.valueOf("test.test.test"); + List argumentTypes = ImmutableList.of(DoubleType.DOUBLE.getTypeSignature()); + TypeManager mockTypeManager = mock(TypeManager.class); + HiveFunctionRegistry mockHiveFunctionRegistry = mock(HiveFunctionRegistry.class); + Class functionClass = mockHiveFunctionRegistry.getClass(mockQualifiedObjectName); + + try { + createHiveScalarFunction(functionClass, name, argumentTypes, mockTypeManager); + } + catch (PrestoException e) { + 
Assert.assertEquals(HIVE_FUNCTION_INITIALIZATION_ERROR.toString(), e.getMessage()); + } + } + + @Test + public void testGetFunctionMetadata() throws NoSuchMethodException, InvocationTargetException, + IllegalAccessException + { + Method method = clas.getDeclaredMethod("getFunctionMetadata"); + + Assert.assertEquals(functionMetadata, method.invoke(instance)); + } + + @Test + public void testGetJavaScalarFunctionImplementation() throws NoSuchMethodException, InvocationTargetException, + IllegalAccessException + { + Method method = clas.getDeclaredMethod("getJavaScalarFunctionImplementation"); + + Assert.assertEquals(implementation, method.invoke(instance)); + } + + @Test + public void testIsHidden() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException + { + Method method = clas.getDeclaredMethod("isHidden"); + + Assert.assertEquals(false, method.invoke(instance)); + } + + @Test + public void testGetInvocationConvention() throws InvocationTargetException, InstantiationException, + IllegalAccessException, NoSuchMethodException + { + Class[] declaredClasses = HiveScalarFunction.class.getDeclaredClasses(); + + Object object = Optional.empty(); + for (Class c : declaredClasses) { + int mod = c.getModifiers(); + String modifier = Modifier.toString(mod); + if (modifier.contains("private") && c.getName().contains("HiveScalarFunctionImplementation")) { + Constructor[] declaredConstructors = c.getDeclaredConstructors(); + AccessibleObject.setAccessible(declaredConstructors, true); + + for (Constructor declaredConstructor : declaredConstructors) { + if (Modifier.toString(declaredConstructor.getModifiers()).contains("private")) { + Constructor constructor = declaredConstructor; + object = constructor.newInstance(mockMethodHandle, mockInvocationConvention); + } + } + Method method = c.getMethod("getInvocationConvention"); + method.setAccessible(true); + + Assert.assertEquals(method.invoke(object), mockInvocationConvention); + } + } + } + + @Test + public void testGetMethodHandle() throws InvocationTargetException, InstantiationException, + IllegalAccessException, NoSuchMethodException + { + Class[] declaredClasses = HiveScalarFunction.class.getDeclaredClasses(); + + Object object = Optional.empty(); + for (Class c : declaredClasses) { + int mod = c.getModifiers(); + String modifier = Modifier.toString(mod); + if (modifier.contains("private") && c.getName().contains("HiveScalarFunctionImplementation")) { + Constructor[] declaredConstructors = c.getDeclaredConstructors(); + AccessibleObject.setAccessible(declaredConstructors, true); + + for (Constructor declaredConstructor : declaredConstructors) { + if (Modifier.toString(declaredConstructor.getModifiers()).contains("private")) { + Constructor constructor = declaredConstructor; + object = constructor.newInstance(mockMethodHandle, mockInvocationConvention); + } + } + Method method = c.getMethod("getMethodHandle"); + method.setAccessible(true); + + Assert.assertEquals(method.invoke(object), mockMethodHandle); + } + } + } + + @Test + public void testIsNullable() throws InvocationTargetException, InstantiationException, + IllegalAccessException, NoSuchMethodException + { + Class[] declaredClasses = HiveScalarFunction.class.getDeclaredClasses(); + + Object object = Optional.empty(); + for (Class c : declaredClasses) { + int mod = c.getModifiers(); + String modifier = Modifier.toString(mod); + if (modifier.contains("private") && c.getName().contains("HiveScalarFunctionImplementation")) { + Constructor[] declaredConstructors = 
c.getDeclaredConstructors(); + AccessibleObject.setAccessible(declaredConstructors, true); + + for (Constructor declaredConstructor : declaredConstructors) { + if (Modifier.toString(declaredConstructor.getModifiers()).contains("private")) { + Constructor constructor = declaredConstructor; + object = constructor.newInstance(mockMethodHandle, mockInvocationConvention); + } + } + Method method = c.getMethod("isNullable"); + method.setAccessible(true); + + try { + method.invoke(object); + } + catch (InvocationTargetException e) { + Assert.assertTrue(e.getCause() instanceof NullPointerException); + } + } + } + } +} diff --git a/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/type/TestBlockInputDecoders.java b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/type/TestBlockInputDecoders.java new file mode 100644 index 0000000000000000000000000000000000000000..2a82ab73fd1bad23bfc726de1127b7c1b09685a5 --- /dev/null +++ b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/type/TestBlockInputDecoders.java @@ -0,0 +1,132 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.functions.type; + +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.type.ArrayType; +import io.prestosql.spi.type.MapType; +import io.prestosql.spi.type.RowType; +import io.prestosql.spi.type.Type; +import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.SettableStructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.locationtech.jts.util.Assert; +import org.testng.annotations.Test; + +import java.util.regex.Pattern; + +import static org.mockito.Mockito.mock; + +public class TestBlockInputDecoders +{ + @Test + public void testCreateBlockInputDecoder() + { + Type mockType = mock(Type.class); + RowType mockRowType = mock(RowType.class); + ArrayType mockArrayType = mock(ArrayType.class); + MapType mockMapType = mock(MapType.class); + + // inspector instanceof ConstantObjectInspector + ConstantObjectInspector mockConstantObjectInspector + = mock(ConstantObjectInspector.class); + try { + BlockInputDecoders.createBlockInputDecoder( + mockConstantObjectInspector, mockType); + } + catch (PrestoException e) { + Assert.isTrue(Pattern.matches("Unsupported Hive type .*", + e.getMessage())); + } + // inspector instanceof PrimitiveObjectInspector + PrimitiveObjectInspector mockPrimitiveObjectInspector + = mock(PrimitiveObjectInspector.class); + try { + BlockInputDecoders.createBlockInputDecoder( + mockPrimitiveObjectInspector, mockType); + } + catch 
(PrestoException e) { + Assert.isTrue(Pattern.matches("Unsupported Hive type .*", + e.getMessage())); + } + // inspector instanceof StandardStructObjectInspector + StandardStructObjectInspector mockStandardStructObjectInspector + = mock(StandardStructObjectInspector.class); + try { + BlockInputDecoders.createBlockInputDecoder( + mockStandardStructObjectInspector, mockRowType); + } + catch (PrestoException e) { + Assert.isTrue(Pattern.matches( + "Unsupported Hive type .*", e.getMessage())); + } + // inspector instanceof SettableStructObjectInspector + SettableStructObjectInspector mockSettableStructObjectInspector = mock( + SettableStructObjectInspector.class); + try { + BlockInputDecoders.createBlockInputDecoder( + mockSettableStructObjectInspector, mockRowType); + } + catch (PrestoException e) { + Assert.isTrue(Pattern.matches( + "Unsupported Hive type .*", e.getMessage())); + } + // inspector instanceof StructObjectInspector + StructObjectInspector mockStructObjectInspector = mock( + StructObjectInspector.class); + try { + BlockInputDecoders.createBlockInputDecoder( + mockStructObjectInspector, mockRowType); + } + catch (PrestoException e) { + Assert.isTrue(Pattern.matches( + "Unsupported Hive type .*", e.getMessage())); + } + // inspector instanceof ListObjectInspector + ListObjectInspector mockListObjectInspector = mock( + ListObjectInspector.class); + try { + BlockInputDecoders.createBlockInputDecoder( + mockListObjectInspector, mockArrayType); + } + catch (PrestoException e) { + Assert.isTrue(Pattern.matches( + "Unsupported Hive type .*", e.getMessage())); + } + // inspector instanceof MapObjectInspector + MapObjectInspector mockMapObjectInspector = mock( + MapObjectInspector.class); + try { + BlockInputDecoders.createBlockInputDecoder( + mockMapObjectInspector, mockMapType); + } + catch (PrestoException e) { + Assert.isTrue(Pattern.matches( + "Unsupported Hive type .*", e.getMessage())); + } + // throw unsupported type + try { + BlockInputDecoders.createBlockInputDecoder(null, mockMapType); + } + catch (PrestoException e) { + Assert.isTrue(Pattern.matches( + "Unsupported Hive type .*", e.getMessage())); + } + } +} diff --git a/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/type/TestHiveTypes.java b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/type/TestHiveTypes.java new file mode 100644 index 0000000000000000000000000000000000000000..aa6d7b68f1bea8828f00d60bdcf9f464a61a2279 --- /dev/null +++ b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/type/TestHiveTypes.java @@ -0,0 +1,405 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.prestosql.plugin.hive.functions.type; + +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.type.ArrayType; +import io.prestosql.spi.type.CharType; +import io.prestosql.spi.type.DecimalType; +import io.prestosql.spi.type.MapType; +import io.prestosql.spi.type.RowType; +import io.prestosql.spi.type.Type; +import io.prestosql.spi.type.TypeSignature; +import io.prestosql.spi.type.VarcharType; +import io.prestosql.testing.assertions.Assert; +import org.apache.hadoop.hive.common.type.HiveChar; +import org.apache.hadoop.hive.common.type.HiveVarchar; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.testng.annotations.Test; + +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Pattern; + +import static io.prestosql.spi.type.StandardTypes.ARRAY; +import static io.prestosql.spi.type.StandardTypes.BIGINT; +import static io.prestosql.spi.type.StandardTypes.BOOLEAN; +import static io.prestosql.spi.type.StandardTypes.CHAR; +import static io.prestosql.spi.type.StandardTypes.DATE; +import static io.prestosql.spi.type.StandardTypes.DECIMAL; +import static io.prestosql.spi.type.StandardTypes.DOUBLE; +import static io.prestosql.spi.type.StandardTypes.GEOMETRY; +import static io.prestosql.spi.type.StandardTypes.INTEGER; +import static io.prestosql.spi.type.StandardTypes.MAP; +import static io.prestosql.spi.type.StandardTypes.REAL; +import static io.prestosql.spi.type.StandardTypes.ROW; +import static io.prestosql.spi.type.StandardTypes.SMALLINT; +import static io.prestosql.spi.type.StandardTypes.TIMESTAMP; +import static io.prestosql.spi.type.StandardTypes.TINYINT; +import static io.prestosql.spi.type.StandardTypes.VARBINARY; +import static io.prestosql.spi.type.StandardTypes.VARCHAR; +import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory.binaryTypeInfo; +import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory.booleanTypeInfo; +import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory.byteTypeInfo; +import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory.charTypeInfo; +import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory.dateTypeInfo; +import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory.doubleTypeInfo; +import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory.floatTypeInfo; +import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory.getStructTypeInfo; +import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory.intTypeInfo; +import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory.longTypeInfo; +import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory.shortTypeInfo; +import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory.timestampTypeInfo; +import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory.varcharTypeInfo; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TestHiveTypes +{ + private Type mockType = mock(Type.class); + + @Test + public void testCreateHiveVarCharString() + { + String s = "test"; + + Assert.assertEquals(new HiveVarchar(s, s.length()), + HiveTypes.createHiveVarChar(s)); + } + + @Test + public void testCreateHiveVarCharSlice() + { + Slice slice = Slices.allocate(1); + + Assert.assertEquals(new 
HiveVarchar(slice.toStringUtf8(), slice.toStringUtf8().length()), + HiveTypes.createHiveVarChar(slice)); + } + + @Test + public void testCreateHiveCharString() + { + String s = "test"; + + Assert.assertEquals(new HiveChar(s, s.length()), + HiveTypes.createHiveChar(s)); + } + + @Test + public void testCreateHiveCharSlice() + { + Slice slice = Slices.allocate(1); + + Assert.assertEquals(new HiveChar(slice.toStringUtf8(), slice.toStringUtf8().length()), + HiveTypes.createHiveChar(slice)); + } + + @Test + public void testToTypeInfo() + { + TypeSignature mockTypeSignature = mock(TypeSignature.class); + when(mockType.getTypeSignature()).thenReturn(mockTypeSignature); + + // case BIGINT + when(mockTypeSignature.getBase()).thenReturn(BIGINT); + Assert.assertEquals(longTypeInfo, HiveTypes.toTypeInfo(mockType)); + // case INTEGER + when(mockTypeSignature.getBase()).thenReturn(INTEGER); + Assert.assertEquals(intTypeInfo, HiveTypes.toTypeInfo(mockType)); + // case SMALLINT + when(mockTypeSignature.getBase()).thenReturn(SMALLINT); + Assert.assertEquals(shortTypeInfo, HiveTypes.toTypeInfo(mockType)); + // case TINYINT + when(mockTypeSignature.getBase()).thenReturn(TINYINT); + Assert.assertEquals(byteTypeInfo, HiveTypes.toTypeInfo(mockType)); + // case BOOLEAN + when(mockTypeSignature.getBase()).thenReturn(BOOLEAN); + Assert.assertEquals(booleanTypeInfo, HiveTypes.toTypeInfo(mockType)); + // case DATE + when(mockTypeSignature.getBase()).thenReturn(DATE); + Assert.assertEquals(dateTypeInfo, HiveTypes.toTypeInfo(mockType)); + // case DECIMAL + when(mockTypeSignature.getBase()).thenReturn(DECIMAL); + try { + HiveTypes.toTypeInfo(mockType); + } + catch (PrestoException e) { + org.locationtech.jts.util.Assert.isTrue(Pattern.matches( + "Unsupported Presto type .*", e.getMessage())); + } + // case REAL + when(mockTypeSignature.getBase()).thenReturn(REAL); + Assert.assertEquals(floatTypeInfo, HiveTypes.toTypeInfo(mockType)); + // case DOUBLE + when(mockTypeSignature.getBase()).thenReturn(DOUBLE); + Assert.assertEquals(doubleTypeInfo, HiveTypes.toTypeInfo(mockType)); + // case TIMESTAMP + when(mockTypeSignature.getBase()).thenReturn(TIMESTAMP); + Assert.assertEquals(timestampTypeInfo, HiveTypes.toTypeInfo(mockType)); + // case VARBINARY + when(mockTypeSignature.getBase()).thenReturn(VARBINARY); + Assert.assertEquals(binaryTypeInfo, HiveTypes.toTypeInfo(mockType)); + // case VARCHAR + when(mockTypeSignature.getBase()).thenReturn(VARCHAR); + try { + HiveTypes.toTypeInfo(mockType); + } + catch (PrestoException e) { + org.locationtech.jts.util.Assert.isTrue(Pattern.matches( + "Unsupported Presto type .*", e.getMessage())); + } + // case CHAR + when(mockTypeSignature.getBase()).thenReturn(CHAR); + try { + HiveTypes.toTypeInfo(mockType); + } + catch (PrestoException e) { + org.locationtech.jts.util.Assert.isTrue(Pattern.matches( + "Unsupported Presto type .*", e.getMessage())); + } + // case ROW + when(mockTypeSignature.getBase()).thenReturn(ROW); + try { + HiveTypes.toTypeInfo(mockType); + } + catch (PrestoException e) { + org.locationtech.jts.util.Assert.isTrue(Pattern.matches( + "Unsupported Presto type .*", e.getMessage())); + } + // case ARRAY + when(mockTypeSignature.getBase()).thenReturn(ARRAY); + try { + HiveTypes.toTypeInfo(mockType); + } + catch (PrestoException e) { + org.locationtech.jts.util.Assert.isTrue(Pattern.matches( + "Unsupported Presto type .*", e.getMessage())); + } + // case MAP + when(mockTypeSignature.getBase()).thenReturn(MAP); + try { + HiveTypes.toTypeInfo(mockType); + } + catch 
(PrestoException e) { + org.locationtech.jts.util.Assert.isTrue(Pattern.matches( + "Unsupported Presto type .*", e.getMessage())); + } + // throw unsupported type + when(mockTypeSignature.getBase()).thenReturn(GEOMETRY); + try { + HiveTypes.toTypeInfo(mockType); + } + catch (PrestoException e) { + org.locationtech.jts.util.Assert.isTrue(Pattern.matches( + "Unsupported Presto TypeSignature .*", e.getMessage())); + } + } + + @Test + public void testToDecimalTypeInfo() + throws IllegalAccessException, + NoSuchMethodException, InstantiationException + { + Constructor hiveTypesConstructor = + HiveTypes.class.getDeclaredConstructor(); + hiveTypesConstructor.setAccessible(true); + Object instance = null; + try { + instance = hiveTypesConstructor.newInstance(); + } + catch (InvocationTargetException e) { + e.printStackTrace(); + } + Class hiveTypesClass = HiveTypes.class; + Method toDecimalTypeInfo = + hiveTypesClass.getDeclaredMethod( + "toDecimalTypeInfo", Type.class); + toDecimalTypeInfo.setAccessible(true); + + // type instanceof DecimalType + DecimalType mockDecimalType = mock(DecimalType.class); + try { + toDecimalTypeInfo.invoke(instance, mockDecimalType); + } + catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + if (cause instanceof IllegalArgumentException) { + IllegalArgumentException ex = (IllegalArgumentException) cause; + org.locationtech.jts.util.Assert.isTrue(Pattern.matches( + "Decimal precision out of allowed range .*", + ex.getMessage())); + } + } + + // throw unsupported type + try { + toDecimalTypeInfo.invoke(instance, mockType); + } + catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + if (cause instanceof PrestoException) { + PrestoException ex = (PrestoException) cause; + org.locationtech.jts.util.Assert.isTrue(Pattern.matches( + "Unsupported Presto type .*", ex.getMessage())); + } + } + } + + @Test + public void testToVarcharTypeInfo() + throws IllegalAccessException, + NoSuchMethodException, InstantiationException, + InvocationTargetException + { + Constructor hiveTypesConstructor = + HiveTypes.class.getDeclaredConstructor(); + hiveTypesConstructor.setAccessible(true); + Object instance = hiveTypesConstructor.newInstance(); + + Class hiveTypesClass = HiveTypes.class; + Method toVarcharTypeInfo = hiveTypesClass.getDeclaredMethod( + "toVarcharTypeInfo", Type.class); + toVarcharTypeInfo.setAccessible(true); + + // type instanceof VarcharType + Assert.assertEquals(toVarcharTypeInfo.invoke(instance, VarcharType.VARCHAR), varcharTypeInfo); + // throw unsupported type + try { + toVarcharTypeInfo.invoke(instance, mockType); + } + catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + if (cause instanceof PrestoException) { + PrestoException ex = (PrestoException) cause; + org.locationtech.jts.util.Assert.isTrue(Pattern.matches( + "Unsupported Presto type .*", ex.getMessage())); + } + } + } + + @Test + public void testToCharTypeInfo() throws NoSuchMethodException, InvocationTargetException, + InstantiationException, IllegalAccessException + { + Constructor hiveTypesConstructor = + HiveTypes.class.getDeclaredConstructor(); + hiveTypesConstructor.setAccessible(true); + Object instance = hiveTypesConstructor.newInstance(); + + Class hiveTypesClass = HiveTypes.class; + Method toVarcharTypeInfo = hiveTypesClass.getDeclaredMethod("toCharTypeInfo", Type.class); + toVarcharTypeInfo.setAccessible(true); + + // type instanceof VarcharType + Assert.assertEquals(toVarcharTypeInfo.invoke(instance, CharType.createCharType(255)), 
charTypeInfo); + // throw unsupported type + try { + toVarcharTypeInfo.invoke(instance, mock(Type.class)); + } + catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + if (cause instanceof PrestoException) { + PrestoException ex = (PrestoException) cause; + org.locationtech.jts.util.Assert.isTrue(Pattern.matches( + "Unsupported Presto type .*", ex.getMessage())); + } + } + } + + @Test + public void testToStructTypeInfo() + throws IllegalAccessException, + NoSuchMethodException, InstantiationException, + InvocationTargetException + { + Constructor hiveTypesConstructor = + HiveTypes.class.getDeclaredConstructor(); + hiveTypesConstructor.setAccessible(true); + Object instance = hiveTypesConstructor.newInstance(); + + Class hiveTypesClass = HiveTypes.class; + Method toStructTypeInfo = hiveTypesClass.getDeclaredMethod( + "toStructTypeInfo", Type.class); + toStructTypeInfo.setAccessible(true); + + // throw unsupported RowType + RowType mockRowType = mock(RowType.class); + List fields = mockRowType.getFields(); + List fieldNames = new ArrayList<>(fields.size()); + List fieldTypes = new ArrayList<>(fields.size()); + + Assert.assertEquals(getStructTypeInfo(fieldNames, fieldTypes), + toStructTypeInfo.invoke(instance, mockRowType)); + } + + @Test + public void testToListTypeInfo() + throws IllegalAccessException, + NoSuchMethodException, InstantiationException, + InvocationTargetException + { + Constructor hiveTypesConstructor = + HiveTypes.class.getDeclaredConstructor(); + hiveTypesConstructor.setAccessible(true); + Object instance = hiveTypesConstructor.newInstance(); + + Class hiveTypesClass = HiveTypes.class; + Method toListTypeInfo = hiveTypesClass.getDeclaredMethod( + "toListTypeInfo", Type.class); + toListTypeInfo.setAccessible(true); + + // type instanceof ArrayType + ArrayType mockArrayType = mock(ArrayType.class); + try { + toListTypeInfo.invoke(instance, mockArrayType); + } + catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + Assert.assertTrue(cause instanceof NullPointerException); + } + } + + @Test + public void testToMapTypeInfo() + throws IllegalAccessException, + NoSuchMethodException, InstantiationException, + InvocationTargetException + { + Constructor hiveTypesConstructor = + HiveTypes.class.getDeclaredConstructor(); + hiveTypesConstructor.setAccessible(true); + Object instance = hiveTypesConstructor.newInstance(); + + Class hiveTypesClass = HiveTypes.class; + Method toMapTypeInfo = hiveTypesClass.getDeclaredMethod( + "toMapTypeInfo", Type.class); + toMapTypeInfo.setAccessible(true); + + // type instanceof MapType + MapType mockMapType = mock(MapType.class); + try { + toMapTypeInfo.invoke(instance, mockMapType); + } + catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + Assert.assertTrue(cause instanceof NullPointerException); + } + } +} diff --git a/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/type/TestObjectEncoders.java b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/type/TestObjectEncoders.java new file mode 100644 index 0000000000000000000000000000000000000000..9ad1a69cd5650a0e2f94b3623eb007730c5cd620 --- /dev/null +++ b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/type/TestObjectEncoders.java @@ -0,0 +1,343 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.functions.type; + +import com.google.inject.Key; +import io.airlift.slice.Slice; +import io.prestosql.server.testing.TestingPrestoServer; +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.block.LongArrayBlock; +import io.prestosql.spi.block.SingleMapBlock; +import io.prestosql.spi.type.ArrayType; +import io.prestosql.spi.type.MapType; +import io.prestosql.spi.type.RowType; +import io.prestosql.spi.type.TestRowType; +import io.prestosql.spi.type.Type; +import io.prestosql.spi.type.TypeManager; +import io.prestosql.spi.type.TypeSignature; +import io.prestosql.testing.assertions.Assert; +import org.apache.hadoop.hive.serde2.io.ByteWritable; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; +import org.apache.hadoop.hive.serde2.io.ShortWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.io.BooleanWritable; +import org.apache.hadoop.io.BytesWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.lang.reflect.AccessibleObject; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.HashMap; +import java.util.List; +import java.util.regex.Pattern; + +import static io.prestosql.plugin.hive.functions.HiveFunctionsTestUtils.createTestingPrestoServer; +import static io.prestosql.plugin.hive.functions.type.ObjectEncoders.createEncoder; +import static io.prestosql.spi.block.MethodHandleUtil.methodHandle; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.BooleanType.BOOLEAN; +import static io.prestosql.spi.type.CharType.createCharType; +import static io.prestosql.spi.type.DecimalType.createDecimalType; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.spi.type.SmallintType.SMALLINT; +import static io.prestosql.spi.type.StandardTypes.DATE; +import static io.prestosql.spi.type.StandardTypes.GEOMETRY; +import static io.prestosql.spi.type.StandardTypes.REAL; +import static io.prestosql.spi.type.StandardTypes.ROW; +import static io.prestosql.spi.type.StandardTypes.TIMESTAMP; +import static io.prestosql.spi.type.StandardTypes.UUID; +import static io.prestosql.spi.type.TinyintType.TINYINT; +import static io.prestosql.spi.type.VarbinaryType.VARBINARY; +import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static java.util.Arrays.asList; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableBinaryObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableBooleanObjectInspector; +import 
static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableByteObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableIntObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableLongObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableShortObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableStringObjectInspector; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class TestObjectEncoders +{ + private TestingPrestoServer server; + private TypeManager typeManager; + + private static final long LONG_VAL = 123456L; + public static final long THREE_LONG_VAL = 3L; + private static final int INT_VAL = 12345; + private static final short SHORT_VAL = 1234; + private static final byte BYTE_VAL = 123; + private static final double DOUBLE_VAL = 0.1; + static final int PRECIS_LONG_VAL = 11; + static final int SCALE_LONG_VAL = 10; + static final int PRECIS_SLICE_VAL = 34; + static final int SCALE_SLICE_VAL = 33; + static final byte[] BYTES = new byte[]{12, 34, 56}; + + @BeforeClass + public void setup() + throws Exception + { + this.server = createTestingPrestoServer(); + this.typeManager = server.getInstance(Key.get(TypeManager.class)); + } + + @Test + public void testPrimitiveObjectEncoders() + { + ObjectInspector inspector; + ObjectEncoder encoder; + + inspector = writableLongObjectInspector; + encoder = createEncoder(BIGINT, inspector); + assertTrue(encoder.encode( + new LongWritable(LONG_VAL)) instanceof Long); + + inspector = writableIntObjectInspector; + encoder = createEncoder(INTEGER, inspector); + assertTrue(encoder.encode( + new IntWritable(INT_VAL)) instanceof Long); + + inspector = writableShortObjectInspector; + encoder = createEncoder(SMALLINT, inspector); + assertTrue(encoder.encode( + new ShortWritable(SHORT_VAL)) instanceof Long); + + inspector = writableByteObjectInspector; + encoder = createEncoder(TINYINT, inspector); + assertTrue(encoder.encode( + new ByteWritable(BYTE_VAL)) instanceof Long); + + inspector = writableBooleanObjectInspector; + encoder = createEncoder(BOOLEAN, inspector); + assertTrue(encoder.encode( + new BooleanWritable(true)) instanceof Boolean); + + inspector = writableDoubleObjectInspector; + encoder = createEncoder(DOUBLE, inspector); + assertTrue(encoder.encode( + new DoubleWritable(DOUBLE_VAL)) instanceof Double); + + inspector = writableHiveDecimalObjectInspector; + encoder = createEncoder(createDecimalType(PRECIS_LONG_VAL, + SCALE_LONG_VAL), inspector); + assertTrue(encoder.encode( + new HiveDecimalWritable("1.2345678910")) instanceof Long); + + encoder = createEncoder(createDecimalType(PRECIS_SLICE_VAL, + SCALE_SLICE_VAL), inspector); + assertTrue(encoder.encode( + new HiveDecimalWritable("1.281734081274028174012432412423134")) + instanceof Slice); + } + + @Test + public void testTextObjectEncoders() + { + ObjectInspector 
inspector; + ObjectEncoder encoder; + + inspector = writableBinaryObjectInspector; + encoder = createEncoder(VARBINARY, inspector); + assertTrue(encoder.encode( + new BytesWritable(BYTES)) instanceof Slice); + + inspector = writableStringObjectInspector; + encoder = createEncoder(VARCHAR, inspector); + assertTrue(encoder.encode( + new Text("test_varchar")) instanceof Slice); + + inspector = writableStringObjectInspector; + encoder = createEncoder(createCharType(SCALE_LONG_VAL), inspector); + assertTrue(encoder.encode( + new Text("test_char")) instanceof Slice); + } + + @Test + public void testComplexObjectEncoders() + { + ObjectInspector inspector; + ObjectEncoder encoder; + + inspector = ObjectInspectors.create(new ArrayType(BIGINT), typeManager); + encoder = createEncoder(new ArrayType(BIGINT), inspector); + assertTrue(encoder instanceof ObjectEncoders.ListObjectEncoder); + Object arrayObject = encoder.encode(new Long[]{1L, 2L, THREE_LONG_VAL}); + assertTrue(arrayObject instanceof LongArrayBlock); + assertEquals(((LongArrayBlock) arrayObject).getLong(0, 0), 1L); + assertEquals(((LongArrayBlock) arrayObject).getLong(1, 0), 2L); + assertEquals(((LongArrayBlock) arrayObject).getLong(2, 0), + THREE_LONG_VAL); + + inspector = ObjectInspectors.create(new MapType( + VARCHAR, + BIGINT, + methodHandle(TestRowType.class, "throwUnsupportedOperation"), + methodHandle(TestRowType.class, "throwUnsupportedOperation"), + methodHandle(TestRowType.class, "throwUnsupportedOperation"), + methodHandle(TestRowType.class, "throwUnsupportedOperation")), + typeManager); + encoder = createEncoder(new MapType( + VARCHAR, + BIGINT, + methodHandle(TestRowType.class, "throwUnsupportedOperation"), + methodHandle(TestRowType.class, "throwUnsupportedOperation"), + methodHandle(TestRowType.class, "throwUnsupportedOperation"), + methodHandle(TestRowType.class, + "throwUnsupportedOperation")), inspector); + assertTrue(encoder instanceof ObjectEncoders.MapObjectEncoder); + assertTrue(encoder.encode(new HashMap() { + }) instanceof SingleMapBlock); + } + + @Test(expectedExceptions = PrestoException.class, + expectedExceptionsMessageRegExp = "Unsupported Presto type.*") + public void testCreateDecoder() + { + Type mockType = mock(Type.class); + TypeManager mockTypeManager = mock(TypeManager.class); + TypeSignature mockTypeSignature = mock(TypeSignature.class); + when(mockType.getTypeSignature()).thenReturn(mockTypeSignature); + // throw unsupported type + when(mockTypeSignature.getBase()).thenReturn(UUID); + + ObjectInputDecoders.createDecoder(mockType, mockTypeManager); + } + + @Test + public void testCreateEncoder() + { + Type mockType = mock(Type.class); + ObjectInspector mockObjectInspector = mock(ObjectInspector.class); + TypeSignature mockTypeSignature = mock(TypeSignature.class); + when(mockType.getTypeSignature()).thenReturn(mockTypeSignature); + + // case DATE + when(mockTypeSignature.getBase()).thenReturn(DATE); + try { + createEncoder(mockType, mockObjectInspector); + } + catch (IllegalArgumentException ignored) { + } + // case REAL + when(mockTypeSignature.getBase()).thenReturn(REAL); + try { + createEncoder(mockType, mockObjectInspector); + } + catch (IllegalArgumentException ignored) { + } + // case TIMESTAMP + when(mockTypeSignature.getBase()).thenReturn(TIMESTAMP); + try { + createEncoder(mockType, mockObjectInspector); + } + catch (IllegalArgumentException ignored) { + } + // case ROW + when(mockTypeSignature.getBase()).thenReturn(ROW); + try { + createEncoder(mockType, mockObjectInspector); + } + catch 
(IllegalArgumentException ignored) { + } + // throw unsupported type + when(mockTypeSignature.getBase()).thenReturn(GEOMETRY); + try { + createEncoder(mockType, mockObjectInspector); + } + catch (PrestoException e) { + org.locationtech.jts.util.Assert.isTrue(Pattern.matches( + "Unsupported Presto type .*", e.getMessage())); + } + } + + @Test + public void testEncode() throws InvocationTargetException, InstantiationException, + IllegalAccessException, NoSuchMethodException + { + List fields = asList( + RowType.field("bool_col", BOOLEAN), + RowType.field("double_col", DOUBLE), + RowType.field("array_col", new ArrayType(VARCHAR))); + RowType rowType = RowType.from(fields); + + Class[] declaredClasses = ObjectEncoders.class.getDeclaredClasses(); + + for (Class c : declaredClasses) { + int mod = c.getModifiers(); + String modifier = Modifier.toString(mod); + if (modifier.contains("public") && c.getName().contains("StructObjectEncoder")) { + Constructor[] declaredConstructors = c.getDeclaredConstructors(); + AccessibleObject.setAccessible(declaredConstructors, true); + + Constructor constructor = declaredConstructors[0]; + Object object = constructor.newInstance(rowType, mock(StructObjectInspector.class)); + Method method = c.getMethod("encode", Object.class); + method.setAccessible(true); + + Assert.assertEquals(method.invoke(object, (Object) null), null); + try { + method.invoke(object, new Text("test_char")); + } + catch (InvocationTargetException e) { + Assert.assertTrue(e.getCause() instanceof IllegalStateException); + } + } + } + } + + @Test + public void testCreate() throws InvocationTargetException, InstantiationException, + IllegalAccessException, NoSuchMethodException + { + List fields = asList( + RowType.field("bool_col", BOOLEAN), + RowType.field("double_col", DOUBLE), + RowType.field("array_col", new ArrayType(VARCHAR))); + RowType rowType = RowType.from(fields); + + Class[] declaredClasses = ObjectEncoders.class.getDeclaredClasses(); + + for (Class c : declaredClasses) { + int mod = c.getModifiers(); + String modifier = Modifier.toString(mod); + if (modifier.contains("public") && c.getName().contains("StructObjectEncoder")) { + Constructor[] declaredConstructors = c.getDeclaredConstructors(); + AccessibleObject.setAccessible(declaredConstructors, true); + + Constructor constructor = declaredConstructors[0]; + Object object = constructor.newInstance(rowType, mock(StructObjectInspector.class)); + Method method = c.getMethod("create", Type.class, Object.class); + method.setAccessible(true); + } + } + } +} diff --git a/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/type/TestObjectInputDecoders.java b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/type/TestObjectInputDecoders.java new file mode 100644 index 0000000000000000000000000000000000000000..1e12ccf38ef2a6088764272d497410d041bca367 --- /dev/null +++ b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/type/TestObjectInputDecoders.java @@ -0,0 +1,214 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.functions.type; + +import com.google.common.collect.ImmutableList; +import com.google.inject.Key; +import io.airlift.slice.Slices; +import io.prestosql.server.testing.TestingPrestoServer; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.LongArrayBlock; +import io.prestosql.spi.type.ArrayType; +import io.prestosql.spi.type.MapType; +import io.prestosql.spi.type.TestRowType; +import io.prestosql.spi.type.TypeManager; +import io.prestosql.testing.assertions.Assert; +import org.apache.hadoop.hive.common.type.Date; +import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.lang.reflect.AccessibleObject; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Optional; + +import static io.prestosql.plugin.hive.functions.HiveFunctionsTestUtils.createTestingPrestoServer; +import static io.prestosql.plugin.hive.functions.type.ObjectInputDecoders.createDecoder; +import static io.prestosql.plugin.hive.functions.type.TestObjectEncoders.BYTES; +import static io.prestosql.plugin.hive.functions.type.TestObjectEncoders.PRECIS_LONG_VAL; +import static io.prestosql.plugin.hive.functions.type.TestObjectEncoders.PRECIS_SLICE_VAL; +import static io.prestosql.plugin.hive.functions.type.TestObjectEncoders.SCALE_LONG_VAL; +import static io.prestosql.plugin.hive.functions.type.TestObjectEncoders.SCALE_SLICE_VAL; +import static io.prestosql.spi.block.MethodHandleUtil.methodHandle; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.BooleanType.BOOLEAN; +import static io.prestosql.spi.type.CharType.createCharType; +import static io.prestosql.spi.type.DecimalType.createDecimalType; +import static io.prestosql.spi.type.Decimals.parseIncludeLeadingZerosInPrecision; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.spi.type.RealType.REAL; +import static io.prestosql.spi.type.SmallintType.SMALLINT; +import static io.prestosql.spi.type.TinyintType.TINYINT; +import static io.prestosql.spi.type.VarbinaryType.VARBINARY; +import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static org.mockito.Mockito.mock; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class TestObjectInputDecoders +{ + private TestingPrestoServer server; + private TypeManager typeManager; + + private static final long ACTUAL_DAYS = 18380L; + private static final int EXPECTED_YEAR = 2020 - 1900; + private static final int EXPECTED_MONTH = 4 - 1; + private static final long BIGINT_VAL = 123456L; + private static final long INTEGER_VAL = 12345L; + private static final long SMALLINT_VAL = 1234L; + private static final long TINYINT_VAL = 123L; + private static final double REAL_VAL = 0.2; + private static final double DOUBLE_VAL = 0.1; + + @BeforeClass + public void setup() + throws Exception + { + this.server = createTestingPrestoServer(); + this.typeManager = server.getInstance(Key.get(TypeManager.class)); + } + + @Test + public void testToDate() + { + Date date = DateTimeUtils.createDate(ACTUAL_DAYS); + assertEquals(date.getYear(), EXPECTED_YEAR); + 
assertEquals(date.getMonth(), EXPECTED_MONTH); + } + + @Test + public void testPrimitiveObjectDecoders() + { + ObjectInputDecoder decoder; + + decoder = createDecoder(BIGINT, typeManager); + assertTrue(decoder.decode(BIGINT_VAL) instanceof Long); + + decoder = createDecoder(INTEGER, typeManager); + assertTrue(decoder.decode(INTEGER_VAL) instanceof Integer); + + decoder = createDecoder(SMALLINT, typeManager); + assertTrue(decoder.decode(SMALLINT_VAL) instanceof Short); + + decoder = createDecoder(TINYINT, typeManager); + assertTrue(decoder.decode(TINYINT_VAL) instanceof Byte); + + decoder = createDecoder(BOOLEAN, typeManager); + assertTrue(decoder.decode(true) instanceof Boolean); + + decoder = createDecoder(REAL, typeManager); + assertTrue(decoder.decode(((float) REAL_VAL)) instanceof Float); + + decoder = createDecoder(DOUBLE, typeManager); + assertTrue(decoder.decode(DOUBLE_VAL) instanceof Double); + } + + @Test + public void testDecimalObjectDecoders() + { + ObjectInputDecoder decoder; + + // short decimal + decoder = createDecoder(createDecimalType( + PRECIS_LONG_VAL, SCALE_LONG_VAL), typeManager); + assertTrue( + decoder.decode(decimal("1.2345678910")) instanceof HiveDecimal); + + // long decimal + decoder = createDecoder(createDecimalType( + PRECIS_SLICE_VAL, SCALE_SLICE_VAL), typeManager); + assertTrue(decoder.decode(decimal( + "1.281734081274028174012432412423134")) instanceof HiveDecimal); + } + + @Test + public void testSliceObjectDecoders() + { + ObjectInputDecoder decoder; + + decoder = createDecoder(VARBINARY, typeManager); + assertTrue( + decoder.decode(Slices.wrappedBuffer(BYTES)) instanceof byte[]); + + decoder = createDecoder(VARCHAR, typeManager); + assertTrue(decoder.decode(Slices.utf8Slice( + "test_varchar")) instanceof String); + + decoder = createDecoder(createCharType(SCALE_LONG_VAL), typeManager); + assertTrue(decoder.decode(Slices.utf8Slice( + "test_char")) instanceof String); + } + + @Test + public void testBlockObjectDecoders() + { + ObjectInputDecoder decoder; + + decoder = createDecoder(new ArrayType(BIGINT), typeManager); + assertTrue( + decoder instanceof ObjectInputDecoders.ArrayObjectInputDecoder); + assertEquals(((ArrayList) decoder.decode( + createLongArrayBlock())).get(0), 2L); + + decoder = createDecoder(new MapType( + BIGINT, + BIGINT, + methodHandle(TestRowType.class, "throwUnsupportedOperation"), + methodHandle(TestRowType.class, "throwUnsupportedOperation"), + methodHandle(TestRowType.class, "throwUnsupportedOperation"), + methodHandle(TestRowType.class, + "throwUnsupportedOperation")), typeManager); + assertTrue( + decoder instanceof ObjectInputDecoders.MapObjectInputDecoder); + HashMap map = (HashMap) decoder.decode(createLongArrayBlock()); + assertEquals(map.get(2L), 1L); + } + + private Block createLongArrayBlock() + { + return new LongArrayBlock(2, Optional.empty(), new long[]{2L, 1L}); + } + + private Object decimal(final String decimalString) + { + return parseIncludeLeadingZerosInPrecision(decimalString).getObject(); + } + + @Test + public void testRowObjectInputDecoder() throws InvocationTargetException, InstantiationException, + IllegalAccessException, NoSuchMethodException + { + Class[] declaredClasses = ObjectInputDecoders.class.getDeclaredClasses(); + + for (Class c : declaredClasses) { + if (c.getName().contains("RowObjectInputDecoder")) { + Constructor[] declaredConstructors = c.getDeclaredConstructors(); + AccessibleObject.setAccessible(declaredConstructors, true); + + Constructor constructor = declaredConstructors[0]; + Object 
object = constructor.newInstance(ImmutableList.of(mock(BlockInputDecoder.class))); + Method method = c.getMethod("decode", Object.class); + method.setAccessible(true); + + Assert.assertNull(method.invoke(object, (Object) null)); + } + } + } +} diff --git a/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/type/TestObjectInspectors.java b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/type/TestObjectInspectors.java new file mode 100644 index 0000000000000000000000000000000000000000..13374e2f20b0ee4b82a310f7e2d50d15653a1c19 --- /dev/null +++ b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/type/TestObjectInspectors.java @@ -0,0 +1,148 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.functions.type; + +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.type.RowType; +import io.prestosql.spi.type.Type; +import io.prestosql.spi.type.TypeManager; +import io.prestosql.spi.type.TypeSignature; +import io.prestosql.spi.type.UnknownType; +import io.prestosql.testing.assertions.Assert; +import org.testng.annotations.Test; + +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; + +import static io.prestosql.client.ClientStandardTypes.ROW; +import static io.prestosql.spi.type.StandardTypes.BOOLEAN; +import static io.prestosql.spi.type.StandardTypes.CHAR; +import static io.prestosql.spi.type.StandardTypes.DATE; +import static io.prestosql.spi.type.StandardTypes.DECIMAL; +import static io.prestosql.spi.type.StandardTypes.DOUBLE; +import static io.prestosql.spi.type.StandardTypes.INTEGER; +import static io.prestosql.spi.type.StandardTypes.REAL; +import static io.prestosql.spi.type.StandardTypes.SMALLINT; +import static io.prestosql.spi.type.StandardTypes.TIMESTAMP; +import static io.prestosql.spi.type.StandardTypes.TINYINT; +import static io.prestosql.spi.type.StandardTypes.VARBINARY; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaBooleanObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaByteObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaDateObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaDoubleObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaFloatObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaIntObjectInspector; +import static 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaShortObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaStringObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaTimestampObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaVoidObjectInspector; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TestObjectInspectors +{ + private TypeManager mockTypeManager = mock(TypeManager.class); + + @Test(expectedExceptions = PrestoException.class, + expectedExceptionsMessageRegExp = "Unsupported Presto type.*") + public void testCreate() + { + Type mockType = mock(Type.class); + TypeSignature mockTypeSignature = mock(TypeSignature.class); + when(mockType.getTypeSignature()).thenReturn(mockTypeSignature); + + // case UnknownType.NAME + when(mockTypeSignature.getBase()).thenReturn(UnknownType.NAME); + Assert.assertEquals(javaVoidObjectInspector, + ObjectInspectors.create(mockType, mockTypeManager)); + // case INTEGER + when(mockTypeSignature.getBase()).thenReturn(INTEGER); + Assert.assertEquals(javaIntObjectInspector, + ObjectInspectors.create(mockType, mockTypeManager)); + // case SMALLINT + when(mockTypeSignature.getBase()).thenReturn(SMALLINT); + Assert.assertEquals(javaShortObjectInspector, + ObjectInspectors.create(mockType, mockTypeManager)); + // case TINYINT + when(mockTypeSignature.getBase()).thenReturn(TINYINT); + Assert.assertEquals(javaByteObjectInspector, + ObjectInspectors.create(mockType, mockTypeManager)); + // case BOOLEAN + when(mockTypeSignature.getBase()).thenReturn(BOOLEAN); + Assert.assertEquals(javaBooleanObjectInspector, + ObjectInspectors.create(mockType, mockTypeManager)); + // case DATE + when(mockTypeSignature.getBase()).thenReturn(DATE); + Assert.assertEquals(javaDateObjectInspector, + ObjectInspectors.create(mockType, mockTypeManager)); + // case DECIMAL + when(mockTypeSignature.getBase()).thenReturn(DECIMAL); + try { + ObjectInspectors.create(mockType, mockTypeManager); + } + catch (IllegalArgumentException e) { + Assert.assertTrue(e.getMessage().contains("Invalid decimal type")); + } + // case REAL + when(mockTypeSignature.getBase()).thenReturn(REAL); + Assert.assertEquals(javaFloatObjectInspector, + ObjectInspectors.create(mockType, mockTypeManager)); + // case DOUBLE + when(mockTypeSignature.getBase()).thenReturn(DOUBLE); + Assert.assertEquals(javaDoubleObjectInspector, + ObjectInspectors.create(mockType, mockTypeManager)); + // case TIMESTAMP + when(mockTypeSignature.getBase()).thenReturn(TIMESTAMP); + Assert.assertEquals(javaTimestampObjectInspector, + ObjectInspectors.create(mockType, mockTypeManager)); + // case VARBINARY + when(mockTypeSignature.getBase()).thenReturn(VARBINARY); + Assert.assertEquals(javaByteArrayObjectInspector, + ObjectInspectors.create(mockType, mockTypeManager)); + // case CHAR + when(mockTypeSignature.getBase()).thenReturn(CHAR); + Assert.assertEquals(javaStringObjectInspector, + ObjectInspectors.create(mockType, mockTypeManager)); + // throw unsupported type + when(mockTypeSignature.getBase()).thenReturn(ROW); + Assert.assertEquals(javaByteArrayObjectInspector, + ObjectInspectors.create(mockType, mockTypeManager)); + } + + @Test + public void testCreateForRow() throws NoSuchMethodException, + InvocationTargetException, InstantiationException, 
+ IllegalAccessException + { + Constructor objectInspectorsConstructor = + ObjectInspectors.class.getDeclaredConstructor(); + objectInspectorsConstructor.setAccessible(true); + Object instance = objectInspectorsConstructor.newInstance(); + Class objectInspectorsClass = ObjectInspectors.class; + Method method = objectInspectorsClass.getDeclaredMethod( + "createForRow", RowType.class, TypeManager.class); + method.setAccessible(true); + + RowType mockRowType = mock(RowType.class); + + Object result = method.invoke(instance, mockRowType, + mockTypeManager).toString(); + + Assert.assertEquals("org.apache.hadoop.hive.serde2." + + "objectinspector.StandardStructObjectInspector<>", result); + } +} diff --git a/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/type/TestPrestoTypes.java b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/type/TestPrestoTypes.java new file mode 100644 index 0000000000000000000000000000000000000000..2311867ae4dd57b470f6ffec49f80f8c8f39d6c7 --- /dev/null +++ b/omnidata/omnidata-hiveudf-loader/src/test/java/io/prestosql/plugin/hive/functions/type/TestPrestoTypes.java @@ -0,0 +1,291 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.functions.type; + +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.type.BigintType; +import io.prestosql.spi.type.BooleanType; +import io.prestosql.spi.type.DateType; +import io.prestosql.spi.type.DoubleType; +import io.prestosql.spi.type.IntegerType; +import io.prestosql.spi.type.RealType; +import io.prestosql.spi.type.SmallintType; +import io.prestosql.spi.type.TimestampType; +import io.prestosql.spi.type.TinyintType; +import io.prestosql.spi.type.TypeManager; +import io.prestosql.spi.type.TypeSignature; +import io.prestosql.spi.type.VarbinaryType; +import io.prestosql.spi.type.VarcharType; +import io.prestosql.testing.assertions.Assert; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.testng.annotations.Test; + +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.regex.Pattern; + +import static io.prestosql.plugin.hive.functions.HiveFunctionErrorCode.HIVE_FUNCTION_UNSUPPORTED_HIVE_TYPE; +import static io.prestosql.plugin.hive.functions.type.PrestoTypes.createDecimalType; +import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category.LIST; +import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category.MAP; +import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category.PRIMITIVE; +import static 
org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category.STRUCT; +import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category.UNION; +import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.BINARY; +import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.BOOLEAN; +import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.BYTE; +import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.DATE; +import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.DOUBLE; +import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.FLOAT; +import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.INT; +import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.LONG; +import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.SHORT; +import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.STRING; +import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.TIMESTAMP; +import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.UNKNOWN; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TestPrestoTypes +{ + private TypeManager mockTypeManager = mock(TypeManager.class); + + @Test + public void testCreateDecimalType() + { + TypeSignature mockTypeSignature = mock(TypeSignature.class); + + try { + createDecimalType(mockTypeSignature); + } + catch (IllegalArgumentException e) { + org.locationtech.jts.util.Assert.isTrue(Pattern.matches( + "Invalid decimal type .*", e.getMessage())); + } + } + + @Test + public void testFromObjectInspector() + { + ObjectInspector mockObjectInspector = mock(ObjectInspector.class); + + // case PRIMITIVE + when(mockObjectInspector.getCategory()).thenReturn(PRIMITIVE); + try { + PrestoTypes.fromObjectInspector(mockObjectInspector, mockTypeManager); + } + catch (IllegalArgumentException ignored) { + } + // case LIST + when(mockObjectInspector.getCategory()).thenReturn(LIST); + try { + PrestoTypes.fromObjectInspector(mockObjectInspector, mockTypeManager); + } + catch (IllegalArgumentException ignored) { + } + // case MAP + when(mockObjectInspector.getCategory()).thenReturn(MAP); + try { + PrestoTypes.fromObjectInspector(mockObjectInspector, mockTypeManager); + } + catch (IllegalArgumentException ignored) { + } + // case STRUCT + when(mockObjectInspector.getCategory()).thenReturn(STRUCT); + try { + PrestoTypes.fromObjectInspector(mockObjectInspector, mockTypeManager); + } + catch (IllegalArgumentException ignored) { + } + + // throw unsupported type + when(mockObjectInspector.getCategory()).thenReturn(UNION); + try { + PrestoTypes.fromObjectInspector(mockObjectInspector, mockTypeManager); + } + catch (PrestoException e) { + org.locationtech.jts.util.Assert.isTrue(Pattern.matches("Unsupported Hive type .*", + e.getMessage())); + } + } + + @Test(expectedExceptions = PrestoException.class, + expectedExceptionsMessageRegExp = "Unsupported Hive type.*") + public void testFromPrimitive() throws InstantiationException, + IllegalAccessException, NoSuchMethodException, + 
InvocationTargetException + { + Constructor prestoTypesConstructor = + PrestoTypes.class.getDeclaredConstructor(); + prestoTypesConstructor.setAccessible(true); + Object instance = prestoTypesConstructor.newInstance(); + Class prestoTypesClass = PrestoTypes.class; + Method method = prestoTypesClass.getDeclaredMethod( + "fromPrimitive", PrimitiveObjectInspector.class); + method.setAccessible(true); + PrimitiveObjectInspector mockPrimitiveObjectInspector = + mock(PrimitiveObjectInspector.class); + + // case BOOLEAN + when(mockPrimitiveObjectInspector.getPrimitiveCategory()) + .thenReturn(BOOLEAN); + Assert.assertEquals(BooleanType.BOOLEAN, + method.invoke(instance, mockPrimitiveObjectInspector)); + // case BYTE + when(mockPrimitiveObjectInspector.getPrimitiveCategory()) + .thenReturn(BYTE); + Assert.assertEquals(TinyintType.TINYINT, + method.invoke(instance, mockPrimitiveObjectInspector)); + // case SHORT + when(mockPrimitiveObjectInspector.getPrimitiveCategory()) + .thenReturn(SHORT); + Assert.assertEquals(SmallintType.SMALLINT, + method.invoke(instance, mockPrimitiveObjectInspector)); + // case INT + when(mockPrimitiveObjectInspector.getPrimitiveCategory()) + .thenReturn(INT); + Assert.assertEquals(IntegerType.INTEGER, + method.invoke(instance, mockPrimitiveObjectInspector)); + // case LONG + when(mockPrimitiveObjectInspector.getPrimitiveCategory()) + .thenReturn(LONG); + Assert.assertEquals(BigintType.BIGINT, + method.invoke(instance, mockPrimitiveObjectInspector)); + // case FLOAT + when(mockPrimitiveObjectInspector.getPrimitiveCategory()) + .thenReturn(FLOAT); + Assert.assertEquals(RealType.REAL, + method.invoke(instance, mockPrimitiveObjectInspector)); + // case DOUBLE + when(mockPrimitiveObjectInspector.getPrimitiveCategory()) + .thenReturn(DOUBLE); + Assert.assertEquals(DoubleType.DOUBLE, + method.invoke(instance, mockPrimitiveObjectInspector)); + // case STRING + when(mockPrimitiveObjectInspector.getPrimitiveCategory()) + .thenReturn(STRING); + Assert.assertEquals(VarcharType.VARCHAR, + method.invoke(instance, mockPrimitiveObjectInspector)); + // case DATE + when(mockPrimitiveObjectInspector.getPrimitiveCategory()) + .thenReturn(DATE); + Assert.assertEquals(DateType.DATE, + method.invoke(instance, mockPrimitiveObjectInspector)); + // case TIMESTAMP + when(mockPrimitiveObjectInspector.getPrimitiveCategory()) + .thenReturn(TIMESTAMP); + Assert.assertEquals(TimestampType.TIMESTAMP, + method.invoke(instance, mockPrimitiveObjectInspector)); + // case BINARY + when(mockPrimitiveObjectInspector.getPrimitiveCategory()) + .thenReturn(BINARY); + Assert.assertEquals(VarbinaryType.VARBINARY, + method.invoke(instance, mockPrimitiveObjectInspector)); + // throw unsupported type + when(mockPrimitiveObjectInspector.getPrimitiveCategory()) + .thenReturn(UNKNOWN); + try { + method.invoke(instance, mockPrimitiveObjectInspector); + } + catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + if (cause instanceof PrestoException) { + PrestoException ex = (PrestoException) cause; + throw new PrestoException(HIVE_FUNCTION_UNSUPPORTED_HIVE_TYPE, + ex.getMessage()); + } + } + } + + @Test + public void testFromList() throws InvocationTargetException, IllegalAccessException, + NoSuchMethodException, InstantiationException + { + Constructor prestoTypesConstructor = + PrestoTypes.class.getDeclaredConstructor(); + prestoTypesConstructor.setAccessible(true); + Object instance = prestoTypesConstructor.newInstance(); + Class prestoTypesClass = PrestoTypes.class; + Method method = 
prestoTypesClass.getDeclaredMethod( + "fromList", ListObjectInspector.class, TypeManager.class); + method.setAccessible(true); + + ListObjectInspector mockListObjectInspector = + mock(ListObjectInspector.class); + + try { + method.invoke(instance, mockListObjectInspector, mockTypeManager); + } + catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + org.locationtech.jts.util.Assert.isTrue(cause instanceof NullPointerException); + } + } + + @Test + public void testFromMap() throws InvocationTargetException, IllegalAccessException, + NoSuchMethodException, InstantiationException + { + Constructor prestoTypesConstructor = + PrestoTypes.class.getDeclaredConstructor(); + prestoTypesConstructor.setAccessible(true); + Object instance = prestoTypesConstructor.newInstance(); + Class prestoTypesClass = PrestoTypes.class; + Method method = prestoTypesClass.getDeclaredMethod( + "fromMap", MapObjectInspector.class, TypeManager.class); + method.setAccessible(true); + + MapObjectInspector mockMapObjectInspector = + mock(MapObjectInspector.class); + + try { + method.invoke(instance, mockMapObjectInspector, mockTypeManager); + } + catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + org.locationtech.jts.util.Assert.isTrue(cause instanceof NullPointerException); + } + } + + @Test + public void testFromStruct() throws InvocationTargetException, IllegalAccessException, + NoSuchMethodException, InstantiationException + { + Constructor prestoTypesConstructor = + PrestoTypes.class.getDeclaredConstructor(); + prestoTypesConstructor.setAccessible(true); + Object instance = prestoTypesConstructor.newInstance(); + Class prestoTypesClass = PrestoTypes.class; + Method method = prestoTypesClass.getDeclaredMethod( + "fromStruct", StructObjectInspector.class, TypeManager.class); + method.setAccessible(true); + + StructObjectInspector mockStructObjectInspector = + mock(StructObjectInspector.class); + + try { + method.invoke(instance, mockStructObjectInspector, mockTypeManager); + } + catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + org.locationtech.jts.util.Assert.isTrue(cause instanceof IllegalArgumentException); + } + } +} diff --git a/omnidata/omnidata-hiveudf-loader/src/test/resources/UdfTest-1.0-SNAPSHOT.jar b/omnidata/omnidata-hiveudf-loader/src/test/resources/UdfTest-1.0-SNAPSHOT.jar new file mode 100644 index 0000000000000000000000000000000000000000..9e67fd6ea036c65e79da1669da7956245a4c3ac2 Binary files /dev/null and b/omnidata/omnidata-hiveudf-loader/src/test/resources/UdfTest-1.0-SNAPSHOT.jar differ diff --git a/omnidata/omnidata-hiveudf-loader/src/test/sql/function-testing.sql b/omnidata/omnidata-hiveudf-loader/src/test/sql/function-testing.sql new file mode 100644 index 0000000000000000000000000000000000000000..9c63333c8caf7d6f15a0bcbcaab72b21ab25a938 --- /dev/null +++ b/omnidata/omnidata-hiveudf-loader/src/test/sql/function-testing.sql @@ -0,0 +1,51 @@ +---- +CREATE TABLE memory.default.function_testing ( + c_bigint bigint, + c_integer integer, + c_smallint smallint, + c_tinyint tinyint, + c_boolean boolean, + c_date date, + c_decimal_52 decimal(5,2), + c_real real, + c_double double, + c_timestamp timestamp, + c_varchar varchar, + c_varchar_a varchar, + c_varchar_z varchar, + c_varchar_null varchar, + c_varchar_10 varchar(10), + c_char_10 char(10), + c_array_integer array(integer), + c_array_varchar array(varchar), + c_array_varchar10 array(varchar(10)), + c_map_varchar_integer map(varchar(10), integer), + c_map_string_string map(varchar, 
varchar), + c_map_varchar_varchar map(varchar(10), varchar(10)) +) + +---- +INSERT INTO memory.default.function_testing VALUES ( + BIGINT '1', + INTEGER '1', + SMALLINT '1', + TINYINT '1', + BOOLEAN 'false', + DATE '2020-04-28', + DECIMAL '123.45', + REAL '123.45', + DOUBLE '123.45', + TIMESTAMP '2020-04-28 15:54', + 'varchar', + 'a', + 'z', + null, + 'varchar10', + 'char10', + ARRAY [1, 2, 3], + ARRAY ['a', 'b', 'c'], + ARRAY ['a', 'b', 'c'], + map(ARRAY['a', 'b', 'c'], ARRAY[1, 2, 3]), + map(ARRAY[cast('a' as varchar), 'b', 'c'], ARRAY[cast('1' as varchar), '2', '3']), + map(ARRAY['a', 'b', 'c'], ARRAY['1', '2', '3']) +) \ No newline at end of file diff --git a/omnidata/omnidata-openlookeng-connector/connector/pom.xml b/omnidata/omnidata-openlookeng-connector/connector/pom.xml index 7f9e6c4bc1dd92230a18868c9681ce3463756202..1f4c602db2c80eb7a46ead54164dd61566f88bfc 100644 --- a/omnidata/omnidata-openlookeng-connector/connector/pom.xml +++ b/omnidata/omnidata-openlookeng-connector/connector/pom.xml @@ -760,6 +760,23 @@ + + com.mycila + license-maven-plugin + 2.3 + + + ${air.main.basedir}/src/main/resource/license/license-header.txt + ${air.main.basedir}/src/main/resource/license/license-header-alternate-2010.txt + ${air.main.basedir}/src/main/resource/license/license-header-alternate-2012.txt + ${air.main.basedir}/src/main/resource/license/license-header-alternate-2020.txt + ${air.main.basedir}/src/main/resource/license/license-header-alternate-2021.txt + ${air.main.basedir}/src/main/resource/license/license-header-alternate-2022.txt + ${air.main.basedir}/src/main/resource/license/license-header-alternate-2022-2022.txt + ${air.main.basedir}/src/main/resource/license/license-header-third.txt + + + diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/HivePageSourceProvider.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/HivePageSourceProvider.java index e72e2cad802d24afd6f6ac1e7297610267f57403..9af74c7504ff28f6cb9e89badbc7103e1a0dad3e 100644 --- a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/HivePageSourceProvider.java +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/HivePageSourceProvider.java @@ -16,7 +16,6 @@ package io.prestosql.plugin.hive; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.huawei.boostkit.omnidata.decode.impl.OpenLooKengDeserializer; import com.huawei.boostkit.omnidata.model.Predicate; import com.huawei.boostkit.omnidata.model.TaskSource; import com.huawei.boostkit.omnidata.model.datasource.DataSource; @@ -31,6 +30,7 @@ import io.prestosql.plugin.hive.HiveBucketing.BucketingVersion; import io.prestosql.plugin.hive.coercions.HiveCoercer; import io.prestosql.plugin.hive.omnidata.OmniDataNodeManager; import io.prestosql.plugin.hive.omnidata.OmniDataNodeStatus; +import io.prestosql.plugin.hive.omnidata.decode.impl.OpenLooKengDeserializer; import io.prestosql.plugin.hive.orc.OrcConcatPageSource; import io.prestosql.plugin.hive.util.IndexCache; import io.prestosql.spi.Page; diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/AbstractDecoding.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/AbstractDecoding.java new file mode 100644 index 
0000000000000000000000000000000000000000..a1b7d5a1e223f0bf4c3f8b438c6fc2a2f7be8455 --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/AbstractDecoding.java @@ -0,0 +1,204 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.omnidata.decode; + +import com.huawei.boostkit.omnidata.exception.OmniDataException; +import io.airlift.slice.SliceInput; +import io.prestosql.plugin.hive.omnidata.decode.type.DecimalDecodeType; +import io.prestosql.plugin.hive.omnidata.decode.type.DecodeType; +import io.prestosql.plugin.hive.omnidata.decode.type.LongToByteDecodeType; +import io.prestosql.plugin.hive.omnidata.decode.type.LongToFloatDecodeType; +import io.prestosql.plugin.hive.omnidata.decode.type.LongToIntDecodeType; +import io.prestosql.plugin.hive.omnidata.decode.type.LongToShortDecodeType; +import io.prestosql.plugin.hive.omnidata.decode.type.TimestampDecodeType; +import io.prestosql.spi.type.DateType; +import io.prestosql.spi.type.RowType; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.stream.IntStream; + +/** + * Abstract decoding + * + * @param decoding type + * @since 2022-07-18 + */ +public abstract class AbstractDecoding + implements Decoding +{ + private static final Map DECODE_METHODS; + + static { + DECODE_METHODS = new HashMap<>(); + Method[] methods = Decoding.class.getDeclaredMethods(); + for (Method method : methods) { + if (method.isAnnotationPresent(Decode.class)) { + DECODE_METHODS.put(method.getAnnotation(Decode.class).value(), method); + } + } + } + + private Method getDecodeMethod(String decodeName) + { + return DECODE_METHODS.get(decodeName); + } + + private String getDecodeName(SliceInput input) + { + int length = input.readInt(); + byte[] bytes = new byte[length]; + input.readBytes(bytes); + + return new String(bytes, StandardCharsets.UTF_8); + } + + private Optional typeToDecodeName(DecodeType type) + { + Class javaType = null; + if (type.getJavaType().isPresent()) { + javaType = type.getJavaType().get(); + } + if (javaType == double.class) { + return Optional.of("DOUBLE_ARRAY"); + } + else if (javaType == float.class) { + return Optional.of("FLOAT_ARRAY"); + } + else if (javaType == int.class) { + return Optional.of("INT_ARRAY"); + } + else if (javaType == long.class) { + return Optional.of("LONG_ARRAY"); + } + else if (javaType == byte.class) { + return Optional.of("BYTE_ARRAY"); + } + else if (javaType == boolean.class) { + return Optional.of("BOOLEAN_ARRAY"); + } + else if (javaType == short.class) { + return Optional.of("SHORT_ARRAY"); + } + else if (javaType == String.class) { + return Optional.of("VARIABLE_WIDTH"); + } + else if (javaType == RowType.class) { + return Optional.of("ROW"); + } + 
else if (javaType == DateType.class) { + return Optional.of("DATE"); + } + else if (javaType == LongToIntDecodeType.class) { + return Optional.of("LONG_TO_INT"); + } + else if (javaType == LongToShortDecodeType.class) { + return Optional.of("LONG_TO_SHORT"); + } + else if (javaType == LongToByteDecodeType.class) { + return Optional.of("LONG_TO_BYTE"); + } + else if (javaType == LongToFloatDecodeType.class) { + return Optional.of("LONG_TO_FLOAT"); + } + else if (javaType == DecimalDecodeType.class) { + return Optional.of("DECIMAL"); + } + else if (javaType == TimestampDecodeType.class) { + return Optional.of("TIMESTAMP"); + } + else { + return Optional.empty(); + } + } + + private boolean[] getIsNullValue(byte value) + { + boolean[] isNullValue = new boolean[8]; + isNullValue[0] = ((value & 0b1000_0000) != 0); + isNullValue[1] = ((value & 0b0100_0000) != 0); + isNullValue[2] = ((value & 0b0010_0000) != 0); + isNullValue[3] = ((value & 0b0001_0000) != 0); + isNullValue[4] = ((value & 0b0000_1000) != 0); + isNullValue[5] = ((value & 0b0000_0100) != 0); + isNullValue[6] = ((value & 0b0000_0010) != 0); + isNullValue[7] = ((value & 0b0000_0001) != 0); + + return isNullValue; + } + + @Override + public T decode(Optional type, SliceInput sliceInput) + { + try { + String decodeName = getDecodeName(sliceInput); + if (type.isPresent()) { + Optional decodeNameOpt = typeToDecodeName(type.get()); + if ("DECIMAL".equals(decodeNameOpt.orElse(decodeName)) && !"RLE".equals(decodeName)) { + Method method = getDecodeMethod("DECIMAL"); + return (T) method.invoke(this, type, sliceInput, decodeName); + } + if (!"RLE".equals(decodeName)) { + decodeName = decodeNameOpt.orElse(decodeName); + } + } + Method method = getDecodeMethod(decodeName); + return (T) method.invoke(this, type, sliceInput); + } + catch (IllegalAccessException | InvocationTargetException e) { + throw new OmniDataException("decode failed " + e.getMessage()); + } + } + + /** + * decode Null Bits + * + * @param sliceInput sliceInput + * @param positionCount positionCount + * @return decode boolean[] + * @since 2022-07-18 + */ + public Optional decodeNullBits(SliceInput sliceInput, int positionCount) + { + if (!sliceInput.readBoolean()) { + return Optional.empty(); + } + + // read null bits 8 at a time + boolean[] valueIsNull = new boolean[positionCount]; + for (int position = 0; position < (positionCount & ~0b111); position += 8) { + boolean[] nextEightValue = getIsNullValue(sliceInput.readByte()); + int finalPosition = position; + IntStream.range(0, 8).forEach(pos -> valueIsNull[finalPosition + pos] = nextEightValue[pos]); + } + + // read last null bits + if ((positionCount & 0b111) > 0) { + byte value = sliceInput.readByte(); + int maskInt = 0b1000_0000; + for (int pos = positionCount & ~0b111; pos < positionCount; pos++) { + valueIsNull[pos] = ((value & maskInt) != 0); + maskInt >>>= 1; + } + } + + return Optional.of(valueIsNull); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/Decode.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/Decode.java new file mode 100644 index 0000000000000000000000000000000000000000..efea3a4904d5ff4df9570fb6c495e8751fc1735d --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/Decode.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.omnidata.decode; + +import com.google.inject.BindingAnnotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.METHOD) +@BindingAnnotation +public @interface Decode +{ + String value(); +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/Decoding.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/Decoding.java new file mode 100644 index 0000000000000000000000000000000000000000..3f14fc7f467885f095bb3c59b272aefed38b6325 --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/Decoding.java @@ -0,0 +1,264 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
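The `Decode` annotation above, together with the static `DECODE_METHODS` registry built in `AbstractDecoding`, is essentially a small name-to-method dispatcher: each `decodeXxx` method on the `Decoding` interface is tagged with its wire-level encoding name, the tagged methods are collected once at class-load time, and the right one is invoked reflectively after the encoding name has been read from the slice. A minimal stand-alone sketch of that pattern follows; the `EncodingName`, `Handler`, and `DecodeDispatchSketch` names are illustrative only and are not part of this patch.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;

/** Illustrative stand-in for the connector's Decode binding annotation. */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@interface EncodingName
{
    String value();
}

public class DecodeDispatchSketch
{
    /** Mirrors the Decoding interface: one method per encoding, tagged with its wire name. */
    interface Handler
    {
        @EncodingName("INT_ARRAY")
        String intArray();

        @EncodingName("LONG_ARRAY")
        String longArray();
    }

    public static void main(String[] args) throws Exception
    {
        // Built once, like AbstractDecoding's static DECODE_METHODS map.
        Map<String, Method> registry = new HashMap<>();
        for (Method method : Handler.class.getDeclaredMethods()) {
            if (method.isAnnotationPresent(EncodingName.class)) {
                registry.put(method.getAnnotation(EncodingName.class).value(), method);
            }
        }

        Handler handler = new Handler()
        {
            @Override
            public String intArray()
            {
                return "decoded an INT_ARRAY block";
            }

            @Override
            public String longArray()
            {
                return "decoded a LONG_ARRAY block";
            }
        };

        // decode(...) reads the encoding name from the slice, then dispatches like this.
        System.out.println(registry.get("INT_ARRAY").invoke(handler));
    }
}
```

In the patch itself the reflective call is wrapped so that `IllegalAccessException` and `InvocationTargetException` surface as `OmniDataException`, which keeps the per-encoding methods free of dispatch boilerplate.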
+ */ + +package io.prestosql.plugin.hive.omnidata.decode; + +import io.airlift.slice.SliceInput; +import io.prestosql.plugin.hive.omnidata.decode.type.DecodeType; + +import java.lang.reflect.InvocationTargetException; +import java.util.Optional; + +/** + * Decode Slice to type + * + * @param + * @since 2022-07-18 + */ +public interface Decoding +{ + /** + * decode + * + * @param type decode type + * @param sliceInput content + * @return T + */ + T decode(Optional type, SliceInput sliceInput); + + /** + * decode array type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + */ + @Decode("ARRAY") + T decodeArray(Optional type, SliceInput sliceInput); + + /** + * decode byte array type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + */ + @Decode("BYTE_ARRAY") + T decodeByteArray(Optional type, SliceInput sliceInput); + + /** + * decode boolean array type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + */ + @Decode("BOOLEAN_ARRAY") + T decodeBooleanArray(Optional type, SliceInput sliceInput); + + /** + * decode int array type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + */ + @Decode("INT_ARRAY") + T decodeIntArray(Optional type, SliceInput sliceInput); + + /** + * decode int128 array type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + */ + @Decode("INT128_ARRAY") + T decodeInt128Array(Optional type, SliceInput sliceInput); + + /** + * decode short array type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + */ + @Decode("SHORT_ARRAY") + T decodeShortArray(Optional type, SliceInput sliceInput); + + /** + * decode long array type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + */ + @Decode("LONG_ARRAY") + T decodeLongArray(Optional type, SliceInput sliceInput); + + /** + * decode float array type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + */ + @Decode("FLOAT_ARRAY") + T decodeFloatArray(Optional type, SliceInput sliceInput); + + /** + * decode double array type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + */ + @Decode("DOUBLE_ARRAY") + T decodeDoubleArray(Optional type, SliceInput sliceInput); + + /** + * decode map type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + */ + @Decode("MAP") + T decodeMap(Optional type, SliceInput sliceInput); + + /** + * decode map element type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + */ + @Decode("MAP_ELEMENT") + T decodeSingleMap(Optional type, SliceInput sliceInput); + + /** + * decode variable width type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + */ + @Decode("VARIABLE_WIDTH") + T decodeVariableWidth(Optional type, SliceInput sliceInput); + + /** + * decode dictionary type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + */ + @Decode("DICTIONARY") + T decodeDictionary(Optional type, SliceInput sliceInput); + + /** + * decode rle type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + * @throws InvocationTargetException throw invocation target exception + * @throws IllegalAccessException 
throw illegal access exception + */ + @Decode("RLE") + T decodeRunLength(Optional type, SliceInput sliceInput) + throws InvocationTargetException, IllegalAccessException; + + /** + * decode row type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + */ + @Decode("ROW") + T decodeRow(Optional type, SliceInput sliceInput); + + /** + * decode date type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + */ + @Decode("DATE") + T decodeDate(Optional type, SliceInput sliceInput); + + /** + * decode long to int type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + */ + @Decode("LONG_TO_INT") + T decodeLongToInt(Optional type, SliceInput sliceInput); + + /** + * decode long to short type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + */ + @Decode("LONG_TO_SHORT") + T decodeLongToShort(Optional type, SliceInput sliceInput); + + /** + * decode long to byte type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + */ + @Decode("LONG_TO_BYTE") + T decodeLongToByte(Optional type, SliceInput sliceInput); + + /** + * decode long to float type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + */ + @Decode("LONG_TO_FLOAT") + T decodeLongToFloat(Optional type, SliceInput sliceInput); + + /** + * decode decimal type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @param decodeType storage type of decimal + * @return T + */ + @Decode("DECIMAL") + T decodeDecimal(Optional type, SliceInput sliceInput, String decodeType); + + /** + * decode timestamp type + * + * @param type type of data to decode + * @param sliceInput data to decode + * @return T + */ + @Decode("TIMESTAMP") + T decodeTimestamp(Optional type, SliceInput sliceInput); +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/impl/OpenLooKengDecoding.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/impl/OpenLooKengDecoding.java new file mode 100644 index 0000000000000000000000000000000000000000..66cf8fc9afc6eee5aaafc2eab1b5bba442c0212e --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/impl/OpenLooKengDecoding.java @@ -0,0 +1,292 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
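Nearly all of the block decoders that follow rely on `decodeNullBits` and `getIsNullValue` from `AbstractDecoding` above: a leading boolean flags whether any nulls are present, and if so the per-position null flags are packed eight to a byte, most-significant bit first, with a final partial byte for the remainder. Below is a small self-contained sketch of that unpacking, assuming the same MSB-first layout; the class and method names are illustrative only.

```java
import java.util.Arrays;

/**
 * Minimal sketch of the null-bitmap layout read by decodeNullBits:
 * positions are packed eight per byte, most-significant bit first.
 */
public class NullBitsSketch
{
    static boolean[] unpack(byte[] packed, int positionCount)
    {
        boolean[] isNull = new boolean[positionCount];
        for (int position = 0; position < positionCount; position++) {
            byte current = packed[position >>> 3];          // byte holding this position
            int mask = 0b1000_0000 >>> (position & 0b111);  // MSB-first within the byte
            isNull[position] = (current & mask) != 0;
        }
        return isNull;
    }

    public static void main(String[] args)
    {
        // 10 positions: byte 0 = 1010_0000 (positions 0 and 2 are null),
        // byte 1 = 0100_0000 covers the two leftover positions (position 9 is null).
        byte[] packed = {(byte) 0b1010_0000, (byte) 0b0100_0000};
        System.out.println(Arrays.toString(unpack(packed, 10)));
    }
}
```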
+ */ + +package io.prestosql.plugin.hive.omnidata.decode.impl; + +import io.airlift.slice.Slice; +import io.airlift.slice.SliceInput; +import io.airlift.slice.Slices; +import io.prestosql.plugin.hive.omnidata.decode.AbstractDecoding; +import io.prestosql.plugin.hive.omnidata.decode.type.ArrayDecodeType; +import io.prestosql.plugin.hive.omnidata.decode.type.DecodeType; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.ByteArrayBlock; +import io.prestosql.spi.block.DictionaryBlock; +import io.prestosql.spi.block.DictionaryId; +import io.prestosql.spi.block.Int128ArrayBlock; +import io.prestosql.spi.block.IntArrayBlock; +import io.prestosql.spi.block.LongArrayBlock; +import io.prestosql.spi.block.RunLengthEncodedBlock; +import io.prestosql.spi.block.ShortArrayBlock; +import io.prestosql.spi.block.VariableWidthBlock; + +import java.util.Optional; +import java.util.stream.IntStream; + +import static io.airlift.slice.SizeOf.SIZE_OF_INT; +import static io.prestosql.spi.block.ArrayBlock.fromElementBlock; +import static io.prestosql.spi.block.RowBlock.fromFieldBlocks; + +/** + * Decode data to block + * + * @since 2022-07-18 + */ +public class OpenLooKengDecoding + extends AbstractDecoding> +{ + @Override + public Block decodeArray(Optional type, SliceInput sliceInput) + { + Optional elementType = Optional.empty(); + if (type.isPresent()) { + if (type.get() instanceof ArrayDecodeType) { + ArrayDecodeType arrayDecodeType = (ArrayDecodeType) type.get(); + elementType = Optional.of((arrayDecodeType).getElementType()); + } + } + Block values = decode(elementType, sliceInput); + int positionCount = sliceInput.readInt(); + int[] offsets = new int[positionCount + 1]; + sliceInput.readBytes(Slices.wrappedIntArray(offsets)); + boolean[] valueIsNull = decodeNullBits(sliceInput, positionCount).orElseGet(() -> new boolean[positionCount]); + + return fromElementBlock(positionCount, Optional.ofNullable(valueIsNull), offsets, values); + } + + @Override + public Block decodeByteArray(Optional type, SliceInput sliceInput) + { + int positionCount = sliceInput.readInt(); + boolean[] valueIsNull = decodeNullBits(sliceInput, positionCount).orElse(null); + byte[] values = new byte[positionCount]; + + IntStream.range(0, positionCount) + .forEach( + position -> { + if (valueIsNull == null || !valueIsNull[position]) { + values[position] = sliceInput.readByte(); + } + }); + + return new ByteArrayBlock(positionCount, Optional.ofNullable(valueIsNull), values); + } + + @Override + public Block decodeBooleanArray(Optional type, SliceInput sliceInput) + { + return decodeByteArray(type, sliceInput); + } + + @Override + public Block decodeIntArray(Optional type, SliceInput sliceInput) + { + int posCount = sliceInput.readInt(); + boolean[] valueIsNull = decodeNullBits(sliceInput, posCount).orElse(null); + int[] values = new int[posCount]; + + for (int position = 0; position < posCount; position++) { + if (valueIsNull == null || !valueIsNull[position]) { + values[position] = sliceInput.readInt(); + } + } + + return new IntArrayBlock(posCount, Optional.ofNullable(valueIsNull), values); + } + + @Override + public Block decodeInt128Array(Optional type, SliceInput sliceInput) + { + int posCount = sliceInput.readInt(); + boolean[] valueIsNull = decodeNullBits(sliceInput, posCount).orElse(null); + long[] values = new long[posCount * 2]; + + for (int position = 0; position < posCount; position++) { + if (valueIsNull == null || !valueIsNull[position]) { + values[position * 2] = sliceInput.readLong(); + values[(position * 
2) + 1] = sliceInput.readLong(); + } + } + + return new Int128ArrayBlock(posCount, Optional.ofNullable(valueIsNull), values); + } + + @Override + public Block decodeShortArray(Optional type, SliceInput sliceInput) + { + int posCount = sliceInput.readInt(); + boolean[] valueIsNull = decodeNullBits(sliceInput, posCount).orElse(null); + short[] values = new short[posCount]; + + for (int position = 0; position < posCount; position++) { + if (valueIsNull == null || !valueIsNull[position]) { + values[position] = sliceInput.readShort(); + } + } + + return new ShortArrayBlock(posCount, Optional.ofNullable(valueIsNull), values); + } + + @Override + public Block decodeLongArray(Optional type, SliceInput sliceInput) + { + int posCount = sliceInput.readInt(); + boolean[] valueIsNull = decodeNullBits(sliceInput, posCount).orElse(null); + long[] values = new long[posCount]; + + for (int position = 0; position < posCount; position++) { + if (valueIsNull == null || !valueIsNull[position]) { + values[position] = sliceInput.readLong(); + } + } + + return new LongArrayBlock(posCount, Optional.ofNullable(valueIsNull), values); + } + + @Override + public Block decodeFloatArray(Optional type, SliceInput sliceInput) + { + return decodeLongArray(type, sliceInput); + } + + @Override + public Block decodeDoubleArray(Optional type, SliceInput sliceInput) + { + return decodeLongArray(type, sliceInput); + } + + @Override + public Block decodeMap(Optional type, SliceInput sliceInput) + { + throw new UnsupportedOperationException(); + } + + @Override + public Block decodeSingleMap(Optional type, SliceInput sliceInput) + { + throw new UnsupportedOperationException(); + } + + @Override + public Block decodeVariableWidth(Optional type, SliceInput sliceInput) + { + int posCount = sliceInput.readInt(); + int[] offsets = new int[posCount + 1]; + + sliceInput.readBytes(Slices.wrappedIntArray(offsets), SIZE_OF_INT, posCount * SIZE_OF_INT); + + boolean[] valueIsNull = decodeNullBits(sliceInput, posCount).orElse(null); + + int blockSize = sliceInput.readInt(); + Slice slice = sliceInput.readSlice(blockSize); + + return new VariableWidthBlock(posCount, slice, offsets, Optional.ofNullable(valueIsNull)); + } + + @Override + public Block decodeDictionary(Optional type, SliceInput sliceInput) + { + int posCount = sliceInput.readInt(); + + Block dictionaryBlock = decode(type, sliceInput); + + int[] ids = new int[posCount]; + sliceInput.readBytes(Slices.wrappedIntArray(ids)); + + long mostSignificantBits = sliceInput.readLong(); + long leastSignificantBits = sliceInput.readLong(); + long sequenceId = sliceInput.readLong(); + + // We always compact the dictionary before we send it. However, dictionaryBlock comes from sliceInput, which may + // over-retain memory. + // As a result, setting dictionaryIsCompacted to true is not appropriate here. + // over-retains memory. 
+ return new DictionaryBlock<>( + posCount, + dictionaryBlock, + ids, + false, + new DictionaryId(mostSignificantBits, leastSignificantBits, sequenceId)); + } + + @Override + public Block decodeRunLength(Optional type, SliceInput sliceInput) + { + int posCount = sliceInput.readInt(); + + Block values = decode(type, sliceInput); + + return new RunLengthEncodedBlock<>(values, posCount); + } + + @Override + public Block decodeRow(Optional type, SliceInput sliceInput) + { + int numFields = sliceInput.readInt(); + Block[] fieldBlocks = new Block[numFields]; + for (int i = 0; i < numFields; i++) { + fieldBlocks[i] = decode(type, sliceInput); + } + + int positionCount = sliceInput.readInt(); + int[] fieldBlockOffsets = new int[positionCount + 1]; + sliceInput.readBytes(Slices.wrappedIntArray(fieldBlockOffsets)); + boolean[] rowIsNull = decodeNullBits(sliceInput, positionCount).orElseGet(() -> new boolean[positionCount]); + + return fromFieldBlocks(positionCount, Optional.of(rowIsNull), fieldBlocks); + } + + @Override + public Block decodeDate(Optional type, SliceInput sliceInput) + { + throw new UnsupportedOperationException(); + } + + @Override + public Block decodeLongToInt(Optional type, SliceInput sliceInput) + { + throw new UnsupportedOperationException(); + } + + @Override + public Block decodeLongToShort(Optional type, SliceInput sliceInput) + { + throw new UnsupportedOperationException(); + } + + @Override + public Block decodeLongToByte(Optional type, SliceInput sliceInput) + { + throw new UnsupportedOperationException(); + } + + @Override + public Block decodeLongToFloat(Optional type, SliceInput sliceInput) + { + throw new UnsupportedOperationException(); + } + + @Override + public Block decodeDecimal(Optional type, SliceInput sliceInput, String decodeType) + { + throw new UnsupportedOperationException(); + } + + @Override + public Block decodeTimestamp(Optional type, SliceInput sliceInput) + { + throw new UnsupportedOperationException(); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/impl/OpenLooKengDeserializer.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/impl/OpenLooKengDeserializer.java new file mode 100644 index 0000000000000000000000000000000000000000..4cbd9963d1b4b4a59e773d2415b558cfd1250b35 --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/impl/OpenLooKengDeserializer.java @@ -0,0 +1,103 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
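As a concrete example of the array layouts handled above, `decodeIntArray` reads a position count, the optional null bitmap, and then one 32-bit value per non-null position. The sketch below mirrors that sequence of reads using plain `DataInput`/`DataOutput` as a stand-in for airlift's `SliceInput` (which is little-endian, unlike `DataInput`); the framing shown is a simplified assumption for illustration, not the connector's exact wire code.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

/** Simplified INT_ARRAY layout: positionCount, "has nulls" flag, then one int per non-null position. */
public class IntArrayLayoutSketch
{
    public static void main(String[] args) throws IOException
    {
        // Writer side: 3 positions, no nulls, values 7, 8, 9 (null bitmap elided for brevity).
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(bytes);
        out.writeInt(3);
        out.writeBoolean(false); // "no nulls" flag, as decodeNullBits expects
        out.writeInt(7);
        out.writeInt(8);
        out.writeInt(9);
        out.flush();

        // Reader side, mirroring decodeIntArray.
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()));
        int positionCount = in.readInt();
        boolean hasNulls = in.readBoolean();
        boolean[] valueIsNull = hasNulls ? new boolean[positionCount] : null; // bitmap decoding elided
        int[] values = new int[positionCount];
        for (int position = 0; position < positionCount; position++) {
            if (valueIsNull == null || !valueIsNull[position]) {
                values[position] = in.readInt();
            }
        }
        System.out.println(positionCount + " positions: " + Arrays.toString(values));
    }
}
```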
+ */ + +package io.prestosql.plugin.hive.omnidata.decode.impl; + +import com.huawei.boostkit.omnidata.decode.Deserializer; +import io.airlift.compress.Decompressor; +import io.airlift.compress.zstd.ZstdDecompressor; +import io.airlift.slice.Slice; +import io.airlift.slice.SliceInput; +import io.airlift.slice.Slices; +import io.hetu.core.transport.execution.buffer.SerializedPage; +import io.prestosql.spi.Page; +import io.prestosql.spi.block.Block; + +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; + +/** + * Deserialize block + * + * @since 2022-07-18 + */ +public class OpenLooKengDeserializer + implements Deserializer +{ + private final Decompressor decompressor; + private final OpenLooKengDecoding decoding; + + /** + * Constructor of deserialize block + */ + public OpenLooKengDeserializer() + { + decoding = new OpenLooKengDecoding(); + decompressor = new ZstdDecompressor(); + } + + /** + * Decompress serialized page + * + * @param page page need decompress + * @param decompressor decompressor + * @return Slice decompressed + */ + public static Slice decompressPage(SerializedPage page, Decompressor decompressor) + { + if (!page.isCompressed()) { + return page.getSlice(); + } + Slice slice = page.getSlice(); + int uncompressedSize = page.getUncompressedSizeInBytes(); + byte[] decompressed = new byte[uncompressedSize]; + if (slice.getBase() instanceof byte[]) { + byte[] sliceBase = (byte[]) slice.getBase(); + checkState( + decompressor.decompress( + sliceBase, + (int) (slice.getAddress() - ARRAY_BYTE_BASE_OFFSET), + slice.length(), + decompressed, + 0, + uncompressedSize) == uncompressedSize); + } + + return Slices.wrappedBuffer(decompressed); + } + + @Override + public Page deserialize(SerializedPage page) + { + checkArgument(page != null, "page is null"); + + if (page.isEncrypted()) { + throw new UnsupportedOperationException("unsupported encrypted page."); + } + + Slice slice = decompressPage(page, decompressor); + SliceInput input = slice.getInput(); + int numberOfBlocks = input.readInt(); + Block[] blocks = new Block[numberOfBlocks]; + for (int i = 0; i < blocks.length; i++) { + blocks[i] = decoding.decode(Optional.empty(), input); + } + + return new Page(page.getPositionCount(), page.getPageMetadata(), blocks); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/ArrayDecodeType.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/ArrayDecodeType.java new file mode 100644 index 0000000000000000000000000000000000000000..be2650369e1e0af5600e5d70249e933769b748a2 --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/ArrayDecodeType.java @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
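`decompressPage` above only decompresses pages that are marked compressed, and `checkState` asserts that the decompressed length equals the uncompressed size recorded with the page. The round trip below uses the same aircompressor classes to show that contract in isolation; the payload string is made up, and Guava's `checkState` is used here just as it is in the connector code.

```java
import io.airlift.compress.zstd.ZstdCompressor;
import io.airlift.compress.zstd.ZstdDecompressor;

import java.nio.charset.StandardCharsets;
import java.util.Arrays;

import static com.google.common.base.Preconditions.checkState;

/** Zstd round trip with the aircompressor classes used by OpenLooKengDeserializer. */
public class ZstdRoundTripSketch
{
    public static void main(String[] args)
    {
        byte[] uncompressed = "omnidata page payload".getBytes(StandardCharsets.UTF_8);

        ZstdCompressor compressor = new ZstdCompressor();
        byte[] compressed = new byte[compressor.maxCompressedLength(uncompressed.length)];
        int compressedSize = compressor.compress(
                uncompressed, 0, uncompressed.length, compressed, 0, compressed.length);

        ZstdDecompressor decompressor = new ZstdDecompressor();
        byte[] decompressed = new byte[uncompressed.length];
        int decompressedSize = decompressor.decompress(
                compressed, 0, compressedSize, decompressed, 0, decompressed.length);

        // Same guard as decompressPage: the sizes must line up exactly.
        checkState(decompressedSize == uncompressed.length, "unexpected decompressed size");
        System.out.println(Arrays.equals(uncompressed, decompressed));
    }
}
```

This size check is also why `deserialize` can safely size the output buffer up front from `getUncompressedSizeInBytes()` before handing the slice to the block decoders.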
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.omnidata.decode.type; + +import java.util.Optional; + +/** + * Array decode type + * + * @param decode type + * @since 2022-07-18 + */ +public class ArrayDecodeType + implements DecodeType +{ + private final T elementType; + + public ArrayDecodeType(T elementType) + { + this.elementType = elementType; + } + + public T getElementType() + { + return elementType; + } + + @Override + public Optional> getJavaType() + { + return Optional.empty(); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/BooleanDecodeType.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/BooleanDecodeType.java new file mode 100644 index 0000000000000000000000000000000000000000..e1d5fe64a4f615c5a488d24790390dd8d1976db9 --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/BooleanDecodeType.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.omnidata.decode.type; + +import java.util.Optional; + +/** + * Boolean decode type + * + * @since 2022-07-18 + */ +public class BooleanDecodeType + implements DecodeType +{ + @Override + public Optional> getJavaType() + { + return Optional.of(boolean.class); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/ByteDecodeType.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/ByteDecodeType.java new file mode 100644 index 0000000000000000000000000000000000000000..33fe178b76ac641eaac2a7e83be9374493ccb900 --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/ByteDecodeType.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.prestosql.plugin.hive.omnidata.decode.type; + +import java.util.Optional; + +/** + * Byte decode type + * + * @since 2022-07-18 + */ +public class ByteDecodeType + implements DecodeType +{ + @Override + public Optional> getJavaType() + { + return Optional.of(byte.class); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/DateDecodeType.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/DateDecodeType.java new file mode 100644 index 0000000000000000000000000000000000000000..d173d70f91a6ac326ae464b9080bbca5f4172ddd --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/DateDecodeType.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.omnidata.decode.type; + +import java.util.Optional; + +/** + * Date Decode Type + * + * @since 2022-07-18 + */ +public class DateDecodeType + implements DecodeType +{ + @Override + public Optional> getJavaType() + { + return Optional.of(io.prestosql.spi.type.DateType.class); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/DecimalDecodeType.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/DecimalDecodeType.java new file mode 100644 index 0000000000000000000000000000000000000000..779e93e384b3e861bf41f1b761a3ff699e2689a6 --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/DecimalDecodeType.java @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.prestosql.plugin.hive.omnidata.decode.type; + +import java.util.Optional; + +/** + * Decimal decode type + * + * @since 2022-07-18 + */ +public class DecimalDecodeType + implements DecodeType +{ + private final int precision; + private final int scale; + + public DecimalDecodeType(int precision, int scale) + { + this.precision = precision; + this.scale = scale; + } + + public int getPrecision() + { + return precision; + } + + public int getScale() + { + return scale; + } + + @Override + public Optional> getJavaType() + { + return Optional.of(DecimalDecodeType.class); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/stub/server/src/main/java/com/huawei/boostkit/omnidata/decode/impl/OpenLooKengDeserializer.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/DecodeType.java similarity index 62% rename from omnidata/omnidata-openlookeng-connector/stub/server/src/main/java/com/huawei/boostkit/omnidata/decode/impl/OpenLooKengDeserializer.java rename to omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/DecodeType.java index 275327c376cb1245fba06e0d3883cbe843a42042..47dc317e408ba1d290fa7c62927f5a05b3ecf6d1 100644 --- a/omnidata/omnidata-openlookeng-connector/stub/server/src/main/java/com/huawei/boostkit/omnidata/decode/impl/OpenLooKengDeserializer.java +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/DecodeType.java @@ -1,5 +1,5 @@ /* - * Copyright (C) 2018-2021. Huawei Technologies Co., Ltd. All rights reserved. + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -12,9 +12,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.huawei.boostkit.omnidata.decode.impl; -import com.huawei.boostkit.omnidata.decode.Deserializer; +package io.prestosql.plugin.hive.omnidata.decode.type; -public class OpenLooKengDeserializer implements Deserializer { +import java.util.Optional; + +/** + * Decode java type + * + * @since 2022-07-18 + */ +public interface DecodeType +{ + /** + * get java class type + * + * @return class type + */ + Optional> getJavaType(); } diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/DoubleDecodeType.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/DoubleDecodeType.java new file mode 100644 index 0000000000000000000000000000000000000000..b1a4b789ec2b598419b125b2a1577540e8b949f2 --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/DoubleDecodeType.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.omnidata.decode.type; + +import java.util.Optional; + +/** + * Double decode type + * + * @since 2022-07-18 + */ +public class DoubleDecodeType + implements DecodeType +{ + @Override + public Optional> getJavaType() + { + return Optional.of(double.class); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/FloatDecodeType.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/FloatDecodeType.java new file mode 100644 index 0000000000000000000000000000000000000000..fdd9f725cdd83bca7b7ea366f67863907d44962a --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/FloatDecodeType.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.omnidata.decode.type; + +import java.util.Optional; + +/** + * Float decode type + * + * @since 2022-07-18 + */ +public class FloatDecodeType + implements DecodeType +{ + @Override + public Optional> getJavaType() + { + return Optional.of(float.class); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/IntDecodeType.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/IntDecodeType.java new file mode 100644 index 0000000000000000000000000000000000000000..129760ad46ed62fba774575aa76a3de165aa3f10 --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/IntDecodeType.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.prestosql.plugin.hive.omnidata.decode.type; + +import java.util.Optional; + +/** + * Int decode type + * + * @since 2022-07-18 + */ +public class IntDecodeType + implements DecodeType +{ + @Override + public Optional> getJavaType() + { + return Optional.of(int.class); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/LongDecodeType.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/LongDecodeType.java new file mode 100644 index 0000000000000000000000000000000000000000..af78cdcff0dcff5c9670068e68c48b0699a49201 --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/LongDecodeType.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.omnidata.decode.type; + +import java.util.Optional; + +/** + * Long decode type + * + * @since 2022-07-18 + */ +public class LongDecodeType + implements DecodeType +{ + @Override + public Optional> getJavaType() + { + return Optional.of(long.class); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/LongToByteDecodeType.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/LongToByteDecodeType.java new file mode 100644 index 0000000000000000000000000000000000000000..7c87a0578fd8ed0c87c33c58a219f821db665f17 --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/LongToByteDecodeType.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.prestosql.plugin.hive.omnidata.decode.type; + +import java.util.Optional; + +/** + * Long To Byte decode + * + * @since 2022-07-18 + */ +public class LongToByteDecodeType + implements DecodeType +{ + @Override + public Optional> getJavaType() + { + return Optional.of(LongToByteDecodeType.class); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/LongToFloatDecodeType.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/LongToFloatDecodeType.java new file mode 100644 index 0000000000000000000000000000000000000000..f073f157e2f74d9c97e120894d68fa90ba309868 --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/LongToFloatDecodeType.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.omnidata.decode.type; + +import java.util.Optional; + +/** + * Double To Float decode + * + * @since 2022-07-18 + */ +public class LongToFloatDecodeType + implements DecodeType +{ + @Override + public Optional> getJavaType() + { + return Optional.of(LongToFloatDecodeType.class); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/LongToIntDecodeType.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/LongToIntDecodeType.java new file mode 100644 index 0000000000000000000000000000000000000000..a78dab75241099548f923ca84fda3ef47c5b646c --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/LongToIntDecodeType.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.prestosql.plugin.hive.omnidata.decode.type; + +import java.util.Optional; + +/** + * Long To Int decode + * + * @since 2022-07-18 + */ +public class LongToIntDecodeType + implements DecodeType +{ + @Override + public Optional> getJavaType() + { + return Optional.of(LongToIntDecodeType.class); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/LongToShortDecodeType.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/LongToShortDecodeType.java new file mode 100644 index 0000000000000000000000000000000000000000..9c0491e1dfcaed735e69f0e6e0bfb02cf8cd1bde --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/LongToShortDecodeType.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.omnidata.decode.type; + +import java.util.Optional; + +/** + * Long To Short decode + * + * @since 2022-07-18 + */ +public class LongToShortDecodeType + implements DecodeType +{ + @Override + public Optional> getJavaType() + { + return Optional.of(LongToShortDecodeType.class); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/MapDecodeType.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/MapDecodeType.java new file mode 100644 index 0000000000000000000000000000000000000000..887689ebe2e96b401d204a658077c6175264464b --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/MapDecodeType.java @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.prestosql.plugin.hive.omnidata.decode.type; + +import java.util.Optional; + +/** + * Map decode type + * + * @param <K> key type + * @param <V> value type + * @since 2022-07-18 + */ +public class MapDecodeType<K, V> + implements DecodeType +{ + private final K keyType; + private final V valueType; + + public MapDecodeType(K keyType, V valueType) + { + this.keyType = keyType; + this.valueType = valueType; + } + + @Override + public Optional<Class<?>> getJavaType() + { + return Optional.empty(); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/RowDecodeType.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/RowDecodeType.java new file mode 100644 index 0000000000000000000000000000000000000000..d91008e406357857fdcb8c45cee72342bc61644e --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/RowDecodeType.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.omnidata.decode.type; + +import java.util.Optional; + +/** + * Row decode type + * + * @since 2022-07-18 + */ +public class RowDecodeType + implements DecodeType +{ + @Override + public Optional<Class<?>> getJavaType() + { + return Optional.of(io.prestosql.spi.type.RowType.class); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/ShortDecodeType.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/ShortDecodeType.java new file mode 100644 index 0000000000000000000000000000000000000000..6256173124bf52e15a8d3de77ec79d153346372d --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/ShortDecodeType.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package io.prestosql.plugin.hive.omnidata.decode.type; + +import java.util.Optional; + +/** + * Short decode type + * + * @since 2022-07-18 + */ +public class ShortDecodeType + implements DecodeType +{ + @Override + public Optional<Class<?>> getJavaType() + { + return Optional.of(short.class); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/TimestampDecodeType.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/TimestampDecodeType.java new file mode 100644 index 0000000000000000000000000000000000000000..375e75769b82728cd58af32367126ae8d7ec4812 --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/TimestampDecodeType.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.plugin.hive.omnidata.decode.type; + +import java.util.Optional; + +/** + * Timestamp decode type + * + * @since 2022-07-18 + */ +public class TimestampDecodeType + implements DecodeType +{ + @Override + public Optional<Class<?>> getJavaType() + { + return Optional.of(TimestampDecodeType.class); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/VarcharDecodeType.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/VarcharDecodeType.java new file mode 100644 index 0000000000000000000000000000000000000000..42f2f2f55676c746fa3d410034ffc62e405bac65 --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/omnidata/decode/type/VarcharDecodeType.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package io.prestosql.plugin.hive.omnidata.decode.type; + +import java.util.Optional; + +/** + * Varchar decode type + * + * @since 2022-07-18 + */ +public class VarcharDecodeType + implements DecodeType +{ + @Override + public Optional<Class<?>> getJavaType() + { + return Optional.of(String.class); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java index abfc416e7ea564866b9e645580dfa7dd07574675..b106d365b1aba21b6872af5985fe1290230ef0c4 100644 --- a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java @@ -17,7 +17,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; import com.google.common.util.concurrent.UncheckedExecutionException; -import com.huawei.boostkit.omnidata.decode.impl.OpenLooKengDeserializer; import com.huawei.boostkit.omnidata.model.Predicate; import com.huawei.boostkit.omnidata.model.TaskSource; import com.huawei.boostkit.omnidata.model.datasource.DataSource; @@ -51,6 +50,7 @@ import io.prestosql.plugin.hive.HivePushDownPageSource; import io.prestosql.plugin.hive.HiveSessionProperties; import io.prestosql.plugin.hive.HiveType; import io.prestosql.plugin.hive.HiveUtil; +import io.prestosql.plugin.hive.omnidata.decode.impl.OpenLooKengDeserializer; import io.prestosql.plugin.hive.orc.OrcPageSource.ColumnAdaptation; import io.prestosql.spi.Page; import io.prestosql.spi.PrestoException; diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/parquet/ParquetPageSourceFactory.java b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/parquet/ParquetPageSourceFactory.java index 639b0ca18ba79e9688dee512ed4fc86f66b36827..dbccc94163cd357cdc2dc4f6c71e1b90d68b64d2 100644 --- a/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/parquet/ParquetPageSourceFactory.java +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/java/io/prestosql/plugin/hive/parquet/ParquetPageSourceFactory.java @@ -16,7 +16,6 @@ package io.prestosql.plugin.hive.parquet; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.huawei.boostkit.omnidata.decode.impl.OpenLooKengDeserializer; import com.huawei.boostkit.omnidata.model.TaskSource; import com.huawei.boostkit.omnidata.model.datasource.DataSource; import com.huawei.boostkit.omnidata.reader.DataReader; @@ -39,6 +38,7 @@ import io.prestosql.plugin.hive.HivePageSourceFactory; import io.prestosql.plugin.hive.HivePartitionKey; import io.prestosql.plugin.hive.HivePushDownPageSource; import io.prestosql.plugin.hive.HiveSessionProperties; +import io.prestosql.plugin.hive.omnidata.decode.impl.OpenLooKengDeserializer; import io.prestosql.spi.Page; import io.prestosql.spi.PrestoException; import io.prestosql.spi.connector.ConnectorPageSource; diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/main/resource/license/license-header-alternate-2022-2022.txt
b/omnidata/omnidata-openlookeng-connector/connector/src/main/resource/license/license-header-alternate-2022-2022.txt new file mode 100644 index 0000000000000000000000000000000000000000..586419df16f9e4ccb846e33f55342e26286a2a47 --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/main/resource/license/license-header-alternate-2022-2022.txt @@ -0,0 +1,12 @@ +Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/test/java/io/prestosql/plugin/hive/omnidata/decode/DeserializerTestUtils.java b/omnidata/omnidata-openlookeng-connector/connector/src/test/java/io/prestosql/plugin/hive/omnidata/decode/DeserializerTestUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..3f1508870e42f1cea9a0e5b7134b528eb1bdcc26 --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/test/java/io/prestosql/plugin/hive/omnidata/decode/DeserializerTestUtils.java @@ -0,0 +1,185 @@ +/* + * Copyright (C) 2018-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.prestosql.plugin.hive.omnidata.decode; + +import com.google.common.collect.ImmutableList; +import com.huawei.boostkit.omnidata.exception.OmniDataException; +import io.prestosql.metadata.Metadata; +import io.prestosql.spi.PageBuilder; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.block.RowBlockBuilder; +import io.prestosql.spi.block.SingleRowBlockWriter; +import io.prestosql.spi.type.MapType; +import io.prestosql.spi.type.RowType; +import io.prestosql.spi.type.StandardTypes; +import io.prestosql.spi.type.Type; +import io.prestosql.spi.type.TypeSignatureParameter; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +import static io.airlift.slice.Slices.wrappedBuffer; +import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.BooleanType.BOOLEAN; +import static io.prestosql.spi.type.DateType.DATE; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.spi.type.RealType.REAL; +import static io.prestosql.spi.type.SmallintType.SMALLINT; +import static io.prestosql.spi.type.TinyintType.TINYINT; +import static io.prestosql.spi.type.VarcharType.VARCHAR; + +/** + * Deserializer Test Utils + * + * @since 2022-07-18 + */ +public class DeserializerTestUtils +{ + public static final Metadata METADATA = createTestMetadataManager(); + + private DeserializerTestUtils() + {} + + /** + * Return a map type + * @param keyType keyType + * @param valueType valueType + * @return type + */ + public static MapType mapType(Type keyType, Type valueType) + { + Type type = METADATA.getFunctionAndTypeManager() + .getParameterizedType(StandardTypes.MAP, + ImmutableList.of(TypeSignatureParameter.of(keyType.getTypeSignature()), + TypeSignatureParameter.of(valueType.getTypeSignature()))); + if (type instanceof MapType) { + return (MapType) type; + } + throw new OmniDataException("Except Map type"); + } + + /** + * create Long sequence block + * + * @param start start + * @param end end + * @return block + */ + public static Block createLongSequenceBlock(int start, int end) + { + BlockBuilder builder = BIGINT.createFixedSizeBlockBuilder(end - start); + + for (int i = start; i < end; i++) { + BIGINT.writeLong(builder, i); + } + + return builder.build(); + } + + /** + * Get Test PageBuilder + * + * @return pageBuilder + */ + public static PageBuilder getTestPageBuilder() + { + // generate rowType + List fieldTypes = new ArrayList<>(); + fieldTypes.add(DOUBLE); + fieldTypes.add(BIGINT); + RowType rowType = RowType.anonymous(fieldTypes); + //generate a page + ImmutableList.Builder typesBuilder = ImmutableList.builder(); + typesBuilder.add(INTEGER, DOUBLE, BOOLEAN, BIGINT, VARCHAR, SMALLINT, DATE, TINYINT, REAL, rowType); + + ImmutableList types = typesBuilder.build(); + PageBuilder pageBuilder = new PageBuilder(types); + + fillUpPageBuilder(pageBuilder); + + // RowType + BlockBuilder builder = pageBuilder.getBlockBuilder(9); + if (!(builder instanceof RowBlockBuilder)) { + throw new OmniDataException("Except RowBlockBuilder but found " + builder.getClass()); + } + RowBlockBuilder rowBlockBuilder = (RowBlockBuilder) builder; + SingleRowBlockWriter singleRowBlockWriter = rowBlockBuilder.beginBlockEntry(); + DOUBLE.writeDouble(singleRowBlockWriter, 1.0); + BIGINT.writeLong(singleRowBlockWriter, 1); + 
rowBlockBuilder.closeEntry(); + + singleRowBlockWriter = rowBlockBuilder.beginBlockEntry(); + singleRowBlockWriter.appendNull(); + singleRowBlockWriter.appendNull(); + rowBlockBuilder.closeEntry(); + + pageBuilder.declarePositions(2); + + return pageBuilder; + } + + private static void fillUpPageBuilder(PageBuilder pageBuilder) + { + BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(0); + blockBuilder.writeInt(1); + blockBuilder.appendNull(); + + // DOUBLE + blockBuilder = pageBuilder.getBlockBuilder(1); + blockBuilder.writeLong(Double.doubleToLongBits(1.0)); + blockBuilder.appendNull(); + + // BOOLEAN false + blockBuilder = pageBuilder.getBlockBuilder(2); + blockBuilder.writeByte(0); + blockBuilder.appendNull(); + + // LONG + blockBuilder = pageBuilder.getBlockBuilder(3); + blockBuilder.writeLong(1); + blockBuilder.appendNull(); + + // VARCHAR + blockBuilder = pageBuilder.getBlockBuilder(4); + blockBuilder.writeBytes(wrappedBuffer("test".getBytes(StandardCharsets.UTF_8)), 0, "test".length()); + blockBuilder.closeEntry(); + blockBuilder.appendNull(); + + // SMALLINT + blockBuilder = pageBuilder.getBlockBuilder(5); + blockBuilder.writeShort(1); + blockBuilder.appendNull(); + + // DATE + blockBuilder = pageBuilder.getBlockBuilder(6); + blockBuilder.writeInt(1); + blockBuilder.appendNull(); + + // TINYINT + blockBuilder = pageBuilder.getBlockBuilder(7); + blockBuilder.writeByte(1); + blockBuilder.appendNull(); + + // REAL + blockBuilder = pageBuilder.getBlockBuilder(8); + blockBuilder.writeInt(Float.floatToIntBits((float) 1.0)); + blockBuilder.appendNull(); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/test/java/io/prestosql/plugin/hive/omnidata/decode/TestOpenLooKengDeserializer.java b/omnidata/omnidata-openlookeng-connector/connector/src/test/java/io/prestosql/plugin/hive/omnidata/decode/TestOpenLooKengDeserializer.java new file mode 100644 index 0000000000000000000000000000000000000000..b70349af7d54b913ec2cd3031081627e370d2eda --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/test/java/io/prestosql/plugin/hive/omnidata/decode/TestOpenLooKengDeserializer.java @@ -0,0 +1,334 @@ +/* + * Copyright (C) 2018-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.prestosql.plugin.hive.omnidata.decode; + +import com.google.common.collect.ImmutableList; +import com.huawei.boostkit.omnidata.exception.OmniDataException; +import com.huawei.boostkit.omnidata.serialize.OmniDataBlockEncodingSerde; +import io.airlift.compress.Decompressor; +import io.airlift.compress.zstd.ZstdDecompressor; +import io.airlift.slice.Slice; +import io.airlift.slice.SliceInput; +import io.hetu.core.transport.execution.buffer.PagesSerde; +import io.hetu.core.transport.execution.buffer.PagesSerdeFactory; +import io.hetu.core.transport.execution.buffer.SerializedPage; +import io.prestosql.plugin.hive.omnidata.decode.impl.OpenLooKengDecoding; +import io.prestosql.plugin.hive.omnidata.decode.impl.OpenLooKengDeserializer; +import io.prestosql.plugin.hive.omnidata.decode.type.ArrayDecodeType; +import io.prestosql.plugin.hive.omnidata.decode.type.DateDecodeType; +import io.prestosql.plugin.hive.omnidata.decode.type.DecodeType; +import io.prestosql.plugin.hive.omnidata.decode.type.LongDecodeType; +import io.prestosql.plugin.hive.omnidata.decode.type.LongToByteDecodeType; +import io.prestosql.plugin.hive.omnidata.decode.type.LongToFloatDecodeType; +import io.prestosql.plugin.hive.omnidata.decode.type.LongToIntDecodeType; +import io.prestosql.plugin.hive.omnidata.decode.type.LongToShortDecodeType; +import io.prestosql.spi.Page; +import io.prestosql.spi.PageBuilder; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.block.DictionaryBlock; +import io.prestosql.spi.block.Int128ArrayBlock; +import io.prestosql.spi.block.LongArrayBlock; +import io.prestosql.spi.block.RunLengthEncodedBlock; +import io.prestosql.spi.block.SingleMapBlock; +import io.prestosql.spi.type.ArrayType; +import io.prestosql.spi.type.MapType; +import io.prestosql.spi.type.Type; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.Optional; + +import static io.prestosql.plugin.hive.omnidata.decode.DeserializerTestUtils.METADATA; +import static io.prestosql.plugin.hive.omnidata.decode.DeserializerTestUtils.createLongSequenceBlock; +import static io.prestosql.plugin.hive.omnidata.decode.DeserializerTestUtils.getTestPageBuilder; +import static io.prestosql.plugin.hive.omnidata.decode.DeserializerTestUtils.mapType; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +/** + * Block deserializer test + * + * @since 2022-07-19 + */ +public class TestOpenLooKengDeserializer +{ + /** + * Test all types + */ + @Test + public void testAllSupportTypesDeserializer() + { + Page page = getTestPageBuilder().build(); + + // serialize page + PagesSerdeFactory factory = new PagesSerdeFactory(new OmniDataBlockEncodingSerde(), false); + PagesSerde pagesSerde = factory.createPagesSerde(); + SerializedPage serializedPage = pagesSerde.serialize(page); + + // deserialize page + OpenLooKengDeserializer deserializer = new OpenLooKengDeserializer(); + Page deserializedPage = deserializer.deserialize(serializedPage); + + assertEquals(2, deserializedPage.getPositionCount()); + assertEquals(10, deserializedPage.getChannelCount()); + + assertEquals(deserializedPage.getBlock(0).getSizeInBytes(), 10); + assertEquals(deserializedPage.getBlock(1).getSizeInBytes(), 18); + assertEquals(deserializedPage.getBlock(2).getSizeInBytes(), 4); + assertEquals(deserializedPage.getBlock(3).getSizeInBytes(), 18); + 
assertEquals(deserializedPage.getBlock(4).getSizeInBytes(), 14); + assertEquals(deserializedPage.getBlock(5).getSizeInBytes(), 6); + assertEquals(deserializedPage.getBlock(6).getSizeInBytes(), 10); + assertEquals(deserializedPage.getBlock(7).getSizeInBytes(), 4); + assertEquals(deserializedPage.getBlock(8).getSizeInBytes(), 10); + assertEquals(deserializedPage.getBlock(9).getSizeInBytes(), 46); + } + + /** + * Test Compressed deserializer + */ + @Test + public void testCompressedDeserializer() + { + long[] values = new long[] {1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0}; + boolean[] valueIsNull = new boolean[] {false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, true}; + LongArrayBlock longArrayBlock = new LongArrayBlock(16, Optional.of(valueIsNull), values); + Page longArrayPage = new Page(longArrayBlock); + + // serialize page + PagesSerdeFactory factory = new PagesSerdeFactory(new OmniDataBlockEncodingSerde(), true); + PagesSerde pagesSerde = factory.createPagesSerde(); + SerializedPage serializedPage = pagesSerde.serialize(longArrayPage); + + // deserialize page + DecodeType[] decodeTypes = new DecodeType[] { + new LongToIntDecodeType(), new LongToShortDecodeType(), new LongToByteDecodeType(), + new LongToFloatDecodeType(), new DateDecodeType() + }; + + OpenLooKengDecoding blockDecoding = new OpenLooKengDecoding(); + Decompressor decompressor = new ZstdDecompressor(); + int failedTimes = 0; + for (int i = 0; i < decodeTypes.length; i++) { + // to decode type + Slice slice = OpenLooKengDeserializer.decompressPage(serializedPage, decompressor); + SliceInput input = slice.getInput(); + int numberOfBlocks = input.readInt(); + Block[] blocks = new Block[numberOfBlocks]; + try { + blocks[0] = blockDecoding.decode(Optional.of(decodeTypes[i]), input); + } + catch (OmniDataException e) { + failedTimes++; + } + } + + assertEquals(failedTimes, decodeTypes.length); + } + + /** + * Test Int128Type deserializer + */ + @Test + public void testInt128TypeDeserializer() + { + Int128ArrayBlock int128ArrayBlock = + new Int128ArrayBlock(0, Optional.empty(), new long[0]); + Page int128ArrayPage = new Page(int128ArrayBlock); + + // serialize page + PagesSerdeFactory factory = new PagesSerdeFactory(new OmniDataBlockEncodingSerde(), false); + PagesSerde pagesSerde = factory.createPagesSerde(); + SerializedPage serializedPage = pagesSerde.serialize(int128ArrayPage); + + // deserialize page + OpenLooKengDeserializer deserializer = new OpenLooKengDeserializer(); + + // generate exception + Page page = deserializer.deserialize(serializedPage); + + assertEquals(0, page.getSizeInBytes()); + } + + /** + * Test RunLengthType deserializer + */ + @Test + public void testRunLengthTypeDeserializer() + { + RunLengthEncodedBlock runLengthEncodedBlock = + new RunLengthEncodedBlock<>(createLongSequenceBlock(4, 5), 100); + Page runLengthPage = new Page(runLengthEncodedBlock); + + // serialize page + PagesSerdeFactory factory = new PagesSerdeFactory(new OmniDataBlockEncodingSerde(), false); + PagesSerde pagesSerde = factory.createPagesSerde(); + SerializedPage serializedPage = pagesSerde.serialize(runLengthPage); + + // deserialize page + OpenLooKengDeserializer deserializer = new OpenLooKengDeserializer(); + + // generate ColumnVector + Page page = deserializer.deserialize(serializedPage); + + assertEquals(9, page.getSizeInBytes()); + } + + /** + * Test ArrayType deserializer + */ + @Test + public void testArrayTypeDeserializer() + { + // generate a page + 
ImmutableList.Builder typeBuilder = ImmutableList.builder(); + typeBuilder.add(new ArrayType<>(BIGINT)); + + ImmutableList types = typeBuilder.build(); + PageBuilder pageBuilder = new PageBuilder(types); + BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(0); + BlockBuilder elementBlockBuilder = BIGINT.createBlockBuilder(null, 2); + for (int i = 0; i < 2; i++) { + BIGINT.writeLong(elementBlockBuilder, 1); + } + blockBuilder.appendStructure(elementBlockBuilder.build()); + + pageBuilder.declarePositions(1); + Page page = pageBuilder.build(); + + // serialize page + PagesSerdeFactory factory = new PagesSerdeFactory(new OmniDataBlockEncodingSerde(), false); + PagesSerde pagesSerde = factory.createPagesSerde(); + SerializedPage serializedPage = pagesSerde.serialize(page); + + //deserialize page + DecodeType[] decodeTypes = new DecodeType[] {new ArrayDecodeType<>(new LongDecodeType())}; + DecodeType firstType = decodeTypes[0]; + if (firstType instanceof ArrayDecodeType) { + assertEquals(((ArrayDecodeType) firstType).getElementType().getClass(), LongDecodeType.class); + } + else { + throw new OmniDataException("except arrayType"); + } + OpenLooKengDeserializer deserializer = new OpenLooKengDeserializer(); + + // generate exception + Page deserialized = deserializer.deserialize(serializedPage); + + assertEquals(23, deserialized.getSizeInBytes()); + } + + /** + * Test MapType deserializer + */ + @Test(expectedExceptions = {UnsupportedOperationException.class, RuntimeException.class}) + public void testMapTypeDeserializer() + { + // generate a page + ImmutableList.Builder typeBuilder = ImmutableList.builder(); + MapType mapType = mapType(BIGINT, BIGINT); + + typeBuilder.add(mapType); + + ImmutableList types = typeBuilder.build(); + PageBuilder pageBuilder = new PageBuilder(types); + BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(0); + blockBuilder.appendNull(); + + pageBuilder.declarePositions(1); + Page page = pageBuilder.build(); + + // serialize page + PagesSerdeFactory factory = new PagesSerdeFactory(METADATA.getFunctionAndTypeManager().getBlockEncodingSerde(), + false); + PagesSerde pagesSerde = factory.createPagesSerde(); + SerializedPage serializedPage = pagesSerde.serialize(page); + + // deserialize page + OpenLooKengDeserializer deserializer = new OpenLooKengDeserializer(); + + // generate exception + deserializer.deserialize(serializedPage); + } + + /** + * Test SingleMapType deserializer + */ + @Test(expectedExceptions = {UnsupportedOperationException.class, RuntimeException.class}) + public void testSingleMapTypeDeserializer() + { + // generate a page + ImmutableList.Builder typeBuilder = ImmutableList.builder(); + MapType mapType = mapType(BIGINT, BIGINT); + + typeBuilder.add(mapType); + + ImmutableList types = typeBuilder.build(); + PageBuilder pageBuilder = new PageBuilder(types); + BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(0); + blockBuilder.appendNull(); + + pageBuilder.declarePositions(1); + Page page = pageBuilder.build(); + Block block = page.getBlock(0); + + Block elementBlock = mapType.getObject(block, 0); + assertTrue(elementBlock instanceof SingleMapBlock); + + Page singleMapPage = new Page(elementBlock); + + // serialize page + PagesSerdeFactory factory = new PagesSerdeFactory(METADATA.getFunctionAndTypeManager().getBlockEncodingSerde(), + false); + PagesSerde pagesSerde = factory.createPagesSerde(); + SerializedPage serializedPage = pagesSerde.serialize(singleMapPage); + + // deserialize page + OpenLooKengDeserializer deserializer = new 
OpenLooKengDeserializer(); + + // generate exception + deserializer.deserialize(serializedPage); + } + + /** + * Test DictionaryType deserializer + */ + @Test + public void testDictionaryTypeDeserializer() + { + int[] ids = new int[100]; + Arrays.setAll(ids, index -> index % 10); + Block dictionary = createLongSequenceBlock(0, 10); + DictionaryBlock dictionaryBlock = new DictionaryBlock<>(dictionary, ids); + Page dictionaryPage = new Page(dictionaryBlock); + + // serialize page + PagesSerdeFactory factory = new PagesSerdeFactory(new OmniDataBlockEncodingSerde(), false); + PagesSerde pagesSerde = factory.createPagesSerde(); + SerializedPage serializedPage = pagesSerde.serialize(dictionaryPage); + + // deserialize page + OpenLooKengDeserializer deserializer = new OpenLooKengDeserializer(); + + // generate exception + Page page = deserializer.deserialize(serializedPage); + + assertEquals(490, page.getSizeInBytes()); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/connector/src/test/java/io/prestosql/plugin/hive/omnidata/decode/type/TestOmniDataDecodeTypes.java b/omnidata/omnidata-openlookeng-connector/connector/src/test/java/io/prestosql/plugin/hive/omnidata/decode/type/TestOmniDataDecodeTypes.java new file mode 100644 index 0000000000000000000000000000000000000000..937bdf55640bf08b5ec4ddb45a839332a9160754 --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/connector/src/test/java/io/prestosql/plugin/hive/omnidata/decode/type/TestOmniDataDecodeTypes.java @@ -0,0 +1,79 @@ +/* + * Copyright (C) 2018-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.prestosql.plugin.hive.omnidata.decode.type; + +import io.prestosql.spi.type.RowType; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static org.testng.Assert.assertEquals; + +/** + * Test All OmniData DecodeTypes + * + * @since 2022-07-18 + */ +public class TestOmniDataDecodeTypes +{ + @Test + public void testOmnidataDecodeTypes() + { + // Test DecimalDecodeType + DecodeType decimalDecodeType = new DecimalDecodeType(10, 5); + assertEquals(decimalDecodeType.getJavaType().get(), DecimalDecodeType.class); + + // Test TimestampDecodeType + DecodeType timestampDecodeType = new TimestampDecodeType(); + assertEquals(timestampDecodeType.getJavaType().get(), TimestampDecodeType.class); + + // Test VarcharDecodeType + DecodeType varcharDecodeType = new VarcharDecodeType(); + assertEquals(varcharDecodeType.getJavaType().get(), String.class); + + // Test ShortDecodeType + DecodeType shortDecodeType = new ShortDecodeType(); + assertEquals(shortDecodeType.getJavaType().get(), short.class); + + // Test RowDecodeType + DecodeType rowDecodeType = new RowDecodeType(); + assertEquals(rowDecodeType.getJavaType().get(), RowType.class); + + // Test MapDecodeType + DecodeType mapDecodeType = new MapDecodeType(new ShortDecodeType(), new ShortDecodeType()); + assertEquals(mapDecodeType.getJavaType(), Optional.empty()); + + // Test IntDecodeType + DecodeType intDecodeType = new IntDecodeType(); + assertEquals(intDecodeType.getJavaType().get(), int.class); + + // Test FloatDecodeType + DecodeType floatDecodeType = new FloatDecodeType(); + assertEquals(floatDecodeType.getJavaType().get(), float.class); + + // Test DoubleDecodeType + DecodeType doubleDecodeType = new DoubleDecodeType(); + assertEquals(doubleDecodeType.getJavaType().get(), double.class); + + // Test ByteDecodeType + DecodeType byteDecodeType = new ByteDecodeType(); + assertEquals(byteDecodeType.getJavaType().get(), byte.class); + + // Test BooleanDecodeType + DecodeType booleanDecodeType = new BooleanDecodeType(); + assertEquals(booleanDecodeType.getJavaType().get(), boolean.class); + } +} diff --git a/omnidata/omnidata-openlookeng-connector/stub/server/pom.xml b/omnidata/omnidata-openlookeng-connector/stub/server/pom.xml index 11fde2862f4568188c78e3b0e85751c7343ae8cf..e4b16fac7880bdf8d2aed7fe897a94eb6b40a45b 100644 --- a/omnidata/omnidata-openlookeng-connector/stub/server/pom.xml +++ b/omnidata/omnidata-openlookeng-connector/stub/server/pom.xml @@ -23,5 +23,11 @@ commons-io ${dep.commons.io.version} + + io.hetu.core + hetu-transport + 1.6.1 + compile + \ No newline at end of file diff --git a/omnidata/omnidata-openlookeng-connector/stub/server/src/main/java/com/huawei/boostkit/omnidata/decode/Deserializer.java b/omnidata/omnidata-openlookeng-connector/stub/server/src/main/java/com/huawei/boostkit/omnidata/decode/Deserializer.java index ea5c2957b0d1ea76e9d72340eb9d0d6f2cdf6c98..e53a6fc6386bec1a6af167795a6d7517801b10ca 100644 --- a/omnidata/omnidata-openlookeng-connector/stub/server/src/main/java/com/huawei/boostkit/omnidata/decode/Deserializer.java +++ b/omnidata/omnidata-openlookeng-connector/stub/server/src/main/java/com/huawei/boostkit/omnidata/decode/Deserializer.java @@ -1,5 +1,5 @@ /* - * Copyright (C) 2018-2021. Huawei Technologies Co., Ltd. All rights reserved. + * Copyright (C) 2018-2022. Huawei Technologies Co., Ltd. All rights reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at @@ -14,5 +14,8 @@ */ package com.huawei.boostkit.omnidata.decode; +import io.hetu.core.transport.execution.buffer.SerializedPage; + public interface Deserializer { + T deserialize(SerializedPage page); } diff --git a/omnidata/omnidata-openlookeng-connector/stub/server/src/main/java/com/huawei/boostkit/omnidata/exception/OmniDataException.java b/omnidata/omnidata-openlookeng-connector/stub/server/src/main/java/com/huawei/boostkit/omnidata/exception/OmniDataException.java index 98ca5c4495ba1850e6f0d583868926a94c0ab433..9dd26007a8c97aacecfb27780ceba6ad0dae8699 100644 --- a/omnidata/omnidata-openlookeng-connector/stub/server/src/main/java/com/huawei/boostkit/omnidata/exception/OmniDataException.java +++ b/omnidata/omnidata-openlookeng-connector/stub/server/src/main/java/com/huawei/boostkit/omnidata/exception/OmniDataException.java @@ -1,5 +1,5 @@ /* - * Copyright (C) 2018-2021. Huawei Technologies Co., Ltd. All rights reserved. + * Copyright (C) 2018-2022. Huawei Technologies Co., Ltd. All rights reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -16,8 +16,32 @@ package com.huawei.boostkit.omnidata.exception; import static com.huawei.boostkit.omnidata.exception.OmniErrorCode.OMNIDATA_GENERIC_ERROR; -public class OmniDataException { +/** + * OmniDataException + * + * @since 2022-07-18 + */ +public class OmniDataException extends RuntimeException { + private static final long serialVersionUID = -9034897193745766939L; + + private final OmniErrorCode errorCode; + + public OmniDataException(String message) { + super(message); + errorCode = OMNIDATA_GENERIC_ERROR; + } + + public OmniDataException(String message, Throwable throwable) { + super(message, throwable); + errorCode = OMNIDATA_GENERIC_ERROR; + } + + public OmniDataException(OmniErrorCode omniErrorCode, String message) { + super(message); + errorCode = omniErrorCode; + } + public OmniErrorCode getErrorCode() { - return OMNIDATA_GENERIC_ERROR; + return errorCode; } } diff --git a/omnidata/omnidata-openlookeng-connector/stub/server/src/main/java/com/huawei/boostkit/omnidata/serialize/OmniDataBlockEncodingSerde.java b/omnidata/omnidata-openlookeng-connector/stub/server/src/main/java/com/huawei/boostkit/omnidata/serialize/OmniDataBlockEncodingSerde.java new file mode 100644 index 0000000000000000000000000000000000000000..7ebc828d2b80453fd99af2f565bd4aedf53a6e6a --- /dev/null +++ b/omnidata/omnidata-openlookeng-connector/stub/server/src/main/java/com/huawei/boostkit/omnidata/serialize/OmniDataBlockEncodingSerde.java @@ -0,0 +1,105 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omnidata.serialize; + +import com.google.common.collect.ImmutableMap; +import io.airlift.slice.SliceInput; +import io.airlift.slice.SliceOutput; +import io.prestosql.spi.block.ArrayBlockEncoding; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockEncoding; +import io.prestosql.spi.block.BlockEncodingSerde; +import io.prestosql.spi.block.ByteArrayBlockEncoding; +import io.prestosql.spi.block.DictionaryBlockEncoding; +import io.prestosql.spi.block.Int128ArrayBlockEncoding; +import io.prestosql.spi.block.IntArrayBlockEncoding; +import io.prestosql.spi.block.LazyBlockEncoding; +import io.prestosql.spi.block.LongArrayBlockEncoding; +import io.prestosql.spi.block.RowBlockEncoding; +import io.prestosql.spi.block.RunLengthBlockEncoding; +import io.prestosql.spi.block.ShortArrayBlockEncoding; +import io.prestosql.spi.block.SingleRowBlockEncoding; +import io.prestosql.spi.block.VariableWidthBlockEncoding; + +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.Optional; + +public class OmniDataBlockEncodingSerde implements BlockEncodingSerde { + private final Map blockEncodings; + public OmniDataBlockEncodingSerde() { + blockEncodings = + ImmutableMap.builder() + .put(VariableWidthBlockEncoding.NAME, new VariableWidthBlockEncoding()) + .put(ByteArrayBlockEncoding.NAME, new ByteArrayBlockEncoding()) + .put(ShortArrayBlockEncoding.NAME, new ShortArrayBlockEncoding()) + .put(IntArrayBlockEncoding.NAME, new IntArrayBlockEncoding()) + .put(LongArrayBlockEncoding.NAME, new LongArrayBlockEncoding()) + .put(Int128ArrayBlockEncoding.NAME, new Int128ArrayBlockEncoding()) + .put(DictionaryBlockEncoding.NAME, new DictionaryBlockEncoding()) + .put(ArrayBlockEncoding.NAME, new ArrayBlockEncoding()) + .put(RowBlockEncoding.NAME, new RowBlockEncoding()) + .put(SingleRowBlockEncoding.NAME, new SingleRowBlockEncoding()) + .put(RunLengthBlockEncoding.NAME, new RunLengthBlockEncoding()) + .put(LazyBlockEncoding.NAME, new LazyBlockEncoding()) + .build(); + } + + private static String readLengthPrefixedString(SliceInput sliceInput) + { + int length = sliceInput.readInt(); + byte[] bytes = new byte[length]; + sliceInput.readBytes(bytes); + + return new String(bytes, StandardCharsets.UTF_8); + } + + private static void writeLengthPrefixedString(SliceOutput sliceOutput, String value) + { + byte[] bytes = value.getBytes(StandardCharsets.UTF_8); + sliceOutput.writeInt(bytes.length); + sliceOutput.writeBytes(bytes); + } + + @Override + public Block readBlock(SliceInput input) + { + return blockEncodings.get(readLengthPrefixedString(input)).readBlock(this, input); + } + + @Override + public void writeBlock(SliceOutput output, Block block) + { + Block readBlock = block; + while (true) { + String encodingName = readBlock.getEncodingName(); + + BlockEncoding blockEncoding = blockEncodings.get(encodingName); + + Optional replacementBlock = blockEncoding.replacementBlockForWrite(readBlock); + if (replacementBlock.isPresent()) { + readBlock = replacementBlock.get(); + continue; + } + + writeLengthPrefixedString(output, encodingName); + + blockEncoding.writeBlock(this, output, readBlock); + + break; + } + } +} diff --git a/omnidata/omnidata-spark-connector/build.sh b/omnidata/omnidata-spark-connector/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..7a528d4476b1a6d758054576f289d09ce8ae9c3e --- /dev/null +++ b/omnidata/omnidata-spark-connector/build.sh @@ -0,0 +1,34 @@ +#!/bin/bash +mvn clean package +jar_name=`ls -n 
connector/target/*.jar | grep omnidata-spark | awk -F ' ' '{print$9}' | awk -F '/' '{print$3}'` +dir_name=`ls -n connector/target/*.jar | grep omnidata-spark | awk -F ' ' '{print$9}' | awk -F '/' '{print$3}' | awk -F '.jar' '{print$1}'` +rm -r $dir_name-aarch64 +rm -r $dir_name-aarch64.zip +mkdir -p $dir_name-aarch64 +cp connector/target/$jar_name $dir_name-aarch64 +cd $dir_name-aarch64 +wget https://mirrors.huaweicloud.com/repository/maven/org/bouncycastle/bcpkix-jdk15on/1.68/bcpkix-jdk15on-1.68.jar +wget https://mirrors.huaweicloud.com/repository/maven/org/apache/curator/curator-client/2.12.0/curator-client-2.12.0.jar +wget https://mirrors.huaweicloud.com/repository/maven/org/apache/curator/curator-framework/2.12.0/curator-framework-2.12.0.jar +wget https://mirrors.huaweicloud.com/repository/maven/org/apache/curator/curator-recipes/2.12.0/curator-recipes-2.12.0.jar +wget https://mirrors.huaweicloud.com/repository/maven/com/alibaba/fastjson/1.2.76/fastjson-1.2.76.jar +wget https://mirrors.huaweicloud.com/repository/maven/de/ruedigermoeller/fst/2.57/fst-2.57.jar +wget https://mirrors.huaweicloud.com/repository/maven/com/google/guava/guava/26.0-jre/guava-26.0-jre.jar +wget https://mirrors.huaweicloud.com/repository/maven/io/hetu/core/hetu-transport/1.6.1/hetu-transport-1.6.1.jar +wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/datatype/jackson-datatype-guava/2.12.4/jackson-datatype-guava-2.12.4.jar +wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/datatype/jackson-datatype-jdk8/2.12.4/jackson-datatype-jdk8-2.12.4.jar +wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/datatype/jackson-datatype-joda/2.12.4/jackson-datatype-joda-2.12.4.jar +wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/datatype/jackson-datatype-jsr310/2.12.4/jackson-datatype-jsr310-2.12.4.jar +wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/module/jackson-module-parameter-names/2.12.4/jackson-module-parameter-names-2.12.4.jar +wget https://mirrors.huaweicloud.com/repository/maven/org/jasypt/jasypt/1.9.3/jasypt-1.9.3.jar +wget https://mirrors.huaweicloud.com/repository/maven/org/openjdk/jol/jol-core/0.2/jol-core-0.2.jar +wget https://repo1.maven.org/maven2/io/airlift/joni/2.1.5.3/joni-2.1.5.3.jar +wget https://mirrors.huaweicloud.com/repository/maven/io/airlift/log/0.193/log-0.193.jar +wget https://mirrors.huaweicloud.com/repository/maven/io/perfmark/perfmark-api/0.23.0/perfmark-api-0.23.0.jar +wget https://mirrors.huaweicloud.com/repository/maven/io/hetu/core/presto-main/1.6.1/presto-main-1.6.1.jar +wget https://mirrors.huaweicloud.com/repository/maven/io/hetu/core/presto-spi/1.6.1/presto-spi-1.6.1.jar +wget https://mirrors.huaweicloud.com/repository/maven/com/google/protobuf/protobuf-java/3.12.0/protobuf-java-3.12.0.jar +wget https://mirrors.huaweicloud.com/repository/maven/io/airlift/slice/0.38/slice-0.38.jar +cd .. 
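+# package the connector jar and the dependency jars downloaded above into a single aarch64 zip, then remove the staging directory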
+zip -r -o $dir_name-aarch64.zip $dir_name-aarch64 +rm -r $dir_name-aarch64 \ No newline at end of file diff --git a/omnioperator/omniop-openlookeng-extension/README.md b/omnioperator/omniop-openlookeng-extension/README.md index 5cad1f0b7356d60a44c4a9830ad3a390936facb4..d92b31d459944d5baa72ec162c2725fd3068fa9b 100644 --- a/omnioperator/omniop-openlookeng-extension/README.md +++ b/omnioperator/omniop-openlookeng-extension/README.md @@ -1 +1,14 @@ -# omniop-openlookeng-extension \ No newline at end of file +# How to build omniop-openlookeng-extension +The project depends on pieces of `hetu-core` and `boostkit-omniop-bindings`, which are recommended to be deployed beforehand. `hetu-core` is available in the central repo, and this project requires version >= `1.6.1`. Thus, you only need to + get `boostkit-omniop-bindings-1.0.0.jar` from [Link](https://www.hikunpeng.com/en/developer/boostkit/big-data?acclerated=3), and install it in your maven repository as follows: + + ``` + mvn install:install-file -DgroupId=com.huawei.boostkit -DartifactId=boostkit-omniop-bindings -Dversion=1.0.0 -Dpackaging=jar -Dfile=boostkit-omniop-bindings-1.0.0.jar + ``` + +Then, you can build omniop-openlookeng-extension as follows: +``` +git clone https://github.com/kunpengcompute/boostkit-bigdata.git +cd ./boostkit-bigdata/omnioperator/omniop-openlookeng-extension +mvn clean install -DskipTests +``` diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/OmniLocalExecutionPlanner.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/OmniLocalExecutionPlanner.java index 51fa693dc2cf0431e920f55360db72a5c057360f..0ebcd413412a2b2168462a9a2d5a81d3c2632e36 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/OmniLocalExecutionPlanner.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/OmniLocalExecutionPlanner.java @@ -20,6 +20,7 @@ import com.google.common.collect.Lists; import com.google.common.primitives.Ints; import io.airlift.log.Logger; import io.airlift.node.NodeInfo; +import io.airlift.slice.Slice; import io.airlift.units.DataSize; import io.hetu.core.transport.execution.buffer.PagesSerdeFactory; import io.prestosql.Session; @@ -87,6 +88,7 @@ import io.prestosql.spi.plan.TopNNode; import io.prestosql.spi.plan.WindowNode; import io.prestosql.spi.predicate.NullableValue; import io.prestosql.spi.relation.CallExpression; +import io.prestosql.spi.relation.ConstantExpression; import io.prestosql.spi.relation.RowExpression; import io.prestosql.spi.relation.VariableReferenceExpression; import io.prestosql.spi.type.StandardTypes; @@ -154,7 +156,6 @@ import nova.hetu.omniruntime.type.DataType; import java.util.ArrayList; import java.util.HashSet; import java.util.LinkedHashMap; -import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Optional; @@ -205,6 +206,7 @@ import static io.prestosql.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUT import static io.prestosql.sql.planner.plan.ExchangeNode.Scope.LOCAL; import static nova.hetu.olk.operator.OrderByOmniOperator.OrderByOmniOperatorFactory.createOrderByOmniOperatorFactory; import static nova.hetu.olk.operator.filterandproject.OmniRowExpressionUtil.expressionStringify; +import static nova.hetu.olk.operator.filterandproject.OmniRowExpressionUtil.generateLikeExpr; import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_AVG; import static 
nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_COUNT_ALL; import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_COUNT_COLUMN; @@ -849,14 +851,8 @@ public class OmniLocalExecutionPlanner private Optional addRowExpression(Optional rowExpression, Optional translatedFilter) { - List newArgs = new LinkedList(); - newArgs.add(((CallExpression) translatedFilter.get()).getArguments().get(0)); - newArgs.add(rowExpression.get()); - Optional likeTranslatedFilter = Optional - .of(new CallExpression(((CallExpression) translatedFilter.get()).getDisplayName(), - ((CallExpression) translatedFilter.get()).getFunctionHandle(), - ((CallExpression) translatedFilter.get()).getType(), newArgs)); - return likeTranslatedFilter; + String sqlString = ((Slice) ((ConstantExpression) ((CallExpression) rowExpression.get()).getArguments().get(0)).getValue()).toStringUtf8(); + return generateLikeExpr(sqlString, translatedFilter); } @Override diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/LazyOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/LazyOmniBlock.java index 09329694c588cf288abae49f0974b3c9a00289d4..10595ea694a7ac32deff91ef882e86c2fc75bce6 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/LazyOmniBlock.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/LazyOmniBlock.java @@ -154,4 +154,10 @@ public class LazyOmniBlock { return lazyBlock; } + + @Override + public void close() + { + nativeLazyVec.close(); + } } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/RowOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/RowOmniBlock.java index b693270b1ea6d357f03e8e9635e442c80e31c818..5730818abcda43d5fb2786bc326884c7223daf1b 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/RowOmniBlock.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/RowOmniBlock.java @@ -67,6 +67,8 @@ public class RowOmniBlock private final DataType dataType; + private boolean[] isNull; + /** * Create a row block directly from columnar nulls and field blocks. 
* @@ -238,7 +240,10 @@ public class RowOmniBlock @Nullable public boolean[] getRowIsNull() { - boolean[] isNull = new boolean[rowIsNull.length]; + if (isNull != null) { + return isNull; + } + isNull = new boolean[rowIsNull.length]; for (int i = 0; i < rowIsNull.length; i++) { isNull[i] = rowIsNull[i] == Vec.NULL; } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/AggregationOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/AggregationOmniOperator.java index 5a6d86a8ff7ed8d71c0f56d09154a834e1b16e92..8fa7b8a998e5a2811b6a4a6235b3b2be6eac6085 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/AggregationOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/AggregationOmniOperator.java @@ -37,12 +37,15 @@ import nova.hetu.omniruntime.type.DataType; import nova.hetu.omniruntime.vector.VecAllocator; import nova.hetu.omniruntime.vector.VecBatch; +import java.util.Arrays; import java.util.List; import java.util.Optional; import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; import static nova.hetu.olk.tool.OperatorUtils.buildVecBatch; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_AVG; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_SUM; /** * The type Aggregation omni operator. @@ -241,12 +244,14 @@ public class AggregationOmniOperator case StandardTypes.DATE: return; case StandardTypes.VARBINARY: - case StandardTypes.ROW: { - if (this.step == AggregationNode.Step.FINAL) { + if (this.step == AggregationNode.Step.FINAL && this.aggregatorTypes.length != 0 && + Arrays.stream(this.aggregatorTypes).allMatch(item -> item == OMNI_AGGREGATION_TYPE_AVG || item == OMNI_AGGREGATION_TYPE_SUM)) { return; } - else if (this.step == AggregationNode.Step.PARTIAL) { - throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "Not support data Type " + base); + case StandardTypes.ROW: { + if (this.step == AggregationNode.Step.FINAL && this.aggregatorTypes.length != 0 && + Arrays.stream(this.aggregatorTypes).allMatch(item -> item == OMNI_AGGREGATION_TYPE_AVG)) { + return; } } default: diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/DistinctLimitOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/DistinctLimitOmniOperator.java index 7570cbdcc6adfd9344c9d43d9e908099145e7500..3f7f1601601e8863b81575dddb1d9c7af95fb8d0 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/DistinctLimitOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/DistinctLimitOmniOperator.java @@ -223,6 +223,7 @@ public class DistinctLimitOmniOperator requireNonNull(page, "page is null"); if (page.getPositionCount() == 0) { + BlockUtils.freePage(page); return; } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/EnforceSingleRowOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/EnforceSingleRowOmniOperator.java index 34b82ba0a5d4dcff1115a4802cd19a9f396085a3..47aed9bf48f40b709947faa5f82e2d470b5f2297 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/EnforceSingleRowOmniOperator.java +++ 
b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/EnforceSingleRowOmniOperator.java @@ -51,6 +51,13 @@ public class EnforceSingleRowOmniOperator @Override public void addInput(Page page) { + requireNonNull(page, "page is null"); + checkState(needsInput(), "Operator did not expect any more data"); + if (page.getPositionCount() == 0) { + BlockUtils.freePage(page); + return; + } + this.page = page; super.addInput(page); } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashAggregationOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashAggregationOmniOperator.java index 6b6e56d6050fc3a987dfcd3626ad5cf97ee78289..15efdcd333a6d909fc9824b50fce9b6342f09013 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashAggregationOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashAggregationOmniOperator.java @@ -47,11 +47,14 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; import java.util.Optional; +import java.util.stream.Collectors; import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; import static nova.hetu.olk.tool.OperatorUtils.buildVecBatch; import static nova.hetu.olk.tool.OperatorUtils.createExpressions; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_AVG; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_SUM; /** * The type Hash aggregation omni operator. @@ -228,6 +231,12 @@ public class HashAggregationOmniOperator this.step = step; this.groupByInputChannels = Arrays.copyOf( requireNonNull(groupByInputChannels, "groupByInputChannels is null."), groupByInputChannels.length); + List groupByTypes = Arrays.stream(this.groupByInputChannels) + .mapToObj(channel -> this.sourceTypes.get(channel)).collect(Collectors.toList()); + if (groupByTypes.stream().map(type -> type.getTypeSignature().getBase()) + .anyMatch(item -> item.equals(StandardTypes.ROW))) { + throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "Not Support data Type " + StandardTypes.ROW); + } this.groupByInputTypes = Arrays.copyOf( requireNonNull(groupByInputTypes, "groupByInputTypes is null."), groupByInputTypes.length); @@ -346,12 +355,14 @@ public class HashAggregationOmniOperator case StandardTypes.DATE: return; case StandardTypes.VARBINARY: - case StandardTypes.ROW: { - if (this.step == AggregationNode.Step.FINAL) { + if (this.step == AggregationNode.Step.FINAL && this.aggregatorTypes.length != 0 && + Arrays.stream(this.aggregatorTypes).allMatch(item -> item == OMNI_AGGREGATION_TYPE_AVG || item == OMNI_AGGREGATION_TYPE_SUM)) { return; } - else if (this.step == AggregationNode.Step.PARTIAL) { - throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "Not support data Type " + base); + case StandardTypes.ROW: { + if (this.step == AggregationNode.Step.FINAL && this.aggregatorTypes.length != 0 && + Arrays.stream(this.aggregatorTypes).allMatch(item -> item == OMNI_AGGREGATION_TYPE_AVG)) { + return; } } default: diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashBuilderOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashBuilderOmniOperator.java index 38e0862e600a26ad15d1b67bfb15e083aa4167e5..3b9efaf18a244f13ed32e94aeea0f9096ce0e4cc 100644 --- 
a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashBuilderOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashBuilderOmniOperator.java @@ -33,6 +33,7 @@ import io.prestosql.operator.PartitionedLookupSourceFactory; import io.prestosql.spi.Page; import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.type.Type; +import nova.hetu.olk.tool.BlockUtils; import nova.hetu.olk.tool.OperatorUtils; import nova.hetu.olk.tool.VecAllocatorHelper; import nova.hetu.omniruntime.operator.OmniOperator; @@ -285,6 +286,7 @@ public class HashBuilderOmniOperator checkState(state == State.CONSUMING_INPUT); int positionCount = page.getPositionCount(); if (positionCount == 0) { + BlockUtils.freePage(page); return; } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LimitOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LimitOmniOperator.java index c4973779bbbe9d39152afb09993944584b22384f..c5972882c61b4290ed009313231103e4732f0fba 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LimitOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LimitOmniOperator.java @@ -123,6 +123,7 @@ public class LimitOmniOperator int rowCount = page.getPositionCount(); if (rowCount == 0) { + BlockUtils.freePage(page); return; } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LookupJoinOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LookupJoinOmniOperator.java index 2e3c69393a5a62918a7c3bf791f6c77064610805..49712034bcd79ff41fb9662fac2fd0bb095e7356 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LookupJoinOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LookupJoinOmniOperator.java @@ -188,6 +188,7 @@ public class LookupJoinOmniOperator int positionCount = page.getPositionCount(); if (positionCount == 0) { + BlockUtils.freePage(page); return; } VecBatch vecBatch = buildVecBatch(omniOperator.getVecAllocator(), page, this); diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/OrderByOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/OrderByOmniOperator.java index 620ae93fa8893dc775479deccfff8f161d0b4287..8229784cbece2e087de175121f1ae56f95a773f2 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/OrderByOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/OrderByOmniOperator.java @@ -255,6 +255,7 @@ public class OrderByOmniOperator int positionCount = page.getPositionCount(); if (positionCount == 0) { + BlockUtils.freePage(page); return; } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/PartitionedOutputOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/PartitionedOutputOmniOperator.java index 8495623e05d7a53300c93ca1e983eb721ee0ebb2..c8b362f9a78dd62c250c0df955570480aed62be9 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/PartitionedOutputOmniOperator.java +++ 
b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/PartitionedOutputOmniOperator.java @@ -100,6 +100,7 @@ public class PartitionedOutputOmniOperator requireNonNull(page, "page is null"); if (page.getPositionCount() == 0) { + BlockUtils.freePage(page); return; } page = pagePreprocessor.apply(page); diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/JsonifyVisitor.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/JsonifyVisitor.java index 95e08c0f04e959b42a6fdc4ea0fe08d7f6f35868..f8c155bcdf773cee6a443fde6f89895ab2bb17bf 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/JsonifyVisitor.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/JsonifyVisitor.java @@ -230,7 +230,7 @@ class JsonifyVisitor case OMNI_VARCHAR: String varcharValue; if (literal.getValue() instanceof Slice) { - varcharValue = ((Slice) literal.getValue()).toStringAscii(); + varcharValue = ((Slice) literal.getValue()).toStringUtf8(); } else { varcharValue = String.valueOf(literal.getValue()); diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/OperatorUtils.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/OperatorUtils.java index c417f43d586eafa60d19cbb81070eccabffcf79c..07d16bb753b8865be3ba6fb2f9c17b8479c0ddf1 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/OperatorUtils.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/OperatorUtils.java @@ -559,7 +559,7 @@ public final class OperatorUtils Block block = dictionaryBlock.getDictionary(); Block omniDictionary = buildOffHeapBlock(vecAllocator, block, block.getClass().getSimpleName(), block.getPositionCount(), blockType); - Block dictionaryOmniBlock = new DictionaryOmniBlock((Vec) omniDictionary.getValues(), + Block dictionaryOmniBlock = new DictionaryOmniBlock(inputBlock.getPositionCount(), (Vec) omniDictionary.getValues(), dictionaryBlock.getIdsArray()); omniDictionary.close(); return dictionaryOmniBlock; diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/HashAggregationOmniOperatorTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/HashAggregationOmniOperatorTest.java index a7e8b294e8a56cd6d43192f47e40a17890c07815..ee66af04df786e4c9bef3beab7924e2327104634 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/HashAggregationOmniOperatorTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/HashAggregationOmniOperatorTest.java @@ -20,6 +20,7 @@ import io.prestosql.operator.OperatorFactory; import io.prestosql.spi.Page; import io.prestosql.spi.plan.AggregationNode; import io.prestosql.spi.plan.PlanNodeId; +import io.prestosql.spi.type.BooleanType; import io.prestosql.spi.type.Type; import nova.hetu.olk.operator.HashAggregationOmniOperator.HashAggregationOmniOperatorFactory; import nova.hetu.olk.tool.OperatorUtils; @@ -36,7 +37,6 @@ import java.util.List; import java.util.Random; import java.util.UUID; -import static java.lang.Math.abs; import static nova.hetu.olk.mock.MockUtil.mockNewWithWithAnyArguments; import static nova.hetu.olk.mock.MockUtil.mockOmniOperator; import static org.junit.Assert.assertEquals; @@ -57,14 +57,14 @@ 
public class HashAggregationOmniOperatorTest { private final int operatorId = new Random().nextInt(); private final PlanNodeId planNodeId = new PlanNodeId(UUID.randomUUID().toString()); - private final int[] groupByInputChannels = {abs(new Random().nextInt()), abs(new Random().nextInt())}; + private final int[] groupByInputChannels = {0}; private final DataType[] groupByInputTypes = {}; - private final int[] aggregationInputChannels = {abs(new Random().nextInt()), abs(new Random().nextInt())}; + private final int[] aggregationInputChannels = {}; private final DataType[] aggregationInputTypes = {}; private final FunctionType[] aggregatorTypes = {}; private final AggregationNode.Step step = AggregationNode.Step.SINGLE; private final DataType[] aggregationOutputTypes = {}; - private final List sourceTypes = new ArrayList<>(); + private final List sourceTypes = Arrays.asList(BooleanType.BOOLEAN); private final List inAndOutputTypes = Arrays.asList(new DataType[]{}, new DataType[]{}); private OmniHashAggregationOperatorFactory omniHashAggregationOperatorFactory; private OmniOperator omniOperator; diff --git a/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..f400007aad9540233b68d66c00fd5d2b560d8d92 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt @@ -0,0 +1,39 @@ +# project name +project(spark-thestral-plugin) + +# required cmake version +cmake_minimum_required(VERSION 3.10) + +# configure cmake +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_COMPILER "g++") + +set(root_directory ${PROJECT_BINARY_DIR}) + +# configure file +configure_file( + "${PROJECT_SOURCE_DIR}/config.h.in" + "${PROJECT_SOURCE_DIR}/config.h" +) + +# for header searching +include_directories(SYSTEM src) + +# compile library +add_subdirectory(src) + +message(STATUS "Build by ${CMAKE_BUILD_TYPE}") + +option(BUILD_CPP_TESTS "test" ON) +message(STATUS "Option BUILD_CPP_TESTS: ${BUILD_CPP_TESTS}") +if(${BUILD_CPP_TESTS}) + enable_testing() + add_subdirectory(test) +endif () + +# options +option(DEBUG_RUNTIME "Debug" OFF) +message(STATUS "Option DEBUG: ${DEBUG_RUNTIME}") + +option(TRACE_RUNTIME "Trace" OFF) +message(STATUS "Option TRACE: ${TRACE_RUNTIME}") diff --git a/omnioperator/omniop-spark-extension/cpp/build.sh b/omnioperator/omniop-spark-extension/cpp/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..26f83e2cb612fc3af663d955df4bcabac177997b --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/build.sh @@ -0,0 +1,56 @@ +#!/bin/bash +# Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
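
Several of the operator hunks above add the same fix: when addInput() receives an empty page, the operator must still release the off-heap vectors (BlockUtils.freePage) before returning, otherwise the early return leaks native memory that the normal VecBatch path would have freed. A hedged C++ sketch of that pattern (OffHeapPage and freeAll are illustrative names, not the openLooKeng API):

#include <vector>

// Stand-in for a page whose column vectors live in native (off-heap) memory.
struct OffHeapPage {
    std::vector<void*> nativeVecs;          // illustrative vector handles
    int positionCount = 0;
    void freeAll() { nativeVecs.clear(); }  // plays the role of BlockUtils.freePage(page)
};

void addInput(OffHeapPage& page) {
    if (page.positionCount == 0) {
        // Before the patch this path returned without releasing the vectors.
        page.freeAll();
        return;
    }
    // ... convert the page into a VecBatch and hand it to the native operator,
    // which takes ownership of the vectors from here on ...
}
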
+ +set -eu + +CURRENT_DIR=$(cd "$(dirname "$BASH_SOURCE")"; pwd) +echo $CURRENT_DIR +cd ${CURRENT_DIR} +if [ -d build ]; then + rm -r build +fi +mkdir build +cd build + +# options +if [ $# != 0 ] ; then + options="" + if [ $1 = 'debug' ]; then + echo "-- Enable Debug" + options="$options -DCMAKE_BUILD_TYPE=Debug -DDEBUG_RUNTIME=ON" + elif [ $1 = 'trace' ]; then + echo "-- Enable Trace" + options="$options -DCMAKE_BUILD_TYPE=Debug -DTRACE_RUNTIME=ON" + elif [ $1 = 'release' ];then + echo "-- Enable Release" + options="$options -DCMAKE_BUILD_TYPE=Release" + elif [ $1 = 'test' ];then + echo "-- Enable Test" + options="$options -DCMAKE_BUILD_TYPE=Test -DBUILD_CPP_TESTS=TRUE" + else + echo "-- Enable Release" + options="$options -DCMAKE_BUILD_TYPE=Release" + fi + cmake .. $options +else + echo "-- Enable Release" + cmake .. -DCMAKE_BUILD_TYPE=Release +fi + +make + +set +eu \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/config.h b/omnioperator/omniop-spark-extension/cpp/config.h new file mode 100644 index 0000000000000000000000000000000000000000..9c9637a16d96737d75c128d3fe0bda6c87c82172 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/config.h @@ -0,0 +1,20 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
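
The DEBUG_RUNTIME and TRACE_RUNTIME options declared in the top-level CMakeLists.txt reach the C++ sources through configure_file(): the #cmakedefine lines in config.h.in become #define (option ON) or a commented-out #undef (option OFF) in the generated config.h, which common/debug.h then uses to decide at compile time whether LogsTrace expands to a printf. A small consumer sketch (the messages are illustrative):

#include <cstdio>
#include "config.h"   // generated by configure_file() from config.h.in

int main() {
#ifdef TRACE_RUNTIME
    // Selected by "./build.sh trace" (-DTRACE_RUNTIME=ON).
    std::puts("trace build: LogsTrace in common/debug.h prints file/function/line");
#else
    // Default "./build.sh release": the macro compiles away entirely.
    std::puts("release build: LogsTrace expands to nothing");
#endif
    return 0;
}
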
+ */ + +//#cmakedefine DEBUG_RUNTIME +//#cmakedefine TRACE_RUNTIME \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/config.h.in b/omnioperator/omniop-spark-extension/cpp/config.h.in new file mode 100644 index 0000000000000000000000000000000000000000..43c74967c62ab21066187cc4eaf6a692706f4a97 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/config.h.in @@ -0,0 +1,2 @@ +#cmakedefine DEBUG_RUNTIME +#cmakedefine TRACE_RUNTIME \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d233edd19c0f5e83b3ee8887976eca0b64519d6f --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt @@ -0,0 +1,58 @@ +include_directories(SYSTEM "/opt/lib/include") +include_directories(SYSTEM "/user/local/include") + +set (PROJ_TARGET spark_columnar_plugin) + + +set (SOURCE_FILES + io/ColumnWriter.cc + io/Compression.cc + io/MemoryPool.cc + io/OutputStream.cc + io/SparkFile.cc + io/WriterOptions.cc + shuffle/splitter.cpp + common/common.cpp + jni/SparkJniWrapper.cpp + jni/OrcColumnarBatchJniReader.cpp + ) + +#Find required protobuf package +find_package(Protobuf REQUIRED) +if(PROTOBUF_FOUND) + message(STATUS "protobuf library found") +else() + message(FATAL_ERROR "protobuf library is needed but cant be found") +endif() + +include_directories(${Protobuf_INCLUDE_DIRS}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +protobuf_generate_cpp(PROTO_SRCS_VB PROTO_HDRS_VB proto/vec_data.proto) +add_library (${PROJ_TARGET} SHARED ${SOURCE_FILES} ${PROTO_SRCS} ${PROTO_HDRS} ${PROTO_SRCS_VB} ${PROTO_HDRS_VB}) + +#JNI +target_include_directories(${PROJ_TARGET} PUBLIC $ENV{JAVA_HOME}/include) +target_include_directories(${PROJ_TARGET} PUBLIC $ENV{JAVA_HOME}/include/linux) +target_include_directories(${PROJ_TARGET} PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) +target_include_directories(${PROJ_TARGET} PUBLIC /opt/lib/include) + +target_link_libraries (${PROJ_TARGET} PUBLIC + orc + crypto + sasl2 + protobuf + z + snappy + lz4 + zstd + boostkit-omniop-runtime-1.0.0-aarch64 + boostkit-omniop-vector-1.0.0-aarch64 + ) + +set_target_properties(${PROJ_TARGET} PROPERTIES + LIBRARY_OUTPUT_DIRECTORY ${root_directory}/releases +) + +target_compile_options(${PROJ_TARGET} PUBLIC -g -O2 -fPIC) + +install(TARGETS ${PROJ_TARGET} DESTINATION lib) diff --git a/omnioperator/omniop-spark-extension/cpp/src/common/BinaryLocation.h b/omnioperator/omniop-spark-extension/cpp/src/common/BinaryLocation.h new file mode 100644 index 0000000000000000000000000000000000000000..683b0fa9d7fd5cb2f849b3637816758c29bf53d4 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/common/BinaryLocation.h @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SPARK_THESTRAL_PLUGIN_BINARYLOCATION_H +#define SPARK_THESTRAL_PLUGIN_BINARYLOCATION_H +class VCLocation { +public: + VCLocation(uint64_t vc_addr, uint32_t vc_len, bool isnull) + : vc_len(vc_len), vc_addr(vc_addr), is_null(isnull) { + } + ~VCLocation() { + } + uint32_t get_vc_len() { + return vc_len; + } + + uint64_t get_vc_addr() { + return vc_addr; + } + + bool get_is_null() { + return is_null; + } + +public: + uint32_t vc_len; + uint64_t vc_addr; + bool is_null; +}; + +class VCBatchInfo { +public: + VCBatchInfo(uint32_t vcb_capacity) { + this->vc_list.reserve(vcb_capacity); + this->vcb_capacity = vcb_capacity; + this->vcb_total_len = 0; + } + + ~VCBatchInfo() { + vc_list.clear(); + } + + uint32_t getVcbCapacity() { + return vcb_capacity; + } + + uint32_t getVcbTotalLen() { + return vcb_total_len; + } + + std::vector &getVcList() { + return vc_list; + } + +public: + uint32_t vcb_capacity; + uint32_t vcb_total_len; + std::vector vc_list; + + +}; +#endif //SPARK_THESTRAL_PLUGIN_BINARYLOCATION_H \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/common/Buffer.h b/omnioperator/omniop-spark-extension/cpp/src/common/Buffer.h new file mode 100644 index 0000000000000000000000000000000000000000..64160069323ee0eece5f5485a38ce00a6dc4f9eb --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/common/Buffer.h @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + #ifndef CPP_BUFFER_H + #define CPP_BUFFER_H + + #include + #include + #include + #include + #include + + class Buffer { + public: + Buffer(uint8_t* data, int64_t size, int64_t capacity) + : data_(data), + size_(size), + capacity_(capacity) { + } + + public: + uint8_t * data_; + int64_t size_; + int64_t capacity_; + }; + + #endif //CPP_BUFFER_H \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/common/common.cpp b/omnioperator/omniop-spark-extension/cpp/src/common/common.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2c6b9fab89ec31c7df596cc4e9b14e3f869a12b2 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/common/common.cpp @@ -0,0 +1,96 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. 
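
VCLocation and VCBatchInfo above are plain descriptors for variable-width column values: each entry records only the address, length, and null flag of a value that lives elsewhere, and the batch reserves room for a known number of entries. A short usage sketch, assuming this patch's headers are on the include path (BytesGen in common/common.cpp later flattens such a list into offsets/nulls/values buffers):

#include <cstdint>
#include <string>
#include <vector>                      // BinaryLocation.h relies on std::vector
#include "common/BinaryLocation.h"     // VCLocation / VCBatchInfo from this patch

int main() {
    std::string value = "omnioperator";

    VCBatchInfo batch(/*vcb_capacity=*/4);
    // A non-null value: point at the caller-owned bytes, nothing is copied.
    batch.getVcList().emplace_back(reinterpret_cast<uint64_t>(value.data()),
                                   static_cast<uint32_t>(value.size()),
                                   /*isnull=*/false);
    // A null value: zero length, null flag set.
    batch.getVcList().emplace_back(/*vc_addr=*/0, /*vc_len=*/0, /*isnull=*/true);

    return batch.getVcList().size() == 2 ? 0 : 1;
}
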
+ * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common.h" + +using namespace omniruntime::vec; + +int32_t BytesGen(uint64_t offsetsAddr, uint64_t nullsAddr, uint64_t valuesAddr, VCBatchInfo& vcb) +{ + int32_t* offsets = reinterpret_cast(offsetsAddr); + char *nulls = reinterpret_cast(nullsAddr); + char* values = reinterpret_cast(valuesAddr); + std::vector &lst = vcb.getVcList(); + int itemsTotalLen = lst.size(); + int valueTotalLen = 0; + for (int i = 0; i < itemsTotalLen; i++) { + char* addr = reinterpret_cast(lst[i].get_vc_addr()); + int len = lst[i].get_vc_len(); + if (i == 0) { + offsets[0] = 0; + } else { + offsets[i] = offsets[i -1] + lst[i - 1].get_vc_len(); + } + if (lst[i].get_is_null()) { + nulls[i] = 1; + } else { + nulls[i] = 0; + } + if (len != 0) { + memcpy((char *) (values + offsets[i]), addr, len); + valueTotalLen += len; + } + } + offsets[itemsTotalLen] = offsets[itemsTotalLen -1] + lst[itemsTotalLen - 1].get_vc_len(); + return valueTotalLen; +} + +uint32_t reversebytes_uint32t(uint32_t const value) +{ + return (value & 0x000000FFU) << 24 | (value & 0x0000FF00U) << 8 | (value & 0x00FF0000U) >> 8 | (value & 0xFF000000U) >> 24; +} + +spark::CompressionKind GetCompressionType(const std::string& name) { + if (name == "uncompressed") { + return spark::CompressionKind::CompressionKind_NONE; + } else if (name == "zlib") { + return spark::CompressionKind::CompressionKind_ZLIB; + } else if (name == "snappy") { + return spark::CompressionKind::CompressionKind_SNAPPY; + } else if (name == "lz4") { + return spark::CompressionKind::CompressionKind_LZ4; + } else if (name == "zstd") { + return spark::CompressionKind::CompressionKind_ZSTD; + } else { + throw std::logic_error("compression codec not supported"); + } +} + +// return: 1 文件存在可访问 +// 0 文件不存在或不能访问 +int IsFileExist(const std::string path) +{ + return !access(path.c_str(), F_OK); +} + +void ReleaseVectorBatch(omniruntime::vec::VectorBatch& vb) +{ + int tmpVectorNum = vb.GetVectorCount(); + std::set vectorBatchAddresses; + vectorBatchAddresses.clear(); + for (int vecIndex = 0; vecIndex < tmpVectorNum; ++vecIndex) { + vectorBatchAddresses.insert(vb.GetVector(vecIndex)); + } + for (Vector * tmpAddress : vectorBatchAddresses) { + if (nullptr == tmpAddress) { + throw std::runtime_error("delete nullptr error for release vectorBatch"); + } + delete tmpAddress; + } + vectorBatchAddresses.clear(); +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/common/common.h b/omnioperator/omniop-spark-extension/cpp/src/common/common.h new file mode 100644 index 0000000000000000000000000000000000000000..fdc3b10e692e3944eeee9cf70f96ed47262a5e77 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/common/common.h @@ -0,0 +1,50 @@ +/* + * Copyright 
(C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CPP_COMMON_H +#define CPP_COMMON_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../io/Common.hh" +#include "../utils/macros.h" +#include "BinaryLocation.h" +#include "debug.h" +#include "Buffer.h" +#include "BinaryLocation.h" + +int32_t BytesGen(uint64_t offsets, uint64_t nulls, uint64_t values, VCBatchInfo& vcb); + +uint32_t reversebytes_uint32t(uint32_t value); + +spark::CompressionKind GetCompressionType(const std::string& name); + +int IsFileExist(const std::string path); + +void ReleaseVectorBatch(omniruntime::vec::VectorBatch& vb); + +#endif //CPP_COMMON_H \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/common/debug.h b/omnioperator/omniop-spark-extension/cpp/src/common/debug.h new file mode 100644 index 0000000000000000000000000000000000000000..39415e255f40fe1a03808f3fa9dcc00f0ad44159 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/common/debug.h @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "../../config.h" +#include "util/global_log.h" + +#ifdef TRACE_RUNTIME +#define LogsTrace(format, ...) \ + do { \ + printf("[TRACE][%s][%s][%d]:" format "\n", __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__); \ + } while (0) +#else +#define LogsTrace(format, ...) +#endif + + +#define LogsDebug(format, ...) \ + do { \ + if (static_cast(LogType::LOG_DEBUG) >= GetLogLevel()) { \ + char logBuf[GLOBAL_LOG_BUF_SIZE]; \ + LogsInfoVargMacro(logBuf, format, ##__VA_ARGS__); \ + std::string logString(logBuf); \ + Log(logString, LogType::LOG_DEBUG); \ + } \ + } while (0) + + +#define LogsInfo(format, ...) 
\ + do { \ + if (static_cast(LogType::LOG_INFO) >= GetLogLevel()) { \ + char logBuf[GLOBAL_LOG_BUF_SIZE]; \ + LogsInfoVargMacro(logBuf, format, ##__VA_ARGS__); \ + std::string logString(logBuf); \ + Log(logString, LogType::LOG_INFO); \ + } \ + } while (0) + +#define LogsWarn(format, ...) \ + do { \ + if (static_cast(LogType::LOG_WARN) >= GetLogLevel()) { \ + char logBuf[GLOBAL_LOG_BUF_SIZE]; \ + LogsInfoVargMacro(logBuf, format, ##__VA_ARGS__); \ + std::string logString(logBuf); \ + Log(logString, LogType::LOG_WARN); \ + } \ + } while (0) + +#define LogsError(format, ...) \ + do { \ + if (static_cast(LogType::LOG_ERROR) >= GetLogLevel()) { \ + char logBuf[GLOBAL_LOG_BUF_SIZE]; \ + LogsInfoVargMacro(logBuf, format, ##__VA_ARGS__); \ + std::string logString(logBuf); \ + Log(logString, LogType::LOG_ERROR); \ + } \ + } while (0) \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/Adaptor.hh b/omnioperator/omniop-spark-extension/cpp/src/io/Adaptor.hh new file mode 100644 index 0000000000000000000000000000000000000000..5ae477e8bc2933bf0ac512a919addb7322bfbf32 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/io/Adaptor.hh @@ -0,0 +1,34 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ADAPTER_HH +#define ADAPTER_HH + +#define PRAGMA(TXT) _Pragma(#TXT) + +#ifdef __clang__ + #define DIAGNOSTIC_IGNORE(XXX) PRAGMA(clang diagnostic ignored XXX) +#elif defined(__GNUC__) + #define DIAGNOSTIC_IGNORE(XXX) PRAGMA(GCC diagnostic ignored XXX) +#elif defined(_MSC_VER) + #define DIAGNOSTIC_IGNORE(XXX) __pragma(warning(disable : XXX)) +#else + #define DIAGNOSTIC_IGNORE(XXX) +#endif + +#endif \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/ColumnWriter.cc b/omnioperator/omniop-spark-extension/cpp/src/io/ColumnWriter.cc new file mode 100644 index 0000000000000000000000000000000000000000..8804070d51db86984ef53aebdef79b8ffba72f03 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/io/ColumnWriter.cc @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
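
BytesGen above turns a VCBatchInfo into the three flat buffers handed over for a variable-width vector: an offsets array with one extra trailing entry, a byte-per-row null array, and a contiguous values buffer; the return value is the number of value bytes copied. A usage sketch, assuming the src/ directory (and the omniruntime headers it pulls in) is on the include path:

#include <cstdint>
#include <string>
#include <vector>
#include "common/common.h"   // BytesGen, VCBatchInfo, VCLocation from this patch

int main() {
    std::string a = "spark";
    std::string b = "omni";

    VCBatchInfo batch(2);
    std::vector<VCLocation>& list = batch.getVcList();
    list.emplace_back(reinterpret_cast<uint64_t>(a.data()),
                      static_cast<uint32_t>(a.size()), /*isnull=*/false);
    list.emplace_back(reinterpret_cast<uint64_t>(b.data()),
                      static_cast<uint32_t>(b.size()), /*isnull=*/false);

    // BytesGen writes offsets[0..N], so the offsets buffer needs N + 1 slots.
    std::vector<int32_t> offsets(list.size() + 1, 0);
    std::vector<char> nulls(list.size(), 0);
    std::vector<char> values(a.size() + b.size(), 0);

    int32_t written = BytesGen(reinterpret_cast<uint64_t>(offsets.data()),
                               reinterpret_cast<uint64_t>(nulls.data()),
                               reinterpret_cast<uint64_t>(values.data()),
                               batch);
    // offsets == {0, 5, 9}, nulls == {0, 0}, values == "sparkomni", written == 9
    return written == 9 ? 0 : 1;
}
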
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "WriterOptions.hh" + +#include "ColumnWriter.hh" + +namespace spark { + StreamsFactory::~StreamsFactory() { + //PASS + } + + class StreamsFactoryImpl : public StreamsFactory { + public: + StreamsFactoryImpl( + const WriterOptions& writerOptions, + OutputStream* outputStream) : + options(writerOptions), + outStream(outputStream) { + } + + virtual std::unique_ptr + createStream() const override; + private: + const WriterOptions& options; + OutputStream* outStream; + }; + + std::unique_ptr StreamsFactoryImpl::createStream() const { + return createCompressor( + options.getCompression(), + outStream, + options.getCompressionStrategy(), + // BufferedOutputStream initial capacity + 1 * 1024 * 1024, + options.getCompressionBlockSize(), + *options.getMemoryPool()); + } + + std::unique_ptr createStreamsFactory( + const WriterOptions& options, + OutputStream* outStream) { + return std::unique_ptr( + new StreamsFactoryImpl(options, outStream)); + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/ColumnWriter.hh b/omnioperator/omniop-spark-extension/cpp/src/io/ColumnWriter.hh new file mode 100644 index 0000000000000000000000000000000000000000..72cfe0e4bd483b83fef98305e724f51f3e193c02 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/io/ColumnWriter.hh @@ -0,0 +1,39 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SPARK_COLUMN_WRITER_HH +#define SPARK_COLUMN_WRITER_HH + +#include "Compression.hh" + +namespace spark { + + class StreamsFactory { + public: + virtual ~StreamsFactory(); + + virtual std::unique_ptr + createStream() const = 0; + }; + + std::unique_ptr createStreamsFactory( + const WriterOptions& options, + OutputStream * outStream); +} + +#endif \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/Common.hh b/omnioperator/omniop-spark-extension/cpp/src/io/Common.hh new file mode 100644 index 0000000000000000000000000000000000000000..e240363567544263de2d8daed503b079ae649fae --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/io/Common.hh @@ -0,0 +1,34 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SPARK_COMMON_HH +#define SPARK_COMMON_HH + +namespace spark { + + enum CompressionKind { + CompressionKind_NONE = 0, + CompressionKind_ZLIB = 1, + CompressionKind_SNAPPY = 2, + CompressionKind_LZO = 3, + CompressionKind_LZ4 = 4, + CompressionKind_ZSTD = 5 + }; +} + +#endif \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/Compression.cc b/omnioperator/omniop-spark-extension/cpp/src/io/Compression.cc new file mode 100644 index 0000000000000000000000000000000000000000..720f7ff1951b389fa230c1e074bdb6e9619f644b --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/io/Compression.cc @@ -0,0 +1,644 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Adaptor.hh" +#include "Compression.hh" +#include "lz4.h" + +#include +#include +#include +#include + +#include "zlib.h" +#include "zstd.h" + +#include "wrap/snappy_wrapper.h" + +#ifndef ZSTD_CLEVEL_DEFAULT +#define ZSTD_CLEVEL_DEFAULT 3 +#endif + +/* These macros are defined in lz4.c */ +#ifndef LZ4_ACCELERATION_DEFAULT +#define LZ4_ACCELERATION_DEFAULT 1 +#endif + +#ifndef LZ4_ACCELERATION_MAX +#define LZ4_ACCELERATION_MAX 65537 +#endif + +namespace spark { + + class CompressionStreamBase: public BufferedOutputStream { + public: + CompressionStreamBase(OutputStream * outStream, + int compressionLevel, + uint64_t capacity, + uint64_t blockSize, + MemoryPool& pool); + + virtual bool Next(void** data, int*size) override = 0; + virtual void BackUp(int count) override; + + virtual std::string getName() const override = 0; + virtual uint64_t flush() override; + + virtual bool isCompressed() const override { return true; } + virtual uint64_t getSize() const override; + + protected: + void writeHeader(char * buffer, size_t compressedSize, bool original) { + buffer[0] = static_cast((compressedSize << 1) + (original ? 
1 : 0)); + buffer[1] = static_cast(compressedSize >> 7); + buffer[2] = static_cast(compressedSize >> 15); + } + + // ensure enough room for compression block header + void ensureHeader(); + + // Buffer to hold uncompressed data until user calls Next() + DataBuffer rawInputBuffer; + + // Compress level + int level; + + // Compressed data output buffer + char * outputBuffer; + + // Size for compressionBuffer + int bufferSize; + + // Compress output position + int outputPosition; + + // Compress output buffer size + int outputSize; + }; + + CompressionStreamBase::CompressionStreamBase(OutputStream * outStream, + int compressionLevel, + uint64_t capacity, + uint64_t blockSize, + MemoryPool& pool) : + BufferedOutputStream(pool, + outStream, + capacity, + blockSize), + rawInputBuffer(pool, blockSize), + level(compressionLevel), + outputBuffer(nullptr), + bufferSize(0), + outputPosition(0), + outputSize(0) { + // PASS + } + + void CompressionStreamBase::BackUp(int count) { + if (count > bufferSize) { + throw std::logic_error("Can't backup that much!"); + } + bufferSize -= count; + } + + uint64_t CompressionStreamBase::flush() { + void * data; + int size; + if (!Next(&data, &size)) { + throw std::runtime_error("Failed to flush compression buffer."); + } + BufferedOutputStream::BackUp(outputSize - outputPosition); + bufferSize = outputSize = outputPosition = 0; + return BufferedOutputStream::flush(); + } + + uint64_t CompressionStreamBase::getSize() const { + return BufferedOutputStream::getSize() - + static_cast(outputSize - outputPosition); + } + + void CompressionStreamBase::ensureHeader() { + // adjust 3 bytes for the compression header + if (outputPosition + 3 >= outputSize) { + int newPosition = outputPosition + 3 - outputSize; + if (!BufferedOutputStream::Next( + reinterpret_cast(&outputBuffer), + &outputSize)) { + throw std::runtime_error( + "Failed to get next output buffer from output stream."); + } + outputPosition = newPosition; + } else { + outputPosition += 3; + } + } + + /** + * Streaming compression base class + */ + class CompressionStream: public CompressionStreamBase { + public: + CompressionStream(OutputStream * outStream, + int compressionLevel, + uint64_t capacity, + uint64_t blockSize, + MemoryPool& pool); + + virtual bool Next(void** data, int*size) override; + virtual std::string getName() const override = 0; + + protected: + // return total compressed size + virtual uint64_t doStreamingCompression() = 0; + }; + + CompressionStream::CompressionStream(OutputStream * outStream, + int compressionLevel, + uint64_t capacity, + uint64_t blockSize, + MemoryPool& pool) : + CompressionStreamBase(outStream, + compressionLevel, + capacity, + blockSize, + pool) { + // PASS + } + + bool CompressionStream::Next(void** data, int*size) { + if (bufferSize != 0) { + ensureHeader(); + + uint64_t totalCompressedSize = doStreamingCompression(); + + char * header = outputBuffer + outputPosition - totalCompressedSize - 3; + if (totalCompressedSize >= static_cast(bufferSize)) { + writeHeader(header, static_cast(bufferSize), true); + memcpy( + header + 3, + rawInputBuffer.data(), + static_cast(bufferSize)); + + int backup = static_cast(totalCompressedSize) - bufferSize; + BufferedOutputStream::BackUp(backup); + outputPosition -= backup; + outputSize -= backup; + } else { + writeHeader(header, totalCompressedSize, false); + } + } + + *data = rawInputBuffer.data(); + *size = static_cast(rawInputBuffer.size()); + bufferSize = *size; + + return true; + } + + class ZlibCompressionStream: public 
CompressionStream { + public: + ZlibCompressionStream(OutputStream * outStream, + int compressionLevel, + uint64_t capacity, + uint64_t blockSize, + MemoryPool& pool); + + virtual ~ZlibCompressionStream() override { + end(); + } + + virtual std::string getName() const override; + + protected: + virtual uint64_t doStreamingCompression() override; + + private: + void init(); + void end(); + z_stream strm; + }; + + ZlibCompressionStream::ZlibCompressionStream( + OutputStream * outStream, + int compressionLevel, + uint64_t capacity, + uint64_t blockSize, + MemoryPool& pool) + : CompressionStream(outStream, + compressionLevel, + capacity, + blockSize, + pool) { + init(); + } + + uint64_t ZlibCompressionStream::doStreamingCompression() { + if (deflateReset(&strm) != Z_OK) { + throw std::runtime_error("Failed to reset inflate."); + } + + strm.avail_in = static_cast(bufferSize); + strm.next_in = rawInputBuffer.data(); + + do { + if (outputPosition >= outputSize) { + if (!BufferedOutputStream::Next( + reinterpret_cast(&outputBuffer), + &outputSize)) { + throw std::runtime_error( + "Failed to get next output buffer from output stream."); + } + outputPosition = 0; + } + strm.next_out = reinterpret_cast + (outputBuffer + outputPosition); + strm.avail_out = static_cast + (outputSize - outputPosition); + + int ret = deflate(&strm, Z_FINISH); + outputPosition = outputSize - static_cast(strm.avail_out); + + if (ret == Z_STREAM_END) { + break; + } else if (ret == Z_OK) { + // needs more buffer so will continue the loop + } else { + throw std::runtime_error("Failed to deflate input data."); + } + } while (strm.avail_out == 0); + + return strm.total_out; + } + + std::string ZlibCompressionStream::getName() const { + return "ZlibCompressionStream"; + } + +// DIAGNOSTIC_PUSH + +#if defined(__GNUC__) || defined(__clang__) + DIAGNOSTIC_IGNORE("-Wold-style-cast") +#endif + + void ZlibCompressionStream::init() { + strm.zalloc = nullptr; + strm.zfree = nullptr; + strm.opaque = nullptr; + strm.next_in = nullptr; + + if (deflateInit2(&strm, level, Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY) + != Z_OK) { + throw std::runtime_error("Error while calling deflateInit2() for zlib."); + } + } + + void ZlibCompressionStream::end() { + (void)deflateEnd(&strm); + } + +// DIAGNOSTIC_PUSH + + enum DecompressState { DECOMPRESS_HEADER, + DECOMPRESS_START, + DECOMPRESS_CONTINUE, + DECOMPRESS_ORIGINAL, + DECOMPRESS_EOF}; + +// DIAGNOSTIC_PUSH + +#if defined(__GNUC__) || defined(__clang__) + DIAGNOSTIC_IGNORE("-Wold-style-cast") +#endif + + /** + * Block compression base class + */ + class BlockCompressionStream: public CompressionStreamBase { + public: + BlockCompressionStream(OutputStream * outStream, + int compressionLevel, + uint64_t capacity, + uint64_t blockSize, + MemoryPool& pool) + : CompressionStreamBase(outStream, + compressionLevel, + capacity, + blockSize, + pool) + , compressorBuffer(pool) { + // PASS + } + + virtual bool Next(void** data, int*size) override; + virtual std::string getName() const override = 0; + + protected: + // compresses a block and returns the compressed size + virtual uint64_t doBlockCompression() = 0; + + // return maximum possible compression size for allocating space for + // compressorBuffer below + virtual uint64_t estimateMaxCompressionSize() = 0; + + // should allocate max possible compressed size + DataBuffer compressorBuffer; + }; + + bool BlockCompressionStream::Next(void** data, int*size) { + if (bufferSize != 0) { + ensureHeader(); + + // perform compression + size_t totalCompressedSize = 
doBlockCompression(); + + const unsigned char * dataToWrite = nullptr; + int totalSizeToWrite = 0; + char * header = outputBuffer + outputPosition - 3; + + if (totalCompressedSize >= static_cast(bufferSize)) { + writeHeader(header, static_cast(bufferSize), true); + dataToWrite = rawInputBuffer.data(); + totalSizeToWrite = bufferSize; + } else { + writeHeader(header, totalCompressedSize, false); + dataToWrite = compressorBuffer.data(); + totalSizeToWrite = static_cast(totalCompressedSize); + } + + char * dst = header + 3; + while (totalSizeToWrite > 0) { + if (outputPosition == outputSize) { + if (!BufferedOutputStream::Next(reinterpret_cast(&outputBuffer), + &outputSize)) { + throw std::logic_error( + "Failed to get next output buffer from output stream."); + } + outputPosition = 0; + dst = outputBuffer; + } else if (outputPosition > outputSize) { + // this will unlikely happen, but we have seen a few on zstd v1.1.0 + throw std::logic_error("Write to an out-of-bound place!"); + } + + int sizeToWrite = std::min(totalSizeToWrite, outputSize - outputPosition); + memcpy(dst, dataToWrite, static_cast(sizeToWrite)); + + outputPosition += sizeToWrite; + dataToWrite += sizeToWrite; + totalSizeToWrite -= sizeToWrite; + dst += sizeToWrite; + } + } + + *data = rawInputBuffer.data(); + *size = static_cast(rawInputBuffer.size()); + bufferSize = *size; + compressorBuffer.resize(estimateMaxCompressionSize()); + + return true; + } + + /** + * LZ4 block compression + */ + class Lz4CompressionSteam: public BlockCompressionStream { + public: + Lz4CompressionSteam(OutputStream * outStream, + int compressionLevel, + uint64_t capacity, + uint64_t blockSize, + MemoryPool& pool) + : BlockCompressionStream(outStream, + compressionLevel, + capacity, + blockSize, + pool) { + this->init(); + } + + virtual std::string getName() const override { + return "Lz4CompressionStream"; + } + + virtual ~Lz4CompressionSteam() override { + this->end(); + } + + protected: + virtual uint64_t doBlockCompression() override; + + virtual uint64_t estimateMaxCompressionSize() override { + return static_cast(LZ4_compressBound(bufferSize)); + } + + private: + void init(); + void end(); + LZ4_stream_t *state; + }; + + uint64_t Lz4CompressionSteam::doBlockCompression() { + int result = LZ4_compress_fast_extState(static_cast(state), + reinterpret_cast(rawInputBuffer.data()), + reinterpret_cast(compressorBuffer.data()), + bufferSize, + static_cast(compressorBuffer.size()), + level); + if (result == 0) { + throw std::runtime_error("Error during block compression using lz4."); + } + return static_cast(result); + } + + void Lz4CompressionSteam::init() { + state = LZ4_createStream(); + if (!state) { + throw std::runtime_error("Error while allocating state for lz4."); + } + } + + void Lz4CompressionSteam::end() { + (void)LZ4_freeStream(state); + state = nullptr; + } + + /** + * Snappy block compression + */ + class SnappyCompressionStream: public BlockCompressionStream { + public: + SnappyCompressionStream(OutputStream * outStream, + int compressionLevel, + uint64_t capacity, + uint64_t blockSize, + MemoryPool& pool) + : BlockCompressionStream(outStream, + compressionLevel, + capacity, + blockSize, + pool) { + } + + virtual std::string getName() const override { + return "SnappyCompressionStream"; + } + + virtual ~SnappyCompressionStream() override { + // PASS + } + + protected: + virtual uint64_t doBlockCompression() override; + + virtual uint64_t estimateMaxCompressionSize() override { + return static_cast + 
(snappy::MaxCompressedLength(static_cast(bufferSize))); + } + }; + + uint64_t SnappyCompressionStream::doBlockCompression() { + size_t compressedLength; + snappy::RawCompress(reinterpret_cast(rawInputBuffer.data()), + static_cast(bufferSize), + reinterpret_cast(compressorBuffer.data()), + &compressedLength); + return static_cast(compressedLength); + } + + /** + * ZSTD block compression + */ + class ZSTDCompressionStream: public BlockCompressionStream{ + public: + ZSTDCompressionStream(OutputStream * outStream, + int compressionLevel, + uint64_t capacity, + uint64_t blockSize, + MemoryPool& pool) + : BlockCompressionStream(outStream, + compressionLevel, + capacity, + blockSize, + pool) { + this->init(); + } + + virtual std::string getName() const override { + return "ZstdCompressionStream"; + } + + virtual ~ZSTDCompressionStream() override { + this->end(); + } + + protected: + virtual uint64_t doBlockCompression() override; + + virtual uint64_t estimateMaxCompressionSize() override { + return ZSTD_compressBound(static_cast(bufferSize)); + } + + private: + void init(); + void end(); + ZSTD_CCtx *cctx; + }; + + uint64_t ZSTDCompressionStream::doBlockCompression() { + return ZSTD_compressCCtx(cctx, + compressorBuffer.data(), + compressorBuffer.size(), + rawInputBuffer.data(), + static_cast(bufferSize), + level); + } + +// DIAGNOSTIC_PUSH + +#if defined(__GNUC__) || defined(__clang__) + DIAGNOSTIC_IGNORE("-Wold-style-cast") +#endif + + void ZSTDCompressionStream::init() { + + cctx = ZSTD_createCCtx(); + if (!cctx) { + throw std::runtime_error("Error while calling ZSTD_createCCtx() for zstd."); + } + } + + + void ZSTDCompressionStream::end() { + (void)ZSTD_freeCCtx(cctx); + cctx = nullptr; + } + +#if defined(__GNUC__) || defined(__clang__) + DIAGNOSTIC_IGNORE("-Wold-style-cast") +#endif + +// DIAGNOSTIC_PUSH + + std::unique_ptr + createCompressor( + CompressionKind kind, + OutputStream * outStream, + CompressionStrategy strategy, + uint64_t bufferCapacity, + uint64_t compressionBlockSize, + MemoryPool& pool) { + switch (static_cast(kind)) { + case CompressionKind_NONE: { + return std::unique_ptr + (new BufferedOutputStream( + pool, outStream, bufferCapacity, compressionBlockSize)); + } + case CompressionKind_ZLIB: { + int level = (strategy == CompressionStrategy_SPEED) ? + Z_BEST_SPEED + 1 : Z_DEFAULT_COMPRESSION; + return std::unique_ptr + (new ZlibCompressionStream( + outStream, level, bufferCapacity, compressionBlockSize, pool)); + } + case CompressionKind_ZSTD: { + int level = (strategy == CompressionStrategy_SPEED) ? + 1 : ZSTD_CLEVEL_DEFAULT; + return std::unique_ptr + (new ZSTDCompressionStream( + outStream, level, bufferCapacity, compressionBlockSize, pool)); + } + case CompressionKind_LZ4: { + int level = (strategy == CompressionStrategy_SPEED) ? 
+ LZ4_ACCELERATION_MAX : LZ4_ACCELERATION_DEFAULT; + return std::unique_ptr + (new Lz4CompressionSteam( + outStream, level, bufferCapacity, compressionBlockSize, pool)); + } + case CompressionKind_SNAPPY: { + int level = 0; + return std::unique_ptr + (new SnappyCompressionStream( + outStream, level, bufferCapacity, compressionBlockSize, pool)); + } + case CompressionKind_LZO: + default: + throw std::logic_error("compression codec not supported"); + } + } + +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/Compression.hh b/omnioperator/omniop-spark-extension/cpp/src/io/Compression.hh new file mode 100644 index 0000000000000000000000000000000000000000..6364d675038af460ed50a81ad5e316edb52f0df8 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/io/Compression.hh @@ -0,0 +1,45 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SPARK_COMPRESSION_HH +#define SPARK_COMPRESSION_HH + +#include "OutputStream.hh" +#include "Common.hh" +#include "WriterOptions.hh" + +namespace spark { + /** + * Create a compressor for the given compression kind. + * @param kind the compression type to implement + * @param outStream the output stream that is the underlying target + * @param strategy compression strategy + * @param bufferCapacity compression stream buffer total capacity + * @param compressionBlockSize compression buffer block size + * @param pool the memory pool + */ + std::unique_ptr + createCompressor(CompressionKind kind, + OutputStream * outStream, + CompressionStrategy strategy, + uint64_t bufferCapacity, + uint64_t compressionBlockSize, + MemoryPool& pool); +} + +#endif \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/MemoryPool.cc b/omnioperator/omniop-spark-extension/cpp/src/io/MemoryPool.cc new file mode 100644 index 0000000000000000000000000000000000000000..b4fd9e345672ad8fdc8821f85eef9cf2323b74ad --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/io/MemoryPool.cc @@ -0,0 +1,158 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
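
Every compressed stream above (zlib, LZ4, snappy, zstd) frames its output with the same 3-byte block header produced by writeHeader(): bit 0 of the first byte marks an uncompressed ("original") block, and the remaining 23 bits hold the block length, low bits first. A self-contained encode/decode sketch of that layout (unsigned char is used here only for clarity; the patch writes through a char pointer):

#include <cassert>
#include <cstdint>

// Same layout as writeHeader() above: size bits 0..6 in byte 0 (shifted past the
// "original" flag), bits 7..14 in byte 1, bits 15..22 in byte 2.
static void encodeHeader(unsigned char* h, uint32_t compressedSize, bool original) {
    h[0] = static_cast<unsigned char>((compressedSize << 1) + (original ? 1 : 0));
    h[1] = static_cast<unsigned char>(compressedSize >> 7);
    h[2] = static_cast<unsigned char>(compressedSize >> 15);
}

static void decodeHeader(const unsigned char* h, uint32_t* compressedSize, bool* original) {
    *original = (h[0] & 1) != 0;
    *compressedSize = (static_cast<uint32_t>(h[0]) >> 1)
                    | (static_cast<uint32_t>(h[1]) << 7)
                    | (static_cast<uint32_t>(h[2]) << 15);
}

int main() {
    unsigned char header[3];
    encodeHeader(header, 262143, /*original=*/false);   // 256 KiB - 1
    uint32_t size = 0;
    bool original = true;
    decodeHeader(header, &size, &original);
    assert(size == 262143 && !original);
    return 0;
}
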
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "MemoryPool.hh" + +#include "Adaptor.hh" + +#include +#include +#include + +namespace spark { + + MemoryPool::~MemoryPool() { + // PASS + } + + class MemoryPoolImpl: public MemoryPool { + public: + virtual ~MemoryPoolImpl() override; + + char* malloc(uint64_t size) override; + void free(char* p) override; + }; + + char* MemoryPoolImpl::malloc(uint64_t size) { + return static_cast(std::malloc(size)); + } + + void MemoryPoolImpl::free(char* p) { + std::free(p); + } + + MemoryPoolImpl::~MemoryPoolImpl() { + // PASS + } + + template + DataBuffer::DataBuffer(MemoryPool& pool, + uint64_t newSize + ): memoryPool(pool), + buf(nullptr), + currentSize(0), + currentCapacity(0) { + resize(newSize); + } + + template + DataBuffer::~DataBuffer() { + for (uint64_t i = currentSize; i > 0; --i) { + (buf + i - 1)->~T(); + } + if (buf) { + memoryPool.free(reinterpret_cast(buf)); + } + } + + template + void DataBuffer::resize(uint64_t newSize) { + reserve(newSize); + if (currentSize > newSize) { + for (uint64_t i = currentSize; i > newSize; --i) { + (buf + i -1)->~T(); + } + } else if (newSize > currentSize) { + for (uint64_t i = currentSize; i < newSize; ++i) { + new (buf + i) T(); + } + } + currentSize = newSize; + } + + template + void DataBuffer::reserve(uint64_t newCapacity) { + if (newCapacity > currentCapacity || !buf) { + if (buf) { + T* buf_old = buf; + buf = reinterpret_cast(memoryPool.malloc(sizeof(T) * newCapacity)); + memcpy(buf, buf_old, sizeof(T) * currentSize); + memoryPool.free(reinterpret_cast(buf_old)); + } else { + buf = reinterpret_cast(memoryPool.malloc(sizeof(T) * newCapacity)); + } + currentCapacity = newCapacity; + } + } + + // Specializations for char + + template <> + DataBuffer::~DataBuffer() { + if (buf) { + memoryPool.free(reinterpret_cast(buf)); + } + } + + template <> + void DataBuffer::resize(uint64_t newSize) { + reserve(newSize); + if (newSize > currentSize) { + memset(buf + currentSize, 0, newSize - currentSize); + } + currentSize = newSize; + } + + // Specializations for unsigned char + + template <> + DataBuffer::~DataBuffer() { + if (buf) { + memoryPool.free(reinterpret_cast(buf)); + } + } + + template <> + void DataBuffer::resize(uint64_t newSize) { + reserve(newSize); + if (newSize > currentSize) { + memset(buf + currentSize, 0, newSize - currentSize); + } + currentSize = newSize; + } + + #ifdef __clang__ + #pragma clang diagnostic ignored "-Wweak-template-vtables" + #endif + + template class DataBuffer; + template class DataBuffer; + template class DataBuffer; + template class DataBuffer; + template class DataBuffer; + template class DataBuffer; + + #ifdef __clang__ + #pragma clang diagnostic ignored "-Wexit-time-destructors" + #endif + + MemoryPool* getDefaultPool() { + static MemoryPoolImpl internal; + return &internal; + } +} // namespace spark \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/MemoryPool.hh b/omnioperator/omniop-spark-extension/cpp/src/io/MemoryPool.hh new file mode 100644 index 0000000000000000000000000000000000000000..0c267d91ced5a98b5072a423a7840f373b69cee0 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/io/MemoryPool.hh @@ -0,0 +1,111 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MEMORYPOOL_HH_ +#define MEMORYPOOL_HH_ + +#include + +namespace spark { + + class MemoryPool { + public: + virtual ~MemoryPool(); + + virtual char* malloc(uint64_t size) = 0; + virtual void free(char* p) = 0; + }; + MemoryPool* getDefaultPool(); + + template + class DataBuffer { + private: + MemoryPool& memoryPool; + T* buf; + // current size + uint64_t currentSize; + // maximal capacity (actual allocated memory) + uint64_t currentCapacity; + + // not implemented + DataBuffer(DataBuffer& buffer); + DataBuffer& operator = (DataBuffer& buffer); + + public: + DataBuffer(MemoryPool& pool, uint64_t _size = 0); + virtual ~DataBuffer(); + + T* data() { + return buf; + } + + const T* data() const { + return buf; + } + + uint64_t size() { + return currentSize; + } + + uint64_t capacity() { + return currentCapacity; + } + + T& operator[](uint64_t i) { + return buf[i]; + } + + void reserve(uint64_t _size); + void resize(uint64_t _size); + }; + + // Specializations for char + + template <> + DataBuffer::~DataBuffer(); + + template <> + void DataBuffer::resize(uint64_t newSize); + + // Specializations for unsigned char + + template <> + DataBuffer::~DataBuffer(); + + template <> + void DataBuffer::resize(uint64_t newSize); + + #ifdef __clang__ + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-Wweak-template-vtables" + #endif + + extern template class DataBuffer; + extern template class DataBuffer; + extern template class DataBuffer; + extern template class DataBuffer; + extern template class DataBuffer; + extern template class DataBuffer; + + #ifdef __clang__ + #pragma clang diagnostic pop + #endif +} // namespace spark + + +#endif /* MEMORYPOOL_HH_ */ diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/OutputStream.cc b/omnioperator/omniop-spark-extension/cpp/src/io/OutputStream.cc new file mode 100644 index 0000000000000000000000000000000000000000..657bb9827c4ed55dec4f2b83a61b99ff77def2c5 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/io/OutputStream.cc @@ -0,0 +1,108 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
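
DataBuffer above separates capacity from size the way the compression streams need it: reserve() only grows the allocation (copying the old contents), while resize() changes the logical size, and the char specializations zero-fill any newly exposed bytes. A small sketch, assuming src/io is on the include path:

#include <cassert>
#include "io/MemoryPool.hh"   // MemoryPool, DataBuffer, getDefaultPool from this patch

int main() {
    spark::MemoryPool& pool = *spark::getDefaultPool();

    spark::DataBuffer<char> buf(pool, 16);   // constructor resizes to 16 zeroed bytes
    assert(buf.size() == 16 && buf.capacity() >= 16);

    buf.reserve(64);                         // capacity grows, size is untouched
    assert(buf.size() == 16 && buf.capacity() >= 64);

    buf.resize(32);                          // size grows, new bytes are zero-filled
    assert(buf[31] == 0);
    return 0;
}
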
+ */ + +#include "OutputStream.hh" + +#include + +namespace spark { + + BufferedOutputStream::BufferedOutputStream( + MemoryPool& pool, + OutputStream * outStream, + uint64_t capacity_, + uint64_t blockSize_) + : outputStream(outStream), + blockSize(blockSize_) { + dataBuffer.reset(new DataBuffer(pool)); + dataBuffer->reserve(capacity_); + } + + BufferedOutputStream::~BufferedOutputStream() { + // PASS + } + + bool BufferedOutputStream::Next(void** buffer, int* size) { + *size = static_cast(blockSize); + uint64_t oldSize = dataBuffer->size(); + uint64_t newSize = oldSize + blockSize; + uint64_t newCapacity = dataBuffer->capacity(); + while (newCapacity < newSize) { + newCapacity += dataBuffer->capacity(); + } + dataBuffer->reserve(newCapacity); + dataBuffer->resize(newSize); + *buffer = dataBuffer->data() + oldSize; + return true; + } + + bool BufferedOutputStream::NextNBytes(void** buffer, int size) { + uint64_t oldSize = dataBuffer->size(); + uint64_t newSize = oldSize + size; + uint64_t newCapacity = dataBuffer->capacity(); + while (newCapacity < newSize) { + newCapacity += dataBuffer->capacity(); + } + dataBuffer->reserve(newCapacity); + dataBuffer->resize(newSize); + *buffer = dataBuffer->data() + oldSize; + return true; + } + + void BufferedOutputStream::BackUp(int count) { + if (count >= 0) { + uint64_t unsignedCount = static_cast(count); + if (unsignedCount <= dataBuffer->size()) { + dataBuffer->resize(dataBuffer->size() - unsignedCount); + } else { + throw std::logic_error("Can't backup that much!"); + } + } + } + + google::protobuf::int64 BufferedOutputStream::ByteCount() const { + return static_cast(dataBuffer->size()); + } + + bool BufferedOutputStream::WriteAliasedRaw(const void *, int) { + throw std::logic_error("WriteAliasedRaw is not supported."); + } + + bool BufferedOutputStream::AllowsAliasing() const { + return false; + } + + std::string BufferedOutputStream::getName() const { + std::ostringstream result; + result << "BufferedOutputStream " << dataBuffer->size() << " of " + << dataBuffer->capacity(); + return result.str(); + } + + uint64_t BufferedOutputStream::getSize() const { + return dataBuffer->size(); + } + + uint64_t BufferedOutputStream::flush() { + uint64_t dataSize = dataBuffer->size(); + outputStream->write(dataBuffer->data(), dataSize); + dataBuffer->resize(0); + return dataSize; + } + +} diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/OutputStream.hh b/omnioperator/omniop-spark-extension/cpp/src/io/OutputStream.hh new file mode 100644 index 0000000000000000000000000000000000000000..0ab4bad3c73a6188e1f142ac09ce07d39dc424f0 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/io/OutputStream.hh @@ -0,0 +1,63 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
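Before the header that declares it, a short sketch of the ZeroCopyOutputStream contract that BufferedOutputStream implements above: Next() exposes blockSize writable bytes backed by the internal DataBuffer<char>, BackUp() returns the unused tail, and flush() pushes the buffered bytes to the wrapped OutputStream. The example is illustrative and not part of the patch; writeLocalFile() comes from SparkFile.hh later in this diff, and the path and sizes are placeholders.

// Sketch only: write "hello" through the Next()/BackUp()/flush() cycle.
#include <cstring>
#include <memory>
#include "OutputStream.hh"
#include "SparkFile.hh"

int main() {
    std::unique_ptr<spark::OutputStream> file = spark::writeLocalFile("/tmp/demo.bin");
    spark::BufferedOutputStream buffered(*spark::getDefaultPool(), file.get(),
                                         64 * 1024 /* capacity */, 4 * 1024 /* blockSize */);

    void* chunk = nullptr;
    int chunkSize = 0;
    buffered.Next(&chunk, &chunkSize);      // grows the buffer by one block
    std::memcpy(chunk, "hello", 5);
    buffered.BackUp(chunkSize - 5);         // hand back the unused bytes

    buffered.flush();                       // writes exactly 5 bytes to the file
    file->close();
    return 0;
}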
+ */ + +#ifndef SPARK_OUTPUTSTREAM_HH +#define SPARK_OUTPUTSTREAM_HH + +#include "SparkFile.hh" +#include "MemoryPool.hh" +#include "wrap/zero_copy_stream_wrapper.h" + +namespace spark { + + /** + * A subclass of Google's ZeroCopyOutputStream that supports output to memory + * buffer, and flushing to OutputStream. + * By extending Google's class, we get the ability to pass it directly + * to the protobuf writers. + */ + class BufferedOutputStream: public google::protobuf::io::ZeroCopyOutputStream { + private: + OutputStream * outputStream; + std::unique_ptr > dataBuffer; + uint64_t blockSize; + + public: + BufferedOutputStream(MemoryPool& pool, + OutputStream * outStream, + uint64_t capacity, + uint64_t block_size); + virtual ~BufferedOutputStream() override; + + virtual bool Next(void** data, int*size) override; + virtual void BackUp(int count) override; + virtual google::protobuf::int64 ByteCount() const override; + virtual bool WriteAliasedRaw(const void * data, int size) override; + virtual bool AllowsAliasing() const override; + + virtual std::string getName() const; + virtual uint64_t getSize() const; + virtual uint64_t flush(); + virtual bool NextNBytes(void** data, int size); + + virtual bool isCompressed() const { return false; } + }; + +} + +#endif // SPARK_OUTPUTSTREAM_HH diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/SparkFile.cc b/omnioperator/omniop-spark-extension/cpp/src/io/SparkFile.cc new file mode 100644 index 0000000000000000000000000000000000000000..51ff4b98f3eb4df234d927276e75bce1cb7bc158 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/io/SparkFile.cc @@ -0,0 +1,179 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
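The reason BufferedOutputStream subclasses google::protobuf::io::ZeroCopyOutputStream, as the header above notes, is so protobuf messages can serialize straight into the pooled buffer. A sketch under assumptions: spark::VecBatch is the message generated from vec_data.proto later in this patch, the generated header name vec_data.pb.h is the usual protoc output and is assumed here, and the path is a placeholder.

// Sketch only: serialize a protobuf message through the buffered stream.
#include <memory>
#include <string>
#include "OutputStream.hh"
#include "SparkFile.hh"
#include "vec_data.pb.h"  // assumed protoc output for vec_data.proto

bool writeBatch(const spark::VecBatch& batch, const std::string& path) {
    std::unique_ptr<spark::OutputStream> file = spark::writeLocalFile(path);
    spark::BufferedOutputStream buffered(*spark::getDefaultPool(), file.get(),
                                         64 * 1024, 4 * 1024);

    // Protobuf pulls chunks via Next()/BackUp(); no intermediate std::string copy.
    bool ok = batch.SerializeToZeroCopyStream(&buffered);
    buffered.flush();
    file->close();
    return ok;
}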
+ */ + +#include "Adaptor.hh" +#include "SparkFile.hh" + +#include +#include +#include +#include +#include + +#ifdef _MSC_VER +#include +#define S_IRUSR _S_IREAD +#define S_IWUSR _S_IWRITE +#define stat _stat64 +#define fstat _fstat64 +#else +#include +#define O_BINARY 0 +#endif + +namespace spark { + + class FileInputStream : public InputStream { + private: + std::string filename; + int file; + uint64_t totalLength; + + public: + FileInputStream(std::string _filename) { + filename = _filename; + file = open(filename.c_str(), O_BINARY | O_RDONLY); + if (file == -1) { + throw std::runtime_error("Can't open " + filename); + } + struct stat fileStat; + if (fstat(file, &fileStat) == -1) { + throw std::runtime_error("Can't stat " + filename); + } + totalLength = static_cast(fileStat.st_size); + } + + ~FileInputStream() override; + + uint64_t getLength() const override { + return totalLength; + } + + uint64_t getNaturalReadSize() const override { + return 128 * 1024; + } + + void read(void* buf, + uint64_t length, + uint64_t offset) override { + if (!buf) { + throw std::runtime_error("Buffer is null"); + } + ssize_t bytesRead = pread(file, buf, length, static_cast(offset)); + + if (bytesRead == -1) { + throw std::runtime_error("Bad read of " + filename); + } + if (static_cast(bytesRead) != length) { + throw std::runtime_error("Short read of " + filename); + } + } + + const std::string& getName() const override { + return filename; + } + }; + + FileInputStream::~FileInputStream() { + close(file); + } + + std::unique_ptr readFile(const std::string& path) { + return spark::readLocalFile(std::string(path)); + } + + std::unique_ptr readLocalFile(const std::string& path) { + return std::unique_ptr(new FileInputStream(path)); + } + + OutputStream::~OutputStream() { + // PASS + }; + + class FileOutputStream : public OutputStream { + private: + std::string filename; + int file; + uint64_t bytesWritten; + bool closed; + + public: + FileOutputStream(std::string _filename) { + bytesWritten = 0; + filename = _filename; + closed = false; + file = open( + filename.c_str(), + O_BINARY | O_CREAT | O_WRONLY | O_TRUNC, + S_IRUSR | S_IWUSR); + if (file == -1) { + throw std::runtime_error("Can't open " + filename); + } + } + + ~FileOutputStream() override; + + uint64_t getLength() const override { + return bytesWritten; + } + + uint64_t getNaturalWriteSize() const override { + return 128 * 1024; + } + + void write(const void* buf, size_t length) override { + if (closed) { + throw std::logic_error("Cannot write to closed stream."); + } + ssize_t bytesWrite = ::write(file, buf, length); + if (bytesWrite == -1) { + throw std::runtime_error("Bad write of " + filename); + } + if (static_cast(bytesWrite) != length) { + throw std::runtime_error("Short write of " + filename); + } + bytesWritten += static_cast(bytesWrite); + } + + const std::string& getName() const override { + return filename; + } + + void close() override { + if (!closed) { + ::close(file); + closed = true; + } + } + }; + + FileOutputStream::~FileOutputStream() { + if (!closed) { + ::close(file); + closed = true; + } + } + + std::unique_ptr writeLocalFile(const std::string& path) { + return std::unique_ptr(new FileOutputStream(path)); + } + + InputStream::~InputStream() { + // PASS + }; +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/SparkFile.hh b/omnioperator/omniop-spark-extension/cpp/src/io/SparkFile.hh new file mode 100644 index 
0000000000000000000000000000000000000000..7c3d8d03bc0fe3865105e13590698155fe9f9ef6 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/io/SparkFile.hh @@ -0,0 +1,117 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SPARK_FILE_HH +#define SPARK_FILE_HH + +#include +#include + +namespace spark { + + /** + * An abstract interface for providing readers a stream of bytes. + */ + class InputStream { + public: + virtual ~InputStream(); + + /** + * Get the total length of the file in bytes. + */ + virtual uint64_t getLength() const = 0; + + /** + * Get the natural size for reads. + * @return the number of bytes that should be read at once + */ + virtual uint64_t getNaturalReadSize() const = 0; + + /** + * Read length bytes from the file starting at offset into + * the buffer starting at buf. + * @param buf the starting position of a buffer. + * @param length the number of bytes to read. + * @param offset the position in the stream to read from. + */ + virtual void read(void* buf, + uint64_t length, + uint64_t offset) = 0; + + /** + * Get the name of the stream for error messages. + */ + virtual const std::string& getName() const = 0; + }; + + /** + * An abstract interface for providing writer a stream of bytes. + */ + class OutputStream { + public: + virtual ~OutputStream(); + + /** + * Get the total length of bytes written. + */ + virtual uint64_t getLength() const = 0; + + /** + * Get the natural size for reads. + * @return the number of bytes that should be written at once + */ + virtual uint64_t getNaturalWriteSize() const =0; + + /** + * Write/Append length bytes pointed by buf to the file stream + * @param buf the starting position of a buffer. + * @param length the number of bytes to write. + */ + virtual void write(const void* buf, size_t length) = 0; + + /** + * Get the name of the stream for error messages. + */ + virtual const std::string& getName() const = 0; + + /** + * Close the stream and flush any pending data to the disk. + */ + virtual void close() = 0; + }; + + /** + * Create a stream to a local file + * @param path the name of the file in the local file system + */ + std::unique_ptr readFile(const std::string& path); + + /** + * Create a stream to a local file. + * @param path the name of the file in the local file system + */ + std::unique_ptr readLocalFile(const std::string& path); + + /** + * Create a stream to write to a local file. 
+ * @param path the name of the file in the local file system + */ + std::unique_ptr writeLocalFile(const std::string& path); +} + +#endif \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/WriterOptions.cc b/omnioperator/omniop-spark-extension/cpp/src/io/WriterOptions.cc new file mode 100644 index 0000000000000000000000000000000000000000..43bc0b38b09d74721460fac7b20065977f509b2c --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/io/WriterOptions.cc @@ -0,0 +1,108 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Common.hh" +#include "SparkFile.hh" +#include "ColumnWriter.hh" + +#include + +namespace spark { + + struct WriterOptionsPrivate { + + uint64_t compressionBlockSize; + CompressionKind compression; + CompressionStrategy compressionStrategy; + MemoryPool* memoryPool; + + WriterOptionsPrivate() { // default to Hive_0_12 + compressionBlockSize = 64 * 1024; // 64K + compression = CompressionKind_ZLIB; + compressionStrategy = CompressionStrategy_SPEED; + memoryPool = getDefaultPool(); + } + }; + + WriterOptions::WriterOptions(): + privateBits(std::unique_ptr + (new WriterOptionsPrivate())) { + // PASS + } + + WriterOptions::WriterOptions(const WriterOptions& rhs): + privateBits(std::unique_ptr + (new WriterOptionsPrivate(*(rhs.privateBits.get())))) { + // PASS + } + + WriterOptions::WriterOptions(WriterOptions& rhs) { + // swap privateBits with rhs + privateBits.swap(rhs.privateBits); + } + + WriterOptions& WriterOptions::operator=(const WriterOptions& rhs) { + if (this != &rhs) { + privateBits.reset(new WriterOptionsPrivate(*(rhs.privateBits.get()))); + } + return *this; + } + + WriterOptions::~WriterOptions() { + // PASS + } + + WriterOptions& WriterOptions::setCompressionBlockSize(uint64_t size) { + privateBits->compressionBlockSize = size; + return *this; + } + + uint64_t WriterOptions::getCompressionBlockSize() const { + return privateBits->compressionBlockSize; + } + + WriterOptions& WriterOptions::setCompression(CompressionKind comp) { + privateBits->compression = comp; + return *this; + } + + CompressionKind WriterOptions::getCompression() const { + return privateBits->compression; + } + + WriterOptions& WriterOptions::setCompressionStrategy( + CompressionStrategy strategy) { + privateBits->compressionStrategy = strategy; + return *this; + } + + CompressionStrategy WriterOptions::getCompressionStrategy() const { + return privateBits->compressionStrategy; + } + + WriterOptions& WriterOptions::setMemoryPool(MemoryPool* memoryPool) { + privateBits->memoryPool = memoryPool; + return *this; + } + + MemoryPool* WriterOptions::getMemoryPool() const { + return privateBits->memoryPool; + } + +} + diff --git 
a/omnioperator/omniop-spark-extension/cpp/src/io/WriterOptions.hh b/omnioperator/omniop-spark-extension/cpp/src/io/WriterOptions.hh new file mode 100644 index 0000000000000000000000000000000000000000..d942fd5131fdf88d6ada35f96aa25af33ea99269 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/io/WriterOptions.hh @@ -0,0 +1,94 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SPARK_WRITER_OPTIONS_HH +#define SPARK_WRITER_OPTIONS_HH + +#include +#include "MemoryPool.hh" +#include "Common.hh" + +namespace spark { + // classes that hold data members so we can maintain binary compatibility + struct WriterOptionsPrivate; + + enum CompressionStrategy { + CompressionStrategy_SPEED = 0, + CompressionStrategy_COMPRESSION + }; + + /** + * Options for creating a Writer. + */ + class WriterOptions { + private: + std::unique_ptr privateBits; + public: + WriterOptions(); + WriterOptions(const WriterOptions&); + WriterOptions(WriterOptions&); + WriterOptions& operator=(const WriterOptions&); + virtual ~WriterOptions(); + + /** + * Set the data compression block size. + */ + WriterOptions& setCompressionBlockSize(uint64_t size); + + /** + * Get the data compression block size. + * @return if not set, return default value. + */ + uint64_t getCompressionBlockSize() const; + + /** + * Set compression kind. + */ + WriterOptions& setCompression(CompressionKind comp); + + /** + * Get the compression kind. + * @return if not set, return default value which is ZLIB. + */ + CompressionKind getCompression() const; + + /** + * Set the compression strategy. + */ + WriterOptions& setCompressionStrategy(CompressionStrategy strategy); + + /** + * Get the compression strategy. + * @return if not set, return default value which is speed. + */ + CompressionStrategy getCompressionStrategy() const; + + /** + * Set the memory pool. + */ + WriterOptions& setMemoryPool(MemoryPool * memoryPool); + + /** + * Get the memory pool. + * @return if not set, return default memory pool. + */ + MemoryPool * getMemoryPool() const; + }; +} + +#endif \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/wrap/coded_stream_wrapper.h b/omnioperator/omniop-spark-extension/cpp/src/io/wrap/coded_stream_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..d623c7d01bfcfeaf0799d17c3d2b5a5c4a97af57 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/io/wrap/coded_stream_wrapper.h @@ -0,0 +1,30 @@ +/* +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef CODED_STREAM_WRAPPER_HH +#define CODED_STREAM_WRAPPER_HH + +#include "io/Adaptor.hh" + + +#ifdef __clang__ + DIAGNOSTIC_IGNORE("-Wshorten-64-to-32") + DIAGNOSTIC_IGNORE("-Wreserved-id-macro") +#endif + +#if defined(__GNUC__) || defined(__clang__) + DIAGNOSTIC_IGNORE("-Wconversion") +#endif + +#include + +#endif diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/wrap/snappy_wrapper.h b/omnioperator/omniop-spark-extension/cpp/src/io/wrap/snappy_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..56ac837ee0de85282bebf1f046c2de075933e30e --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/io/wrap/snappy_wrapper.h @@ -0,0 +1,30 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SNAPPY_WRAPPER_HH +#define SNAPPY_WRAPPER_HH + +#include "../Adaptor.hh" + +#ifdef __clang__ + DIAGNOSTIC_IGNORE("-Wreserved-id-macro") +#endif + +#include + +#endif diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/wrap/zero_copy_stream_wrapper.h b/omnioperator/omniop-spark-extension/cpp/src/io/wrap/zero_copy_stream_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..252839ec87ad5493d1531b8fc064f05df96e1b89 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/io/wrap/zero_copy_stream_wrapper.h @@ -0,0 +1,36 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
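The wrapper headers above lean on a DIAGNOSTIC_IGNORE macro from Adaptor.hh, which is not part of this diff. For orientation only, a hypothetical definition in the style of Apache ORC's adaptor header is sketched below; the actual Adaptor.hh in the repository may differ.

// Hypothetical sketch (not the real Adaptor.hh): map a warning name onto a
// compiler-specific "diagnostic ignored" pragma so third-party headers such as
// protobuf and snappy can be included without flooding the build with warnings.
#define PRAGMA_STRINGIFY(x) #x
#define APPLY_PRAGMA(x) _Pragma(PRAGMA_STRINGIFY(x))

#ifdef __clang__
  #define DIAGNOSTIC_IGNORE(warning) APPLY_PRAGMA(clang diagnostic ignored warning)
#elif defined(__GNUC__)
  #define DIAGNOSTIC_IGNORE(warning) APPLY_PRAGMA(GCC diagnostic ignored warning)
#else
  #define DIAGNOSTIC_IGNORE(warning)
#endif

// Example expansion under clang: DIAGNOSTIC_IGNORE("-Wconversion")
//   -> _Pragma("clang diagnostic ignored \"-Wconversion\"")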
+ */ + +#ifndef ZERO_COPY_STREAM_WRAPPER_HH +#define ZERO_COPY_STREAM_WRAPPER_HH + +#include "../Adaptor.hh" + +#if defined(__GNUC__) || defined(__clang__) + DIAGNOSTIC_IGNORE("-Wdeprecated") + DIAGNOSTIC_IGNORE("-Wpadded") + DIAGNOSTIC_IGNORE("-Wunused-parameter") +#endif + +#ifdef __clang__ + DIAGNOSTIC_IGNORE("-Wreserved-id-macro") +#endif + +#include + +#endif diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.cpp b/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3b38f0deae9ebf178e2878cb4e69b8d54ec7e902 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.cpp @@ -0,0 +1,553 @@ +/** + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "OrcColumnarBatchJniReader.h" +using namespace omniruntime::vec; +using namespace std; +using namespace orc; + +jclass runtimeExceptionClass; +jclass jsonClass; +jclass arrayListClass; +jmethodID jsonMethodInt; +jmethodID jsonMethodLong; +jmethodID jsonMethodHas; +jmethodID jsonMethodString; +jmethodID jsonMethodJsonObj; +jmethodID arrayListGet; +jmethodID arrayListSize; +jmethodID jsonMethodObj; + +int initJniId(JNIEnv *env) +{ + /* + * init table scan log + */ + jsonClass = env->FindClass("org/json/JSONObject"); + arrayListClass = env->FindClass("java/util/ArrayList"); + + arrayListGet = env->GetMethodID(arrayListClass, "get", "(I)Ljava/lang/Object;"); + arrayListSize = env->GetMethodID(arrayListClass, "size", "()I"); + + // get int method + jsonMethodInt = env->GetMethodID(jsonClass, "getInt", "(Ljava/lang/String;)I"); + if (jsonMethodInt == NULL) + return -1; + + // get long method + jsonMethodLong = env->GetMethodID(jsonClass, "getLong", "(Ljava/lang/String;)J"); + if (jsonMethodLong == NULL) + return -1; + + // get has method + jsonMethodHas = env->GetMethodID(jsonClass, "has", "(Ljava/lang/String;)Z"); + if (jsonMethodHas == NULL) + return -1; + + // get string method + jsonMethodString = env->GetMethodID(jsonClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;"); + if (jsonMethodString == NULL) + return -1; + + // get json object method + jsonMethodJsonObj = env->GetMethodID(jsonClass, "getJSONObject", "(Ljava/lang/String;)Lorg/json/JSONObject;"); + if (jsonMethodJsonObj == NULL) + return -1; + + // get json object method + jsonMethodObj = env->GetMethodID(jsonClass, "get", "(Ljava/lang/String;)Ljava/lang/Object;"); + if (jsonMethodJsonObj == NULL) + return -1; + + jclass local_class = env->FindClass("Ljava/lang/RuntimeException;"); + runtimeExceptionClass = (jclass)env->NewGlobalRef(local_class); + env->DeleteLocalRef(local_class); + 
if (runtimeExceptionClass == NULL) + return -1; + + return 0; +} + +void JNI_OnUnload(JavaVM *vm, const void *reserved) +{ + JNIEnv *env = nullptr; + vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_8); + env->DeleteGlobalRef(runtimeExceptionClass); +} + +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_initializeReader(JNIEnv *env, + jobject jObj, jstring path, jobject jsonObj) +{ + /* + * init logger and jni env method id + */ + initJniId(env); + + /* + * get tailLocation from json obj + */ + jlong tailLocation = env->CallLongMethod(jsonObj, jsonMethodLong, env->NewStringUTF("tailLocation")); + jstring serTailJstr = + (jstring)env->CallObjectMethod(jsonObj, jsonMethodString, env->NewStringUTF("serializedTail")); + const char *pathPtr = env->GetStringUTFChars(path, nullptr); + std::string filePath(pathPtr); + orc::MemoryPool *pool = orc::getDefaultPool(); + orc::ReaderOptions readerOptions; + readerOptions.setMemoryPool(*pool); + readerOptions.setTailLocation(tailLocation); + if (serTailJstr != NULL) { + const char *ptr = env->GetStringUTFChars(serTailJstr, nullptr); + std::string serTail(ptr); + readerOptions.setSerializedFileTail(serTail); + env->ReleaseStringUTFChars(serTailJstr, ptr); + } + + std::unique_ptr reader = createReader(orc::readFile(filePath), readerOptions); + env->ReleaseStringUTFChars(path, pathPtr); + orc::Reader *readerNew = reader.release(); + return (jlong)(readerNew); +} + +int getLiteral(orc::Literal &lit, int leafType, string value) +{ + switch ((orc::PredicateDataType)leafType) { + case orc::PredicateDataType::LONG: { + lit = orc::Literal(static_cast(std::stol(value))); + break; + } + case orc::PredicateDataType::STRING: { + lit = orc::Literal(value.c_str(), value.size()); + break; + } + case orc::PredicateDataType::DATE: { + lit = orc::Literal(PredicateDataType::DATE, static_cast(std::stol(value))); + break; + } + case orc::PredicateDataType::DECIMAL: { + vector valList; + istringstream tmpAllStr(value); + string tmpStr; + while (tmpAllStr >> tmpStr) { + valList.push_back(tmpStr); + } + Decimal decimalVal(valList[0]); + lit = orc::Literal(decimalVal.value, static_cast(std::stoi(valList[1])), + static_cast(std::stoi(valList[2]))); + break; + } + default: { + LogsError("ERROR: TYPE ERROR: TYPEID"); + } + } + return 0; +} + +int buildLeafs(int leafOp, vector &litList, Literal &lit, string leafNameString, int leafType, + SearchArgumentBuilder &builder) +{ + switch ((PredicateOperatorType)leafOp) { + case PredicateOperatorType::LESS_THAN: { + builder.lessThan(leafNameString, (PredicateDataType)leafType, lit); + break; + } + case PredicateOperatorType::LESS_THAN_EQUALS: { + builder.lessThanEquals(leafNameString, (PredicateDataType)leafType, lit); + break; + } + case PredicateOperatorType::EQUALS: { + builder.equals(leafNameString, (PredicateDataType)leafType, lit); + break; + } + case PredicateOperatorType::NULL_SAFE_EQUALS: { + builder.nullSafeEquals(leafNameString, (PredicateDataType)leafType, lit); + break; + } + case PredicateOperatorType::IS_NULL: { + builder.isNull(leafNameString, (PredicateDataType)leafType); + break; + } + case PredicateOperatorType::IN: { + builder.in(leafNameString, (PredicateDataType)leafType, litList); + break; + } + default: { + LogsError("ERROR operator ID"); + } + } + return 1; +} + +int initLeafs(JNIEnv *env, SearchArgumentBuilder &builder, jobject &jsonExp, jobject &jsonLeaves) +{ + jstring leaf = (jstring)env->CallObjectMethod(jsonExp, jsonMethodString, env->NewStringUTF("leaf")); + jobject 
leafJsonObj = env->CallObjectMethod(jsonLeaves, jsonMethodJsonObj, leaf); + jstring leafName = (jstring)env->CallObjectMethod(leafJsonObj, jsonMethodString, env->NewStringUTF("name")); + std::string leafNameString(env->GetStringUTFChars(leafName, nullptr)); + jint leafOp = (jint)env->CallIntMethod(leafJsonObj, jsonMethodInt, env->NewStringUTF("op")); + jint leafType = (jint)env->CallIntMethod(leafJsonObj, jsonMethodInt, env->NewStringUTF("type")); + Literal lit(0L); + jstring leafValue = (jstring)env->CallObjectMethod(leafJsonObj, jsonMethodString, env->NewStringUTF("literal")); + if (leafValue != nullptr) { + std::string leafValueString(env->GetStringUTFChars(leafValue, nullptr)); + if (leafValueString.size() != 0) { + getLiteral(lit, leafType, leafValueString); + } + } + std::vector litList; + jobject litListValue = env->CallObjectMethod(leafJsonObj, jsonMethodObj, env->NewStringUTF("literalList")); + if (litListValue != nullptr) { + int childs = (int)env->CallIntMethod(litListValue, arrayListSize); + for (int i = 0; i < childs; i++) { + jstring child = (jstring)env->CallObjectMethod(litListValue, arrayListGet, i); + std::string childString(env->GetStringUTFChars(child, nullptr)); + getLiteral(lit, leafType, childString); + litList.push_back(lit); + } + } + buildLeafs((int)leafOp, litList, lit, leafNameString, (int)leafType, builder); + return 1; +} + +int initExpressionTree(JNIEnv *env, SearchArgumentBuilder &builder, jobject &jsonExp, jobject &jsonLeaves) +{ + int op = env->CallIntMethod(jsonExp, jsonMethodInt, env->NewStringUTF("op")); + if (op == (int)(Operator::LEAF)) { + initLeafs(env, builder, jsonExp, jsonLeaves); + } else { + switch ((Operator)op) { + case Operator::OR: { + builder.startOr(); + break; + } + case Operator::AND: { + builder.startAnd(); + break; + } + case Operator::NOT: { + builder.startNot(); + break; + } + } + jobject childList = env->CallObjectMethod(jsonExp, jsonMethodObj, env->NewStringUTF("child")); + int childs = (int)env->CallIntMethod(childList, arrayListSize); + for (int i = 0; i < childs; i++) { + jobject child = env->CallObjectMethod(childList, arrayListGet, i); + initExpressionTree(env, builder, child, jsonLeaves); + } + builder.end(); + } + return 0; +} + + +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_initializeRecordReader(JNIEnv *env, + jobject jObj, jlong reader, jobject jsonObj) +{ + orc::Reader *readerPtr = (orc::Reader *)reader; + // get offset from json obj + jlong offset = env->CallLongMethod(jsonObj, jsonMethodLong, env->NewStringUTF("offset")); + jlong length = env->CallLongMethod(jsonObj, jsonMethodLong, env->NewStringUTF("length")); + jobjectArray includedColumns = + (jobjectArray)env->CallObjectMethod(jsonObj, jsonMethodObj, env->NewStringUTF("includedColumns")); + if (includedColumns == NULL) + return -1; + std::list includedColumnsLenArray; + jint arrLen = env->GetArrayLength(includedColumns); + jboolean isCopy = JNI_FALSE; + for (int i = 0; i < arrLen; i++) { + jstring colName = (jstring)env->GetObjectArrayElement(includedColumns, i); + const char *convertedValue = (env)->GetStringUTFChars(colName, &isCopy); + std::string colNameString = convertedValue; + includedColumnsLenArray.push_back(colNameString); + } + RowReaderOptions rowReaderOpts; + if (arrLen != 0) { + rowReaderOpts.include(includedColumnsLenArray); + } else { + std::list includeFirstCol; + includeFirstCol.push_back(0); + rowReaderOpts.include(includeFirstCol); + } + rowReaderOpts.range(offset, length); + + jboolean 
hasExpressionTree = env->CallBooleanMethod(jsonObj, jsonMethodHas, env->NewStringUTF("expressionTree")); + if (hasExpressionTree) { + jobject expressionTree = env->CallObjectMethod(jsonObj, jsonMethodJsonObj, env->NewStringUTF("expressionTree")); + jobject leaves = env->CallObjectMethod(jsonObj, jsonMethodJsonObj, env->NewStringUTF("leaves")); + std::unique_ptr builder = SearchArgumentFactory::newBuilder(); + initExpressionTree(env, *builder, expressionTree, leaves); + auto sargBuilded = (*builder).build(); + rowReaderOpts.searchArgument(std::unique_ptr(sargBuilded.release())); + } + + + std::unique_ptr rowReader = readerPtr->createRowReader(rowReaderOpts); + return (jlong)(rowReader.release()); +} + + +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_initializeBatch(JNIEnv *env, + jobject jObj, jlong rowReader, jlong batchSize) +{ + orc::RowReader *rowReaderPtr = (orc::RowReader *)(rowReader); + uint64_t batchLen = (uint64_t)batchSize; + std::unique_ptr batch = rowReaderPtr->createRowBatch(batchLen); + orc::ColumnVectorBatch *rtn = batch.release(); + return (jlong)rtn; +} + +template uint64_t copyFixwidth(orc::ColumnVectorBatch *field) +{ + VectorAllocator *allocator = omniruntime::vec::GetProcessGlobalVecAllocator(); + using T = typename NativeType::type; + ORC_TYPE *lvb = dynamic_cast(field); + FixedWidthVector *originalVector = new FixedWidthVector(allocator, lvb->numElements); + for (int i = 0; i < lvb->numElements; i++) { + if (lvb->notNull.data()[i]) { + originalVector->SetValue(i, (T)(lvb->data.data()[i])); + } else { + originalVector->SetValueNull(i); + } + } + return (uint64_t)originalVector; +} + + +uint64_t copyVarwidth(int maxLen, orc::ColumnVectorBatch *field, int vcType) +{ + VectorAllocator *allocator = omniruntime::vec::GetProcessGlobalVecAllocator(); + orc::StringVectorBatch *lvb = dynamic_cast(field); + uint64_t totalLen = + maxLen * (lvb->numElements) > lvb->getMemoryUsage() ? 
maxLen * (lvb->numElements) : lvb->getMemoryUsage(); + VarcharVector *originalVector = new VarcharVector(allocator, totalLen, lvb->numElements); + for (int i = 0; i < lvb->numElements; i++) { + if (lvb->notNull.data()[i]) { + string tmpStr(reinterpret_cast(lvb->data.data()[i]), lvb->length.data()[i]); + if (vcType == orc::TypeKind::CHAR && tmpStr.back() == ' ') { + tmpStr.erase(tmpStr.find_last_not_of(" ") + 1); + } + originalVector->SetValue(i, reinterpret_cast(tmpStr.data()), tmpStr.length()); + } else { + originalVector->SetValueNull(i); + } + } + return (uint64_t)originalVector; +} + +int copyToOminVec(int maxLen, int vcType, int &ominTypeId, uint64_t &ominVecId, orc::ColumnVectorBatch *field) +{ + switch (vcType) { + case orc::TypeKind::DATE: + case orc::TypeKind::INT: { + if (vcType == orc::TypeKind::DATE) { + ominTypeId = static_cast(OMNI_DATE32); + } else { + ominTypeId = static_cast(OMNI_INT); + } + ominVecId = copyFixwidth(field); + break; + } + case orc::TypeKind::LONG: { + ominTypeId = static_cast(OMNI_LONG); + ominVecId = copyFixwidth(field); + break; + } + case orc::TypeKind::CHAR: + case orc::TypeKind::STRING: + case orc::TypeKind::VARCHAR: { + ominTypeId = static_cast(OMNI_VARCHAR); + ominVecId = (uint64_t)copyVarwidth(maxLen, field, vcType); + break; + } + default: { + LogsError("orc::TypeKind::UNKNOWN ERROR %d", vcType); + } + } + return 1; +} + +int copyToOminDecimalVec(int vcType, int &ominTypeId, uint64_t &ominVecId, orc::ColumnVectorBatch *field) +{ + VectorAllocator *allocator = VectorAllocator::GetGlobalAllocator(); + if (vcType > 18) { + ominTypeId = static_cast(OMNI_DECIMAL128); + orc::Decimal128VectorBatch *lvb = dynamic_cast(field); + FixedWidthVector *originalVector = + new FixedWidthVector(allocator, lvb->numElements); + for (int i = 0; i < lvb->numElements; i++) { + if (lvb->notNull.data()[i]) { + bool wasNegative = false; + int64_t highbits = lvb->values.data()[i].getHighBits(); + uint64_t lowbits = lvb->values.data()[i].getLowBits(); + uint64_t high = 0; + uint64_t low = 0; + if (highbits < 0) { + low = ~lowbits + 1; + high = static_cast(~highbits); + if (low == 0) { + high += 1; + } + highbits = high | ((uint64_t)1 << 63); + } + Decimal128 d128(highbits, low); + originalVector->SetValue(i, d128); + } else { + originalVector->SetValueNull(i); + } + } + ominVecId = (uint64_t)originalVector; + } else { + ominTypeId = static_cast(OMNI_DECIMAL64); + orc::Decimal64VectorBatch *lvb = dynamic_cast(field); + FixedWidthVector *originalVector = new FixedWidthVector(allocator, lvb->numElements); + for (int i = 0; i < lvb->numElements; i++) { + if (lvb->notNull.data()[i]) { + originalVector->SetValue(i, (int64_t)(lvb->values.data()[i])); + } else { + originalVector->SetValueNull(i); + } + } + ominVecId = (uint64_t)originalVector; + } + return 1; +} + +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_recordReaderNext(JNIEnv *env, + jobject jObj, jlong rowReader, jlong reader, jlong batch, jintArray typeId, jlongArray vecNativeId) +{ + orc::RowReader *rowReaderPtr = (orc::RowReader *)rowReader; + orc::ColumnVectorBatch *columnVectorBatch = (orc::ColumnVectorBatch *)batch; + orc::Reader *readerPtr = (orc::Reader *)reader; + const orc::Type &baseTp = rowReaderPtr->getSelectedType(); + int vecCnt = 0; + long batchRowSize = 0; + if (rowReaderPtr->next(*columnVectorBatch)) { + orc::StructVectorBatch *root = dynamic_cast(columnVectorBatch); + vecCnt = root->fields.size(); + batchRowSize = root->fields[0]->numElements; + for (int id = 0; id < 
vecCnt; id++) { + int vcType = baseTp.getSubtype(id)->getKind(); + int maxLen = baseTp.getSubtype(id)->getMaximumLength(); + int ominTypeId = 0; + uint64_t ominVecId = 0; + try { + if (vcType != orc::TypeKind::DECIMAL) { + copyToOminVec(maxLen, vcType, ominTypeId, ominVecId, root->fields[id]); + } else { + copyToOminDecimalVec(baseTp.getSubtype(id)->getPrecision(), ominTypeId, ominVecId, + root->fields[id]); + } + } catch (omniruntime::exception::OmniException &e) { + env->ThrowNew(runtimeExceptionClass, e.what()); + return (jlong)batchRowSize; + } + env->SetIntArrayRegion(typeId, id, 1, &ominTypeId); + jlong ominVec = static_cast(ominVecId); + env->SetLongArrayRegion(vecNativeId, id, 1, &ominVec); + } + } + return (jlong)batchRowSize; +} + +/* + * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Method: recordReaderGetRowNumber + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_recordReaderGetRowNumber( + JNIEnv *env, jobject jObj, jlong rowReader) +{ + orc::RowReader *rowReaderPtr = (orc::RowReader *)rowReader; + uint64_t rownum = rowReaderPtr->getRowNumber(); + return (jlong)rownum; +} + +/* + * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Method: recordReaderGetProgress + * Signature: (J)F + */ +JNIEXPORT jfloat JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_recordReaderGetProgress( + JNIEnv *env, jobject jObj, jlong rowReader) +{ + jfloat curProgress = 1; + throw std::runtime_error("recordReaderGetProgress is unsupported"); + return curProgress; +} + +/* + * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Method: recordReaderClose + * Signature: (J)F + */ +JNIEXPORT void JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_recordReaderClose(JNIEnv *env, + jobject jObj, jlong rowReader, jlong reader, jlong batchReader) +{ + orc::ColumnVectorBatch *columnVectorBatch = (orc::ColumnVectorBatch *)batchReader; + if (nullptr == columnVectorBatch) { + throw std::runtime_error("delete nullptr error for batch reader"); + } + delete columnVectorBatch; + orc::RowReader *rowReaderPtr = (orc::RowReader *)rowReader; + if (nullptr == rowReaderPtr) { + throw std::runtime_error("delete nullptr error for row reader"); + } + delete rowReaderPtr; + orc::Reader *readerPtr = (orc::Reader *)reader; + if (nullptr == readerPtr) { + throw std::runtime_error("delete nullptr error for reader"); + } + delete readerPtr; +} + +/* + * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Method: recordReaderSeekToRow + * Signature: (JJ)F + */ +JNIEXPORT void JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_recordReaderSeekToRow(JNIEnv *env, + jobject jObj, jlong rowReader, jlong rowNumber) +{ + orc::RowReader *rowReaderPtr = (orc::RowReader *)rowReader; + rowReaderPtr->seekToRow((long)rowNumber); +} + + +JNIEXPORT jobjectArray JNICALL +Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_getAllColumnNames(JNIEnv *env, jobject jObj, jlong reader) +{ + orc::Reader *readerPtr = (orc::Reader *)reader; + int32_t cols = static_cast(readerPtr->getType().getSubtypeCount()); + jobjectArray ret = + (jobjectArray)env->NewObjectArray(cols, env->FindClass("java/lang/String"), env->NewStringUTF("")); + for (int i = 0; i < cols; i++) { + env->SetObjectArrayElement(ret, i, env->NewStringUTF(readerPtr->getType().getFieldName(i).data())); + } + return ret; +} + +JNIEXPORT jlong JNICALL 
Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_getNumberOfRows(JNIEnv *env, + jobject jObj, jlong rowReader, jlong batch) +{ + orc::RowReader *rowReaderPtr = (orc::RowReader *)rowReader; + orc::ColumnVectorBatch *columnVectorBatch = (orc::ColumnVectorBatch *)batch; + rowReaderPtr->next(*columnVectorBatch); + jlong rows = columnVectorBatch->numElements; + return rows; +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.h b/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.h new file mode 100644 index 0000000000000000000000000000000000000000..5d05f73471a014de14fe870ec602570b73110d84 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.h @@ -0,0 +1,148 @@ +/** + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Header for class THESTRAL_PLUGIN_ORCCOLUMNARBATCHJNIREADER_H */ + +#ifndef THESTRAL_PLUGIN_ORCCOLUMNARBATCHJNIREADER_H +#define THESTRAL_PLUGIN_ORCCOLUMNARBATCHJNIREADER_H + +#include "orc/ColumnPrinter.hh" +#include "orc/Exceptions.hh" +#include "orc/Type.hh" +#include "orc/Vector.hh" +#include "orc/Reader.hh" +#include "orc/OrcFile.hh" +#include "orc/MemoryPool.hh" +#include "orc/sargs/SearchArgument.hh" +#include "orc/sargs/Literal.hh" +#include +#include +#include +#include +#include +#include +#include "jni.h" +#include "json/json.h" +#include "vector/vector_common.h" +#include "util/omni_exception.h" +#include +#include +#include "../common/debug.h" + +#ifdef __cplusplus +extern "C" { +#endif + +enum class Operator { + OR, + AND, + NOT, + LEAF, + CONSTANT +}; + +enum class PredicateOperatorType { + EQUALS = 0, + NULL_SAFE_EQUALS, + LESS_THAN, + LESS_THAN_EQUALS, IN, BETWEEN, IS_NULL +}; + +/* + * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Method: initializeReader + * Signature: (Ljava/lang/String;Lorg/json/simple/JSONObject;)J + */ +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_initializeReader + (JNIEnv* env, jobject jObj, jstring path, jobject job); + +/* + * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Method: initializeRecordReader + * Signature: (JLorg/json/simple/JSONObject;)J + */ +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_initializeRecordReader + (JNIEnv* env, jobject jObj, jlong reader, jobject job); + +/* + * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Method: initializeRecordReader + * Signature: (JLorg/json/simple/JSONObject;)J + */ +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_initializeBatch + (JNIEnv* env, 
jobject jObj, jlong rowReader, jlong batchSize); + +/* + * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Method: recordReaderNext + * Signature: (J[I[J)J + */ +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_recordReaderNext + (JNIEnv *, jobject, jlong, jlong, jlong, jintArray, jlongArray); + +/* + * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Method: recordReaderGetRowNumber + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_recordReaderGetRowNumber + (JNIEnv *, jobject, jlong); + +/* + * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Method: recordReaderGetProgress + * Signature: (J)F + */ +JNIEXPORT jfloat JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_recordReaderGetProgress + (JNIEnv *, jobject, jlong); + + +/* + * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Method: recordReaderClose + * Signature: (J)F + */ +JNIEXPORT void JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_recordReaderClose + (JNIEnv *, jobject, jlong, jlong, jlong); + +/* + * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Method: recordReaderSeekToRow + * Signature: (JJ)F + */ +JNIEXPORT void JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_recordReaderSeekToRow + (JNIEnv *, jobject, jlong, jlong); + +JNIEXPORT jobjectArray JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_getAllColumnNames + (JNIEnv *, jobject, jlong); + +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_getNumberOfRows(JNIEnv *env, + jobject jObj, jlong rowReader, jlong batch); + +int getLiteral(orc::Literal &lit, int leafType, std::string value); + +int buildLeafs(int leafOp, std::vector &litList, orc::Literal &lit, std::string leafNameString, int leafType, + orc::SearchArgumentBuilder &builder); + +int copyToOminVec(int maxLen, int vcType, int &ominTypeId, uint64_t &ominVecId, orc::ColumnVectorBatch *field); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0f03512410d0c403945f2cec3993c6638b17a8af --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp @@ -0,0 +1,223 @@ +/** + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
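The JNI entry points declared above wrap a conventional ORC C++ read loop: open a Reader, derive a RowReader with projection and search arguments, then pull ColumnVectorBatch objects until next() returns false. A native-only sketch of that lifecycle, not part of the patch; the file path and batch size are illustrative.

// Sketch only: the sequence mirrored by initializeReader -> initializeRecordReader
// -> initializeBatch -> recordReaderNext -> recordReaderClose.
#include <iostream>
#include <memory>
#include "orc/MemoryPool.hh"
#include "orc/OrcFile.hh"
#include "orc/Reader.hh"

int main() {
    orc::ReaderOptions readerOptions;
    readerOptions.setMemoryPool(*orc::getDefaultPool());
    std::unique_ptr<orc::Reader> reader =
        orc::createReader(orc::readLocalFile("/tmp/example.orc"), readerOptions);

    orc::RowReaderOptions rowReaderOptions;   // column projection / sargs would go here
    std::unique_ptr<orc::RowReader> rowReader = reader->createRowReader(rowReaderOptions);
    std::unique_ptr<orc::ColumnVectorBatch> batch = rowReader->createRowBatch(4096);

    while (rowReader->next(*batch)) {         // analogous to recordReaderNext
        std::cout << "read " << batch->numElements << " rows" << std::endl;
    }
    return 0;                                 // unique_ptr cleanup ~ recordReaderClose
}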
+ */ + +#include +#include + +#include "../io/SparkFile.hh" +#include "../io/ColumnWriter.hh" +#include "../shuffle/splitter.h" +#include "jni_common.h" +#include "SparkJniWrapper.hh" +#include "concurrent_map.h" + +static jint JNI_VERSION = JNI_VERSION_1_8; + +static jclass split_result_class; +static jclass runtime_exception_class; + +static jmethodID split_result_constructor; + +using namespace spark; +using namespace google::protobuf::io; +using namespace omniruntime::vec; + +static ConcurrentMap> shuffle_splitter_holder_; + +jint JNI_OnLoad(JavaVM* vm, void* reserved) { + JNIEnv* env; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION) != JNI_OK) { + return JNI_ERR; + } + + illegal_access_exception_class = + CreateGlobalClassReference(env, "Ljava/lang/IllegalAccessException;"); + + split_result_class = + CreateGlobalClassReference(env, "Lcom/huawei/boostkit/spark/vectorized/SplitResult;"); + split_result_constructor = GetMethodID(env, split_result_class, "", "(JJJJJ[J)V"); + + runtime_exception_class = CreateGlobalClassReference(env, "Ljava/lang/RuntimeException;"); + + return JNI_VERSION; +} + +void JNI_OnUnload(JavaVM* vm, void* reserved) { + JNIEnv* env; + vm->GetEnv(reinterpret_cast(&env), JNI_VERSION); + + env->DeleteGlobalRef(split_result_class); + + env->DeleteGlobalRef(runtime_exception_class); + + shuffle_splitter_holder_.Clear(); +} + +JNIEXPORT jlong JNICALL +Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_nativeMake( + JNIEnv* env, jobject, jstring partitioning_name_jstr, jint num_partitions, + jstring jInputType, jint jNumCols, jint buffer_size, + jstring compression_type_jstr, jstring data_file_jstr, jint num_sub_dirs, + jstring local_dirs_jstr, jlong compress_block_size, + jint spill_batch_row, jlong spill_memory_threshold) { + if (partitioning_name_jstr == nullptr) { + env->ThrowNew(env->FindClass("java/lang/Exception"), + std::string("Short partitioning name can't be null").c_str()); + return 0; + } + + const char* inputTypeCharPtr = env->GetStringUTFChars(jInputType, JNI_FALSE); + DataTypes inputVecTypes = Deserialize(inputTypeCharPtr); + const int32_t *inputVecTypeIds = inputVecTypes.GetIds(); + // + std::vector inputDataTpyes = inputVecTypes.Get(); + int32_t size = inputDataTpyes.size(); + uint32_t *inputDataPrecisions = new uint32_t[size]; + uint32_t *inputDataScales = new uint32_t[size]; + for (int i = 0; i < size; ++i) { + inputDataPrecisions[i] = inputDataTpyes[i].GetPrecision(); + inputDataScales[i] = inputDataTpyes[i].GetScale(); + } + inputDataTpyes.clear(); + + InputDataTypes inputDataTypesTmp; + inputDataTypesTmp.inputVecTypeIds = (int32_t *)inputVecTypeIds; + inputDataTypesTmp.inputDataPrecisions = inputDataPrecisions; + inputDataTypesTmp.inputDataScales = inputDataScales; + + if (data_file_jstr == nullptr) { + env->ThrowNew(env->FindClass("java/lang/Exception"), + std::string("Shuffle DataFile can't be null").c_str()); + return 0; + } + if (local_dirs_jstr == nullptr) { + env->ThrowNew(env->FindClass("java/lang/Exception"), + std::string("Shuffle DataFile can't be null").c_str()); + return 0; + } + + auto partitioning_name_c = env->GetStringUTFChars(partitioning_name_jstr, JNI_FALSE); + auto partitioning_name = std::string(partitioning_name_c); + env->ReleaseStringUTFChars(partitioning_name_jstr, partitioning_name_c); + + auto splitOptions = SplitOptions::Defaults(); + if (buffer_size > 0) { + splitOptions.buffer_size = buffer_size; + } + if (num_sub_dirs > 0) { + splitOptions.num_sub_dirs = num_sub_dirs; + } + if (compression_type_jstr != NULL) { + 
auto compression_type_result = GetCompressionType(env, compression_type_jstr); + splitOptions.compression_type = compression_type_result; + } + + auto data_file_c = env->GetStringUTFChars(data_file_jstr, JNI_FALSE); + splitOptions.data_file = std::string(data_file_c); + env->ReleaseStringUTFChars(data_file_jstr, data_file_c); + + //TODO: memory pool select + + auto local_dirs = env->GetStringUTFChars(local_dirs_jstr, JNI_FALSE); + setenv("NATIVESQL_SPARK_LOCAL_DIRS", local_dirs, 1); + env->ReleaseStringUTFChars(local_dirs_jstr, local_dirs); + + if (spill_batch_row > 0){ + splitOptions.spill_batch_row_num = spill_batch_row; + } + if (spill_memory_threshold > 0){ + splitOptions.spill_mem_threshold = spill_memory_threshold; + } + if (compress_block_size > 0){ + splitOptions.compress_block_size = compress_block_size; + } + + jclass cls = env->FindClass("java/lang/Thread"); + jmethodID mid = env->GetStaticMethodID(cls, "currentThread", "()Ljava/lang/Thread;"); + jobject thread = env->CallStaticObjectMethod(cls, mid); + if (thread == NULL) { + std::cout << "Thread.currentThread() return NULL" <GetMethodID(cls, "getId", "()J"); + jlong sid = env->CallLongMethod(thread, mid_getid); + splitOptions.thread_id = (int64_t)sid; + } + + try{ + auto splitter = Splitter::Make(partitioning_name, inputDataTypesTmp, jNumCols, num_partitions, std::move(splitOptions)); + return shuffle_splitter_holder_.Insert(std::shared_ptr(splitter)); + } catch (omniruntime::exception::OmniException & e) { + env->ThrowNew(runtime_exception_class, e.what()); + } +} + +JNIEXPORT jlong JNICALL +Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_split( + JNIEnv *env, jobject jObj, jlong splitter_id, jlong jVecBatchAddress) { + auto splitter = shuffle_splitter_holder_.Lookup(splitter_id); + if (!splitter) { + std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); + env->ThrowNew(env->FindClass("java/lang/Exception"), error_message.c_str()); + return -1; + } + + auto vecBatch = (VectorBatch *) jVecBatchAddress; + + try { + splitter->Split(*vecBatch); + } catch (omniruntime::exception::OmniException & e) { + env->ThrowNew(runtime_exception_class, e.what()); + } +} + +JNIEXPORT jobject JNICALL +Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_stop( + JNIEnv* env, jobject, jlong splitter_id) { + auto splitter = shuffle_splitter_holder_.Lookup(splitter_id); + if (!splitter) { + std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); + env->ThrowNew(env->FindClass("java/lang/Exception"), error_message.c_str()); + } + try { + splitter->Stop(); + } catch (omniruntime::exception::OmniException & e) { + env->ThrowNew(runtime_exception_class, e.what()); + } + const auto& partition_length = splitter->PartitionLengths(); + auto partition_length_arr = env->NewLongArray(partition_length.size()); + auto src = reinterpret_cast(partition_length.data()); + env->SetLongArrayRegion(partition_length_arr, 0, partition_length.size(), src); + jobject split_result = env->NewObject( + split_result_class, split_result_constructor, splitter->TotalComputePidTime(), + splitter->TotalWriteTime(), splitter->TotalSpillTime(), + splitter->TotalBytesWritten(), splitter->TotalBytesSpilled(), partition_length_arr); + + return split_result; +} + +JNIEXPORT void JNICALL +Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_close( + JNIEnv* env, jobject, jlong splitter_id) { + auto splitter = shuffle_splitter_holder_.Lookup(splitter_id); + if (!splitter) { + std::string error_message = "Invalid splitter id " + 
std::to_string(splitter_id); + env->ThrowNew(env->FindClass("java/lang/Exception"), error_message.c_str()); + } + shuffle_splitter_holder_.Erase(splitter_id); +} diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.hh b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.hh new file mode 100644 index 0000000000000000000000000000000000000000..91ff665e4ea2448295722b9260615207074d801d --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.hh @@ -0,0 +1,57 @@ +/** + * Copyright (C) 2021-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#ifndef SPARK_JNI_WRAPPER +#define SPARK_JNI_WRAPPER +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: com_huawei_boostkit_spark_jni_SparkJniWrapper + * Method: nativeMake + * Signature: ()V + */ +JNIEXPORT jlong JNICALL +Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_nativeMake( + JNIEnv* env, jobject, jstring partitioning_name_jstr, jint num_partitions, + jstring jInputType, jint jNumCols, jint buffer_size, + jstring compression_type_jstr, jstring data_file_jstr, jint num_sub_dirs, + jstring local_dirs_jstr, jlong compress_block_size, + jint spill_batch_row, jlong spill_memory_threshold); + +JNIEXPORT jlong JNICALL +Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_split( + JNIEnv* env, jobject jObj, jlong splitter_id, jlong jVecBatchAddress); + +JNIEXPORT jobject JNICALL +Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_stop( + JNIEnv* env, jobject, jlong splitter_id); + +JNIEXPORT void JNICALL +Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_close( + JNIEnv* env, jobject, jlong splitter_id); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/concurrent_map.h b/omnioperator/omniop-spark-extension/cpp/src/jni/concurrent_map.h new file mode 100644 index 0000000000000000000000000000000000000000..e7888010d72e38e9e70f02f461b405bafe3a682a --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/concurrent_map.h @@ -0,0 +1,80 @@ +/** + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef THESTRAL_PLUGIN_MASTER_CONCURRENT_MAP_H
+#define THESTRAL_PLUGIN_MASTER_CONCURRENT_MAP_H
+
+#include <jni.h>
+#include <memory>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+
+/**
+ * A utility class that maps module ids to module pointers.
+ * @tparam Holder class of the object to hold
+ */
+template <typename Holder>
+class ConcurrentMap {
+public:
+    ConcurrentMap() : module_id_(init_module_id_) {}
+
+    jlong Insert(Holder holder) {
+        std::lock_guard<std::mutex> lock(mtx_);
+        jlong result = module_id_++;
+        map_.insert(std::pair<jlong, Holder>(result, holder));
+        return result;
+    }
+
+    void Erase(jlong module_id) {
+        std::lock_guard<std::mutex> lock(mtx_);
+        map_.erase(module_id);
+    }
+
+    Holder Lookup(jlong module_id) {
+        std::lock_guard<std::mutex> lock(mtx_);
+        auto it = map_.find(module_id);
+        if (it != map_.end()) {
+            return it->second;
+        }
+        return nullptr;
+    }
+
+    void Clear() {
+        std::lock_guard<std::mutex> lock(mtx_);
+        map_.clear();
+    }
+
+    size_t Size() {
+        std::lock_guard<std::mutex> lock(mtx_);
+        return map_.size();
+    }
+
+private:
+    // Initialize the module id starting value to a number greater than zero
+    // to allow for easier debugging of uninitialized java variables.
+    static constexpr int init_module_id_ = 4;
+
+    int64_t module_id_;
+    std::mutex mtx_;
+    // map from module ids returned to Java to module pointers
+    std::unordered_map<jlong, Holder> map_;
+};
+
+#endif
diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h b/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h
new file mode 100644
index 0000000000000000000000000000000000000000..d67561a58648cd26699130bb0a2aa2bfd680c5e3
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h
@@ -0,0 +1,57 @@
+/**
+ * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved.
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef THESTRAL_PLUGIN_MASTER_JNI_COMMON_H +#define THESTRAL_PLUGIN_MASTER_JNI_COMMON_H + +#include + +#include "../common/common.h" + +static jclass illegal_access_exception_class; + +spark::CompressionKind GetCompressionType(JNIEnv* env, jstring codec_jstr) { + auto codec_c = env->GetStringUTFChars(codec_jstr, JNI_FALSE); + auto codec = std::string(codec_c); + auto compression_type = GetCompressionType(codec); + env->ReleaseStringUTFChars(codec_jstr, codec_c); + return compression_type; +} + +jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) { + jclass local_class = env->FindClass(class_name); + jclass global_class = (jclass)env->NewGlobalRef(local_class); + env->DeleteLocalRef(local_class); + if (global_class == nullptr) { + std::string error_message = "Unable to createGlobalClassReference for" + std::string(class_name); + env->ThrowNew(illegal_access_exception_class, error_message.c_str()); + } + return global_class; +} + +jmethodID GetMethodID(JNIEnv* env, jclass this_class, const char* name, const char* sig) { + jmethodID ret = env->GetMethodID(this_class, name, sig); + if (ret == nullptr) { + std::string error_message = "Unable to find method " + std::string(name) + " within signature" + std::string(sig); + env->ThrowNew(illegal_access_exception_class, error_message.c_str()); + } + + return ret; +} +#endif //THESTRAL_PLUGIN_MASTER_JNI_COMMON_H diff --git a/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto b/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto new file mode 100644 index 0000000000000000000000000000000000000000..c40472020171692ea7b0acde2dd873efeda691f4 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto @@ -0,0 +1,60 @@ +syntax = "proto3"; + +package spark; +option java_package = "com.huawei.boostkit.spark.serialize"; +option java_outer_classname = "VecData"; + +message VecBatch { + int32 rowCnt = 1; + int32 vecCnt = 2; + repeated Vec vecs = 3; +} + +message Vec { + VecType vecType = 1; + bytes offset = 2; + bytes values = 3; + bytes nulls = 4; +} + +message VecType { + enum VecTypeId { + VEC_TYPE_NONE = 0; + VEC_TYPE_INT = 1; + VEC_TYPE_LONG = 2; + VEC_TYPE_DOUBLE = 3; + VEC_TYPE_BOOLEAN = 4; + VEC_TYPE_SHORT = 5; + VEC_TYPE_DECIMAL64 = 6; + VEC_TYPE_DECIMAL128 = 7; + VEC_TYPE_DATE32 = 8; + VEC_TYPE_DATE64 = 9; + VEC_TYPE_TIME32 = 10; + VEC_TYPE_TIME64 = 11; + VEC_TYPE_TIMESTAMP = 12; + VEC_TYPE_INTERVAL_MONTHS = 13; + VEC_TYPE_INTERVAL_DAY_TIME =14; + VEC_TYPE_VARCHAR = 15; + VEC_TYPE_CHAR = 16; + VEC_TYPE_DICTIONARY = 17; + VEC_TYPE_CONTAINER = 18; + VEC_TYPE_INVALID = 19; + } + + VecTypeId typeId = 1; + int32 width = 2; + uint32 precision = 3; + uint32 scale = 4; + enum DateUnit { + DAY = 0; + MILLI = 1; + } + DateUnit dateUnit = 5; + enum TimeUnit { + SEC = 0; + MILLISEC = 1; + MICROSEC = 2; + NANOSEC = 3; + } + TimeUnit timeUnit = 6; +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5fdff534440c394fb8aaddd2311badfca57bb00c --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp @@ -0,0 +1,975 @@ +/** + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "splitter.h" +#include "utils.h" + +uint64_t SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD = UINT64_MAX; + +SplitOptions SplitOptions::Defaults() { return SplitOptions(); } + +// 计算分区id,每个batch初始化 +int Splitter::ComputeAndCountPartitionId(VectorBatch& vb) { + auto num_rows = vb.GetRowCount(); + std::fill(std::begin(partition_id_cnt_cur_), std::end(partition_id_cnt_cur_), 0); + partition_id_.resize(num_rows); + + if (singlePartitionFlag) { + partition_id_cnt_cur_[0] = num_rows; + partition_id_cnt_cache_[0] += num_rows; + for (auto i = 0; i < num_rows; ++i) { + partition_id_[i] = 0; + } + } else { + IntVector* hashVct = static_cast(vb.GetVector(0)); + for (auto i = 0; i < num_rows; ++i) { + // positive mod + int32_t pid = hashVct->GetValue(i); + if (pid >= num_partitions_) { + LogsError(" Illegal pid Value: %d >= partition number %d .", pid, num_partitions_); + throw std::runtime_error("Shuffle pidVec Illegal pid Value!"); + } + partition_id_[i] = pid; + partition_id_cnt_cur_[pid]++; + partition_id_cnt_cache_[pid]++; + } + } + return 0; +} + +//分区信息内存分配 +int Splitter::AllocatePartitionBuffers(int32_t partition_id, int32_t new_size) { + std::vector> new_binary_builders; + std::vector> new_value_buffers; + std::vector> new_validity_buffers; + + int num_fields = column_type_id_.size(); + auto fixed_width_idx = 0; + + for (auto i = 0; i < num_fields; ++i) { + switch (column_type_id_[i]) { + case SHUFFLE_BINARY: { + break; + } + case SHUFFLE_LARGE_BINARY: + case SHUFFLE_NULL: + break; + case SHUFFLE_1BYTE: + case SHUFFLE_2BYTE: + case SHUFFLE_4BYTE: + case SHUFFLE_8BYTE: + case SHUFFLE_DECIMAL128: + default: { + void *ptr_tmp = static_cast(options_.allocator->alloc(new_size * (1 << column_type_id_[i]))); + fixed_valueBuffer_size_[partition_id] = new_size * (1 << column_type_id_[i]); + if (nullptr == ptr_tmp) { + throw std::runtime_error("Allocator for AllocatePartitionBuffers Failed! 
"); + } + std::shared_ptr value_buffer (new Buffer((uint8_t *)ptr_tmp, 0, new_size * (1 << column_type_id_[i]))); + new_value_buffers.push_back(std::move(value_buffer)); + new_validity_buffers.push_back(nullptr); + fixed_width_idx++; + break; + } + } + } + + // point to newly allocated buffers + fixed_width_idx = 0; + for (auto i = 0; i < num_fields; ++i) { + switch (column_type_id_[i]) { + case SHUFFLE_1BYTE: + case SHUFFLE_2BYTE: + case SHUFFLE_4BYTE: + case SHUFFLE_8BYTE: + case SHUFFLE_DECIMAL128: { + partition_fixed_width_value_addrs_[fixed_width_idx][partition_id] = + const_cast(new_value_buffers[fixed_width_idx].get()->data_); + partition_fixed_width_validity_addrs_[fixed_width_idx][partition_id] = nullptr; + // partition_fixed_width_buffers_[fixed_width_idx][partition_id] 位置0执行bitmap,位置1指向数据 + partition_fixed_width_buffers_[fixed_width_idx][partition_id] = { + std::move(new_validity_buffers[fixed_width_idx]), + std::move(new_value_buffers[fixed_width_idx])}; + fixed_width_idx++; + break; + } + case SHUFFLE_BINARY: + default: { + break; + } + } + } + + partition_buffer_size_[partition_id] = new_size; + return 0; +} + +int Splitter::SplitFixedWidthValueBuffer(VectorBatch& vb) { + const auto num_rows = vb.GetRowCount(); + for (auto col = 0; col < fixed_width_array_idx_.size(); ++col) { + std::fill(std::begin(partition_buffer_idx_offset_), + std::end(partition_buffer_idx_offset_), 0); + auto col_idx_vb = fixed_width_array_idx_[col]; + auto col_idx_schema = singlePartitionFlag ? col_idx_vb : (col_idx_vb - 1); + const auto& dst_addrs = partition_fixed_width_value_addrs_[col]; + if (vb.GetVector(col_idx_vb)->GetEncoding() == OMNI_VEC_ENCODING_DICTIONARY) { + LogsDebug("Dictionary Columnar process!"); + auto ids_tmp = static_cast(options_.allocator->alloc(num_rows * sizeof(int32_t))); + Buffer *ids (new Buffer((uint8_t*)ids_tmp, 0, num_rows * sizeof(int32_t))); + if (ids->data_ == nullptr) { + throw std::runtime_error("Allocator for SplitFixedWidthValueBuffer ids Failed! 
"); + } + auto dictionaryTmp = ((DictionaryVector *)(vb.GetVector(col_idx_vb)))->ExtractDictionaryAndIds(0, num_rows, (int32_t *)(ids->data_)); + auto src_addr = VectorHelper::GetValuesAddr(dictionaryTmp); + switch (column_type_id_[col_idx_schema]) { +#define PROCESS(SHUFFLE_TYPE, CTYPE) \ + case SHUFFLE_TYPE: \ + for (auto row = 0; row < num_rows; ++row) { \ + auto pid = partition_id_[row]; \ + auto dst_offset = \ + partition_buffer_idx_base_[pid] + partition_buffer_idx_offset_[pid]; \ + reinterpret_cast(dst_addrs[pid])[dst_offset] = \ + reinterpret_cast(src_addr)[reinterpret_cast(ids->data_)[row]]; \ + partition_fixed_width_buffers_[col][pid][1]->size_ += (1 << SHUFFLE_TYPE); \ + partition_buffer_idx_offset_[pid]++; \ + } \ + break; + PROCESS(SHUFFLE_1BYTE, uint8_t) + PROCESS(SHUFFLE_2BYTE, uint16_t) + PROCESS(SHUFFLE_4BYTE, uint32_t) + PROCESS(SHUFFLE_8BYTE, uint64_t) +#undef PROCESS + case SHUFFLE_DECIMAL128: + for (auto row = 0; row < num_rows; ++row) { + auto pid = partition_id_[row]; + auto dst_offset = + partition_buffer_idx_base_[pid] + partition_buffer_idx_offset_[pid]; + reinterpret_cast(dst_addrs[pid])[dst_offset << 1] = + reinterpret_cast(src_addr)[reinterpret_cast(ids->data_)[row] << 1]; // 前64位取值、赋值 + reinterpret_cast(dst_addrs[pid])[dst_offset << 1 | 1] = + reinterpret_cast(src_addr)[reinterpret_cast(ids->data_)[row] << 1 | 1]; // 后64位取值、赋值 + partition_fixed_width_buffers_[col][pid][1]->size_ += + (1 << SHUFFLE_DECIMAL128); //decimal128 16Bytes + partition_buffer_idx_offset_[pid]++; + } + break; + default: { + LogsError("SplitFixedWidthValueBuffer not match this type: %d", column_type_id_[col_idx_schema]); + throw std::runtime_error("SplitFixedWidthValueBuffer not match this type: " + column_type_id_[col_idx_schema]); + } + } + options_.allocator->free(ids->data_, ids->capacity_); + if (nullptr == ids) { + throw std::runtime_error("delete nullptr error for ids"); + } + delete ids; + } else { + auto src_addr = VectorHelper::GetValuesAddr(vb.GetVector(col_idx_vb)); + switch (column_type_id_[col_idx_schema]) { +#define PROCESS(SHUFFLE_TYPE, CTYPE) \ + case SHUFFLE_TYPE: \ + for (auto row = 0; row < num_rows; ++row) { \ + auto pid = partition_id_[row]; \ + auto dst_offset = \ + partition_buffer_idx_base_[pid] + partition_buffer_idx_offset_[pid]; \ + reinterpret_cast(dst_addrs[pid])[dst_offset] = \ + reinterpret_cast(src_addr)[row]; \ + partition_fixed_width_buffers_[col][pid][1]->size_ += (1 << SHUFFLE_TYPE); \ + partition_buffer_idx_offset_[pid]++; \ + } \ + break; + PROCESS(SHUFFLE_1BYTE, uint8_t) + PROCESS(SHUFFLE_2BYTE, uint16_t) + PROCESS(SHUFFLE_4BYTE, uint32_t) + PROCESS(SHUFFLE_8BYTE, uint64_t) +#undef PROCESS + case SHUFFLE_DECIMAL128: + for (auto row = 0; row < num_rows; ++row) { + auto pid = partition_id_[row]; + auto dst_offset = + partition_buffer_idx_base_[pid] + partition_buffer_idx_offset_[pid]; + reinterpret_cast(dst_addrs[pid])[dst_offset << 1] = + reinterpret_cast(src_addr)[row << 1]; // 前64位取值、赋值 + reinterpret_cast(dst_addrs[pid])[(dst_offset << 1) | 1] = + reinterpret_cast(src_addr)[(row << 1) | 1]; // 后64位取值、赋值 + partition_fixed_width_buffers_[col][pid][1]->size_ += + (1 << SHUFFLE_DECIMAL128); //decimal128 16Bytes + partition_buffer_idx_offset_[pid]++; + } + break; + default: { + LogsError("ERROR: SplitFixedWidthValueBuffer not match this type: %d", column_type_id_[col_idx_schema]); + throw std::runtime_error("ERROR: SplitFixedWidthValueBuffer not match this type: " + column_type_id_[col_idx_schema]); + } + } + } + } + return 0; +} + +int 
Splitter::SplitBinaryArray(VectorBatch& vb) +{ + const auto numRows = vb.GetRowCount(); + auto vecCntVb = vb.GetVectorCount(); + auto vecCntSchema = singlePartitionFlag ? vecCntVb : vecCntVb - 1; + for (auto colSchema = 0; colSchema < vecCntSchema; ++colSchema) { + switch (column_type_id_[colSchema]) { + case SHUFFLE_BINARY: { + auto colVb = singlePartitionFlag ? colSchema : colSchema + 1; + if (vb.GetVector(colVb)->GetEncoding() == OMNI_VEC_ENCODING_DICTIONARY) { + for (auto row = 0; row < numRows; ++row) { + auto pid = partition_id_[row]; + uint8_t *dst = nullptr; + auto str_len = ((DictionaryVector *)(vb.GetVector(colVb)))->GetVarchar(row, &dst); + bool isnull = ((DictionaryVector *)(vb.GetVector(colVb)))->IsValueNull(row); + cached_vectorbatch_size_ += str_len; // 累计变长部分cache数据 + VCLocation cl((uint64_t) dst, str_len, isnull); + if ((vc_partition_array_buffers_[pid][colSchema].size() != 0) && + (vc_partition_array_buffers_[pid][colSchema].back().getVcList().size() < + options_.spill_batch_row_num)) { + vc_partition_array_buffers_[pid][colSchema].back().getVcList().push_back(cl); + vc_partition_array_buffers_[pid][colSchema].back().vcb_total_len += str_len; + } else { + VCBatchInfo svc(options_.spill_batch_row_num); + svc.getVcList().push_back(cl); + svc.vcb_total_len += str_len; + vc_partition_array_buffers_[pid][colSchema].push_back(svc); + } + } + } else { + VarcharVector *vc = nullptr; + vc = static_cast(vb.GetVector(colVb)); + for (auto row = 0; row < numRows; ++row) { + auto pid = partition_id_[row]; + uint8_t *dst = nullptr; + int str_len = vc->GetValue(row, &dst); + bool isnull = vc->IsValueNull(row); + cached_vectorbatch_size_ += str_len; // 累计变长部分cache数据 + VCLocation cl((uint64_t) dst, str_len, isnull); + if ((vc_partition_array_buffers_[pid][colSchema].size() != 0) && + (vc_partition_array_buffers_[pid][colSchema].back().getVcList().size() < + options_.spill_batch_row_num)) { + vc_partition_array_buffers_[pid][colSchema].back().getVcList().push_back(cl); + vc_partition_array_buffers_[pid][colSchema].back().vcb_total_len += str_len; + } else { + VCBatchInfo svc(options_.spill_batch_row_num); + svc.getVcList().push_back(cl); + svc.vcb_total_len += str_len; + vc_partition_array_buffers_[pid][colSchema].push_back(svc); + } + } + } + break; + } + case SHUFFLE_LARGE_BINARY: + break; + default:{ + break; + } + } + } + return 0; +} + +int Splitter::SplitFixedWidthValidityBuffer(VectorBatch& vb){ + for (auto col = 0; col < fixed_width_array_idx_.size(); ++col) { + auto col_idx = fixed_width_array_idx_[col]; + auto& dst_addrs = partition_fixed_width_validity_addrs_[col]; + // 分配内存并初始化 + for (auto pid = 0; pid < num_partitions_; ++pid) { + if (partition_id_cnt_cur_[pid] > 0 && dst_addrs[pid] == nullptr) { + // init bitmap if it's null + auto new_size = partition_id_cnt_cur_[pid] > options_.buffer_size ? partition_id_cnt_cur_[pid] : options_.buffer_size; + auto ptr_tmp = static_cast(options_.allocator->alloc(new_size)); + if (nullptr == ptr_tmp) { + throw std::runtime_error("Allocator for ValidityBuffer Failed! 
"); + } + std::shared_ptr validity_buffer (new Buffer((uint8_t *)ptr_tmp, 0, new_size)); + dst_addrs[pid] = const_cast(validity_buffer->data_); + std::memset(validity_buffer->data_, 0, new_size); + partition_fixed_width_buffers_[col][pid][0] = std::move(validity_buffer); + fixed_nullBuffer_size_[pid] = new_size; + } + } + + // 计算并填充数据 + auto src_addr = const_cast((uint8_t*)((vb.GetVector(col_idx))->GetValueNulls())); + std::fill(std::begin(partition_buffer_idx_offset_), + std::end(partition_buffer_idx_offset_), 0); + const auto num_rows = vb.GetRowCount(); + for (auto row = 0; row < num_rows; ++row) { + auto pid = partition_id_[row]; + auto dst_offset = partition_buffer_idx_base_[pid] + partition_buffer_idx_offset_[pid]; + dst_addrs[pid][dst_offset] = src_addr[row]; + partition_buffer_idx_offset_[pid]++; + partition_fixed_width_buffers_[col][pid][0]->size_ += 1; + } + } + return 0; +} + +int Splitter::CacheVectorBatch(int32_t partition_id, bool reset_buffers) { + if (partition_buffer_idx_base_[partition_id] > 0 && fixed_width_array_idx_.size() > 0) { + auto fixed_width_idx = 0; + auto num_fields = num_fields_; + int64_t batch_partition_size = 0; + std::vector>> bufferArrayTotal(num_fields); + + for (int i = 0; i < num_fields; ++i) { + switch (column_type_id_[i]) { + case SHUFFLE_BINARY: { + break; + } + case SHUFFLE_LARGE_BINARY: { + break; + } + case SHUFFLE_NULL: { + break; + } + default: { + auto& buffers = partition_fixed_width_buffers_[fixed_width_idx][partition_id]; + batch_partition_size += buffers[0]->capacity_; // 累计null数组所占内存大小 + batch_partition_size += buffers[1]->capacity_; // 累计value数组所占内存大小 + if (reset_buffers) { + bufferArrayTotal[fixed_width_idx] = std::move(buffers); + buffers = {nullptr}; + partition_fixed_width_validity_addrs_[fixed_width_idx][partition_id] = nullptr; + partition_fixed_width_value_addrs_[fixed_width_idx][partition_id] = nullptr; + } else { + bufferArrayTotal[fixed_width_idx] = buffers; + } + fixed_width_idx++; + break; + } + } + } + cached_vectorbatch_size_ += batch_partition_size; + partition_cached_vectorbatch_[partition_id].push_back(std::move(bufferArrayTotal)); + partition_buffer_idx_base_[partition_id] = 0; + } + return 0; +} + +int Splitter::DoSplit(VectorBatch& vb) { + // for the first input record batch, scan binary arrays and large binary + // arrays to get their empirical sizes + + if (!first_vector_batch_) { + first_vector_batch_ = true; + } + + for (auto col = 0; col < fixed_width_array_idx_.size(); ++col) { + auto col_idx = fixed_width_array_idx_[col]; + if (vb.GetVector(col_idx)->GetValueNulls() != nullptr) { + input_fixed_width_has_null_[col] = true; + } + } + + // prepare partition buffers and spill if necessary + for (auto pid = 0; pid < num_partitions_; ++pid) { + if (fixed_width_array_idx_.size() > 0 && + partition_id_cnt_cur_[pid] > 0 && + partition_buffer_idx_base_[pid] + partition_id_cnt_cur_[pid] > partition_buffer_size_[pid]) { + auto new_size = partition_id_cnt_cur_[pid] > options_.buffer_size ? partition_id_cnt_cur_[pid] : options_.buffer_size; + if (partition_buffer_size_[pid] == 0) { // first allocate? 
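+                // This branch allocates the partition's buffers for the first time; the
+                // else branch below first caches the already-filled buffers (so they can
+                // be spilled later) and then reallocates. new_size is the larger of the
+                // incoming row count for this partition and options_.buffer_size.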
+ AllocatePartitionBuffers(pid, new_size); + } else { // not first allocate, spill + CacheVectorBatch(pid, true); + AllocatePartitionBuffers(pid, new_size); + } + } + } + SplitFixedWidthValueBuffer(vb); + SplitFixedWidthValidityBuffer(vb); + + current_fixed_alloc_buffer_size_ = 0; // 用于统计定长split但未cache部分内存大小 + for (auto pid = 0; pid < num_partitions_; ++pid) { + // update partition buffer base + partition_buffer_idx_base_[pid] += partition_id_cnt_cur_[pid]; + current_fixed_alloc_buffer_size_ += fixed_valueBuffer_size_[pid]; + current_fixed_alloc_buffer_size_ += fixed_nullBuffer_size_[pid]; + } + + // Binary split last vector batch... + SplitBinaryArray(vb); + vectorBatch_cache_.push_back(&vb); // record for release vector + + // 阈值检查,是否溢写 + num_row_splited_ += vb.GetRowCount(); + if (num_row_splited_ + vb.GetRowCount() >= SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD) { + LogsDebug(" Spill For Row Num Threshold."); + TIME_NANO_OR_RAISE(total_spill_time_, SpillToTmpFile()); + } + if (cached_vectorbatch_size_ + current_fixed_alloc_buffer_size_ >= options_.spill_mem_threshold) { + LogsDebug(" Spill For Memory Size Threshold."); + TIME_NANO_OR_RAISE(total_spill_time_, SpillToTmpFile()); + } + return 0; +} + +void Splitter::ToSplitterTypeId(int num_cols) +{ + for (int i = 0; i < num_cols; ++i) { + switch (input_col_types.inputVecTypeIds[i]) { + case OMNI_INT:{ + column_type_id_.push_back(ShuffleTypeId::SHUFFLE_4BYTE); + vector_batch_col_types_.push_back(OMNI_INT); + break; + } + case OMNI_LONG:{ + column_type_id_.push_back(ShuffleTypeId::SHUFFLE_8BYTE); + vector_batch_col_types_.push_back(OMNI_LONG); + break; + } + case OMNI_DOUBLE:{ + column_type_id_.push_back(ShuffleTypeId::SHUFFLE_8BYTE); + vector_batch_col_types_.push_back(OMNI_DOUBLE); + break; + } + case OMNI_DATE32:{ + column_type_id_.push_back(ShuffleTypeId::SHUFFLE_4BYTE); + vector_batch_col_types_.push_back(OMNI_DATE32); + break; + } + case OMNI_DATE64:{ + column_type_id_.push_back(ShuffleTypeId::SHUFFLE_8BYTE); + vector_batch_col_types_.push_back(OMNI_DATE64); + break; + } + case OMNI_DECIMAL64:{ + column_type_id_.push_back(ShuffleTypeId::SHUFFLE_8BYTE); + vector_batch_col_types_.push_back(OMNI_DECIMAL64); + break; + } + case OMNI_DECIMAL128:{ + column_type_id_.push_back(ShuffleTypeId::SHUFFLE_DECIMAL128); + vector_batch_col_types_.push_back(OMNI_DECIMAL128); + break; + } + case OMNI_CHAR:{ + column_type_id_.push_back(ShuffleTypeId::SHUFFLE_BINARY); + vector_batch_col_types_.push_back(OMNI_CHAR); + break; + } + case OMNI_VARCHAR:{ + column_type_id_.push_back(ShuffleTypeId::SHUFFLE_BINARY); + vector_batch_col_types_.push_back(OMNI_VARCHAR); + break; + } + default:{ + throw std::runtime_error("Unsupported DataTypeId."); + } + } + } +} + +int Splitter::Split_Init(){ + num_row_splited_ = 0; + cached_vectorbatch_size_ = 0; + partition_id_cnt_cur_.resize(num_partitions_); + partition_id_cnt_cache_.resize(num_partitions_); + partition_buffer_size_.resize(num_partitions_); + partition_buffer_idx_base_.resize(num_partitions_); + partition_buffer_idx_offset_.resize(num_partitions_); + partition_cached_vectorbatch_.resize(num_partitions_); + partition_serialization_size_.resize(num_partitions_); + fixed_width_array_idx_.clear(); + partition_lengths_.resize(num_partitions_); + fixed_valueBuffer_size_.resize(num_partitions_); + fixed_nullBuffer_size_.resize(num_partitions_); + + //obtain configed dir from Environment Variables + configured_dirs_ = GetConfiguredLocalDirs(); + sub_dir_selection_.assign(configured_dirs_.size(), 0); + + // Both 
data_file and shuffle_index_file should be set through jni. + // For test purpose, Create a temporary subdirectory in the system temporary + // dir with prefix "columnar-shuffle" + if (options_.data_file.length() == 0) { + options_.data_file = CreateTempShuffleFile(configured_dirs_[0]); + } + + for (int i = 0; i < column_type_id_.size(); ++i) { + switch (column_type_id_[i]) { + case ShuffleTypeId::SHUFFLE_1BYTE: + case ShuffleTypeId::SHUFFLE_2BYTE: + case ShuffleTypeId::SHUFFLE_4BYTE: + case ShuffleTypeId::SHUFFLE_8BYTE: + case ShuffleTypeId::SHUFFLE_DECIMAL128: + if (singlePartitionFlag) { + fixed_width_array_idx_.push_back(i); + } else { + fixed_width_array_idx_.push_back(i + 1); + } + break; + case ShuffleTypeId::SHUFFLE_BINARY: + default: + break; + } + } + auto num_fixed_width = fixed_width_array_idx_.size(); + partition_fixed_width_validity_addrs_.resize(num_fixed_width); + partition_fixed_width_value_addrs_.resize(num_fixed_width); + partition_fixed_width_buffers_.resize(num_fixed_width); + input_fixed_width_has_null_.resize(num_fixed_width, false); + for (auto i = 0; i < num_fixed_width; ++i) { + partition_fixed_width_validity_addrs_[i].resize(num_partitions_); + partition_fixed_width_value_addrs_[i].resize(num_partitions_); + partition_fixed_width_buffers_[i].resize(num_partitions_); + } + + /* init varchar partition */ + vc_partition_array_buffers_.resize(num_partitions_); + for (auto i = 0; i < num_partitions_; ++i) { + vc_partition_array_buffers_[i].resize(column_type_id_.size()); + } + return 0; +} + +int Splitter::Split(VectorBatch& vb ) +{ + //计算vectorBatch分区信息 + LogsTrace(" split vb row number: %d ", vb.GetRowCount()); + TIME_NANO_OR_RAISE(total_compute_pid_time_, ComputeAndCountPartitionId(vb)); + //执行分区动作 + DoSplit(vb); + return 0; +} + +std::shared_ptr Splitter::CaculateSpilledTmpFilePartitionOffsets() { + void *ptr_tmp = static_cast(options_.allocator->alloc((num_partitions_ + 1) * sizeof(uint32_t))); + if (nullptr == ptr_tmp) { + throw std::runtime_error("Allocator for partitionOffsets Failed! 
"); + } + std::shared_ptr ptrPartitionOffsets (new Buffer((uint8_t*)ptr_tmp, 0, (num_partitions_ + 1) * sizeof(uint32_t))); + uint32_t pidOffset = 0; + // 顺序记录每个partition的offset + auto pid = 0; + for (pid = 0; pid < num_partitions_; ++pid) { + reinterpret_cast(ptrPartitionOffsets->data_)[pid] = pidOffset; + pidOffset += partition_serialization_size_[pid]; + // reset partition_cached_vectorbatch_size_ to 0 + partition_serialization_size_[pid] = 0; + } + reinterpret_cast(ptrPartitionOffsets->data_)[pid] = pidOffset; + return ptrPartitionOffsets; +} + +spark::VecType::VecTypeId CastShuffleTypeIdToVecType(int32_t tmpType) { + switch (tmpType) { + case OMNI_NONE: + return spark::VecType::VEC_TYPE_NONE; + case OMNI_INT: + return spark::VecType::VEC_TYPE_INT; + case OMNI_LONG: + return spark::VecType::VEC_TYPE_LONG; + case OMNI_DOUBLE: + return spark::VecType::VEC_TYPE_DOUBLE; + case OMNI_BOOLEAN: + return spark::VecType::VEC_TYPE_BOOLEAN; + case OMNI_SHORT: + return spark::VecType::VEC_TYPE_SHORT; + case OMNI_DECIMAL64: + return spark::VecType::VEC_TYPE_DECIMAL64; + case OMNI_DECIMAL128: + return spark::VecType::VEC_TYPE_DECIMAL128; + case OMNI_DATE32: + return spark::VecType::VEC_TYPE_DATE32; + case OMNI_DATE64: + return spark::VecType::VEC_TYPE_DATE64; + case OMNI_TIME32: + return spark::VecType::VEC_TYPE_TIME32; + case OMNI_TIME64: + return spark::VecType::VEC_TYPE_TIME64; + case OMNI_TIMESTAMP: + return spark::VecType::VEC_TYPE_TIMESTAMP; + case OMNI_INTERVAL_MONTHS: + return spark::VecType::VEC_TYPE_INTERVAL_MONTHS; + case OMNI_INTERVAL_DAY_TIME: + return spark::VecType::VEC_TYPE_INTERVAL_DAY_TIME; + case OMNI_VARCHAR: + return spark::VecType::VEC_TYPE_VARCHAR; + case OMNI_CHAR: + return spark::VecType::VEC_TYPE_CHAR; + case OMNI_CONTAINER: + return spark::VecType::VEC_TYPE_CONTAINER; + case OMNI_INVALID: + return spark::VecType::VEC_TYPE_INVALID; + default: { + throw std::runtime_error("castShuffleTypeIdToVecType() unexpected ShuffleTypeId"); + } + } +}; + +int Splitter::SerializingFixedColumns(int32_t partitionId, + spark::Vec& vec, + int fixColIndexTmp, + SplitRowInfo* splitRowInfoTmp) +{ + LogsDebug(" Fix col :%d th...", fixColIndexTmp); + LogsDebug(" partition_cached_vectorbatch_[%d].size: %ld", partitionId, partition_cached_vectorbatch_[partitionId].size()); + if (splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp] < partition_cached_vectorbatch_[partitionId].size()) { + auto colIndexTmpSchema = 0; + colIndexTmpSchema = singlePartitionFlag ? fixed_width_array_idx_[fixColIndexTmp] : fixed_width_array_idx_[fixColIndexTmp] - 1; + auto onceCopyLen = splitRowInfoTmp->onceCopyRow * (1 << column_type_id_[colIndexTmpSchema]); + // 临时内存,拷贝拼接onceCopyRow批,用完释放 + void *ptr_value_tmp = static_cast(options_.allocator->alloc(onceCopyLen)); + std::shared_ptr ptr_value (new Buffer((uint8_t*)ptr_value_tmp, 0, onceCopyLen)); + void *ptr_validity_tmp = static_cast(options_.allocator->alloc(splitRowInfoTmp->onceCopyRow)); + std::shared_ptr ptr_validity (new Buffer((uint8_t*)ptr_validity_tmp, 0, splitRowInfoTmp->onceCopyRow)); + if (nullptr == ptr_value->data_ || nullptr == ptr_validity->data_) { + throw std::runtime_error("Allocator for tmp buffer Failed! 
"); + } + // options_.spill_batch_row_num长度切割与拼接 + int destCopyedLength = 0; + int memCopyLen = 0; + int cacheBatchSize = 0; + while (destCopyedLength < onceCopyLen) { + if (splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp] >= partition_cached_vectorbatch_[partitionId].size()) { // 数组越界保护 + throw std::runtime_error("Columnar shuffle CacheBatchIndex out of bound."); + } + cacheBatchSize = partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->size_; + LogsDebug(" partitionId:%d splitRowInfoTmp.cacheBatchIndex[%d]:%d cacheBatchSize:%d onceCopyLen:%d destCopyedLength:%d splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp]:%d ", + partitionId, + fixColIndexTmp, + splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp], + cacheBatchSize, + onceCopyLen, + destCopyedLength, + splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp]); + if ((onceCopyLen - destCopyedLength) >= (cacheBatchSize - splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp])) { + memCopyLen = cacheBatchSize - splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp]; + memcpy((uint8_t*)(ptr_value->data_) + destCopyedLength, + partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->data_ + splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp], + memCopyLen); + // (destCopyedLength / (1 << column_type_id_[colIndexTmpSchema])) 等比例计算null数组偏移 + memcpy((uint8_t*)(ptr_validity->data_) + (destCopyedLength / (1 << column_type_id_[colIndexTmpSchema])), + partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->data_ + (splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp] / (1 << column_type_id_[colIndexTmpSchema])), + memCopyLen / (1 << column_type_id_[colIndexTmpSchema])); + // 释放内存 + LogsDebug(" free memory Partition[%d] cacheindex[col%d]:%d ", partitionId, fixColIndexTmp, splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]); + options_.allocator->free(partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->data_, + partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->capacity_); + options_.allocator->free(partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->data_, + partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->capacity_); + destCopyedLength += memCopyLen; + splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp] += 1; // cacheBatchIndex下标后移 + splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp] = 0; // 初始化下一个cacheBatch的起始偏移 + } else { + memCopyLen = onceCopyLen - destCopyedLength; + memcpy((uint8_t*)(ptr_value->data_) + destCopyedLength, + partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->data_ + splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp], + memCopyLen); + // (destCopyedLength / (1 << column_type_id_[colIndexTmpSchema])) 等比例计算null数组偏移 + memcpy((uint8_t*)(ptr_validity->data_) + (destCopyedLength / (1 << column_type_id_[colIndexTmpSchema])), + partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->data_ + (splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp] / (1 << column_type_id_[colIndexTmpSchema])), + memCopyLen / (1 << column_type_id_[colIndexTmpSchema])); + destCopyedLength = onceCopyLen; // copy目标完成,结束while循环 + 
splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp] += memCopyLen; + } + LogsDebug(" memCopyedLen=%d.", memCopyLen); + LogsDebug(" splitRowInfoTmp.cacheBatchIndex[fix_col%d]=%d splitRowInfoTmp.cacheBatchCopyedLen[fix_col%d]=%d ", + fixColIndexTmp, + splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp], + fixColIndexTmp, + splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp]); + } + vec.set_values(ptr_value->data_, onceCopyLen); + vec.set_nulls(ptr_validity->data_, splitRowInfoTmp->onceCopyRow); + // 临时内存,拷贝拼接onceCopyRow批,用完释放 + options_.allocator->free(ptr_value->data_, ptr_value->capacity_); + options_.allocator->free(ptr_validity->data_, ptr_validity->capacity_); + } + // partition_cached_vectorbatch_[partition_id][cache_index][col][0]代表ByteMap, + // partition_cached_vectorbatch_[partition_id][cache_index][col][1]代表value +} + +int Splitter::SerializingBinaryColumns(int32_t partitionId, spark::Vec& vec, int colIndex, int curBatch) +{ + LogsDebug(" vc_partition_array_buffers_[partitionId:%d][colIndex:%d] cacheBatchNum:%lu curBatch:%d", partitionId, colIndex, vc_partition_array_buffers_[partitionId][colIndex].size(), curBatch); + VCBatchInfo vcb = vc_partition_array_buffers_[partitionId][colIndex][curBatch]; + int valuesTotalLen = vcb.getVcbTotalLen(); + std::vector lst = vcb.getVcList(); + int itemsTotalLen = lst.size(); + auto OffsetsByte(std::make_unique(itemsTotalLen + 1)); + auto nullsByte(std::make_unique(itemsTotalLen)); + auto valuesByte(std::make_unique(valuesTotalLen)); + BytesGen(reinterpret_cast(OffsetsByte.get()), + reinterpret_cast(nullsByte.get()), + reinterpret_cast(valuesByte.get()), vcb); + vec.set_values(valuesByte.get(), valuesTotalLen); + // nulls add boolean array; serizelized tobytearray + vec.set_nulls((char *)nullsByte.get(), itemsTotalLen); + vec.set_offset(OffsetsByte.get(), (itemsTotalLen + 1) * sizeof(int32_t)); +} + +int Splitter::protoSpillPartition(int32_t partition_id, std::unique_ptr &bufferStream) { + LogsDebug(" Spill Pid:%d.", partition_id); + SplitRowInfo splitRowInfoTmp; + splitRowInfoTmp.copyedRow = 0; + splitRowInfoTmp.remainCopyRow = partition_id_cnt_cache_[partition_id]; + splitRowInfoTmp.cacheBatchIndex.resize(fixed_width_array_idx_.size()); + splitRowInfoTmp.cacheBatchCopyedLen.resize(fixed_width_array_idx_.size()); + LogsDebug(" remainCopyRow %d ", splitRowInfoTmp.remainCopyRow); + auto partition_cache_batch_num = partition_cached_vectorbatch_[partition_id].size(); + LogsDebug(" partition_cache_batch_num %lu ", partition_cache_batch_num); + int curBatch = 0; // 变长cache batch下标,split已按照options_.spill_batch_row_num切割完成 + total_spill_row_num_ += splitRowInfoTmp.remainCopyRow; + while (0 < splitRowInfoTmp.remainCopyRow) { + if (options_.spill_batch_row_num < splitRowInfoTmp.remainCopyRow) { + splitRowInfoTmp.onceCopyRow = options_.spill_batch_row_num; + } else { + splitRowInfoTmp.onceCopyRow = splitRowInfoTmp.remainCopyRow; + } + + vecBatchProto->set_rowcnt(splitRowInfoTmp.onceCopyRow); + vecBatchProto->set_veccnt(column_type_id_.size()); + int fixColIndexTmp = 0; + for (size_t indexSchema = 0; indexSchema < column_type_id_.size(); indexSchema++) { + spark::Vec * vec = vecBatchProto->add_vecs(); + switch (column_type_id_[indexSchema]) { + case ShuffleTypeId::SHUFFLE_1BYTE: + case ShuffleTypeId::SHUFFLE_2BYTE: + case ShuffleTypeId::SHUFFLE_4BYTE: + case ShuffleTypeId::SHUFFLE_8BYTE: + case ShuffleTypeId::SHUFFLE_DECIMAL128:{ + SerializingFixedColumns(partition_id, *vec, fixColIndexTmp, &splitRowInfoTmp); + fixColIndexTmp++; // 定长序列化数量++ + break; + } + 
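// Variable-length (varchar/char) columns take the branch below: the per-partition,
+                    // per-column VCBatchInfo batches collected in SplitBinaryArray are serialized
+                    // one per proto chunk, with curBatch advancing in step so they stay aligned
+                    // with the fixed-width columns of the same rows.
+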
case ShuffleTypeId::SHUFFLE_BINARY: { + SerializingBinaryColumns(partition_id, *vec, indexSchema, curBatch); + break; + } + default: { + throw std::runtime_error("Unsupported ShuffleType."); + } + } + spark::VecType *vt = vec->mutable_vectype(); + vt->set_typeid_(CastShuffleTypeIdToVecType(vector_batch_col_types_[indexSchema])); + LogsDebug("precision[indexSchema %d]: %d ", indexSchema, input_col_types.inputDataPrecisions[indexSchema]); + LogsDebug("scale[indexSchema %d]: %d ", indexSchema, input_col_types.inputDataScales[indexSchema]); + vt->set_precision(input_col_types.inputDataPrecisions[indexSchema]); + vt->set_scale(input_col_types.inputDataScales[indexSchema]); + } + curBatch++; + + uint32_t vecBatchProtoSize = reversebytes_uint32t(vecBatchProto->ByteSize()); + void *buffer = nullptr; + if (!bufferStream->NextNBytes(&buffer, sizeof(vecBatchProtoSize))) { + LogsError("Allocate Memory Failed: Flush Spilled Data, Next failed."); + throw std::runtime_error("Allocate Memory Failed: Flush Spilled Data, Next failed."); + } + // set serizalized bytes to stream + memcpy(buffer, &vecBatchProtoSize, sizeof(vecBatchProtoSize)); + LogsDebug(" A Slice Of vecBatchProtoSize: %d ", reversebytes_uint32t(vecBatchProtoSize)); + + vecBatchProto->SerializeToZeroCopyStream(bufferStream.get()); + + splitRowInfoTmp.remainCopyRow -= splitRowInfoTmp.onceCopyRow; + splitRowInfoTmp.copyedRow += splitRowInfoTmp.onceCopyRow; + LogsTrace(" SerializeVecBatch:\n%s", vecBatchProto->DebugString().c_str()); + vecBatchProto->Clear(); + } + + uint64_t partitionBatchSize = bufferStream->flush(); + total_bytes_spilled_ += partitionBatchSize; + partition_serialization_size_[partition_id] = partitionBatchSize; + LogsDebug(" partitionBatch write length: %lu", partitionBatchSize); + + // 及时清理分区数据 + partition_cached_vectorbatch_[partition_id].clear(); // 定长数据内存释放 + for (size_t col = 0; col < column_type_id_.size(); col++) { + vc_partition_array_buffers_[partition_id][col].clear(); // binary 释放内存 + } + + return 0; +} + +int Splitter::WriteDataFileProto() { + LogsDebug(" spill DataFile: %s ", (options_.next_spilled_file_dir + ".data").c_str()); + std::unique_ptr outStream = writeLocalFile(options_.next_spilled_file_dir + ".data"); + WriterOptions options; + // tmp spilled file no need compression + options.setCompression(CompressionKind_NONE); + std::unique_ptr streamsFactory = createStreamsFactory(options, outStream.get()); + std::unique_ptr bufferStream = streamsFactory->createStream(); + // 顺序写入每个partition的offset + for (auto pid = 0; pid < num_partitions_; ++pid) { + protoSpillPartition(pid, bufferStream); + } + std::fill(std::begin(partition_id_cnt_cache_), std::end(partition_id_cnt_cache_), 0); + outStream->close(); + return 0; +} + +void Splitter::MergeSpilled() { + LogsDebug(" Merge Spilled Tmp File."); + std::unique_ptr outStream = writeLocalFile(options_.data_file); + LogsDebug(" MergeSpilled target dir: %s ", options_.data_file.c_str()); + WriterOptions options; + options.setCompression(options_.compression_type); + options.setCompressionBlockSize(options_.compress_block_size); + options.setCompressionStrategy(CompressionStrategy_COMPRESSION); + std::unique_ptr streamsFactory = createStreamsFactory(options, outStream.get()); + std::unique_ptr bufferOutPutStream = streamsFactory->createStream(); + + void* bufferOut = nullptr; + int sizeOut = 0; + for (int pid = 0; pid < num_partitions_; pid++) { + LogsDebug(" MergeSplled traversal partition( %d ) ",pid); + for (auto &pair : spilled_tmp_files_info_) { + auto 
tmpDataFilePath = pair.first + ".data"; + auto tmpPartitionOffset = reinterpret_cast(pair.second->data_)[pid]; + auto tmpPartitionSize = reinterpret_cast(pair.second->data_)[pid + 1] - reinterpret_cast(pair.second->data_)[pid]; + LogsDebug(" get Partition Stream...tmpPartitionOffset %d tmpPartitionSize %d path %s", + tmpPartitionOffset, tmpPartitionSize, tmpDataFilePath.c_str()); + std::unique_ptr inputStream = readLocalFile(tmpDataFilePath); + uint64_t targetLen = tmpPartitionSize; + uint64_t seekPosit = tmpPartitionOffset; + uint64_t onceReadLen = 0; + while ((targetLen > 0) && bufferOutPutStream->Next(&bufferOut, &sizeOut)) { + onceReadLen = targetLen > sizeOut ? sizeOut : targetLen; + inputStream->read(bufferOut, onceReadLen, seekPosit); + targetLen -= onceReadLen; + seekPosit += onceReadLen; + if (onceReadLen < sizeOut) { + // Reached END. + bufferOutPutStream->BackUp(sizeOut - onceReadLen); + break; + } + } + + uint64_t flushSize = bufferOutPutStream->flush(); + total_bytes_written_ += flushSize; + LogsDebug(" Merge Flush Partition[%d] flushSize: %ld ", pid, flushSize); + partition_lengths_[pid] += flushSize; + } + } + outStream->close(); +} + +int Splitter::DeleteSpilledTmpFile() { + for (auto &pair : spilled_tmp_files_info_) { + auto tmpDataFilePath = pair.first + ".data"; + // 释放存储有各个临时文件的偏移数据内存 + options_.allocator->free(pair.second->data_, pair.second->capacity_); + if (IsFileExist(tmpDataFilePath)) { + remove(tmpDataFilePath.c_str()); + } + } + // 释放内存空间,Reset spilled_tmp_files_info_, 这个地方是否有内存泄漏的风险??? + spilled_tmp_files_info_.clear(); + return 0; +} + +int Splitter::SpillToTmpFile() { + for (auto pid = 0; pid < num_partitions_; ++pid) { + CacheVectorBatch(pid, true); + partition_buffer_size_[pid] = 0; //溢写之后将其清零,条件溢写需要重新分配内存 + } + + options_.next_spilled_file_dir = CreateTempShuffleFile(NextSpilledFileDir()); + WriteDataFileProto(); + std::shared_ptr ptrTmp = CaculateSpilledTmpFilePartitionOffsets(); + spilled_tmp_files_info_[options_.next_spilled_file_dir] = ptrTmp; + + LogsDebug(" free vectorBatch memory... 
"); + auto cache_vectorBatch_num = vectorBatch_cache_.size(); + for (auto i = 0; i < cache_vectorBatch_num; ++i) { + ReleaseVectorBatch(*vectorBatch_cache_[i]); + if (nullptr == vectorBatch_cache_[i]) { + throw std::runtime_error("delete nullptr error for free vectorBatch"); + } + delete vectorBatch_cache_[i]; + vectorBatch_cache_[i] = nullptr; + } + vectorBatch_cache_.clear(); + num_row_splited_ = 0; + cached_vectorbatch_size_ = 0; + return 0; +} + +Splitter::Splitter(InputDataTypes inputDataTypes, int32_t num_cols, int32_t num_partitions, SplitOptions options, bool flag) + : num_fields_(num_cols), + num_partitions_(num_partitions), + options_(std::move(options)), + singlePartitionFlag(flag), + input_col_types(inputDataTypes) +{ + LogsDebug("Input Schema colNum: %d", num_cols); + ToSplitterTypeId(num_cols); +} + +std::shared_ptr Create(InputDataTypes inputDataTypes, + int32_t num_cols, + int32_t num_partitions, + SplitOptions options, + bool flag) +{ + std::shared_ptr res( + new Splitter(inputDataTypes, + num_cols, + num_partitions, + std::move(options), + flag)); + res->Split_Init(); + return res; +} + +std::shared_ptr Splitter::Make( + const std::string& short_name, + InputDataTypes inputDataTypes, + int32_t num_cols, + int num_partitions, + SplitOptions options) { + if (short_name == "hash" || short_name == "rr" || short_name == "range") { + return Create(inputDataTypes, num_cols, num_partitions, std::move(options), false); + } else if (short_name == "single") { + return Create(inputDataTypes, num_cols, num_partitions, std::move(options), true); + } else { + throw("ERROR: Unsupported Splitter Type."); + } +} + +std::string Splitter::NextSpilledFileDir() { + auto spilled_file_dir = GetSpilledShuffleFileDir(configured_dirs_[dir_selection_], + sub_dir_selection_[dir_selection_]); + LogsDebug(" spilled_file_dir %s ", spilled_file_dir.c_str()); + sub_dir_selection_[dir_selection_] = + (sub_dir_selection_[dir_selection_] + 1) % options_.num_sub_dirs; + dir_selection_ = (dir_selection_ + 1) % configured_dirs_.size(); + return spilled_file_dir; +} + +int Splitter::Stop() { + LogsDebug(" Spill For Splitter Stopped."); + TIME_NANO_OR_RAISE(total_spill_time_, SpillToTmpFile()); + TIME_NANO_OR_RAISE(total_write_time_, MergeSpilled()); + TIME_NANO_OR_RAISE(total_write_time_, DeleteSpilledTmpFile()); + LogsDebug("total_spill_row_num_: %ld ", total_spill_row_num_); + if (nullptr == vecBatchProto) { + throw std::runtime_error("delete nullptr error for free protobuf vecBatch memory"); + } + delete vecBatchProto; //free protobuf vecBatch memory + return 0; +} + + + diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h new file mode 100644 index 0000000000000000000000000000000000000000..6339bec516397d4610d28a5f9205ee3ab9616d55 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h @@ -0,0 +1,192 @@ +/** + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CPP_SPLITTER_H +#define CPP_SPLITTER_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "type.h" +#include "../io/ColumnWriter.hh" +#include "../common/common.h" +#include "vec_data.pb.h" +#include "google/protobuf/io/zero_copy_stream_impl.h" + +using namespace std; +using namespace spark; +using namespace google::protobuf::io; +using namespace omniruntime::vec; +using namespace omniruntime::type; +using namespace omniruntime::mem; + +struct SplitRowInfo { + uint32_t copyedRow = 0; + uint32_t onceCopyRow = 0; + uint64_t remainCopyRow = 0; + vector cacheBatchIndex; // 记录各定长列的溢写Batch下标 + vector cacheBatchCopyedLen; // 记录各定长列的溢写Batch内部偏移 +}; + +class Splitter { + + virtual int DoSplit(VectorBatch& vb); + + int WriteDataFileProto(); + + std::shared_ptr CaculateSpilledTmpFilePartitionOffsets(); + + int SerializingFixedColumns(int32_t partitionId, + spark::Vec& vec, + int fixColIndexTmp, + SplitRowInfo* splitRowInfoTmp); + + int SerializingBinaryColumns(int32_t partitionId, + spark::Vec& vec, + int colIndex, + int curBatch); + + int protoSpillPartition(int32_t partition_id, std::unique_ptr &bufferStream); + + int ComputeAndCountPartitionId(VectorBatch& vb); + + int AllocatePartitionBuffers(int32_t partition_id, int32_t new_size); + + int SplitFixedWidthValueBuffer(VectorBatch& vb); + + int SplitFixedWidthValidityBuffer(VectorBatch& vb); + + int SplitBinaryArray(VectorBatch& vb); + + int CacheVectorBatch(int32_t partition_id, bool reset_buffers); + + void ToSplitterTypeId(int num_cols); + + void MergeSpilled(); + + std::vector partition_id_; // 记录当前vb每一行的pid + std::vector partition_id_cnt_cur_; // 统计不同partition记录的行数(当前处理中的vb) + std::vector partition_id_cnt_cache_; // 统计不同partition记录的行数,cache住的 + // column number + uint32_t num_row_splited_; // cached row number + uint64_t cached_vectorbatch_size_; // cache total vectorbatch size in bytes + uint64_t current_fixed_alloc_buffer_size_ = 0; + std::vector fixed_valueBuffer_size_; // 当前定长omniAlloc已经分配value内存大小byte + std::vector fixed_nullBuffer_size_; // 当前定长omniAlloc已分配null内存大小byte + // int32_t num_cache_vector_; + std::vector column_type_id_; // 各列映射SHUFFLE类型,schema列id序列 + std::vector> partition_fixed_width_validity_addrs_; + std::vector> partition_fixed_width_value_addrs_; // + std::vector>>> partition_fixed_width_buffers_; + std::vector>> partition_binary_builders_; + std::vector partition_buffer_size_; // 各分区的buffer大小 + std::vector fixed_width_array_idx_; // 记录各定长类型列的序号,VB 列id序列 + std::vector binary_array_idx_; //记录各变长类型列序号 + std::vector partition_buffer_idx_base_; //当前已缓存的各partition行数据记录,用于定位缓冲buffer当前可用位置 + std::vector partition_buffer_idx_offset_; //split定长列时用于统计offset的临时变量 + std::vector partition_serialization_size_; // 记录序列化后的各partition大小,用于stop返回partition偏移 in bytes + + std::vector input_fixed_width_has_null_; // 定长列是否含有null标志数组 + + // configured local dirs for spilled file + int32_t dir_selection_ = 0; + std::vector sub_dir_selection_; + std::vector configured_dirs_; + + std::vector>>>> partition_cached_vectorbatch_; + std::vector vectorBatch_cache_; + /* + * 
varchar buffers: + * partition_array_buffers_[partition_id][col_id][varcharBatch_id] + * + */ + std::vector>> vc_partition_array_buffers_; + + int64_t total_bytes_written_ = 0; + int64_t total_bytes_spilled_ = 0; + int64_t total_write_time_ = 0; + int64_t total_spill_time_ = 0; + int64_t total_compute_pid_time_ = 0; + int64_t total_spill_row_num_ = 0; + std::vector partition_lengths_; + +private: + bool first_vector_batch_ = false; + std::vector vector_batch_col_types_; + InputDataTypes input_col_types; + std::vector binary_array_empirical_size_; + +public: + bool singlePartitionFlag = false; + int32_t num_partitions_; + SplitOptions options_; + // 分区数 + int32_t num_fields_; + + std::map> spilled_tmp_files_info_; + + VecBatch *vecBatchProto = new VecBatch(); //protobuf 序列化对象结构 + + virtual int Split_Init(); + + virtual int Split(VectorBatch& vb); + + int Stop(); + + int SpillToTmpFile(); + + Splitter(InputDataTypes inputDataTypes, + int32_t num_cols, + int32_t num_partitions, + SplitOptions options, + bool flag); + + static std::shared_ptr Make( + const std::string &short_name, + InputDataTypes inputDataTypes, + int32_t num_cols, + int num_partitions, + SplitOptions options); + + std::string NextSpilledFileDir(); + + int DeleteSpilledTmpFile(); + + int64_t TotalBytesWritten() const { return total_bytes_written_; } + + int64_t TotalBytesSpilled() const { return total_bytes_spilled_; } + + int64_t TotalWriteTime() const { return total_write_time_; } + + int64_t TotalSpillTime() const { return total_spill_time_; } + + int64_t TotalComputePidTime() const { return total_compute_pid_time_; } + + const std::vector& PartitionLengths() const { return partition_lengths_; } +}; + + +#endif // CPP_SPLITTER_H diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/type.h b/omnioperator/omniop-spark-extension/cpp/src/shuffle/type.h new file mode 100644 index 0000000000000000000000000000000000000000..446cedc5f89988f115aedb7d9b3bc9b7c1c0a177 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/type.h @@ -0,0 +1,72 @@ +/** + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef CPP_TYPE_H +#define CPP_TYPE_H +#include +#include "../io/SparkFile.hh" +#include "../io/ColumnWriter.hh" + +using namespace spark; +using namespace omniruntime::mem; + +static constexpr int32_t kDefaultSplitterBufferSize = 4096; +static constexpr int32_t kDefaultNumSubDirs = 64; + +struct SplitOptions { + int32_t buffer_size = kDefaultSplitterBufferSize; + int32_t num_sub_dirs = kDefaultNumSubDirs; + CompressionKind compression_type = CompressionKind_NONE; + std::string next_spilled_file_dir = ""; + + std::string data_file; + + int64_t thread_id = -1; + int64_t task_attempt_id = -1; + + BaseAllocator *allocator = omniruntime::mem::GetProcessRootAllocator(); + + uint64_t spill_batch_row_num = 4096; // default value + uint64_t spill_mem_threshold = 1024 * 1024 * 1024; // default value + uint64_t compress_block_size = 64 * 1024; // default value + + static SplitOptions Defaults(); +}; + +enum ShuffleTypeId : int { + SHUFFLE_1BYTE = 0, + SHUFFLE_2BYTE = 1, + SHUFFLE_4BYTE = 2, + SHUFFLE_8BYTE = 3, + SHUFFLE_DECIMAL128 = 4, + SHUFFLE_BIT = 5, + SHUFFLE_BINARY = 6, + SHUFFLE_LARGE_BINARY = 7, + SHUFFLE_NULL = 8, + NUM_TYPES = 9, + SHUFFLE_NOT_IMPLEMENTED = 10 +}; + +struct InputDataTypes { + int32_t *inputVecTypeIds = nullptr; + uint32_t *inputDataPrecisions = nullptr; + uint32_t *inputDataScales = nullptr; +}; + +#endif //CPP_TYPE_H \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/utils.h b/omnioperator/omniop-spark-extension/cpp/src/shuffle/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..b7c523cb9c491483ae33d410554c97b66a894dad --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/utils.h @@ -0,0 +1,134 @@ +/** + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef THESTRAL_PLUGIN_MASTER_UTILS_H +#define THESTRAL_PLUGIN_MASTER_UTILS_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +constexpr char kSep = '/'; + +static std::string GenerateUUID() { + boost::uuids::random_generator generator; + return boost::uuids::to_string(generator()); +} + +std::string MakeRandomName(int num_chars) { + static const std::string chars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; + std::string resBuf = ""; + char tmp; + for (int i = 0; i < num_chars; i++) { + tmp = chars[random() % (chars.length())]; + resBuf += tmp; + } + return resBuf; +} + +std::string MakeTemporaryDir(const std::string& prefix) { + const int kNumChars = 8; + std::string suffix = MakeRandomName(kNumChars); + return prefix + suffix; +} + +std::vector GetConfiguredLocalDirs() { + auto joined_dirs_c = std::getenv("NATIVESQL_SPARK_LOCAL_DIRS"); + std::vector res; + if (joined_dirs_c != nullptr && strcmp(joined_dirs_c, "") > 0) { + auto joined_dirs = std::string(joined_dirs_c); + std::string delimiter = ","; + + size_t pos; + while ((pos = joined_dirs.find(delimiter)) != std::string::npos) { + auto dir = joined_dirs.substr(0, pos); + if (dir.length() > 0) { + res.push_back(std::move(dir)); + } + joined_dirs.erase(0, pos + delimiter.length()); + } + if (joined_dirs.length() > 0) { + res.push_back(std::move(joined_dirs)); + } + return res; + } else { + auto omni_tmp_dir = MakeTemporaryDir("columnar-shuffle-"); + if (!IsFileExist(omni_tmp_dir.c_str())) { + mkdir(omni_tmp_dir.c_str(), S_IRWXU|S_IRWXG|S_IROTH|S_IXOTH); + } + return std::vector{omni_tmp_dir}; + } +} + +std::string EnsureTrailingSlash(const std::string& v) { + if (v.length() > 0 && v.back() != kSep) { + // XXX How about "C:" on Windows? We probably don't want to turn it into "C:/"... + // Unless the local filesystem always uses absolute paths + return std::string(v) + kSep; + } else { + return std::string(v); + } +} + +std::string RemoveLeadingSlash(std::string key) { + while (!key.empty() && key.front() == kSep) { + key.erase(0); + } + return key; +} + +std::string ConcatAbstractPath(const std::string& base, const std::string& stem) { + if(stem.empty()) { + throw std::runtime_error("stem empty! "); + } + + if (base.empty()) { + return stem; + } + return EnsureTrailingSlash(base) + std::string(RemoveLeadingSlash(stem)); +} + +std::string GetSpilledShuffleFileDir(const std::string& configured_dir, + int32_t sub_dir_id) { + std::stringstream ss; + ss << std::setfill('0') << std::setw(2) << std::hex << sub_dir_id; + auto dir = ConcatAbstractPath(configured_dir, "shuffle_" + ss.str()); + return dir; +} + +std::string CreateTempShuffleFile(const std::string& dir) { + if (dir.length() == 0) { + throw std::runtime_error("CreateTempShuffleFile failed!"); + } + + if (!IsFileExist(dir.c_str())) { + mkdir(dir.c_str(), S_IRWXU|S_IRWXG|S_IROTH|S_IXOTH); + } + + std::string file_path = ConcatAbstractPath(dir, "temp_shuffle_" + GenerateUUID()); + return file_path; +} + +#endif //THESTRAL_PLUGIN_MASTER_UTILS_H diff --git a/omnioperator/omniop-spark-extension/cpp/src/utils/macros.h b/omnioperator/omniop-spark-extension/cpp/src/utils/macros.h new file mode 100644 index 0000000000000000000000000000000000000000..7c6deca18aab90690c05944ee5225094f5d5dc1e --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/utils/macros.h @@ -0,0 +1,35 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SPARK_THESTRAL_PLUGIN_MACROS_H +#define SPARK_THESTRAL_PLUGIN_MACROS_H + +#pragma once + +#include +#include + +#define TIME_NANO_OR_RAISE(time, expr) \ + do { \ + auto start = std::chrono::steady_clock::now(); \ + (expr); \ + auto end = std::chrono::steady_clock::now(); \ + time += std::chrono::duration_cast(end - start).count(); \ + } while (false); + +#endif //SPARK_THESTRAL_PLUGIN_MACROS_H \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a06358d823aa333e4abcf829c14078ac6228de86 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt @@ -0,0 +1,45 @@ +add_subdirectory(shuffle) +add_subdirectory(utils) +add_subdirectory(tablescan) + +# configure +set(TP_TEST_TARGET tptest) +set(MY_LINK + utilstest + shuffletest + tablescantest + ) + +# find gtest package +find_package(GTest REQUIRED) + +set (UT_FILES + tptest.cpp + shuffle/shuffle_test.cpp + tablescan/scan_test.cpp + ) + +# compile a executable file +add_executable(${TP_TEST_TARGET} ${UT_FILES}) +add_dependencies(${TP_TEST_TARGET} ${MY_LINK}) + +# dependent libraries +target_link_libraries(${TP_TEST_TARGET} + ${GTEST_BOTH_LIBRARIES} + ${MY_LINK} + gtest + pthread + stdc++ + dl + boostkit-omniop-runtime-1.0.0-aarch64 + boostkit-omniop-vector-1.0.0-aarch64 + securec + spark_columnar_plugin) + +target_compile_options(${TP_TEST_TARGET} PUBLIC -g -O2 -fPIC) + +# dependent include +target_include_directories(${TP_TEST_TARGET} PRIVATE ${GTEST_INCLUDE_DIRS}) + +# discover tests +gtest_discover_tests(${TP_TEST_TARGET}) diff --git a/omnioperator/omniop-spark-extension/cpp/test/shuffle/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/test/shuffle/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ff8bf512a412063f3adb32e6bf02f3fd672ea797 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/test/shuffle/CMakeLists.txt @@ -0,0 +1,10 @@ +aux_source_directory(${CMAKE_CURRENT_LIST_DIR} SHUFFLE_TESTS_LIST) +set(SHUFFLE_TEST_TARGET shuffletest) +add_library(${SHUFFLE_TEST_TARGET} STATIC ${SHUFFLE_TESTS_LIST}) +target_compile_options(${SHUFFLE_TEST_TARGET} PUBLIC ) +target_include_directories(${SHUFFLE_TEST_TARGET} PUBLIC ${CMAKE_BINARY_DIR}/src) +target_include_directories(${SHUFFLE_TEST_TARGET} PUBLIC /opt/lib/include) +target_include_directories(${SHUFFLE_TEST_TARGET} PUBLIC $ENV{JAVA_HOME}/include) +target_include_directories(${SHUFFLE_TEST_TARGET} PUBLIC $ENV{JAVA_HOME}/include/linux) + + diff --git a/omnioperator/omniop-spark-extension/cpp/test/shuffle/shuffle_test.cpp b/omnioperator/omniop-spark-extension/cpp/test/shuffle/shuffle_test.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..f67d9d4742e77f03aafb887528351a423876147c --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/test/shuffle/shuffle_test.cpp @@ -0,0 +1,464 @@ +/** + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "gtest/gtest.h" +#include "../utils/test_utils.h" + +static std::string tmpTestingDir; +static std::string tmpShuffleFilePath; + +class ShuffleTest : public testing::Test { +protected: + + // run before first case... + static void SetUpTestSuite() { + tmpTestingDir = s_shuffle_tests_dir; + if (!IsFileExist(tmpTestingDir)) { + mkdir(tmpTestingDir.c_str(), S_IRWXU|S_IRWXG|S_IROTH|S_IXOTH); + } + } + + // run after last case... + static void TearDownTestSuite() { + if (IsFileExist(tmpTestingDir)) { + DeletePathAll(tmpTestingDir.c_str()); + } + } + + // run before each case... + virtual void SetUp() override { + } + + // run after each case... + virtual void TearDown() override { + if (IsFileExist(tmpShuffleFilePath)) { + remove(tmpShuffleFilePath.c_str()); + } + } + +}; + +TEST_F (ShuffleTest, Split_SingleVarChar) { + tmpShuffleFilePath = tmpTestingDir + "/shuffle_split_SingleVarChar"; + int32_t inputVecTypeIds[] = {OMNI_VARCHAR}; + int colNumber = sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]); + InputDataTypes inputDataTypes; + inputDataTypes.inputVecTypeIds = inputVecTypeIds; + inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; + inputDataTypes.inputDataScales = new uint32_t[colNumber]; + int splitterId = Test_splitter_nativeMake("hash", + 4, + inputDataTypes, + colNumber, + 1024, + "lz4", + tmpShuffleFilePath, + 0, + tmpTestingDir); + VectorBatch* vb1 = CreateVectorBatch_1row_varchar_withPid(3, "N"); + Test_splitter_split(splitterId, vb1); + VectorBatch* vb2 = CreateVectorBatch_1row_varchar_withPid(2, "F"); + Test_splitter_split(splitterId, vb2); + VectorBatch* vb3 = CreateVectorBatch_1row_varchar_withPid(3, "N"); + Test_splitter_split(splitterId, vb3); + VectorBatch* vb4 = CreateVectorBatch_1row_varchar_withPid(2, "F"); + Test_splitter_split(splitterId, vb4); + VectorBatch* vb5 = CreateVectorBatch_1row_varchar_withPid(2, "F"); + Test_splitter_split(splitterId, vb5); + VectorBatch* vb6 = CreateVectorBatch_1row_varchar_withPid(1, "R"); + Test_splitter_split(splitterId, vb6); + VectorBatch* vb7 = CreateVectorBatch_1row_varchar_withPid(3,"N"); + Test_splitter_split(splitterId, vb7); + Test_splitter_stop(splitterId); + Test_splitter_close(splitterId); + delete[] inputDataTypes.inputDataPrecisions; + delete[] inputDataTypes.inputDataScales; +} + +TEST_F (ShuffleTest, Split_Fixed_Cols) { + tmpShuffleFilePath = tmpTestingDir + "/shuffle_split_fixed_cols"; + int32_t inputVecTypeIds[] = {OMNI_INT, 
OMNI_LONG, OMNI_DOUBLE}; + int colNumber = sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]); + InputDataTypes inputDataTypes; + inputDataTypes.inputVecTypeIds = inputVecTypeIds; + inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; + inputDataTypes.inputDataScales = new uint32_t[colNumber]; + int partitionNum = 4; + int splitterId = Test_splitter_nativeMake("hash", + partitionNum, + inputDataTypes, + colNumber, + 4096, + "lz4", + tmpShuffleFilePath, + 0, + tmpTestingDir); + for (uint64_t j = 0; j < 999; j++) { + VectorBatch* vb = CreateVectorBatch_3fixedCols_withPid(partitionNum, 999); + Test_splitter_split(splitterId, vb); + } + Test_splitter_stop(splitterId); + Test_splitter_close(splitterId); + delete[] inputDataTypes.inputDataPrecisions; + delete[] inputDataTypes.inputDataScales; +} + +TEST_F (ShuffleTest, Split_Fixed_SinglePartition_SomeNullRow) { + tmpShuffleFilePath = tmpTestingDir + "/shuffle_split_fixed_singlePartition_someNullRow"; + int32_t inputVecTypeIds[] = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR}; + int colNumber = sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]); + InputDataTypes inputDataTypes; + inputDataTypes.inputVecTypeIds = inputVecTypeIds; + inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; + inputDataTypes.inputDataScales = new uint32_t[colNumber]; + int partitionNum = 1; + int splitterId = Test_splitter_nativeMake("single", + partitionNum, + inputDataTypes, + colNumber, + 4096, + "lz4", + tmpShuffleFilePath, + 0, + tmpTestingDir); + for (uint64_t j = 0; j < 100; j++) { + VectorBatch* vb = CreateVectorBatch_someNullRow_vectorBatch(); + Test_splitter_split(splitterId, vb); + } + Test_splitter_stop(splitterId); + Test_splitter_close(splitterId); + delete[] inputDataTypes.inputDataPrecisions; + delete[] inputDataTypes.inputDataScales; +} + +TEST_F (ShuffleTest, Split_Fixed_SinglePartition_SomeNullCol) { + tmpShuffleFilePath = tmpTestingDir + "/shuffle_split_fixed_singlePartition_someNullCol"; + int32_t inputVecTypeIds[] = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR}; + int colNumber = sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]); + InputDataTypes inputDataTypes; + inputDataTypes.inputVecTypeIds = inputVecTypeIds; + inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; + inputDataTypes.inputDataScales = new uint32_t[colNumber]; + int partitionNum = 1; + int splitterId = Test_splitter_nativeMake("single", + partitionNum, + inputDataTypes, + colNumber, + 4096, + "lz4", + tmpShuffleFilePath, + 0, + tmpTestingDir); + for (uint64_t j = 0; j < 100; j++) { + VectorBatch* vb = CreateVectorBatch_someNullCol_vectorBatch(); + Test_splitter_split(splitterId, vb); + } + Test_splitter_stop(splitterId); + Test_splitter_close(splitterId); + delete[] inputDataTypes.inputDataPrecisions; + delete[] inputDataTypes.inputDataScales; +} + +TEST_F (ShuffleTest, Split_Compression_None) { + Test_Shuffle_Compression("uncompressed", 4, 999, 999); +} + +TEST_F (ShuffleTest, Split_Compression_zstd) { + Test_Shuffle_Compression("zstd", 4, 999, 999); +} + +TEST_F (ShuffleTest, Split_Compression_Lz4) { + Test_Shuffle_Compression("lz4", 4, 999, 999); +} + +TEST_F (ShuffleTest, Split_Compression_Snappy) { + Test_Shuffle_Compression("snappy", 4, 999, 999); +} + +TEST_F (ShuffleTest, Split_Compression_Zlib) { + Test_Shuffle_Compression("zlib", 4, 999, 999); +} + +TEST_F (ShuffleTest, Split_Mix_LargeSize) { + tmpShuffleFilePath = tmpTestingDir + "/shuffle_split_mix_largeSize"; + int32_t inputVecTypeIds[] = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, 
OMNI_VARCHAR}; + int colNumber = sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]); + InputDataTypes inputDataTypes; + inputDataTypes.inputVecTypeIds = inputVecTypeIds; + inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; + inputDataTypes.inputDataScales = new uint32_t[colNumber]; + int partitionNum = 4; + int splitterId = Test_splitter_nativeMake("hash", + partitionNum, + inputDataTypes, + colNumber, + 4096, + "lz4", + tmpShuffleFilePath, + 0, + tmpTestingDir); + for (uint64_t j = 0; j < 999; j++) { + VectorBatch* vb = CreateVectorBatch_4col_withPid(partitionNum, 999); + Test_splitter_split(splitterId, vb); + } + Test_splitter_stop(splitterId); + Test_splitter_close(splitterId); + delete[] inputDataTypes.inputDataPrecisions; + delete[] inputDataTypes.inputDataScales; +} + +TEST_F (ShuffleTest, Split_Long_10WRows) { + tmpShuffleFilePath = tmpTestingDir + "/shuffle_split_long_10WRows"; + int32_t inputVecTypeIds[] = {OMNI_LONG}; + int colNumber = sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]); + InputDataTypes inputDataTypes; + inputDataTypes.inputVecTypeIds = inputVecTypeIds; + inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; + inputDataTypes.inputDataScales = new uint32_t[colNumber]; + int partitionNum = 10; + int splitterId = Test_splitter_nativeMake("hash", + partitionNum, + inputDataTypes, + colNumber, + 4096, + "lz4", + tmpShuffleFilePath, + 0, + tmpTestingDir); + for (uint64_t j = 0; j < 100; j++) { + VectorBatch* vb = CreateVectorBatch_1longCol_withPid(partitionNum, 10000); + Test_splitter_split(splitterId, vb); + } + Test_splitter_stop(splitterId); + Test_splitter_close(splitterId); + delete[] inputDataTypes.inputDataPrecisions; + delete[] inputDataTypes.inputDataScales; +} + +TEST_F (ShuffleTest, Split_VarChar_LargeSize) { + tmpShuffleFilePath = tmpTestingDir + "/shuffle_varchar_largeSize"; + int32_t inputVecTypeIds[] = {OMNI_VARCHAR, OMNI_VARCHAR, OMNI_VARCHAR, OMNI_VARCHAR}; + int colNumber = sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]); + InputDataTypes inputDataTypes; + inputDataTypes.inputVecTypeIds = inputVecTypeIds; + inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; + inputDataTypes.inputDataScales = new uint32_t[colNumber]; + int partitionNum = 4; + int splitterId = Test_splitter_nativeMake("hash", + partitionNum, + inputDataTypes, + colNumber, + 4096, + "lz4", + tmpShuffleFilePath, + 64, + tmpTestingDir); + for (uint64_t j = 0; j < 99; j++) { + VectorBatch* vb = CreateVectorBatch_4varcharCols_withPid(partitionNum, 99); + Test_splitter_split(splitterId, vb); + } + Test_splitter_stop(splitterId); + Test_splitter_close(splitterId); + delete[] inputDataTypes.inputDataPrecisions; + delete[] inputDataTypes.inputDataScales; +} + +TEST_F (ShuffleTest, Split_VarChar_First) { + tmpShuffleFilePath = tmpTestingDir + "/shuffle_split_varchar_first"; + int32_t inputVecTypeIds[] = {OMNI_VARCHAR, OMNI_INT}; + int colNumber = sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]); + InputDataTypes inputDataTypes; + inputDataTypes.inputVecTypeIds = inputVecTypeIds; + inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; + inputDataTypes.inputDataScales = new uint32_t[colNumber]; + int partitionNum = 4; + int splitterId = Test_splitter_nativeMake("hash", + partitionNum, + inputDataTypes, + colNumber, + 4096, + "lz4", + tmpShuffleFilePath, + 0, + tmpTestingDir); + VectorBatch* vb0 = CreateVectorBatch_2column_1row_withPid(0, "corpbrand #4", 1); + Test_splitter_split(splitterId, vb0); + VectorBatch* vb1 = 
CreateVectorBatch_2column_1row_withPid(3, "brandmaxi #4", 1); + Test_splitter_split(splitterId, vb1); + VectorBatch* vb2 = CreateVectorBatch_2column_1row_withPid(1, "edu packnameless #9", 1); + Test_splitter_split(splitterId, vb2); + VectorBatch* vb3 = CreateVectorBatch_2column_1row_withPid(1, "amalgunivamalg #11", 1); + Test_splitter_split(splitterId, vb3); + VectorBatch* vb4 = CreateVectorBatch_2column_1row_withPid(0, "brandcorp #2", 1); + Test_splitter_split(splitterId, vb4); + VectorBatch* vb5 = CreateVectorBatch_2column_1row_withPid(0, "scholarbrand #2", 1); + Test_splitter_split(splitterId, vb5); + VectorBatch* vb6 = CreateVectorBatch_2column_1row_withPid(2, "edu packcorp #6", 1); + Test_splitter_split(splitterId, vb6); + VectorBatch* vb7 = CreateVectorBatch_2column_1row_withPid(2, "edu packamalg #1", 1); + Test_splitter_split(splitterId, vb7); + VectorBatch* vb8 = CreateVectorBatch_2column_1row_withPid(0, "brandnameless #8", 1); + Test_splitter_split(splitterId, vb8); + VectorBatch* vb9 = CreateVectorBatch_2column_1row_withPid(2, "univmaxi #2", 1); + Test_splitter_split(splitterId, vb9); + Test_splitter_stop(splitterId); + Test_splitter_close(splitterId); + delete[] inputDataTypes.inputDataPrecisions; + delete[] inputDataTypes.inputDataScales; +} + +TEST_F (ShuffleTest, Split_Dictionary) { + tmpShuffleFilePath = tmpTestingDir + "/shuffle_split_dictionary"; + int32_t inputVecTypeIds[] = {OMNI_INT, OMNI_LONG}; + int colNumber = sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]); + InputDataTypes inputDataTypes; + inputDataTypes.inputVecTypeIds = inputVecTypeIds; + inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; + inputDataTypes.inputDataScales = new uint32_t[colNumber]; + int partitionNum = 4; + int splitterId = Test_splitter_nativeMake("hash", + partitionNum, + inputDataTypes, + colNumber, + 4096, + "lz4", + tmpShuffleFilePath, + 0, + tmpTestingDir); + for (uint64_t j = 0; j < 2; j++) { + VectorBatch* vb = CreateVectorBatch_2dictionaryCols_withPid(partitionNum); + Test_splitter_split(splitterId, vb); + } + Test_splitter_stop(splitterId); + Test_splitter_close(splitterId); + delete[] inputDataTypes.inputDataPrecisions; + delete[] inputDataTypes.inputDataScales; +} + +TEST_F (ShuffleTest, Split_Char) { + tmpShuffleFilePath = tmpTestingDir + "/shuffle_char_largeSize"; + int32_t inputVecTypeIds[] = {OMNI_CHAR, OMNI_CHAR, OMNI_CHAR, OMNI_CHAR}; + int colNumber = sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]); + InputDataTypes inputDataTypes; + inputDataTypes.inputVecTypeIds = inputVecTypeIds; + inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; + inputDataTypes.inputDataScales = new uint32_t[colNumber]; + int partitionNum = 4; + int splitterId = Test_splitter_nativeMake("hash", + partitionNum, + inputDataTypes, + colNumber, + 4096, + "lz4", + tmpShuffleFilePath, + 64, + tmpTestingDir); + for (uint64_t j = 0; j < 99; j++) { + VectorBatch* vb = CreateVectorBatch_4charCols_withPid(partitionNum, 99); + Test_splitter_split(splitterId, vb); + } + Test_splitter_stop(splitterId); + Test_splitter_close(splitterId); + delete[] inputDataTypes.inputDataPrecisions; + delete[] inputDataTypes.inputDataScales; +} + +TEST_F (ShuffleTest, Split_Decimal128) { + tmpShuffleFilePath = tmpTestingDir + "/shuffle_decimal128_test"; + int32_t inputVecTypeIds[] = {OMNI_DECIMAL128}; + int colNumber = sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]); + InputDataTypes inputDataTypes; + inputDataTypes.inputVecTypeIds = inputVecTypeIds; + inputDataTypes.inputDataPrecisions = new 
uint32_t[colNumber]; + inputDataTypes.inputDataScales = new uint32_t[colNumber]; + int partitionNum = 4; + int splitterId = Test_splitter_nativeMake("hash", + partitionNum, + inputDataTypes, + colNumber, + 4096, + "lz4", + tmpShuffleFilePath, + 0, + tmpTestingDir); + for (uint64_t j = 0; j < 999; j++) { + VectorBatch* vb = CreateVectorBatch_1decimal128Col_withPid(partitionNum, 999); + Test_splitter_split(splitterId, vb); + } + Test_splitter_stop(splitterId); + Test_splitter_close(splitterId); + delete[] inputDataTypes.inputDataPrecisions; + delete[] inputDataTypes.inputDataScales; +} + +TEST_F (ShuffleTest, Split_Decimal64) { + tmpShuffleFilePath = tmpTestingDir + "/shuffle_decimal64_test"; + int32_t inputVecTypeIds[] = {OMNI_DECIMAL64}; + int colNumber = sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]); + InputDataTypes inputDataTypes; + inputDataTypes.inputVecTypeIds = inputVecTypeIds; + inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; + inputDataTypes.inputDataScales = new uint32_t[colNumber]; + int partitionNum = 4; + int splitterId = Test_splitter_nativeMake("hash", + partitionNum, + inputDataTypes, + colNumber, + 4096, + "lz4", + tmpShuffleFilePath, + 0, + tmpTestingDir); + for (uint64_t j = 0; j < 999; j++) { + VectorBatch* vb = CreateVectorBatch_1decimal64Col_withPid(partitionNum, 999); + Test_splitter_split(splitterId, vb); + } + Test_splitter_stop(splitterId); + Test_splitter_close(splitterId); + delete[] inputDataTypes.inputDataPrecisions; + delete[] inputDataTypes.inputDataScales; +} + +TEST_F (ShuffleTest, Split_Decimal64_128) { + tmpShuffleFilePath = tmpTestingDir + "/shuffle_decimal64_128"; + int32_t inputVecTypeIds[] = {OMNI_DECIMAL64, OMNI_DECIMAL128}; + int colNumber = sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]); + InputDataTypes inputDataTypes; + inputDataTypes.inputVecTypeIds = inputVecTypeIds; + inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; + inputDataTypes.inputDataScales = new uint32_t[colNumber]; + int partitionNum = 4; + int splitterId = Test_splitter_nativeMake("hash", + partitionNum, + inputDataTypes, + colNumber, + 4096, + "lz4", + tmpShuffleFilePath, + 0, + tmpTestingDir); + for (uint64_t j = 0; j < 999; j++) { + VectorBatch* vb = CreateVectorBatch_2decimalCol_withPid(partitionNum, 999); + Test_splitter_split(splitterId, vb); + } + Test_splitter_stop(splitterId); + Test_splitter_close(splitterId); + delete[] inputDataTypes.inputDataPrecisions; + delete[] inputDataTypes.inputDataScales; +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/test/tablescan/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/test/tablescan/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ad201c2a601781fecad4bee63cfd2ba57f80a880 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/test/tablescan/CMakeLists.txt @@ -0,0 +1,13 @@ +set(MAIN_PATH ${CMAKE_CURRENT_SOURCE_DIR}) +configure_file(scan_test.h.in ${CMAKE_CURRENT_SOURCE_DIR}/scan_test.h) + +aux_source_directory(${CMAKE_CURRENT_LIST_DIR} SCAN_TESTS_LIST) +set(SCAN_TEST_TARGET tablescantest) +add_library(${SCAN_TEST_TARGET} STATIC ${SCAN_TESTS_LIST}) +target_compile_options(${SCAN_TEST_TARGET} PUBLIC ) + +target_include_directories(${SCAN_TEST_TARGET} PUBLIC $ENV{JAVA_HOME}/include) +target_include_directories(${SCAN_TEST_TARGET} PUBLIC $ENV{JAVA_HOME}/include/linux) +target_include_directories(${SCAN_TEST_TARGET} PUBLIC /opt/lib/include) + + diff --git 
a/omnioperator/omniop-spark-extension/cpp/test/tablescan/resources/orc_data_all_type b/omnioperator/omniop-spark-extension/cpp/test/tablescan/resources/orc_data_all_type new file mode 100644 index 0000000000000000000000000000000000000000..9cc57fa78ccdae728d2d902f587c30c337b0e4a5 Binary files /dev/null and b/omnioperator/omniop-spark-extension/cpp/test/tablescan/resources/orc_data_all_type differ diff --git a/omnioperator/omniop-spark-extension/cpp/test/tablescan/scan_test.cpp b/omnioperator/omniop-spark-extension/cpp/test/tablescan/scan_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3cfff6a9fa5e0611b5d478b2caa32c70428b4871 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/test/tablescan/scan_test.cpp @@ -0,0 +1,163 @@ +/** + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "gtest/gtest.h" +#include <memory> +#include <string> +#include "../../src/jni/OrcColumnarBatchJniReader.h" +#include "scan_test.h" +#include "orc/sargs/SearchArgument.hh" + +static std::string filename = "/resources/orc_data_all_type"; +static orc::ColumnVectorBatch *batchPtr; + +/* + * CREATE TABLE `orc_test` ( `c1` int, `c2` varChar(60), `c3` string, `c4` bigint, + * `c5` char(40), `c6` float, `c7` double, `c8` decimal(9,8), `c9` decimal(18,5), + * `c10` boolean, `c11` smallint, `c12` timestamp, `c13` date) stored as orc; + * + * insert into `orc_test` values (10, "varchar_1", "string_type_1", 10000, "char_1", + * 11.11, 1111.1111, 121.1111, 131.1111, true, 11, '2021-12-01 01:00:11', '2021-12-01'); + * insert into `orc_test` values (20, "varchar_2", NULL, 20000, "char_2", + * 11.22, 1111.2222, 121.2222, 131.2222, true, 12, '2021-12-01 01:22:11', '2021-12-02'); + * insert into `orc_test` values (30, "varchar_3", "string_type_3", NULL, "char_2", + * 11.33, 1111.333, 121.3333, 131.2222, NULL, 13, '2021-12-01 01:33:11', '2021-12-03'); + * insert into `orc_test` values (40, "varchar_4", "string_type_4", 40000, NULL, + * 11.44, NULL, 121.2222, 131.44, false, 14, '2021-12-01 01:44:11', '2021-12-04'); + * insert into `orc_test` values (50, "varchar_5", "string_type_5", 50000, "char_5", + * 11.55, 1111.55, 121.55, 131.55, true, 15, '2021-12-01 01:55:11', '2021-12-05'); + * + */ +class ScanTest : public testing::Test { +protected: + // run before each case... + virtual void SetUp() override + { + orc::ReaderOptions readerOpts; + orc::RowReaderOptions rowReaderOptions; + std::unique_ptr<orc::Reader> reader = orc::createReader(orc::readFile(PROJECT_PATH + filename), readerOpts); + std::unique_ptr<orc::RowReader> rowReader = reader->createRowReader(); + std::unique_ptr<orc::ColumnVectorBatch> batch = rowReader->createRowBatch(4096); + rowReader->next(*batch); + batchPtr = batch.release(); + } + + // run after each case... 
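+    // (TearDown below frees the ORC row batch released into batchPtr by SetUp; the reader and row reader are destroyed when SetUp returns)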
+ virtual void TearDown() override { + delete batchPtr; + } +}; + +TEST_F(ScanTest, test_get_literal) +{ + orc::Literal tmpLit(0L); + // test get long + getLiteral(tmpLit, 0, "123456789"); + ASSERT_EQ(tmpLit.toString() == "123456789", true); + + // test get string + getLiteral(tmpLit, 2, "testStringForLit"); + ASSERT_EQ(tmpLit.toString() == "testStringForLit", true); + + // test get date + getLiteral(tmpLit, 3, "987654321"); + ASSERT_EQ(tmpLit.toString() == "987654321", true); +} + +TEST_F(ScanTest, test_copy_vec) +{ + orc::StructVectorBatch *root = static_cast<orc::StructVectorBatch *>(batchPtr); + int omniType = 0; + uint64_t ominVecId = 0; + // int type + copyToOminVec(0, 3, omniType, ominVecId, root->fields[0]); + ASSERT_EQ(omniType == 1, true); + omniruntime::vec::IntVector *olbInt = (omniruntime::vec::IntVector *)(ominVecId); + ASSERT_EQ(olbInt->GetValue(0) == 10, true); + delete olbInt; + + // varchar type + copyToOminVec(60, 16, omniType, ominVecId, root->fields[1]); + ASSERT_EQ(omniType == 15, true); + uint8_t *actualChar = nullptr; + omniruntime::vec::VarcharVector * olbVc = (omniruntime::vec::VarcharVector *)(ominVecId); + int len = olbVc->GetValue(0, &actualChar); + std::string actualStr(reinterpret_cast<char *>(actualChar), 0, len); + ASSERT_EQ(actualStr == "varchar_1", true); + delete olbVc; + + // string type + copyToOminVec(0, 7, omniType, ominVecId, root->fields[2]); + ASSERT_EQ(omniType == 15, true); + omniruntime::vec::VarcharVector *olbStr = (omniruntime::vec::VarcharVector *)(ominVecId); + len = olbStr->GetValue(0, &actualChar); + std::string actualStr2(reinterpret_cast<char *>(actualChar), 0, len); + ASSERT_EQ(actualStr2 == "string_type_1", true); + delete olbStr; + + // bigint type + copyToOminVec(0, 4, omniType, ominVecId, root->fields[3]); + ASSERT_EQ(omniType == 2, true); + omniruntime::vec::LongVector *olbLong = (omniruntime::vec::LongVector *)(ominVecId); + ASSERT_EQ(olbLong->GetValue(0) == 10000, true); + delete olbLong; + + // char type + copyToOminVec(40, 17, omniType, ominVecId, root->fields[4]); + ASSERT_EQ(omniType == 15, true); + omniruntime::vec::VarcharVector *olbChar40 = (omniruntime::vec::VarcharVector *)(ominVecId); + len = olbChar40->GetValue(0, &actualChar); + std::string actualStr3(reinterpret_cast<char *>(actualChar), 0, len); + ASSERT_EQ(actualStr3 == "char_1", true); + delete olbChar40; +} + +TEST_F(ScanTest, test_build_leafs) +{ + int leafOp = 0; + std::vector<orc::Literal> litList; + std::string leafNameString; + int leafType = 0; + std::unique_ptr<orc::SearchArgumentBuilder> builder = orc::SearchArgumentFactory::newBuilder(); + (*builder).startAnd(); + orc::Literal lit(100L); + + + // test equal + buildLeafs(0, litList, lit, "leaf-0", 0, *builder); + + // test LESS_THAN + buildLeafs(2, litList, lit, "leaf-1", 0, *builder); + + // test LESS_THAN_EQUALS + buildLeafs(3, litList, lit, "leaf-1", 0, *builder); + + // test NULL_SAFE_EQUALS + buildLeafs(1, litList, lit, "leaf-1", 0, *builder); + + // test IS_NULL + buildLeafs(6, litList, lit, "leaf-1", 0, *builder); + + std::string result = ((*builder).end().build())->toString(); + std::string buildString = + "leaf-0 = (leaf-0 = 100), leaf-1 = (leaf-1 < 100), leaf-2 = (leaf-1 <= 100), leaf-3 = (leaf-1 null_safe_= " + "100), leaf-4 = (leaf-1 is null), expr = (and leaf-0 leaf-1 leaf-2 leaf-3 leaf-4)"; + + ASSERT_EQ(buildString == result, true); +} diff --git a/omnioperator/omniop-spark-extension/cpp/test/tablescan/scan_test.h.in b/omnioperator/omniop-spark-extension/cpp/test/tablescan/scan_test.h.in new file mode 100644 index 
0000000000000000000000000000000000000000..5ca616ec499c349478cb839213a4eb7bb289439c --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/test/tablescan/scan_test.h.in @@ -0,0 +1 @@ +#define PROJECT_PATH "@MAIN_PATH@" \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/test/tptest.cpp b/omnioperator/omniop-spark-extension/cpp/test/tptest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..321631487e58d3b4f64baaa1a63b5fc25238a721 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/test/tptest.cpp @@ -0,0 +1,24 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "gtest/gtest.h" + +int main(int argc, char **argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/test/utils/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/test/utils/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..0865325fc7c231363d6a09fb937e047f1bd0be5f --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/test/utils/CMakeLists.txt @@ -0,0 +1,8 @@ +aux_source_directory(${CMAKE_CURRENT_LIST_DIR} UTILS_TESTS_LIST) +set(UTILS_TEST_TARGET utilstest) +add_library(${UTILS_TEST_TARGET} ${UTILS_TESTS_LIST}) +target_include_directories(${UTILS_TEST_TARGET} PUBLIC /opt/lib/include) +target_include_directories(${UTILS_TEST_TARGET} PUBLIC ${CMAKE_BINARY_DIR}/src) +target_include_directories(${UTILS_TEST_TARGET} PUBLIC $ENV{JAVA_HOME}/include) +target_include_directories(${UTILS_TEST_TARGET} PUBLIC $ENV{JAVA_HOME}/include/linux) + diff --git a/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.cpp b/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5f7458566a44e7b4f829b4da6163788d060c0c9d --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.cpp @@ -0,0 +1,727 @@ +/** + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "test_utils.h" + +using namespace omniruntime::vec; + +void ToVectorTypes(const int32_t *dataTypeIds, int32_t dataTypeCount, std::vector &dataTypes) +{ + for (int i = 0; i < dataTypeCount; ++i) { + if (dataTypeIds[i] == OMNI_VARCHAR) { + dataTypes.push_back(VarcharDataType(50)); + continue; + } else if (dataTypeIds[i] == OMNI_CHAR) { + dataTypes.push_back(CharDataType(50)); + continue; + } + dataTypes.push_back(DataType(dataTypeIds[i])); + } +} + +VectorBatch* CreateInputData(const int32_t numRows, + const int32_t numCols, + int32_t* inputTypeIds, + int64_t* allData) +{ + auto *vecBatch = new VectorBatch(numCols, numRows); + vector inputTypes; + ToVectorTypes(inputTypeIds, numCols, inputTypes); + vecBatch->NewVectors(omniruntime::vec::GetProcessGlobalVecAllocator(), inputTypes); + for (int i = 0; i < numCols; ++i) { + switch (inputTypeIds[i]) { + case OMNI_INT: + ((IntVector *)vecBatch->GetVector(i))->SetValues(0, (int32_t *)allData[i], numRows); + break; + case OMNI_LONG: + ((LongVector *)vecBatch->GetVector(i))->SetValues(0, (int64_t *)allData[i], numRows); + break; + case OMNI_DOUBLE: + ((DoubleVector *)vecBatch->GetVector(i))->SetValues(0, (double *)allData[i], numRows); + break; + case OMNI_SHORT: + ((IntVector *)vecBatch->GetVector(i))->SetValues(0, (int32_t *)allData[i], numRows); + break; + case OMNI_VARCHAR: + case OMNI_CHAR: { + for (int j = 0; j < numRows; ++j) { + int64_t addr = (reinterpret_cast(allData[i]))[j]; + std::string s (reinterpret_cast(addr)); + ((VarcharVector *)vecBatch->GetVector(i))->SetValue(j, (uint8_t *)(s.c_str()), s.length()); + } + break; + } + case OMNI_DECIMAL128: + ((Decimal128Vector *)vecBatch->GetVector(i))->SetValues(0, (int64_t *) allData[i], numRows); + break; + default:{ + LogError("No such data type %d", inputTypeIds[i]); + } + } + } + return vecBatch; +} + +VarcharVector *CreateVarcharVector(VarcharDataType type, std::string *values, int32_t length) +{ + VectorAllocator * vecAllocator = omniruntime::vec::GetProcessGlobalVecAllocator(); + uint32_t width = type.GetWidth(); + VarcharVector *vector = std::make_unique(vecAllocator, length * width, length).release(); + for (int32_t i = 0; i < length; i++) { + vector->SetValue(i, reinterpret_cast(values[i].c_str()), values[i].length()); + } + return vector; +} + +Decimal128Vector *CreateDecimal128Vector(Decimal128 *values, int32_t length) +{ + VectorAllocator *vecAllocator = omniruntime::vec::GetProcessGlobalVecAllocator(); + Decimal128Vector *vector = std::make_unique(vecAllocator, length).release(); + for (int32_t i = 0; i < length; i++) { + vector->SetValue(i, values[i]); + } + return vector; +} + +Vector *CreateVector(DataType &vecType, int32_t rowCount, va_list &args) +{ + switch (vecType.GetId()) { + case OMNI_INT: + case OMNI_DATE32: + return CreateVector(va_arg(args, int32_t *), rowCount); + case OMNI_LONG: + case OMNI_DECIMAL64: + return CreateVector(va_arg(args, int64_t *), rowCount); + case OMNI_DOUBLE: + return CreateVector(va_arg(args, double *), rowCount); + case OMNI_BOOLEAN: + return CreateVector(va_arg(args, bool *), rowCount); + case OMNI_VARCHAR: + case 
OMNI_CHAR: + return CreateVarcharVector(static_cast(vecType), va_arg(args, std::string *), rowCount); + case OMNI_DECIMAL128: + return CreateDecimal128Vector(va_arg(args, Decimal128 *), rowCount); + default: + std::cerr << "Unsupported type : " << vecType.GetId() << std::endl; + return nullptr; + } +} + +DictionaryVector *CreateDictionaryVector(DataType &vecType, int32_t rowCount, int32_t *ids, int32_t idsCount, ...) +{ + va_list args; + va_start(args, idsCount); + Vector *dictionary = CreateVector(vecType, rowCount, args); + va_end(args); + auto vec = std::make_unique(dictionary, ids, idsCount).release(); + delete dictionary; + return vec; +} + +Vector *buildVector(const DataType &aggType, int32_t rowNumber) +{ + VectorAllocator *vecAllocator = omniruntime::vec::GetProcessGlobalVecAllocator(); + switch (aggType.GetId()) { + case OMNI_NONE: { + LongVector *col = new LongVector(vecAllocator, rowNumber); + for (int32_t j = 0; j < rowNumber; ++j) { + col->SetValueNull(j); + } + return col; + } + case OMNI_INT: + case OMNI_DATE32: { + IntVector *col = new IntVector(vecAllocator, rowNumber); + for (int32_t j = 0; j < rowNumber; ++j) { + col->SetValue(j, 1); + } + return col; + } + case OMNI_LONG: + case OMNI_DECIMAL64: { + LongVector *col = new LongVector(vecAllocator, rowNumber); + for (int32_t j = 0; j < rowNumber; ++j) { + col->SetValue(j, 1); + } + return col; + } + case OMNI_DOUBLE: { + DoubleVector *col = new DoubleVector(vecAllocator, rowNumber); + for (int32_t j = 0; j < rowNumber; ++j) { + col->SetValue(j, 1); + } + return col; + } + case OMNI_BOOLEAN: { + BooleanVector *col = new BooleanVector(vecAllocator, rowNumber); + for (int32_t j = 0; j < rowNumber; ++j) { + col->SetValue(j, 1); + } + return col; + } + case OMNI_DECIMAL128: { + Decimal128Vector *col = new Decimal128Vector(vecAllocator, rowNumber); + for (int32_t j = 0; j < rowNumber; ++j) { + col->SetValue(j, Decimal128(0, 1)); + } + return col; + } + case OMNI_VARCHAR: + case OMNI_CHAR: { + VarcharDataType charType = (VarcharDataType &)aggType; + VarcharVector *col = new VarcharVector(vecAllocator, charType.GetWidth() * rowNumber, rowNumber); + for (int32_t j = 0; j < rowNumber; ++j) { + std::string str = std::to_string(j); + col->SetValue(j, reinterpret_cast(str.c_str()), str.size()); + } + return col; + } + default: { + LogError("No such %d type support", aggType.GetId()); + return nullptr; + } + } +} + +VectorBatch *CreateVectorBatch(DataTypes &types, int32_t rowCount, ...) 
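+// Variadic builder: callers pass one column-data pointer per entry in types; CreateVector above consumes them in order via va_arg.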
+{ + int32_t typesCount = types.GetSize(); + VectorBatch *vectorBatch = std::make_unique(typesCount).release(); + va_list args; + va_start(args, rowCount); + for (int32_t i = 0; i < typesCount; i++) { + DataType type = types.Get()[i]; + vectorBatch->SetVector(i, CreateVector(type, rowCount, args)); + } + va_end(args); + return vectorBatch; +} + +/** + * create a VectorBatch with 1 col 1 row varchar value and it's partition id + * + * @param {int} pid partition id for this row + * @param {string} inputString varchar row value + * @return {VectorBatch} a VectorBatch + */ +VectorBatch* CreateVectorBatch_1row_varchar_withPid(int pid, std::string inputString) { + // gen vectorBatch + const int32_t numCols = 2; + int32_t* inputTypes = new int32_t[numCols]; + inputTypes[0] = OMNI_INT; + inputTypes[1] = OMNI_VARCHAR; + const int32_t numRows = 1; + auto* col1 = new int32_t[numRows]; + col1[0] = pid; + auto* col2 = new int64_t[numRows]; + std::string* strTmp = new std::string(inputString); + col2[0] = (int64_t)(strTmp->c_str()); + + int64_t allData[numCols] = {reinterpret_cast(col1), + reinterpret_cast(col2)}; + VectorBatch* in = CreateInputData(numRows, numCols, inputTypes, allData); + delete[] inputTypes; + delete[] col1; + delete[] col2; + delete strTmp; + return in; +} + +/** + * create a VectorBatch with 4col OMNI_INT OMNI_LONG OMNI_DOUBLE OMNI_VARCHAR and it's partition id + * + * @param {int} parNum partition number + * @param {int} rowNum row number + * @return {VectorBatch} a VectorBatch + */ +VectorBatch* CreateVectorBatch_4col_withPid(int parNum, int rowNum) { + int partitionNum = parNum; + const int32_t numCols = 5; + int32_t* inputTypes = new int32_t[numCols]; + inputTypes[0] = OMNI_INT; + inputTypes[1] = OMNI_INT; + inputTypes[2] = OMNI_LONG; + inputTypes[3] = OMNI_DOUBLE; + inputTypes[4] = OMNI_VARCHAR; + + const int32_t numRows = rowNum; + auto* col0 = new int32_t[numRows]; + auto* col1 = new int32_t[numRows]; + auto* col2 = new int64_t[numRows]; + auto* col3 = new double[numRows]; + auto* col4 = new int64_t[numRows]; + string startStr = "_START_"; + string endStr = "_END_"; + + std::vector string_cache_test_; + for (int i = 0; i < numRows; i++) { + col0[i] = (i+1) % partitionNum; + col1[i] = i + 1; + col2[i] = i + 1; + col3[i] = i + 1; + std::string* strTmp = new std::string(startStr + to_string(i + 1) + endStr); + string_cache_test_.push_back(strTmp); + col4[i] = (int64_t)((*strTmp).c_str()); + } + + int64_t allData[numCols] = {reinterpret_cast(col0), + reinterpret_cast(col1), + reinterpret_cast(col2), + reinterpret_cast(col3), + reinterpret_cast(col4)}; + VectorBatch* in = CreateInputData(numRows, numCols, inputTypes, allData); + delete[] inputTypes; + delete[] col0; + delete[] col1; + delete[] col2; + delete[] col3; + delete[] col4; + + for (int p = 0; p < string_cache_test_.size(); p++) { + delete string_cache_test_[p]; // release memory + } + return in; +} + +VectorBatch* CreateVectorBatch_1longCol_withPid(int parNum, int rowNum) { + int partitionNum = parNum; + const int32_t numCols = 2; + int32_t* inputTypes = new int32_t[numCols]; + inputTypes[0] = OMNI_INT; + inputTypes[1] = OMNI_LONG; + + const int32_t numRows = rowNum; + auto* col0 = new int32_t[numRows]; + auto* col1 = new int64_t[numRows]; + for (int i = 0; i < numRows; i++) { + col0[i] = (i+1) % partitionNum; + col1[i] = i + 1; + } + + int64_t allData[numCols] = {reinterpret_cast(col0), + reinterpret_cast(col1)}; + VectorBatch* in = CreateInputData(numRows, numCols, inputTypes, allData); + delete[] inputTypes; + 
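+    // the vectors built in CreateInputData hold their own copies of the data, so the temporary input buffers can be freed now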
delete[] col0; + delete[] col1; + return in; +} + +VectorBatch* CreateVectorBatch_2column_1row_withPid(int pid, std::string strVar, int intVar) { + const int32_t numCols = 3; + int32_t* inputTypes = new int32_t[numCols]; + inputTypes[0] = OMNI_INT; + inputTypes[1] = OMNI_VARCHAR; + inputTypes[2] = OMNI_INT; + + const int32_t numRows = 1; + auto* col0 = new int32_t[numRows]; + auto* col1 = new int64_t[numRows]; + auto* col2 = new int32_t[numRows]; + + col0[0] = pid; + std::string* strTmp = new std::string(strVar); + col1[0] = (int64_t)(strTmp->c_str()); + col2[0] = intVar; + + int64_t allData[numCols] = {reinterpret_cast(col0), + reinterpret_cast(col1), + reinterpret_cast(col2)}; + VectorBatch* in = CreateInputData(numRows, numCols, inputTypes, allData); + delete[] inputTypes; + delete[] col0; + delete[] col1; + delete[] col2; + delete strTmp; + return in; +} + +VectorBatch* CreateVectorBatch_4varcharCols_withPid(int parNum, int rowNum) { + int partitionNum = parNum; + const int32_t numCols = 5; + int32_t* inputTypes = new int32_t[numCols]; + inputTypes[0] = OMNI_INT; + inputTypes[1] = OMNI_VARCHAR; + inputTypes[2] = OMNI_VARCHAR; + inputTypes[3] = OMNI_VARCHAR; + inputTypes[4] = OMNI_VARCHAR; + + const int32_t numRows = rowNum; + auto* col0 = new int32_t[numRows]; + auto* col1 = new int64_t[numRows]; + auto* col2 = new int64_t[numRows]; + auto* col3 = new int64_t[numRows]; + auto* col4 = new int64_t[numRows]; + + std::vector string_cache_test_; + for (int i = 0; i < numRows; i++) { + col0[i] = (i+1) % partitionNum; + std::string* strTmp1 = new std::string("Col1_START_" + to_string(i + 1) + "_END_"); + col1[i] = (int64_t)((*strTmp1).c_str()); + std::string* strTmp2 = new std::string("Col2_START_" + to_string(i + 1) + "_END_"); + col2[i] = (int64_t)((*strTmp2).c_str()); + std::string* strTmp3 = new std::string("Col3_START_" + to_string(i + 1) + "_END_"); + col3[i] = (int64_t)((*strTmp3).c_str()); + std::string* strTmp4 = new std::string("Col4_START_" + to_string(i + 1) + "_END_"); + col4[i] = (int64_t)((*strTmp4).c_str()); + string_cache_test_.push_back(strTmp1); + string_cache_test_.push_back(strTmp2); + string_cache_test_.push_back(strTmp3); + string_cache_test_.push_back(strTmp4); + } + + int64_t allData[numCols] = {reinterpret_cast(col0), + reinterpret_cast(col1), + reinterpret_cast(col2), + reinterpret_cast(col3), + reinterpret_cast(col4)}; + VectorBatch* in = CreateInputData(numRows, numCols, inputTypes, allData); + delete[] inputTypes; + delete[] col0; + delete[] col1; + delete[] col2; + delete[] col3; + delete[] col4; + + for (int p = 0; p < string_cache_test_.size(); p++) { + delete string_cache_test_[p]; // release memory + } + return in; +} + +VectorBatch* CreateVectorBatch_4charCols_withPid(int parNum, int rowNum) { + int partitionNum = parNum; + const int32_t numCols = 5; + int32_t* inputTypes = new int32_t[numCols]; + inputTypes[0] = OMNI_INT; + inputTypes[1] = OMNI_CHAR; + inputTypes[2] = OMNI_CHAR; + inputTypes[3] = OMNI_CHAR; + inputTypes[4] = OMNI_CHAR; + + const int32_t numRows = rowNum; + auto* col0 = new int32_t[numRows]; + auto* col1 = new int64_t[numRows]; + auto* col2 = new int64_t[numRows]; + auto* col3 = new int64_t[numRows]; + auto* col4 = new int64_t[numRows]; + + std::vector string_cache_test_; + for (int i = 0; i < numRows; i++) { + col0[i] = (i+1) % partitionNum; + std::string* strTmp1 = new std::string("Col1_CHAR_" + to_string(i + 1) + "_END_"); + col1[i] = (int64_t)((*strTmp1).c_str()); + std::string* strTmp2 = new std::string("Col2_CHAR_" + to_string(i + 1) 
+ "_END_"); + col2[i] = (int64_t)((*strTmp2).c_str()); + std::string* strTmp3 = new std::string("Col3_CHAR_" + to_string(i + 1) + "_END_"); + col3[i] = (int64_t)((*strTmp3).c_str()); + std::string* strTmp4 = new std::string("Col4_CHAR_" + to_string(i + 1) + "_END_"); + col4[i] = (int64_t)((*strTmp4).c_str()); + string_cache_test_.push_back(strTmp1); + string_cache_test_.push_back(strTmp2); + string_cache_test_.push_back(strTmp3); + string_cache_test_.push_back(strTmp4); + } + + int64_t allData[numCols] = {reinterpret_cast(col0), + reinterpret_cast(col1), + reinterpret_cast(col2), + reinterpret_cast(col3), + reinterpret_cast(col4)}; + VectorBatch* in = CreateInputData(numRows, numCols, inputTypes, allData); + delete[] inputTypes; + delete[] col0; + delete[] col1; + delete[] col2; + delete[] col3; + delete[] col4; + + for (int p = 0; p < string_cache_test_.size(); p++) { + delete string_cache_test_[p]; // release memory + } + return in; +} + +VectorBatch* CreateVectorBatch_3fixedCols_withPid(int parNum, int rowNum) { + int partitionNum = parNum; + + // gen vectorBatch + const int32_t numCols = 4; + int32_t* inputTypes = new int32_t[numCols]; + inputTypes[0] = OMNI_INT; + inputTypes[1] = OMNI_INT; + inputTypes[2] = OMNI_LONG; + inputTypes[3] = OMNI_DOUBLE; + + const int32_t numRows = rowNum; + auto* col0 = new int32_t[numRows]; + auto* col1 = new int32_t[numRows]; + auto* col2 = new int64_t[numRows]; + auto* col3 = new double[numRows]; + for (int i = 0; i < numRows; i++) { + col0[i] = i % partitionNum; + col1[i] = i + 1; + col2[i] = i + 1; + col3[i] = i + 1; + } + + int64_t allData[numCols] = {reinterpret_cast(col0), + reinterpret_cast(col1), + reinterpret_cast(col2), + reinterpret_cast(col3)}; + VectorBatch* in = CreateInputData(numRows, numCols, inputTypes, allData); + delete[] inputTypes; + delete[] col0; + delete[] col1; + delete[] col2; + delete[] col3; + return in; +} + +VectorBatch* CreateVectorBatch_2dictionaryCols_withPid(int partitionNum) { + // dictionary test + // construct input data + const int32_t dataSize = 6; + // prepare data + int32_t data0[dataSize] = {111, 112, 113, 114, 115, 116}; + int64_t data1[dataSize] = {221, 222, 223, 224, 225, 226}; + void *datas[2] = {data0, data1}; + DataTypes sourceTypes(std::vector({ IntDataType(), LongDataType()})); + int32_t ids[] = {0, 1, 2, 3, 4, 5}; + VectorBatch *vectorBatch = new VectorBatch(3, dataSize); + VectorAllocator *allocator = omniruntime::vec::GetProcessGlobalVecAllocator(); + IntVector *intVectorTmp = new IntVector(allocator, 6); + for (int i = 0; i < intVectorTmp->GetSize(); i++) { + intVectorTmp->SetValue(i, (i+1) % partitionNum); + } + for (int32_t i = 0; i < 3; i ++) { + if (i == 0) { + vectorBatch->SetVector(i, intVectorTmp); + } else { + omniruntime::vec::DataType dataType = sourceTypes.Get()[i - 1]; + vectorBatch->SetVector(i, CreateDictionaryVector(dataType, dataSize, ids, dataSize, datas[i - 1])); + } + } + return vectorBatch; +} + +VectorBatch* CreateVectorBatch_1decimal128Col_withPid(int partitionNum, int rowNum) { + auto decimal128InputVec = buildVector(Decimal128DataType(38, 2), rowNum); + VectorAllocator *allocator = VectorAllocator::GetGlobalAllocator(); + IntVector *intVectorPid = new IntVector(allocator, rowNum); + for (int i = 0; i < intVectorPid->GetSize(); i++) { + intVectorPid->SetValue(i, (i+1) % partitionNum); + } + VectorBatch *vecBatch = new VectorBatch(2); + vecBatch->SetVector(0, intVectorPid); + vecBatch->SetVector(1, decimal128InputVec); + return vecBatch; +} + +VectorBatch* 
CreateVectorBatch_1decimal64Col_withPid(int partitionNum, int rowNum) { + auto decimal64InputVec = buildVector(Decimal64DataType(7, 2), rowNum); + VectorAllocator *allocator = VectorAllocator::GetGlobalAllocator(); + IntVector *intVectorPid = new IntVector(allocator, rowNum); + for (int i = 0; i < intVectorPid->GetSize(); i++) { + intVectorPid->SetValue(i, (i+1) % partitionNum); + } + VectorBatch *vecBatch = new VectorBatch(2); + vecBatch->SetVector(0, intVectorPid); + vecBatch->SetVector(1, decimal64InputVec); + return vecBatch; +} + +VectorBatch* CreateVectorBatch_2decimalCol_withPid(int partitionNum, int rowNum) { + auto decimal64InputVec = buildVector(Decimal64DataType(7, 2), rowNum); + auto decimal128InputVec = buildVector(Decimal128DataType(38, 2), rowNum); + VectorAllocator *allocator = VectorAllocator::GetGlobalAllocator(); + IntVector *intVectorPid = new IntVector(allocator, rowNum); + for (int i = 0; i < intVectorPid->GetSize(); i++) { + intVectorPid->SetValue(i, (i+1) % partitionNum); + } + VectorBatch *vecBatch = new VectorBatch(3); + vecBatch->SetVector(0, intVectorPid); + vecBatch->SetVector(1, decimal64InputVec); + vecBatch->SetVector(2, decimal128InputVec); + return vecBatch; +} + +VectorBatch* CreateVectorBatch_someNullRow_vectorBatch() { + const int32_t numRows = 6; + int32_t data1[numRows] = {0, 1, 2, 0, 1, 2}; + int64_t data2[numRows] = {0, 1, 2, 3, 4, 5}; + double data3[numRows] = {0.0, 1.1, 2.2, 3.3, 4.4, 5.5}; + std::string data4[numRows] = {"abcde", "fghij", "klmno", "pqrst", "", ""}; + + auto vec0 = CreateVector(data1, numRows); + auto vec1 = CreateVector(data2, numRows); + auto vec2 = CreateVector(data3, numRows); + auto vec3 = CreateVarcharVector(VarcharDataType(5), data4, numRows); + for (int i = 0; i < numRows; i = i + 2) { + vec0->SetValueNull(i); + vec1->SetValueNull(i); + vec2->SetValueNull(i); + vec3->SetValueNull(i); + } + VectorBatch *vecBatch = new VectorBatch(4); + vecBatch->SetVector(0, vec0); + vecBatch->SetVector(1, vec1); + vecBatch->SetVector(2, vec2); + vecBatch->SetVector(3, vec3); + return vecBatch; +} + +VectorBatch* CreateVectorBatch_someNullCol_vectorBatch() { + const int32_t numRows = 6; + int32_t data1[numRows] = {0, 1, 2, 0, 1, 2}; + int64_t data2[numRows] = {0, 1, 2, 3, 4, 5}; + double data3[numRows] = {0.0, 1.1, 2.2, 3.3, 4.4, 5.5}; + std::string data4[numRows] = {"abcde", "fghij", "klmno", "pqrst", "", ""}; + + auto vec0 = CreateVector(data1, numRows); + auto vec1 = CreateVector(data2, numRows); + auto vec2 = CreateVector(data3, numRows); + auto vec3 = CreateVarcharVector(VarcharDataType(5), data4, numRows); + for (int i = 0; i < numRows; i = i + 1) { + vec1->SetValueNull(i); + vec3->SetValueNull(i); + } + VectorBatch *vecBatch = new VectorBatch(4); + vecBatch->SetVector(0, vec0); + vecBatch->SetVector(1, vec1); + vecBatch->SetVector(2, vec2); + vecBatch->SetVector(3, vec3); + return vecBatch; +} + +void Test_Shuffle_Compression(std::string compStr, int32_t numPartition, int32_t numVb, int32_t numRow) { + std::string shuffleTestsDir = s_shuffle_tests_dir; + std::string tmpDataFilePath = shuffleTestsDir + "/shuffle_" + compStr; + if (!IsFileExist(shuffleTestsDir)) { + mkdir(shuffleTestsDir.c_str(), S_IRWXU|S_IRWXG|S_IROTH|S_IXOTH); + } + int32_t inputVecTypeIds[] = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR}; + int colNumber = sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]); + InputDataTypes inputDataTypes; + inputDataTypes.inputVecTypeIds = inputVecTypeIds; + inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; + 
inputDataTypes.inputDataScales = new uint32_t[colNumber]; + int partitionNum = numPartition; + int splitterId = Test_splitter_nativeMake("hash", + partitionNum, + inputDataTypes, + colNumber, + 4096, + compStr.c_str(), + tmpDataFilePath, + 0, + shuffleTestsDir); + for (uint64_t j = 0; j < numVb; j++) { + VectorBatch* vb = CreateVectorBatch_4col_withPid(partitionNum, numRow); + Test_splitter_split(splitterId, vb); + } + Test_splitter_stop(splitterId); + Test_splitter_close(splitterId); + delete[] inputDataTypes.inputDataPrecisions; + delete[] inputDataTypes.inputDataScales; + if (IsFileExist(tmpDataFilePath)) { + remove(tmpDataFilePath.c_str()); + } +} + +long Test_splitter_nativeMake(std::string partitioning_name, + int num_partitions, + InputDataTypes inputDataTypes, + int numCols, + int buffer_size, + const char* compression_type_jstr, + std::string data_file_jstr, + int num_sub_dirs, + std::string local_dirs_jstr) { + auto splitOptions = SplitOptions::Defaults(); + if (buffer_size > 0) { + splitOptions.buffer_size = buffer_size; + } + if (num_sub_dirs > 0) { + splitOptions.num_sub_dirs = num_sub_dirs; + } + setenv("NATIVESQL_SPARK_LOCAL_DIRS", local_dirs_jstr.c_str(), 1); + auto compression_type_result = GetCompressionType(compression_type_jstr); + splitOptions.compression_type = compression_type_result; + splitOptions.data_file = data_file_jstr; + //TODO: memory pool select + auto splitter = Splitter::Make(partitioning_name, inputDataTypes, numCols, num_partitions, std::move(splitOptions)); + return shuffle_splitter_holder_.Insert(std::shared_ptr(splitter)); +} + +int Test_splitter_split(long splitter_id, VectorBatch* vb) { + auto splitter = shuffle_splitter_holder_.Lookup(splitter_id); + //初始化split各全局变量 + splitter->Split(*vb); +} + +void Test_splitter_stop(long splitter_id) { + auto splitter = shuffle_splitter_holder_.Lookup(splitter_id); + if (!splitter) { + std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); + throw std::runtime_error("Test no splitter."); + } + splitter->Stop(); +} + +void Test_splitter_close(long splitter_id) { + auto splitter = shuffle_splitter_holder_.Lookup(splitter_id); + if (!splitter) { + std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); + throw std::runtime_error("Test no splitter."); + } + shuffle_splitter_holder_.Erase(splitter_id); +} + +void GetFilePath(const char *path, const char *filename, char *filepath) { + strcpy(filepath, path); + if(filepath[strlen(path) - 1] != '/') { + strcat(filepath, "/"); + } + strcat(filepath, filename); +} + +void DeletePathAll(const char* path) { + DIR *dir; + struct dirent *dirInfo; + struct stat statBuf; + char filepath[256] = {0}; + lstat(path, &statBuf); + if (S_ISREG(statBuf.st_mode)) { + remove(path); + } else if (S_ISDIR(statBuf.st_mode)) { + if ((dir = opendir(path)) != NULL) { + while ((dirInfo = readdir(dir)) != NULL) { + GetFilePath(path, dirInfo->d_name, filepath); + if (strcmp(dirInfo->d_name, ".") == 0 || strcmp(dirInfo->d_name, "..") == 0) { + continue; + } + DeletePathAll(filepath); + rmdir(filepath); + } + closedir(dir); + rmdir(path); + } + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.h b/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..042dc4a142ebaa20cabba735894f4c0b221b1ab7 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.h @@ -0,0 +1,93 @@ +/** + * Copyright (C) 
2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SPARK_THESTRAL_PLUGIN_TEST_UTILS_H +#define SPARK_THESTRAL_PLUGIN_TEST_UTILS_H + +#include +#include +#include +#include +#include "../../src/shuffle/splitter.h" +#include "../../src/jni/concurrent_map.h" + +static ConcurrentMap> shuffle_splitter_holder_; + +static std::string s_shuffle_tests_dir = "/tmp/shuffleTests"; + +VectorBatch* CreateInputData(const int32_t numRows, const int32_t numCols, int32_t* inputTypeIds, int64_t* allData); + +Vector *buildVector(const DataType &aggType, int32_t rowNumber); + +VectorBatch* CreateVectorBatch_1row_varchar_withPid(int pid, std::string inputChar); + +VectorBatch* CreateVectorBatch_4col_withPid(int parNum, int rowNum); + +VectorBatch* CreateVectorBatch_1longCol_withPid(int parNum, int rowNum); + +VectorBatch* CreateVectorBatch_2column_1row_withPid(int pid, std::string strVar, int intVar); + +VectorBatch* CreateVectorBatch_4varcharCols_withPid(int parNum, int rowNum); + +VectorBatch* CreateVectorBatch_4charCols_withPid(int parNum, int rowNum); + +VectorBatch* CreateVectorBatch_3fixedCols_withPid(int parNum, int rowNum); + +VectorBatch* CreateVectorBatch_2dictionaryCols_withPid(int partitionNum); + +VectorBatch* CreateVectorBatch_1decimal128Col_withPid(int partitionNum, int rowNum); + +VectorBatch* CreateVectorBatch_1decimal64Col_withPid(int partitionNum, int rowNum); + +VectorBatch* CreateVectorBatch_2decimalCol_withPid(int partitionNum, int rowNum); +VectorBatch* CreateVectorBatch_someNullRow_vectorBatch(); + +VectorBatch* CreateVectorBatch_someNullCol_vectorBatch(); + +void Test_Shuffle_Compression(std::string compStr, int32_t numPartition, int32_t numVb, int32_t numRow); + +long Test_splitter_nativeMake(std::string partitioning_name, + int num_partitions, + InputDataTypes inputDataTypes, + int numCols, + int buffer_size, + const char* compression_type_jstr, + std::string data_file_jstr, + int num_sub_dirs, + std::string local_dirs_jstr); + +int Test_splitter_split(long splitter_id, VectorBatch* vb); + +void Test_splitter_stop(long splitter_id); + +void Test_splitter_close(long splitter_id); + +template T *CreateVector(V *values, int32_t length) +{ + VectorAllocator *vecAllocator = omniruntime::vec::GetProcessGlobalVecAllocator(); + auto vector = new T(vecAllocator, length); + vector->SetValues(0, values, length); + return vector; +} + +void GetFilePath(const char *path, const char *filename, char *filepath); + +void DeletePathAll(const char* path); + +#endif //SPARK_THESTRAL_PLUGIN_TEST_UTILS_H \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/dev/checkstyle-suppressions.xml b/omnioperator/omniop-spark-extension/dev/checkstyle-suppressions.xml new file 
mode 100644 index 0000000000000000000000000000000000000000..337f1e91889208eda7bcf193048c1c155d8371cf --- /dev/null +++ b/omnioperator/omniop-spark-extension/dev/checkstyle-suppressions.xml @@ -0,0 +1,51 @@ + + + + + + + + + + + + + + + + + + diff --git a/omnioperator/omniop-spark-extension/dev/checkstyle.xml b/omnioperator/omniop-spark-extension/dev/checkstyle.xml new file mode 100644 index 0000000000000000000000000000000000000000..11684fe7df9a6c1d5a7f76dc708987d937ddb759 --- /dev/null +++ b/omnioperator/omniop-spark-extension/dev/checkstyle.xml @@ -0,0 +1,190 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/omnioperator/omniop-spark-extension/java/pom.xml b/omnioperator/omniop-spark-extension/java/pom.xml new file mode 100644 index 0000000000000000000000000000000000000000..1c6823ab7a808fd8c31f3d6ac1c20b2078cf2d29 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/pom.xml @@ -0,0 +1,371 @@ + + + + 4.0.0 + + com.huawei.kunpeng + boostkit-omniop-spark-parent + 3.1.1-1.0.0 + ../pom.xml + + + boostkit-omniop-spark + jar + BoostKit Spark Native Sql Engine Extension With OmniOperator + + + ../cpp/ + ../cpp/build/releases/ + ${cpp.test} + incremental + 0.6.1 + 3.0.0 + 1.6.2 + ${project.build.directory}/scala-${scala.binary.version}/jars + + + + + org.apache.spark + spark-sql_${scala.binary.version} + provided + + + org.apache.hadoop + hadoop-client + provided + + + com.google.protobuf + protobuf-java + + + com.huawei.boostkit + boostkit-omniop-bindings + + + junit + junit + test + + + io.trino.tpcds + tpcds + 1.4 + test + + + org.scalatest + scalatest_${scala.binary.version} + 3.2.3 + test + + + org.mockito + mockito-core + 2.23.4 + test + + + org.apache.spark + spark-core_${scala.binary.version} + test-jar + test + 3.1.1 + + + org.apache.spark + spark-catalyst_${scala.binary.version} + test-jar + test + 3.1.1 + + + org.apache.spark + spark-sql_${scala.binary.version} + test-jar + 3.1.1 + test + + + com.tdunning + json + 1.8 + + + org.apache.spark + spark-hive_${scala.binary.version} + 3.1.1 + test + + + + + + + + + ${artifactId}-${version}${dep.os.arch} + + + ${cpp.build.dir} + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + kr.motd.maven + os-maven-plugin + ${os.plugin.version} + + + + + exec-maven-plugin + org.codehaus.mojo + 3.0.0 + + + Build CPP + generate-resources + + exec + + + bash + + ${cpp.dir}/build.sh + ${plugin.cpp.test} + + + + + + + org.xolstice.maven.plugins + protobuf-maven-plugin + ${protobuf.maven.version} + + ${project.basedir}/../cpp/src/proto + + + + + compile + + + + + + net.alchim31.maven + scala-maven-plugin + 4.4.0 + + ${scala.recompile.mode} + + + + scala-compile-first + process-resources + + add-source + compile + + + + scala-test-compile + process-test-resources + + testCompile + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + prepare-test-jar + test-compile + + test-jar + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.8.0 + + 1.8 + 1.8 + + + + compile + + compile + + + + + + org.apache.maven.plugins + maven-assembly-plugin + 3.1.0 + + + jar-with-dependencies + + + + + make-assembly + package + + single + + + + + + org.scalastyle + scalastyle-maven-plugin + 1.0.0 + + false + true + true + false + ${project.basedir}/src/main/scala + 
${project.basedir}/src/test/scala + ${user.dir}/scalastyle-config.xml + ${project.basedir}/target/scalastyle-output.xml + ${project.build.sourceEncoding} + ${project.reporting.outputEncoding} + + + + + check + + + + + + org.apache.maven.plugins + maven-checkstyle-plugin + 3.1.2 + + false + true + + ${project.basedir}/src/main/java + ${project.basedir}/src/main/scala + + + ${project.basedir}/src/test/java + ${project.basedir}/src/test/scala + + dev/checkstyle.xml + ${project.basedir}/target/checkstyle-output.xml + ${project.build.sourceEncoding} + ${project.reporting.outputEncoding} + + + + com.puppycrawl.tools + checkstyle + 8.29 + + + + + + check + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + org.scalatest + scalatest-maven-plugin + + false + + ${project.build.directory}/surefire-reports + . + TestSuite.txt + + + + test + + test + + + + + + org.scoverage + scoverage-maven-plugin + 1.4.11 + + + test + test + + report + + + + + true + true + ${project.build.sourceEncoding} + + + + + \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/AircompressorCodec.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/AircompressorCodec.java new file mode 100644 index 0000000000000000000000000000000000000000..2d57291bce319e9402a69ddb91f4277ba6aaa61b --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/AircompressorCodec.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.huawei.boostkit.spark.compress;
+
+import io.airlift.compress.Compressor;
+import io.airlift.compress.Decompressor;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.EnumSet;
+
+public class AircompressorCodec implements CompressionCodec {
+    private final Compressor compressor;
+    private final Decompressor decompressor;
+
+    AircompressorCodec(Compressor compressor, Decompressor decompressor) {
+        this.compressor = compressor;
+        this.decompressor = decompressor;
+    }
+
+    // Thread-local buffer reused across calls to avoid reallocating on every block
+    private static final ThreadLocal<byte[]> threadBuffer =
+        new ThreadLocal<byte[]>() {
+            @Override
+            protected byte[] initialValue() {
+                return null;
+            }
+        };
+
+    protected static byte[] getBuffer(int size) {
+        byte[] result = threadBuffer.get();
+        if (result == null || result.length < size || result.length > size * 2) {
+            result = new byte[size];
+            threadBuffer.set(result);
+        }
+        return result;
+    }
+
+    @Override
+    public boolean compress(ByteBuffer in, ByteBuffer out,
+                            ByteBuffer overflow) throws IOException {
+        int inBytes = in.remaining();
+        // I should work on a patch for Snappy to support an overflow buffer
+        // to prevent the extra buffer copy.
+        byte[] compressed = getBuffer(compressor.maxCompressedLength(inBytes));
+        int outBytes =
+            compressor.compress(in.array(), in.arrayOffset() + in.position(), inBytes,
+                compressed, 0, compressed.length);
+        if (outBytes < inBytes) {
+            int remaining = out.remaining();
+            if (remaining >= outBytes) {
+                System.arraycopy(compressed, 0, out.array(), out.arrayOffset()
+                    + out.position(), outBytes);
+                out.position(out.position() + outBytes);
+            } else {
+                System.arraycopy(compressed, 0, out.array(), out.arrayOffset()
+                    + out.position(), remaining);
+                out.position(out.limit());
+                System.arraycopy(compressed, remaining, overflow.array(),
+                    overflow.arrayOffset(), outBytes - remaining);
+                overflow.position(outBytes - remaining);
+            }
+            return true;
+        } else {
+            return false;
+        }
+    }
+
+    @Override
+    public int decompress(byte[] input, int inputLength, byte[] output) throws IOException {
+        int uncompressLen =
+            decompressor.decompress(input, 0, inputLength,
+                output, 0, output.length);
+        return uncompressLen;
+    }
+
+    @Override
+    public CompressionCodec modify(EnumSet<Modifier> modifiers) {
+        // snappy allows no modifications
+        return this;
+    }
+
+    @Override
+    public void reset() {
+        // Nothing to do.
+    }
+
+    @Override
+    public void close() {
+        // Nothing to do.
+    }
+}
diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/CompressionCodec.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/CompressionCodec.java
new file mode 100644
index 0000000000000000000000000000000000000000..453deaaf84eaaa9e0610db41b52b23b793dc6911
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/CompressionCodec.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.compress; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.EnumSet; + +public interface CompressionCodec { + + enum Modifier { + /* speed/compression tradeoffs */ + FASTEST, + FAST, + DEFAULT, + /* data sensitivity modifiers */ + TEXT, + BINARY + }; + + /** + * Compress the in buffer to the out buffer. + * @param in the bytes to compress + * @param out the uncompressed bytes + * @param overflow put any additional bytes here + * @return true if the output is smaller than input + * @throws IOException + */ + boolean compress(ByteBuffer in, ByteBuffer out, ByteBuffer overflow + ) throws IOException; + + /** + * Decompress the in buffer to the out buffer. + * @param input the bytes to decompress + * @param output the decompressed bytes + * @throws IOException + */ + int decompress(byte[] input, int inputLength, byte[] output) throws IOException; + + /** + * Produce a modified compression codec if the underlying algorithm allows + * modification. + * + * This does not modify the current object, but returns a new object if + * modifications are possible. Returns the same object if no modifications + * are possible. + * @param modifiers compression modifiers (nullable) + * @return codec for use after optional modification + */ + CompressionCodec modify(EnumSet modifiers); + + /** Resets the codec, preparing it for reuse. */ + void reset(); + + /** Closes the codec, releasing the resources. */ + void close(); + +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/CompressionKind.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/CompressionKind.java new file mode 100644 index 0000000000000000000000000000000000000000..016a1928e723ea3afea2dc6e661a3fa27902cebf --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/CompressionKind.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.compress; + +/** + * An enumeration that lists the generic compression algorithms that + * can be applied to ORC files. 
+ */ +public enum CompressionKind { + NONE, ZLIB, SNAPPY, LZO, LZ4 +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/CompressionUtil.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/CompressionUtil.java new file mode 100644 index 0000000000000000000000000000000000000000..bef30bb718c6b8a7561febe795cfb2e2a29527f2 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/CompressionUtil.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.compress; + +import io.airlift.compress.lz4.Lz4Compressor; +import io.airlift.compress.lz4.Lz4Decompressor; +import io.airlift.compress.lzo.LzoCompressor; +import io.airlift.compress.lzo.LzoDecompressor; + +public class CompressionUtil { + public static CompressionCodec createCodec(String compressionCodec) { + switch (compressionCodec) { + case "zlib": + return new ZlibCodec(); + case "snappy": + return new SnappyCodec(); + case "lzo": + return new AircompressorCodec(new LzoCompressor(), + new LzoDecompressor()); + case "lz4": + return new AircompressorCodec(new Lz4Compressor(), + new Lz4Decompressor()); + default: + throw new IllegalArgumentException("Unknown compression codec: " + + compressionCodec); + + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/DecompressionStream.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/DecompressionStream.java new file mode 100644 index 0000000000000000000000000000000000000000..4bbe922ca85907f0eda0e7820277d757a00d2ebe --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/DecompressionStream.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
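A minimal round-trip sketch of the codec contract above, going through the factory just added; the heap ByteBuffers, the "zlib" name, and the 64 KB input are illustrative assumptions, not part of the patch:

import com.huawei.boostkit.spark.compress.CompressionCodec;
import com.huawei.boostkit.spark.compress.CompressionUtil;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;

public class CodecRoundTrip {
    public static void main(String[] args) throws IOException {
        CompressionCodec codec = CompressionUtil.createCodec("zlib");
        byte[] raw = new byte[64 * 1024];
        Arrays.fill(raw, (byte) 'a');                        // highly compressible input
        ByteBuffer in = ByteBuffer.wrap(raw);
        ByteBuffer out = ByteBuffer.allocate(raw.length);
        ByteBuffer overflow = ByteBuffer.allocate(raw.length);
        // compress() returns true only when the compressed form is smaller than the input;
        // anything that does not fit in `out` spills into `overflow` (not hit for this input).
        if (codec.compress(in, out, overflow)) {
            byte[] compressed = Arrays.copyOf(out.array(), out.position());
            byte[] restored = new byte[raw.length];
            int restoredLen = codec.decompress(compressed, compressed.length, restored);
            System.out.println("compressed=" + compressed.length + " restored=" + restoredLen);
        }
        codec.close();
    }
}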
+ */
+
+package com.huawei.boostkit.spark.compress;
+
+
+import java.io.IOException;
+import java.io.InputStream;
+
+public class DecompressionStream extends InputStream {
+    public static final int HEADER_SIZE = 3;
+    public static final int UNCOMPRESSED_LENGTH = 64 * 1024;
+
+    protected int compressBlockSize = 64 * 1024;
+    private boolean finishedReading = false;
+    protected final InputStream in;
+    private byte[] compressed; // scratch buffer holding the raw compressed block
+    private byte[] uncompressed; // decompressed data
+    private int uncompressedCursor = 0; // read cursor into the decompressed array
+    private int uncompressedLimit = 0; // limit of valid bytes in the decompressed array
+
+    private final CompressionCodec codec;
+
+    public DecompressionStream(InputStream in, CompressionCodec codec, int compressBlockSize) throws IOException {
+        this.compressBlockSize = compressBlockSize;
+        this.in = in;
+        this.codec = codec;
+        this.readHeader();
+    }
+
+    public void close() throws IOException {
+        this.compressed = null;
+        this.uncompressed = null;
+        if (this.in != null) {
+            this.in.close();
+        }
+    }
+
+    protected void readHeader() throws IOException {
+        int[] b = new int[3];
+        for (int i = 0; i < HEADER_SIZE; i++) {
+            int ret = in.read();
+            if (ret == -1) {
+                finishedReading = true;
+                return;
+            }
+            b[i] = ret & 0xff;
+        }
+        boolean isOriginal = (b[0] & 0x01) == 1;
+        int chunkLength = (b[2] << 15) | (b[1] << 7) | (b[0] >> 1);
+
+        uncompressedCursor = 0;
+        uncompressedLimit = 0;
+        // read the entire input data to the buffer
+        compressed = new byte[chunkLength]; // 8K
+        int readBytes = 0;
+        while (readBytes < chunkLength) {
+            int ret = in.read(compressed, readBytes, chunkLength - readBytes);
+            if (ret == -1) {
+                finishedReading = true;
+                break;
+            }
+            readBytes += ret;
+        }
+        if (readBytes < chunkLength) {
+            throw new IOException("failed to read chunk!");
+        }
+        if (isOriginal) {
+            uncompressed = compressed;
+            uncompressedLimit = chunkLength;
+            return;
+        }
+        if (uncompressed == null || UNCOMPRESSED_LENGTH > uncompressed.length) {
+            uncompressed = new byte[UNCOMPRESSED_LENGTH];
+        }
+
+        int actualUncompressedLength = codec.decompress(compressed, chunkLength, uncompressed);
+        uncompressedLimit = actualUncompressedLength;
+    }
+
+    public int read(byte[] data, int offset, int length) throws IOException {
+        if (!ensureUncompressed()) {
+            return -1;
+        }
+        int actualLength = Math.min(length, uncompressedLimit - uncompressedCursor);
+        System.arraycopy(uncompressed, uncompressedCursor, data, offset, actualLength);
+        uncompressedCursor += actualLength;
+        return actualLength;
+    }
+
+    public int read() throws IOException {
+        if (!ensureUncompressed()) {
+            return -1;
+        }
+        int data = 0xff & uncompressed[uncompressedCursor];
+        uncompressedCursor += 1;
+        return data;
+    }
+
+    private boolean ensureUncompressed() throws IOException {
+        while (uncompressed == null || (uncompressedLimit - uncompressedCursor) == 0) {
+            if (finishedReading) {
+                return false;
+            }
+            readHeader();
+        }
+        return true;
+    }
+
+    public int available() throws IOException {
+        if (!ensureUncompressed()) {
+            return 0;
+        }
+        return uncompressedLimit - uncompressedCursor;
+    }
+}
diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/SnappyCodec.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/SnappyCodec.java
new file mode 100644
index 0000000000000000000000000000000000000000..b6e2c98825b517f75aa6e0446de0b23840f3e965
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/SnappyCodec.java
@@ -0,0 +1,35 @@
+/*
+ *
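For reference, the 3-byte block header that readHeader() above parses packs an "original" flag in bit 0 of the first byte and a 23-bit chunk length in the remaining bits. A small sketch of that layout; the encoder shown is simply the inverse of the parse and is not itself part of this patch:

public class BlockHeaderLayout {
    // Inverse of: chunkLength = (b[2] << 15) | (b[1] << 7) | (b[0] >> 1)
    static int[] encodeHeader(int chunkLength, boolean isOriginal) {
        int b0 = ((chunkLength & 0x7f) << 1) | (isOriginal ? 1 : 0);
        int b1 = (chunkLength >> 7) & 0xff;
        int b2 = (chunkLength >> 15) & 0xff;
        return new int[] {b0, b1, b2};
    }

    static int decodeLength(int[] b) {
        return (b[2] << 15) | (b[1] << 7) | (b[0] >> 1);   // same expression as readHeader()
    }

    public static void main(String[] args) {
        int[] header = encodeHeader(64 * 1024, false);     // a 64 KB compressed block
        assert decodeLength(header) == 64 * 1024;
        assert (header[0] & 0x01) == 0;                    // not an "original" (uncompressed) block
    }
}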
Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.compress; + +import io.airlift.compress.snappy.SnappyCompressor; +import io.airlift.compress.snappy.SnappyDecompressor; + +import java.io.IOException; + +public class SnappyCodec extends AircompressorCodec { + + SnappyCodec() { + super(new SnappyCompressor(), new SnappyDecompressor()); + } + + @Override + public int decompress(byte[] input, int inputLength, byte[] output) throws IOException { + return super.decompress(input, inputLength, output); + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/ZlibCodec.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/ZlibCodec.java new file mode 100644 index 0000000000000000000000000000000000000000..31212a0bd1bdaeaecd61b9eb8622b8ee566df937 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/ZlibCodec.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.spark.compress; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.EnumSet; +import java.util.zip.DataFormatException; +import java.util.zip.Deflater; +import java.util.zip.Inflater; + +public class ZlibCodec implements CompressionCodec { + + private int level; + private int strategy; + + public ZlibCodec() { + level = Deflater.DEFAULT_COMPRESSION; + strategy = Deflater.DEFAULT_STRATEGY; + } + + private ZlibCodec(int level, int strategy) { + this.level = level; + this.strategy = strategy; + } + + @Override + public boolean compress(ByteBuffer in, ByteBuffer out, + ByteBuffer overflow) throws IOException { + int length = in.remaining(); + int outSize = 0; + Deflater deflater = new Deflater(level, true); + try { + deflater.setStrategy(strategy); + deflater.setInput(in.array(), in.arrayOffset() + in.position(), length); + deflater.finish(); + int offset = out.arrayOffset() + out.position(); + while (!deflater.finished() && (length > outSize)) { + int size = deflater.deflate(out.array(), offset, out.remaining()); + out.position(size + out.position()); + outSize += size; + offset += size; + // if we run out of space in the out buffer, use the overflow + if (out.remaining() == 0) { + if (overflow == null) { + return false; + } + out = overflow; + offset = out.arrayOffset() + out.position(); + } + } + }finally { + deflater.end(); + } + return length > outSize; + } + + @Override + public int decompress(byte[] input, int inputLength, byte[] output) throws IOException { + Inflater inflater = new Inflater(true); + int offset = 0; + int length = output.length; + try { + inflater.setInput(input, 0, inputLength); + while (!(inflater.finished() || inflater.needsDictionary() || + inflater.needsInput())) { + try { + int count = inflater.inflate(output, offset, length - offset); + offset += count; + } catch (DataFormatException dfe) { + throw new IOException("Bad compression data", dfe); + } + } + } finally { + inflater.end(); + } + return offset; + } + + @Override + public CompressionCodec modify(/* @Nullable */ EnumSet modifiers){ + + if (modifiers == null){ + return this; + } + + int l = this.level; + int s = this.strategy; + + for (Modifier m : modifiers){ + switch (m){ + case BINARY: + /* filtered == less LZ77, more huffman */ + s = Deflater.FILTERED; + break; + case TEXT: + s = Deflater.DEFAULT_STRATEGY; + break; + case FASTEST: + // deflate_fast looking for 8 byte patterns + l = Deflater.BEST_SPEED; + break; + case FAST: + // deflate_fast looking for 16 byte patterns + l = Deflater.BEST_SPEED + 1; + break; + case DEFAULT: + // deflate_slow looking for 128 byte patterns + l = Deflater.DEFAULT_COMPRESSION; + break; + default: + break; + } + } + return new ZlibCodec(l,s); + } + + @Override + public void reset(){ + level = Deflater.DEFAULT_COMPRESSION; + strategy = Deflater.DEFAULT_STRATEGY; + } + + @Override + public void close(){ + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/NativeLoader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/NativeLoader.java new file mode 100644 index 0000000000000000000000000000000000000000..7cd435f7ce052e1ece0c9b5140fc151a9e7152a0 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/NativeLoader.java @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. 
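ZlibCodec.modify() above derives a new codec rather than mutating the receiver; a brief, purely illustrative sketch of requesting the faster, binary-oriented variant:

import com.huawei.boostkit.spark.compress.CompressionCodec;
import com.huawei.boostkit.spark.compress.ZlibCodec;

import java.util.EnumSet;

public class ZlibModifyExample {
    public static void main(String[] args) {
        ZlibCodec base = new ZlibCodec();
        // BEST_SPEED level plus FILTERED strategy, per the switch in modify() above
        CompressionCodec fast = base.modify(
                EnumSet.of(CompressionCodec.Modifier.FASTEST, CompressionCodec.Modifier.BINARY));
        // `base` keeps its default level and strategy; `fast` is an independent instance
    }
}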
+ * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.jni; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; + +import nova.hetu.omniruntime.utils.NativeLog; + +/** + * @since 2021.08 + */ + +public class NativeLoader { + + private static volatile NativeLoader INSTANCE; + private static final String LIBRARY_NAME = "spark_columnar_plugin"; + private static final Logger LOG = LoggerFactory.getLogger(NativeLoader.class); + private static final int BUFFER_SIZE = 1024; + + public static NativeLoader getInstance() { + if (INSTANCE == null) { + synchronized (NativeLoader.class) { + if (INSTANCE == null) { + INSTANCE = new NativeLoader(); + } + } + } + return INSTANCE; + } + + private NativeLoader() { + File tempFile = null; + try { + String nativeLibraryPath = File.separator + System.mapLibraryName(LIBRARY_NAME); + tempFile = File.createTempFile(LIBRARY_NAME, ".so"); + try (InputStream in = NativeLoader.class.getResourceAsStream(nativeLibraryPath); + FileOutputStream fos = new FileOutputStream(tempFile)) { + int i; + byte[] buf = new byte[BUFFER_SIZE]; + while ((i = in.read(buf)) != -1) { + fos.write(buf, 0, i); + } + System.load(tempFile.getCanonicalPath()); + NativeLog.getInstance(); + } + } catch (IOException e) { + LOG.warn("fail to load library from Jar!errmsg:{}", e.getMessage()); + System.loadLibrary(LIBRARY_NAME); + } finally { + if (tempFile != null) { + tempFile.deleteOnExit(); + } + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReader.java new file mode 100644 index 0000000000000000000000000000000000000000..22707a88e58000e403a845a123ff2e5839205079 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReader.java @@ -0,0 +1,283 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.jni; + +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.Decimal64DataType; +import nova.hetu.omniruntime.type.Decimal128DataType; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.VarcharVec; +import nova.hetu.omniruntime.vector.Decimal128Vec; +import nova.hetu.omniruntime.vector.Vec; + +import org.apache.hadoop.hive.ql.io.sarg.ExpressionTree; +import org.apache.hadoop.hive.ql.io.sarg.PredicateLeaf; +import org.apache.orc.OrcFile.ReaderOptions; +import org.apache.orc.Reader.Options; +import org.json.JSONObject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.apache.orc.TypeDescription; + +import java.sql.Date; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + + +public class OrcColumnarBatchJniReader { + private static final Logger LOGGER = LoggerFactory.getLogger(OrcColumnarBatchJniReader.class); + + public long reader; + public long recordReader; + public long batchReader; + public int[] colsToGet; + public int realColsCnt; + + public OrcColumnarBatchJniReader() { + NativeLoader.getInstance(); + } + + public JSONObject getSubJson(ExpressionTree etNode) { + JSONObject jsonObject = new JSONObject(); + jsonObject.put("op", etNode.getOperator().ordinal()); + if (etNode.getOperator().toString().equals("LEAF")) { + jsonObject.put("leaf", etNode.toString()); + return jsonObject; + } + ArrayList child = new ArrayList(); + for (ExpressionTree childNode : etNode.getChildren()) { + JSONObject rtnJson = getSubJson(childNode); + child.add(rtnJson); + } + jsonObject.put("child", child); + return jsonObject; + } + + public JSONObject getLeavesJson(List leaves, TypeDescription schema) { + JSONObject jsonObjectList = new JSONObject(); + for (int i = 0; i < leaves.size(); i++) { + PredicateLeaf pl = leaves.get(i); + JSONObject jsonObject = new JSONObject(); + jsonObject.put("op", pl.getOperator().ordinal()); + jsonObject.put("name", pl.getColumnName()); + jsonObject.put("type", pl.getType().ordinal()); + if (pl.getLiteral() != null) { + if (pl.getType() == PredicateLeaf.Type.DATE) { + jsonObject.put("literal", ((int)Math.ceil(((Date)pl.getLiteral()).getTime()* 1.0/3600/24/1000)) + ""); + } else if (pl.getType() == PredicateLeaf.Type.DECIMAL) { + int decimalP = schema.findSubtype(pl.getColumnName()).getPrecision(); + int decimalS = schema.findSubtype(pl.getColumnName()).getScale(); + String[] spiltValues = pl.getLiteral().toString().split("\\."); + String strToAdd = ""; + if (spiltValues.length == 2) { + strToAdd = String.format("%1$" + decimalS + "s", spiltValues[1]).replace(' ', '0'); + } else { + strToAdd = String.format("%1$" + decimalS + "s", "").replace(' ', '0'); + } + jsonObject.put("literal", spiltValues[0] + "." 
+ strToAdd + " " + decimalP + " " + decimalS); + } else { + jsonObject.put("literal", pl.getLiteral().toString()); + } + } else { + jsonObject.put("literal", ""); + } + if ((pl.getLiteralList() != null) && (pl.getLiteralList().size() != 0)){ + List lst = new ArrayList(); + for (Object ob : pl.getLiteralList()) { + if (pl.getType() == PredicateLeaf.Type.DECIMAL) { + int decimalP = schema.findSubtype(pl.getColumnName()).getPrecision(); + int decimalS = schema.findSubtype(pl.getColumnName()).getScale(); + String[] spiltValues = ob.toString().split("\\."); + String strToAdd = ""; + if (spiltValues.length == 2) { + strToAdd = String.format("%1$" + decimalS + "s", spiltValues[1]).replace(' ', '0'); + } else { + strToAdd = String.format("%1$" + decimalS + "s", "").replace(' ', '0'); + } + lst.add(spiltValues[0] + "." + strToAdd + " " + decimalP + " " + decimalS); + } else if (pl.getType() == PredicateLeaf.Type.DATE) { + lst.add(((int)Math.ceil(((Date)pl.getLiteral()).getTime()* 1.0/3600/24/1000)) + ""); + } else { + lst.add(ob.toString()); + } + } + jsonObject.put("literalList", lst); + } else { + jsonObject.put("literalList", new ArrayList()); + } + jsonObjectList.put("leaf-" + i, jsonObject); + } + return jsonObjectList; + } + + /** + * Init Orc reader. + * + * @param path split file path + * @param options split file options + */ + public long initializeReaderJava(String path, ReaderOptions options) { + JSONObject job = new JSONObject(); + if (options.getOrcTail() == null) { + job.put("serializedTail", ""); + } else { + job.put("serializedTail", options.getOrcTail().getSerializedTail().toString()); + } + job.put("tailLocation", 9223372036854775807L); + reader = initializeReader(path, job); + return reader; + } + + /** + * Init Orc RecordReader. + * + * @param options split file options + */ + public long initializeRecordReaderJava(Options options) { + JSONObject job = new JSONObject(); + if (options.getInclude() == null) { + job.put("include", ""); + } else { + job.put("include", options.getInclude().toString()); + } + job.put("offset", options.getOffset()); + job.put("length", options.getLength()); + if (options.getSearchArgument() != null) { + LOGGER.debug("SearchArgument:" + options.getSearchArgument().toString()); + JSONObject jsonexpressionTree = getSubJson(options.getSearchArgument().getExpression()); + job.put("expressionTree", jsonexpressionTree); + JSONObject jsonleaves = getLeavesJson(options.getSearchArgument().getLeaves(), options.getSchema()); + job.put("leaves", jsonleaves); + } + + List allCols; + if (options.getColumnNames() == null) { + allCols = Arrays.asList(getAllColumnNames(reader)); + } else { + allCols = Arrays.asList(options.getColumnNames()); + } + ArrayList colToInclu = new ArrayList(); + List optionField = options.getSchema().getFieldNames(); + colsToGet = new int[optionField.size()]; + realColsCnt = 0; + for (int i = 0; i < optionField.size(); i++) { + if (allCols.contains(optionField.get(i))) { + colToInclu.add(optionField.get(i)); + colsToGet[i] = 0; + realColsCnt++; + } else { + colsToGet[i] = -1; + } + } + job.put("includedColumns", colToInclu.toArray()); + recordReader = initializeRecordReader(reader, job); + return recordReader; + } + + public long initBatchJava(long batchSize) { + batchReader = initializeBatch(recordReader, batchSize); + return 0; + } + + public long getNumberOfRowsJava() { + return getNumberOfRows(recordReader, batchReader); + } + + public long getRowNumber() { + return recordReaderGetRowNumber(recordReader); + } + + public float 
getProgress() { + return recordReaderGetProgress(recordReader); + } + + public void close() { + recordReaderClose(recordReader, reader, batchReader); + } + + public void seekToRow(long rowNumber) { + recordReaderSeekToRow(recordReader, rowNumber); + } + + public int next(Vec[] vecList) { + int vectorCnt = vecList.length; + int[] typeIds = new int[realColsCnt]; + long[] vecNativeIds = new long[realColsCnt]; + long rtn = recordReaderNext(recordReader, reader, batchReader, typeIds, vecNativeIds); + if (rtn == 0) { + return 0; + } + int nativeGetId = 0; + for (int i = 0; i < vectorCnt; i++) { + if (colsToGet[i] != 0) { + continue; + } + switch (DataType.DataTypeId.values()[typeIds[nativeGetId]]) { + case OMNI_DATE32: + case OMNI_INT: { + vecList[i] = new IntVec(vecNativeIds[nativeGetId]); + break; + } + case OMNI_LONG: { + vecList[i] = new LongVec(vecNativeIds[nativeGetId]); + break; + } + case OMNI_VARCHAR: { + vecList[i] = new VarcharVec(vecNativeIds[nativeGetId]); + break; + } + case OMNI_DECIMAL128: { + vecList[i] = new Decimal128Vec(vecNativeIds[nativeGetId], Decimal128DataType.DECIMAL128); + break; + } + case OMNI_DECIMAL64: { + vecList[i] = new LongVec(vecNativeIds[nativeGetId]); + break; + } + default: { + LOGGER.error("UNKNOWN TYPE ERROR IN JAVA" + DataType.DataTypeId.values()[typeIds[i]]); + } + } + nativeGetId++; + } + return (int)rtn; + } + + public native long initializeReader(String path, JSONObject job); + + public native long initializeRecordReader(long reader, JSONObject job); + + public native long initializeBatch(long rowReader, long batchSize); + + public native long recordReaderNext(long rowReader, long reader, long batchReader, int[] typeId, long[] vecNativeId); + + public native long recordReaderGetRowNumber(long rowReader); + + public native float recordReaderGetProgress(long rowReader); + + public native void recordReaderClose(long rowReader, long reader, long batchReader); + + public native void recordReaderSeekToRow(long rowReader, long rowNumber); + + public native String[] getAllColumnNames(long reader); + + public native long getNumberOfRows(long rowReader, long batch); +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/SparkJniWrapper.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/SparkJniWrapper.java new file mode 100644 index 0000000000000000000000000000000000000000..9aa7c414bc841fc10f2e0daae30fd168b1690055 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/SparkJniWrapper.java @@ -0,0 +1,93 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
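Pulling the JNI entry points above together, a hedged sketch of the read loop. It assumes the Configuration already carries the ORC schema and include-column settings consumed by OrcColumnarNativeReader.buildOptions() (added later in this patch); the batch size is illustrative:

import com.huawei.boostkit.spark.jni.OrcColumnarBatchJniReader;
import nova.hetu.omniruntime.vector.Vec;
import org.apache.hadoop.conf.Configuration;
import org.apache.orc.OrcFile;
import org.apache.spark.sql.execution.datasources.orc.OrcColumnarNativeReader;

public class OrcJniReadLoop {
    // Reads one split [start, start + length) of an ORC file through the native reader.
    static void readSplit(String path, Configuration conf, long start, long length) {
        OrcColumnarBatchJniReader reader = new OrcColumnarBatchJniReader();
        reader.initializeReaderJava(path, OrcFile.readerOptions(conf));
        reader.initializeRecordReaderJava(OrcColumnarNativeReader.buildOptions(conf, start, length));
        reader.initBatchJava(4096);                      // batch size is illustrative
        Vec[] vecs = new Vec[reader.colsToGet.length];   // one slot per requested field
        int rows;
        while ((rows = reader.next(vecs)) > 0) {
            // consume `rows` rows from the populated entries of `vecs`,
            // then release the native vectors owned by the caller
        }
        reader.close();
    }
}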
+ */ + +package com.huawei.boostkit.spark.jni; + +import com.huawei.boostkit.spark.vectorized.PartitionInfo; +import com.huawei.boostkit.spark.vectorized.SplitResult; + +public class SparkJniWrapper { + + public SparkJniWrapper() { + NativeLoader.getInstance(); + } + + public long make(PartitionInfo part, + int bufferSize, + String codec, + String dataFile, + int subDirsPerLocalDir, + String localDirs, + long shuffleCompressBlockSize, + int shuffleSpillBatchRowNum, + long shuffleSpillMemoryThreshold) { + return nativeMake( + part.getPartitionName(), + part.getPartitionNum(), + part.getInputTypes(), + part.getNumCols(), + bufferSize, + codec, + dataFile, + subDirsPerLocalDir, + localDirs, + shuffleCompressBlockSize, + shuffleSpillBatchRowNum, + shuffleSpillMemoryThreshold); + } + + public native long nativeMake( + String shortName, + int numPartitions, + String inputTypes, + int numCols, + int bufferSize, + String codec, + String dataFile, + int subDirsPerLocalDir, + String localDirs, + long shuffleCompressBlockSize, + int shuffleSpillBatchRowNum, + long shuffleSpillMemoryThreshold + ); + + /** + * Split one record batch represented by bufAddrs and bufSizes into several batches. The batch is + * split according to the first column as partition id. During splitting, the data in native + * buffers will be write to disk when the buffers are full. + * + * @param nativeVectorBatch Addresses of nativeVectorBatch + */ + public native void split(long splitterId, long nativeVectorBatch); + + /** + * Write the data remained in the buffers hold by native splitter to each partition's temporary + * file. And stop processing splitting + * + * @param splitterId splitter instance id + * @return SplitResult + */ + public native SplitResult stop(long splitterId); + + /** + * Release resources associated with designated splitter instance. + * + * @param splitterId splitter instance id + */ + public native void close(long splitterId); +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java new file mode 100644 index 0000000000000000000000000000000000000000..1b94c47b03f77f628e1a3a0e6f5d0a6ee2147597 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java @@ -0,0 +1,118 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
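A hedged lifecycle sketch for this wrapper. Partition count, codec name, paths, and thresholds are illustrative; PartitionInfo and SplitResult are defined later in this patch, and the addresses refer to native VectorBatch instances whose leading column carries the partition id, as documented on split() above:

import com.huawei.boostkit.spark.jni.SparkJniWrapper;
import com.huawei.boostkit.spark.vectorized.PartitionInfo;
import com.huawei.boostkit.spark.vectorized.SplitResult;

public class ShuffleSplitExample {
    static long[] splitAll(String inputTypes, long[] nativeBatchAddresses) {
        SparkJniWrapper jni = new SparkJniWrapper();
        PartitionInfo part = new PartitionInfo("hash", 200, 4, inputTypes);
        long splitterId = jni.make(part, 4096, "lz4", "/tmp/shuffle_data", 64, "/tmp/local_dirs",
                64 * 1024, 4096, 256 * 1024 * 1024);
        for (long address : nativeBatchAddresses) {
            jni.split(splitterId, address);          // partitions rows by the leading pid column
        }
        SplitResult result = jni.stop(splitterId);   // flushes buffered rows to the data file
        jni.close(splitterId);
        return result.getPartitionLengths();
    }
}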
+ */ + +package com.huawei.boostkit.spark.serialize; + + +import com.google.protobuf.InvalidProtocolBufferException; +import nova.hetu.omniruntime.vector.BooleanVec; +import nova.hetu.omniruntime.vector.Decimal128Vec; +import nova.hetu.omniruntime.vector.DoubleVec; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.ShortVec; +import nova.hetu.omniruntime.vector.VarcharVec; +import nova.hetu.omniruntime.vector.Vec; + +import org.apache.spark.sql.execution.vectorized.OmniColumnVector; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + + +public class ShuffleDataSerializer { + + public static ColumnarBatch deserialize(byte[] bytes) { + try { + VecData.VecBatch vecBatch = VecData.VecBatch.parseFrom(bytes); + int vecCount = vecBatch.getVecCnt(); + int rowCount = vecBatch.getRowCnt(); + ColumnVector[] vecs = new ColumnVector[vecCount]; + for (int i = 0; i < vecCount; i++) { + vecs[i] = buildVec(vecBatch.getVecs(i), rowCount); + } + return new ColumnarBatch(vecs, rowCount); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException("deserialize failed. errmsg:" + e.getMessage()); + } + } + + private static ColumnVector buildVec(VecData.Vec protoVec, int vecSize) { + VecData.VecType protoTypeId = protoVec.getVecType(); + Vec vec; + DataType type; + switch (protoTypeId.getTypeId()) { + case VEC_TYPE_INT: + type = DataTypes.IntegerType; + vec = new IntVec(vecSize); + break; + case VEC_TYPE_DATE32: + type = DataTypes.DateType; + vec = new IntVec(vecSize); + break; + case VEC_TYPE_LONG: + type = DataTypes.LongType; + vec = new LongVec(vecSize); + break; + case VEC_TYPE_DATE64: + type = DataTypes.DateType; + vec = new LongVec(vecSize); + break; + case VEC_TYPE_DECIMAL64: + type = DataTypes.createDecimalType(protoTypeId.getPrecision(), protoTypeId.getScale()); + vec = new LongVec(vecSize); + break; + case VEC_TYPE_SHORT: + type = DataTypes.ShortType; + vec = new ShortVec(vecSize); + break; + case VEC_TYPE_BOOLEAN: + type = DataTypes.BooleanType; + vec = new BooleanVec(vecSize); + break; + case VEC_TYPE_DOUBLE: + type = DataTypes.DoubleType; + vec = new DoubleVec(vecSize); + break; + case VEC_TYPE_VARCHAR: + case VEC_TYPE_CHAR: + type = DataTypes.StringType; + vec = new VarcharVec(protoVec.getValues().size(), vecSize); + if (vec instanceof VarcharVec) { + ((VarcharVec) vec).setOffsetsBuf(protoVec.getOffset().toByteArray()); + } + break; + case VEC_TYPE_DECIMAL128: + type = DataTypes.createDecimalType(protoTypeId.getPrecision(), protoTypeId.getScale()); + vec = new Decimal128Vec(vecSize); + break; + case VEC_TYPE_TIME32: + case VEC_TYPE_TIME64: + case VEC_TYPE_INTERVAL_DAY_TIME: + case VEC_TYPE_INTERVAL_MONTHS: + default: + throw new IllegalStateException("Unexpected value: " + protoTypeId.getTypeId()); + } + vec.setValuesBuf(protoVec.getValues().toByteArray()); + vec.setNullsBuf(protoVec.getNulls().toByteArray()); + OmniColumnVector vecTmp = new OmniColumnVector(vecSize, type, false); + vecTmp.setVec(vec); + return vecTmp; + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/vectorized/PartitionInfo.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/vectorized/PartitionInfo.java new file mode 100644 index 
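On the read side of the shuffle, the deserializer above is essentially a one-call API; a minimal sketch, assuming `serializedVecBatch` holds exactly one VecBatch protobuf message:

import com.huawei.boostkit.spark.serialize.ShuffleDataSerializer;
import org.apache.spark.sql.vectorized.ColumnarBatch;

public class DeserializeExample {
    static ColumnarBatch toBatch(byte[] serializedVecBatch) {
        // Each column comes back wrapped in an OmniColumnVector backed by a native Vec
        return ShuffleDataSerializer.deserialize(serializedVecBatch);
    }
}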
0000000000000000000000000000000000000000..67ce94c3cc73cc7e91687662f70935a1f6852afb --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/vectorized/PartitionInfo.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.vectorized; + +import java.io.Serializable; + +/** + * hold partitioning info needed by splitter + */ +public class PartitionInfo implements Serializable { + private final String partitionName; + private final int partitionNum; + private final int numCols; + private final String inputTypes; + + /** + * Init PartitionInfo + * + * @param partitionName Partitioning name. "single" for SinglePartitioning, "rr" for + * RoundRobinPartitioning, "hash" for HashPartitioning, "range" for RangePartitioning + * @param partitionNum partition number + */ + public PartitionInfo(String partitionName, int partitionNum, int numCols, String inputTypes) { + this.partitionName = partitionName; + this.partitionNum = partitionNum; + this.numCols = numCols; + this.inputTypes = inputTypes; + } + + public String getPartitionName() { + return partitionName; + } + + public int getPartitionNum() { + return partitionNum; + } + + public int getNumCols() { + return numCols; + } + + public String getInputTypes() { + return inputTypes; + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/vectorized/SplitResult.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/vectorized/SplitResult.java new file mode 100644 index 0000000000000000000000000000000000000000..f0169749c9cd926e54be3a9ba6da956b250e3470 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/vectorized/SplitResult.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.spark.vectorized; + +/** POJO to hold native split result */ +public class SplitResult { + private final long totalComputePidTime; + private final long totalWriteTime; + private final long totalSpillTime; + private final long totalBytesWritten; + private final long totalBytesSpilled; + private final long[] partitionLengths; + + public SplitResult( + long totalComputePidTime, + long totalWriteTime, + long totalSpillTime, + long totalBytesWritten, + long totalBytesSpilled, + long[] partitionLengths) { + this.totalComputePidTime = totalComputePidTime; + this.totalWriteTime = totalWriteTime; + this.totalSpillTime = totalSpillTime; + this.totalBytesWritten = totalBytesWritten; + this.totalBytesSpilled = totalBytesSpilled; + this.partitionLengths = partitionLengths; + } + + public long getTotalComputePidTime() { + return totalComputePidTime; + } + + public long getTotalWriteTime() { + return totalWriteTime; + } + + public long getTotalSpillTime() { + return totalSpillTime; + } + + public long getTotalBytesWritten() { + return totalBytesWritten; + } + + public long getTotalBytesSpilled() { + return totalBytesSpilled; + } + + public long[] getPartitionLengths() { + return partitionLengths; + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java new file mode 100644 index 0000000000000000000000000000000000000000..f64724c100ee56f29a0fdb0306698eacf8ebf8b8 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java @@ -0,0 +1,219 @@ +/* + * Copyright (C) 2021-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.orc; + +import com.google.common.annotations.VisibleForTesting; +import com.huawei.boostkit.spark.jni.OrcColumnarBatchJniReader; +import nova.hetu.omniruntime.vector.Vec; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.RecordReader; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.hadoop.mapreduce.lib.input.FileSplit; +import org.apache.orc.OrcConf; +import org.apache.orc.OrcFile; +import org.apache.orc.Reader; +import org.apache.orc.TypeDescription; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; +import org.apache.spark.sql.execution.vectorized.OmniColumnVector; +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +import java.io.IOException; + +/** + * To support vectorization in WholeStageCodeGen, this reader returns ColumnarBatch. + * After creating, `initialize` and `initBatch` should be called sequentially. + */ +public class OmniOrcColumnarBatchReader extends RecordReader { + + // The capacity of vectorized batch. + + private int capacity; + + /** + * The column IDs of the physical ORC file schema which are required by this reader. + * -1 means this required column is partition column, or it doesn't exist in the ORC file. + * Ideally partition column should never appear in the physical file, and should only appear + * in the directory name. However, Spark allows partition columns inside physical file, + * but Spark will discard the values from the file, and use the partition value got from + * directory name. The column order will be reserved though. + */ + @VisibleForTesting + public int[] requestedDataColIds; + + // Native Record reader from ORC row batch. + private OrcColumnarBatchJniReader recordReader; + + private StructField[] requiredFields; + + // The result columnar batch for vectorized execution by whole-stage codegen. + @VisibleForTesting + public ColumnarBatch columnarBatch; + + // The wrapped ORC column vectors. + private org.apache.spark.sql.vectorized.ColumnVector[] orcVectorWrappers; + + private Vec[] vecs; + + public OmniOrcColumnarBatchReader(int capacity) { + this.capacity = capacity; + } + + + @Override + public Void getCurrentKey() { + return null; + } + + @Override + public ColumnarBatch getCurrentValue() { + return columnarBatch; + } + + @Override + public float getProgress() throws IOException { + return recordReader.getProgress(); + } + + @Override + public boolean nextKeyValue() throws IOException { + return nextBatch(); + } + + @Override + public void close() throws IOException { + // TODO close omni vec + if (recordReader != null) { + recordReader.close(); + recordReader = null; + } + } + + /** + * Initialize ORC file reader and batch record reader. + * Please note that `initBatch` is needed to be called after this. 
+ */ + @Override + public void initialize( + InputSplit inputSplit, TaskAttemptContext taskAttemptContext) throws IOException { + FileSplit fileSplit = (FileSplit)inputSplit; + Configuration conf = taskAttemptContext.getConfiguration(); + OrcFile.ReaderOptions readerOptions = OrcFile.readerOptions(conf) + .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf)) + .filesystem(fileSplit.getPath().getFileSystem(conf)); +// long reader = OrcColumnarNativeReader.initializeReaderJava(fileSplit.getPath(), readerOptions); + Reader.Options options = + OrcColumnarNativeReader.buildOptions(conf, fileSplit.getStart(), fileSplit.getLength()); + recordReader = new OrcColumnarBatchJniReader(); + recordReader.initializeReaderJava(fileSplit.getPath().toString(), readerOptions); + recordReader.initializeRecordReaderJava(options); + } + + /** + * Initialize columnar batch by setting required schema and partition information. + * With this information, this creates ColumnarBatch with the full schema. + * + * @param orcSchema Schema from ORC file reader. + * @param requiredFields All the fields that are required to return, including partition fields. + * @param requestedDataColIds Requested column ids from orcSchema. -1 if not existed. + * @param requestedPartitionColIds Requested column ids from partition schema. -1 if not existed. + * @param partitionValues Values of partition columns. + */ + public void initBatch( + TypeDescription orcSchema, + StructField[] requiredFields, + int[] requestedDataColIds, + int[] requestedPartitionColIds, + InternalRow partitionValues) { + // wrap = new OrcShimUtils.VectorizedRowBatchWrap(orcSchema.createRowBatch(capacity)); + // assert(!wrap.batch().selectedInUse); // `selectedInUse` should be initialized with `false`. + assert(requiredFields.length == requestedDataColIds.length); + assert(requiredFields.length == requestedPartitionColIds.length); + // If a required column is also partition column, use partition value and don't read from file. + for (int i = 0; i < requiredFields.length; i++) { + if (requestedPartitionColIds[i] != -1) { + requestedDataColIds[i] = -1; + } + } + this.requiredFields = requiredFields; + this.requestedDataColIds = requestedDataColIds; + + StructType resultSchema = new StructType(requiredFields); + + // Just wrap the ORC column vector instead of copying it to Spark column vector. + orcVectorWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()]; + + for (int i = 0; i < requiredFields.length; i++) { + DataType dt = requiredFields[i].dataType(); + if (requestedPartitionColIds[i] != -1) { + OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt); + ColumnVectorUtils.populate(partitionCol, partitionValues, requestedPartitionColIds[i]); + partitionCol.setIsConstant(); + orcVectorWrappers[i] = partitionCol; + } else { + int colId = requestedDataColIds[i]; + // Initialize the missing columns once. + if (colId == -1) { + OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); + missingCol.putNulls(0, capacity); + missingCol.setIsConstant(); + orcVectorWrappers[i] = missingCol; + } else { + orcVectorWrappers[i] = new OmniColumnVector(capacity, dt, false); + } + } + } + // init batch + recordReader.initBatchJava(capacity); + vecs = new Vec[orcVectorWrappers.length]; + columnarBatch = new ColumnarBatch(orcVectorWrappers); + } + + /** + * Return true if there exists more data in the next batch. If exists, prepare the next batch + * by copying from ORC VectorizedRowBatch columns to Spark ColumnarBatch columns. 
+ */ + private boolean nextBatch() throws IOException { + // TODO recordReader.nextBatch(wrap.batch()); + int batchSize = capacity; + if ((requiredFields.length == 1 && requestedDataColIds[0] == -1) || requiredFields.length == 0) { + batchSize = (int) recordReader.getNumberOfRowsJava(); + } else { + batchSize = recordReader.next(vecs); + } + if (batchSize == 0) { + return false; + } + columnarBatch.setNumRows(batchSize); + + for (int i = 0; i < requiredFields.length; i++) { + if (requestedDataColIds[i] != -1) { + int colId = requestedDataColIds[i]; + ((OmniColumnVector) orcVectorWrappers[i]).setVec(vecs[colId]); + } + } + return true; + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarNativeReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarNativeReader.java new file mode 100644 index 0000000000000000000000000000000000000000..fc581846caedd068d4572863afe24449e581ad9f --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarNativeReader.java @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2021-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.orc; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import org.apache.commons.codec.binary.Base64; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.hive.ql.io.sarg.SearchArgument; +import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentImpl; +import org.apache.orc.OrcConf; +import org.apache.orc.OrcFile.ReaderOptions; +import org.apache.orc.Reader.Options; +import org.apache.orc.TypeDescription; +import org.apache.orc.mapred.OrcInputFormat; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class OrcColumnarNativeReader { + private static final Logger LOGGER = LoggerFactory.getLogger(OrcColumnarNativeReader.class); + + public static Options buildOptions(Configuration conf, long start, long length) { + TypeDescription schema = + TypeDescription.fromString(OrcConf.MAPRED_INPUT_SCHEMA.getString(conf)); + Options options = new Options(conf) + .range(start, length) + .useZeroCopy(OrcConf.USE_ZEROCOPY.getBoolean(conf)) + .skipCorruptRecords(OrcConf.SKIP_CORRUPT_DATA.getBoolean(conf)) + .tolerateMissingSchema(OrcConf.TOLERATE_MISSING_SCHEMA.getBoolean(conf)); + if (schema != null) { + options.schema(schema); + } else { + // TODO + LOGGER.error("TODO: null schema should support"); + } + options.include(OrcInputFormat.parseInclude(schema, + OrcConf.INCLUDE_COLUMNS.getString(conf))); + String kryoSarg = OrcConf.KRYO_SARG.getString(conf); + String sargColumns = OrcConf.SARG_COLUMNS.getString(conf); + if (kryoSarg != null && sargColumns != null) { + byte[] sargBytes = Base64.decodeBase64(kryoSarg); + SearchArgument sarg = + new Kryo().readObject(new Input(sargBytes), SearchArgumentImpl.class); + options.searchArgument(sarg, sargColumns.split(",")); + sarg.getExpression().toString(); + } + return options; + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java new file mode 100644 index 0000000000000000000000000000000000000000..c0e00761ec3c77120b8563924be8fa464339ba65 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java @@ -0,0 +1,812 @@ +/* + * Copyright (C) 2021-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.vectorized; + +import nova.hetu.omniruntime.vector.*; + +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.types.UTF8String; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Arrays; + +/** + * OmniColumnVector + */ +public class OmniColumnVector extends WritableColumnVector { + private static final boolean BIG_ENDIAN_PLATFORM = ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); + + /** + * Allocates columns to store elements of each field of the schema on heap. + * Capacity is the initial capacity of the vector and it will grow as necessary. + * Capacity is in number of elements, not number of bytes. + */ + public static OmniColumnVector[] allocateColumns(int capacity, StructType schema, boolean initVec) { + return allocateColumns(capacity, schema.fields(), initVec); + } + + /** + * Allocates columns to store elements of each field on heap. Capacity is the + * initial capacity of the vector and it will grow as necessary. Capacity is in + * number of elements, not number of bytes. + */ + public static OmniColumnVector[] allocateColumns(int capacity, StructField[] fields, boolean initVec) { + OmniColumnVector[] vectors = new OmniColumnVector[fields.length]; + for (int i = 0; i < fields.length; i++) { + vectors[i] = new OmniColumnVector(capacity, fields[i].dataType(), initVec); + } + return vectors; + } + + // The data stored in these arrays need to maintain binary compatible. We can + // directly pass this buffer to external components. + // This is faster than a boolean array and we optimize this over memory + // footprint. + // Array for each type. Only 1 is populated for any type. 
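+    // Which field is populated depends on the Spark DataType of this column (see getVec and
+    // setVec below): StringType and ByteType share charsTypeDataVec, DateType reuses
+    // intDataVec, and DecimalType uses longDataVec for 64-bit decimals and decimal128DataVec
+    // otherwise. Dictionary-encoded input is held in dictionaryData instead.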
+ private BooleanVec booleanDataVec; + private ShortVec shortDataVec; + private IntVec intDataVec; + private LongVec longDataVec; + private DoubleVec doubleDataVec; + private Decimal128Vec decimal128DataVec; + private VarcharVec charsTypeDataVec; + private DictionaryVec dictionaryData; + + // init vec + private boolean initVec; + + public OmniColumnVector(int capacity, DataType type, boolean initVec) { + super(capacity, type); + this.initVec = initVec; + if (this.initVec) { + reserveInternal(capacity); + } + reset(); + } + + /** + * get vec + * + * @return Vec + */ + public Vec getVec() { + if (dictionaryData != null) { + return dictionaryData; + } + + if (type instanceof LongType) { + return longDataVec; + } else if (type instanceof BooleanType) { + return booleanDataVec; + } else if (type instanceof ShortType) { + return shortDataVec; + } else if (type instanceof IntegerType) { + return intDataVec; + } else if (type instanceof DecimalType) { + if (DecimalType.is64BitDecimalType(type)) { + return longDataVec; + } else { + return decimal128DataVec; + } + } else if (type instanceof DoubleType) { + return doubleDataVec; + } else if (type instanceof StringType) { + return charsTypeDataVec; + } else if (type instanceof DateType) { + return intDataVec; + } else if (type instanceof ByteType) { + return charsTypeDataVec; + } else { + return null; + } + } + + /** + * set Vec + * + * @param vec Vec + */ + public void setVec(Vec vec) { + if (vec instanceof DictionaryVec) { + dictionaryData = (DictionaryVec) vec; + } else if (type instanceof LongType) { + this.longDataVec = (LongVec) vec; + } else if (type instanceof DecimalType) { + if (DecimalType.is64BitDecimalType(type)) { + this.longDataVec = (LongVec) vec; + } else { + this.decimal128DataVec = (Decimal128Vec) vec; + } + } else if (type instanceof BooleanType) { + this.booleanDataVec = (BooleanVec) vec; + } else if (type instanceof ShortType) { + this.shortDataVec = (ShortVec) vec; + } else if (type instanceof IntegerType) { + this.intDataVec = (IntVec) vec; + } else if (type instanceof DoubleType) { + this.doubleDataVec = (DoubleVec) vec; + } else if (type instanceof StringType) { + this.charsTypeDataVec = (VarcharVec) vec; + } else if (type instanceof DateType) { + this.intDataVec = (IntVec) vec; + } else if (type instanceof ByteType) { + this.charsTypeDataVec = (VarcharVec) vec; + } else { + return; + } + } + + @Override + public void close() { + super.close(); + if (booleanDataVec != null) { + booleanDataVec.close(); + } + if (shortDataVec != null) { + shortDataVec.close(); + } + if (intDataVec != null) { + intDataVec.close(); + } + if (longDataVec != null) { + longDataVec.close(); + } + if (doubleDataVec != null) { + doubleDataVec.close(); + } + if (decimal128DataVec != null) { + decimal128DataVec.close(); + } + if (charsTypeDataVec != null) { + charsTypeDataVec.close(); + } + if (dictionaryData != null) { + dictionaryData.close(); + dictionaryData = null; + } + } + + // + // APIs dealing with nulls + // + + @Override + public boolean hasNull() { + throw new UnsupportedOperationException("hasNull is not supported"); + } + + @Override + public int numNulls() { + throw new UnsupportedOperationException("numNulls is not supported"); + } + + @Override + public void putNotNull(int rowId) {} + + @Override + public void putNull(int rowId) { + if (dictionaryData != null) { + dictionaryData.setNull(rowId); + return; + } + if (type instanceof BooleanType) { + booleanDataVec.setNull(rowId); + } else if (type instanceof ByteType) { + 
charsTypeDataVec.setNull(rowId); + } else if (type instanceof ShortType) { + shortDataVec.setNull(rowId); + } else if (type instanceof IntegerType) { + intDataVec.setNull(rowId); + } else if (type instanceof DecimalType) { + if (DecimalType.is64BitDecimalType(type)) { + longDataVec.setNull(rowId); + } else { + decimal128DataVec.setNull(rowId); + } + } else if (type instanceof LongType || DecimalType.is64BitDecimalType(type)) { + longDataVec.setNull(rowId); + } else if (type instanceof FloatType) { + return; + } else if (type instanceof DoubleType) { + doubleDataVec.setNull(rowId); + } else if (type instanceof StringType) { + charsTypeDataVec.setNull(rowId); + } else if (type instanceof DateType) { + intDataVec.setNull(rowId); + } + } + + @Override + public void putNulls(int rowId, int count) { + boolean[] nullValue = new boolean[count]; + Arrays.fill(nullValue, true); + if (dictionaryData != null) { + dictionaryData.setNulls(rowId, nullValue, 0, count); + return; + } + if (type instanceof BooleanType) { + booleanDataVec.setNulls(rowId, nullValue, 0, count); + } else if (type instanceof ByteType) { + charsTypeDataVec.setNulls(rowId, nullValue, 0, count); + } else if (type instanceof ShortType) { + shortDataVec.setNulls(rowId, nullValue, 0, count); + } else if (type instanceof IntegerType) { + intDataVec.setNulls(rowId, nullValue, 0, count); + } else if (type instanceof DecimalType) { + if (DecimalType.is64BitDecimalType(type)) { + longDataVec.setNulls(rowId, nullValue, 0, count); + } else { + decimal128DataVec.setNulls(rowId, nullValue, 0, count); + } + } else if (type instanceof LongType || DecimalType.is64BitDecimalType(type)) { + longDataVec.setNulls(rowId, nullValue, 0, count); + } else if (type instanceof FloatType) { + return; + } else if (type instanceof DoubleType) { + doubleDataVec.setNulls(rowId, nullValue, 0, count); + } else if (type instanceof StringType) { + charsTypeDataVec.setNulls(rowId, nullValue, 0, count); + } else if (type instanceof DateType) { + intDataVec.setNulls(rowId, nullValue, 0, count); + } + } + + @Override + public void putNotNulls(int rowId, int count) {} + + @Override + public boolean isNullAt(int rowId) { + if (dictionaryData != null) { + return dictionaryData.isNull(rowId); + } + if (type instanceof BooleanType) { + return booleanDataVec.isNull(rowId); + } else if (type instanceof ByteType) { + return charsTypeDataVec.isNull(rowId); + } else if (type instanceof ShortType) { + return shortDataVec.isNull(rowId); + } else if (type instanceof IntegerType) { + return intDataVec.isNull(rowId); + } else if (type instanceof DecimalType) { + if (DecimalType.is64BitDecimalType(type)) { + return longDataVec.isNull(rowId); + } else { + return decimal128DataVec.isNull(rowId); + } + } else if (type instanceof LongType || DecimalType.is64BitDecimalType(type)) { + return longDataVec.isNull(rowId); + } else if (type instanceof FloatType) { + return false; + } else if (type instanceof DoubleType) { + return doubleDataVec.isNull(rowId); + } else if (type instanceof StringType) { + return charsTypeDataVec.isNull(rowId); + } else if (type instanceof DateType) { + return intDataVec.isNull(rowId); + } else { + + throw new UnsupportedOperationException("isNullAt is not supported for type:" + type); + } + } + + // + // APIs dealing with Booleans + // + + @Override + public void putBoolean(int rowId, boolean value) { + booleanDataVec.set(rowId, value); + } + + @Override + public void putBooleans(int rowId, int count, boolean value) { + for (int i = 0; i < count; ++i) { + 
booleanDataVec.set(i + rowId, value); + } + } + + @Override + public boolean getBoolean(int rowId) { + if (dictionaryData != null) { + return dictionaryData.getBoolean(rowId); + } + return booleanDataVec.get(rowId); + } + + @Override + public boolean[] getBooleans(int rowId, int count) { + assert (dictionary == null); + boolean[] array = new boolean[count]; + for (int i = 0; i < count; ++i) { + array[i] = booleanDataVec.get(rowId + i); + } + return array; + } + + // + + // + // APIs dealing with Bytes + // + + @Override + public void putByte(int rowId, byte value) { + charsTypeDataVec.set(rowId, new byte[]{value}); + } + + @Override + public void putBytes(int rowId, int count, byte value) { + for (int i = 0; i < count; ++i) { + charsTypeDataVec.set(rowId, new byte[]{value}); + } + } + + @Override + public void putBytes(int rowId, int count, byte[] src, int srcIndex) { + byte[] array = new byte[count]; + System.arraycopy(src, srcIndex, array, 0, count); + charsTypeDataVec.set(rowId, array); + } + + /** + * + * @param length length of string value + * @param src src value + * @param offset offset value + * @return return count of elements + */ + public final int appendString(int length, byte[] src, int offset) { + reserve(elementsAppended + 1); + int result = elementsAppended; + putBytes(elementsAppended, length, src, offset); + elementsAppended++; + return result; + } + + @Override + public byte getByte(int rowId) { + if (dictionary != null) { + return (byte) dictionary.decodeToInt(dictionaryIds.getDictId(rowId)); + } else if (dictionaryData != null) { + return dictionaryData.getBytes(rowId)[0]; + } else { + return charsTypeDataVec.get(rowId)[0]; + } + } + + @Override + public byte[] getBytes(int rowId, int count) { + assert (dictionary == null); + byte[] array = new byte[count]; + for (int i = 0; i < count; i++) { + if (type instanceof StringType) { + array[i] = ((VarcharVec) ((OmniColumnVector) getChild(0)).getVec()).get(rowId + i)[0]; + } else if (type instanceof ByteType) { + array[i] = charsTypeDataVec.get(rowId + i)[0]; + } else { + throw new UnsupportedOperationException("getBytes is not supported for type:" + type); + } + } + return array; + } + + @Override + public UTF8String getUTF8String(int rowId) { + if (dictionaryData != null) { + return UTF8String.fromBytes(dictionaryData.getBytes(rowId)); + } else { + return UTF8String.fromBytes(charsTypeDataVec.get(rowId)); + } + } + + @Override + protected UTF8String getBytesAsUTF8String(int rowId, int count) { + return UTF8String.fromBytes(getBytes(rowId, count), rowId, count); + } + + // + // APIs dealing with Shorts + // + + @Override + public void putShort(int rowId, short value) { + shortDataVec.set(rowId, value); + } + + @Override + public void putShorts(int rowId, int count, short value) { + for (int i = 0; i < count; ++i) { + shortDataVec.set(i + rowId, value); + } + } + + @Override + public void putShorts(int rowId, int count, short[] src, int srcIndex) { + shortDataVec.put(src, rowId, srcIndex, count); + } + + @Override + public void putShorts(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException("putShorts is not supported"); + } + + @Override + public short getShort(int rowId) { + if (dictionary != null) { + return (short) dictionary.decodeToInt(dictionaryIds.getDictId(rowId)); + } else if (dictionaryData != null) { + throw new UnsupportedOperationException("getShort is not supported for dictionary vector"); + } else { + return shortDataVec.get(rowId); + } + } + + @Override + public short[] 
getShorts(int rowId, int count) { + assert (dictionary == null); + short[] array = new short[count]; + for (int i = 0; i < count; i++) { + array[i] = shortDataVec.get(rowId + i); + } + return array; + } + + // + // APIs dealing with Ints + // + + @Override + public void putInt(int rowId, int value) { + intDataVec.set(rowId, value); + } + + @Override + public void putInts(int rowId, int count, int value) { + for (int i = 0; i < count; ++i) { + intDataVec.set(rowId + i, value); + } + } + + @Override + public void putInts(int rowId, int count, int[] src, int srcIndex) { + intDataVec.put(src, rowId, srcIndex, count); + } + + @Override + public void putInts(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException("putInts is not supported"); + } + + @Override + public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; + for (int i = 0; i < count; ++i, srcOffset += 4) { + intDataVec.set(rowId + i, Platform.getInt(src, srcOffset)); + if (BIG_ENDIAN_PLATFORM) { + intDataVec.set(rowId + i, Integer.reverseBytes(intDataVec.get(i + rowId))); + } + } + } + + @Override + public int getInt(int rowId) { + if (dictionary != null) { + return dictionary.decodeToInt(dictionaryIds.getDictId(rowId)); + } else if (dictionaryData != null) { + return dictionaryData.getInt(rowId); + } else { + return intDataVec.get(rowId); + } + } + + @Override + public int[] getInts(int rowId, int count) { + assert (dictionary == null); + int[] array = new int[count]; + for (int i = 0; i < count; i++) { + array[i] = intDataVec.get(rowId + i); + } + return array; + } + + /** + * Returns the dictionary Id for rowId. This should only be called when the + * ColumnVector is dictionaryIds. We have this separate method for dictionaryIds + * as per SPARK-16928. 
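+   * In this implementation the dictionary ids are simply read from the backing IntVec.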
+ */ + public int getDictId(int rowId) { + assert (dictionary == null) : "A ColumnVector dictionary should not have a dictionary for itself."; + return intDataVec.get(rowId); + } + + // + // APIs dealing with Longs + // + + @Override + public void putLong(int rowId, long value) { + longDataVec.set(rowId, value); + } + + @Override + public void putLongs(int rowId, int count, long value) { + for (int i = 0; i < count; ++i) { + longDataVec.set(i + rowId, value); + } + } + + @Override + public void putLongs(int rowId, int count, long[] src, int srcIndex) { + longDataVec.put(src, rowId, srcIndex, count); + } + + @Override + public void putLongs(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException("putLongs is not supported"); + } + + @Override + public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; + for (int i = 0; i < count; ++i, srcOffset += 8) { + longDataVec.set(i + rowId, Platform.getLong(src, srcOffset)); + if (BIG_ENDIAN_PLATFORM) { + longDataVec.set(i + rowId, Long.reverseBytes(longDataVec.get(i + rowId))); + } + } + } + + @Override + public long getLong(int rowId) { + if (dictionary != null) { + return dictionary.decodeToLong(dictionaryIds.getDictId(rowId)); + } else if (dictionaryData != null) { + return dictionaryData.getLong(rowId); + } else { + return longDataVec.get(rowId); + } + } + + @Override + public long[] getLongs(int rowId, int count) { + assert (dictionary == null); + long[] array = new long[count]; + for (int i = 0; i < count; i++) { + array[i] = longDataVec.get(rowId + i); + } + return array; + } + + // + // APIs dealing with floats, omni-vector not support float data type + // + + @Override + public void putFloat(int rowId, float value) { + throw new UnsupportedOperationException("putFloat is not supported"); + } + + @Override + public void putFloats(int rowId, int count, float value) { + throw new UnsupportedOperationException("putFloats is not supported"); + } + + @Override + public void putFloats(int rowId, int count, float[] src, int srcIndex) { + throw new UnsupportedOperationException("putFloats is not supported"); + } + + @Override + public void putFloats(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException("putFloats is not supported"); + } + + @Override + public void putFloatsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException("putFloatsLittleEndian is not supported"); + } + + @Override + public float getFloat(int rowId) { + throw new UnsupportedOperationException("getFloat is not supported"); + } + + @Override + public float[] getFloats(int rowId, int count) { + throw new UnsupportedOperationException("getFloats is not supported"); + } + + // + // APIs dealing with doubles + // + + @Override + public void putDouble(int rowId, double value) { + doubleDataVec.set(rowId, value); + } + + @Override + public void putDoubles(int rowId, int count, double value) { + for (int i = 0; i < count; i++) { + doubleDataVec.set(rowId + i, value); + } + } + + @Override + public void putDoubles(int rowId, int count, double[] src, int srcIndex) { + throw new UnsupportedOperationException("putDoubles is not supported"); + } + + @Override + public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException("putDoubles is not supported"); + } + + @Override + public void putDoublesLittleEndian(int rowId, int count, byte[] 
src, int srcIndex) { + if (!BIG_ENDIAN_PLATFORM) { + putDoubles(rowId, count, src, srcIndex); + } else { + ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); + for (int i = 0; i < count; ++i) { + doubleDataVec.set(i + rowId, bb.getDouble(srcIndex + (8 * i))); + } + } + } + + @Override + public double getDouble(int rowId) { + if (dictionary != null) { + return dictionary.decodeToDouble(dictionaryIds.getDictId(rowId)); + } else if (dictionaryData != null) { + return dictionaryData.getDouble(rowId); + } else { + return doubleDataVec.get(rowId); + } + } + + @Override + public double[] getDoubles(int rowId, int count) { + assert (dictionary == null); + double[] array = new double[count]; + for (int i = 0; i < count; i++) { + array[i] = doubleDataVec.get(rowId + i); + } + return array; + } + + // + // APIs dealing with Arrays + // + + @Override + public int getArrayLength(int rowId) { + throw new UnsupportedOperationException("getArrayLength is not supported"); + } + + @Override + public int getArrayOffset(int rowId) { + throw new UnsupportedOperationException("getArrayOffset is not supported"); + } + + @Override + public void putArray(int rowId, int offset, int length) { + throw new UnsupportedOperationException("putArray is not supported"); + } + + // + // APIs dealing with Byte Arrays + // + + @Override + public int putByteArray(int rowId, byte[] value, int offset, int length) { + throw new UnsupportedOperationException("putByteArray is not supported"); + } + + /** + * + * @param value BigDecimal + * @return return count of elements + */ + public final int appendDecimal(Decimal value) + { + reserve(elementsAppended + 1); + int result = elementsAppended; + if (value.precision() <= Decimal.MAX_LONG_DIGITS()) { + longDataVec.set(elementsAppended, value.toUnscaledLong()); + } else { + decimal128DataVec.setBigInteger(elementsAppended, value.toJavaBigDecimal().unscaledValue()); + } + elementsAppended++; + return result; + } + + @Override + public void putDecimal(int rowId, Decimal value, int precision) { + if (precision <= Decimal.MAX_LONG_DIGITS()) { + longDataVec.set(rowId, value.toUnscaledLong()); + } else { + decimal128DataVec.setBigInteger(rowId, value.toJavaBigDecimal().unscaledValue()); + } + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; + if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.apply(getLong(rowId), precision, scale); + } else { + BigInteger value; + if (dictionaryData != null) { + value = Decimal128Vec.getDecimal(dictionaryData.getDecimal128(rowId)); + } else { + value = decimal128DataVec.getBigInteger(rowId); + } + return Decimal.apply(new BigDecimal(value, scale), precision, scale); + } + } + + @Override + public boolean isArray() { + return false; + } + + // Spilt this function out since it is the slow path. 
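+    // reserveInternal below allocates the backing omni vector for this column's Spark type.
+    // Note that for StringType the VarcharVec capacity is only an estimate
+    // (newCapacity * 4 * 200 bytes), i.e. it assumes UTF-8 values of at most ~200 characters
+    // rather than the real column width.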
+ @Override + protected void reserveInternal(int newCapacity) { + if (type instanceof BooleanType) { + booleanDataVec = new BooleanVec(newCapacity); + } else if (type instanceof ByteType) { + charsTypeDataVec = new VarcharVec(newCapacity * 4, newCapacity); + } else if (type instanceof ShortType) { + shortDataVec = new ShortVec(newCapacity); + } else if (type instanceof IntegerType) { + intDataVec = new IntVec(newCapacity); + } else if (type instanceof DecimalType) { + if (DecimalType.is64BitDecimalType(type)) { + longDataVec = new LongVec(newCapacity); + } else { + decimal128DataVec = new Decimal128Vec(newCapacity); + } + } else if (type instanceof LongType) { + longDataVec = new LongVec(newCapacity); + } else if (type instanceof FloatType) { + throw new UnsupportedOperationException("reserveInternal is not supported for type:" + type); + } else if (type instanceof DoubleType) { + doubleDataVec = new DoubleVec(newCapacity); + } else if (type instanceof StringType) { + // need to set with real column size, suppose char(200) utf8 + charsTypeDataVec = new VarcharVec(newCapacity * 4 * 200, newCapacity); + } else if (type instanceof DateType) { + intDataVec = new IntVec(newCapacity); + } else { + throw new UnsupportedOperationException("reserveInternal is not supported for type:" + type); + } + capacity = newCapacity; + } + + @Override + protected OmniColumnVector reserveNewColumn(int capacity, DataType type) { + return new OmniColumnVector(capacity, type, true); + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala new file mode 100644 index 0000000000000000000000000000000000000000..bdb5b530c4bad8e12b53a6e6d2fa652fdb98482c --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.spark + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, CustomShuffleReaderExec} +import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.execution.window.WindowExec +import org.apache.spark.sql.types.ColumnarBatchSupportUtil.checkColumnarBatchSupport + +case class RowGuard(child: SparkPlan) extends SparkPlan { + def output: Seq[Attribute] = child.output + protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException + } + def children: Seq[SparkPlan] = Seq(child) +} + +case class ColumnarGuardRule() extends Rule[SparkPlan] { + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + val preferColumnar: Boolean = columnarConf.enablePreferColumnar + val enableColumnarShuffle: Boolean = columnarConf.enableColumnarShuffle + val enableColumnarSort: Boolean = columnarConf.enableColumnarSort + val enableTakeOrderedAndProject: Boolean = columnarConf.enableTakeOrderedAndProject && + columnarConf.enableColumnarShuffle + val enableColumnarUnion: Boolean = columnarConf.enableColumnarUnion + val enableColumnarWindow: Boolean = columnarConf.enableColumnarWindow + val enableColumnarHashAgg: Boolean = columnarConf.enableColumnarHashAgg + val enableColumnarProject: Boolean = columnarConf.enableColumnarProject + val enableColumnarFilter: Boolean = columnarConf.enableColumnarFilter + val enableColumnarBroadcastExchange: Boolean = columnarConf.enableColumnarBroadcastExchange && + columnarConf.enableColumnarBroadcastJoin + val enableColumnarBroadcastJoin: Boolean = columnarConf.enableColumnarBroadcastJoin + val enableColumnarSortMergeJoin: Boolean = columnarConf.enableColumnarSortMergeJoin + val enableShuffledHashJoin: Boolean = columnarConf.enableShuffledHashJoin + val enableColumnarFileScan: Boolean = columnarConf.enableColumnarFileScan + val optimizeLevel: Integer = columnarConf.joinOptimizationThrottle + + private def tryConvertToColumnar(plan: SparkPlan): Boolean = { + try { + val columnarPlan = plan match { + case plan: FileSourceScanExec => + if (!checkColumnarBatchSupport(conf, plan)) { + return false + } + if (!enableColumnarFileScan) return false + ColumnarFileSourceScanExec( + plan.relation, + plan.output, + plan.requiredSchema, + plan.partitionFilters, + plan.optionalBucketSet, + plan.optionalNumCoalescedBuckets, + plan.dataFilters, + plan.tableIdentifier, + plan.disableBucketedScan + ).buildCheck() + case plan: ProjectExec => + if (!enableColumnarProject) return false + ColumnarProjectExec(plan.projectList, plan.child).buildCheck() + case plan: FilterExec => + if (!enableColumnarFilter) return false + ColumnarFilterExec(plan.condition, plan.child).buildCheck() + case plan: HashAggregateExec => + if (!enableColumnarHashAgg) return false + new ColumnarHashAggregateExec( + plan.requiredChildDistributionExpressions, + plan.groupingExpressions, + plan.aggregateExpressions, + plan.aggregateAttributes, + plan.initialInputBufferOffset, + plan.resultExpressions, + plan.child).buildCheck() + case plan: SortExec => + if (!enableColumnarSort) return false + 
ColumnarSortExec(plan.sortOrder, plan.global, + plan.child, plan.testSpillFrequency).buildCheck() + case plan: BroadcastExchangeExec => + if (!enableColumnarBroadcastExchange) return false + new ColumnarBroadcastExchangeExec(plan.mode, plan.child) + case plan: TakeOrderedAndProjectExec => + if (!enableTakeOrderedAndProject) return false + ColumnarTakeOrderedAndProjectExec( + plan.limit, + plan.sortOrder, + plan.projectList, + plan.child).buildCheck() + case plan: UnionExec => + if (!enableColumnarUnion) return false + ColumnarUnionExec(plan.children).buildCheck() + case plan: ShuffleExchangeExec => + if (!enableColumnarShuffle) return false + new ColumnarShuffleExchangeExec(plan.outputPartitioning, plan.child).buildCheck() + case plan: BroadcastHashJoinExec => + // We need to check if BroadcastExchangeExec can be converted to columnar-based. + // If not, BHJ should also be row-based. + if (!enableColumnarBroadcastJoin) return false + val left = plan.left + left match { + case exec: BroadcastExchangeExec => + new ColumnarBroadcastExchangeExec(exec.mode, exec.child) + case BroadcastQueryStageExec(_, plan: BroadcastExchangeExec) => + new ColumnarBroadcastExchangeExec(plan.mode, plan.child) + case BroadcastQueryStageExec(_, plan: ReusedExchangeExec) => + plan match { + case ReusedExchangeExec(_, b: BroadcastExchangeExec) => + new ColumnarBroadcastExchangeExec(b.mode, b.child) + case _ => + } + case _ => + } + val right = plan.right + right match { + case exec: BroadcastExchangeExec => + new ColumnarBroadcastExchangeExec(exec.mode, exec.child) + case BroadcastQueryStageExec(_, plan: BroadcastExchangeExec) => + new ColumnarBroadcastExchangeExec(plan.mode, plan.child) + case BroadcastQueryStageExec(_, plan: ReusedExchangeExec) => + plan match { + case ReusedExchangeExec(_, b: BroadcastExchangeExec) => + new ColumnarBroadcastExchangeExec(b.mode, b.child) + case _ => + } + case _ => + } + ColumnarBroadcastHashJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + plan.left, + plan.right).buildCheck() + case plan: SortMergeJoinExec => + if (!enableColumnarSortMergeJoin) return false + new ColumnarSortMergeJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.condition, + plan.left, + plan.right, + plan.isSkewJoin).buildCheck() + case plan: WindowExec => + if (!enableColumnarWindow) return false + ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, + plan.orderSpec, plan.child).buildCheck() + case plan: ShuffledHashJoinExec => + if (!enableShuffledHashJoin) return false + ColumnarShuffledHashJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + plan.left, + plan.right).buildCheck() + case plan: BroadcastNestedLoopJoinExec => return false + case p => + p + } + } + catch { + case e: UnsupportedOperationException => + logDebug(s"[OPERATOR FALLBACK] ${e} ${plan.getClass} falls back to Spark operator") + return false + case r: RuntimeException => + logDebug(s"[OPERATOR FALLBACK] ${r} ${plan.getClass} falls back to Spark operator") + return false + case t: Throwable => + logDebug(s"[OPERATOR FALLBACK] ${t} ${plan.getClass} falls back to Spark operator") + return false + } + true + } + + private def existsMultiCodegens(plan: SparkPlan, count: Int = 0): Boolean = + plan match { + case plan: CodegenSupport if plan.supportCodegen => + if ((count + 1) >= optimizeLevel) return true + plan.children.map(existsMultiCodegens(_, count + 1)).exists(_ == true) + case plan: ShuffledHashJoinExec => + if 
((count + 1) >= optimizeLevel) return true + plan.children.map(existsMultiCodegens(_, count + 1)).exists(_ == true) + case other => false + } + + private def supportCodegen(plan: SparkPlan): Boolean = plan match { + case plan: CodegenSupport => + plan.supportCodegen + case _ => false + } + + /** + * Inserts an InputAdapter on top of those that do not support codegen. + */ + private def insertRowGuardRecursive(plan: SparkPlan): SparkPlan = { + plan match { + case p: ShuffleExchangeExec => + RowGuard(p.withNewChildren(p.children.map(insertRowGuardOrNot))) + case p: BroadcastExchangeExec => + RowGuard(p.withNewChildren(p.children.map(insertRowGuardOrNot))) + case p: ShuffledHashJoinExec => + RowGuard(p.withNewChildren(p.children.map(insertRowGuardRecursive))) + case p if !supportCodegen(p) => + // insert row guard them recursively + p.withNewChildren(p.children.map(insertRowGuardOrNot)) + case p: CustomShuffleReaderExec => + p.withNewChildren(p.children.map(insertRowGuardOrNot)) + case p: BroadcastQueryStageExec => + p + case p => RowGuard(p.withNewChildren(p.children.map(insertRowGuardRecursive))) + } + } + + private def insertRowGuard(plan: SparkPlan): SparkPlan = { + RowGuard(plan.withNewChildren(plan.children.map(insertRowGuardOrNot))) + } + + /** + * Inserts a WholeStageCodegen on top of those that support codegen. + */ + private def insertRowGuardOrNot(plan: SparkPlan): SparkPlan = { + plan match { + // For operators that will output domain object, do not insert WholeStageCodegen for it as + // domain object can not be written into unsafe row. + case plan if !preferColumnar && existsMultiCodegens(plan) => + insertRowGuardRecursive(plan) + case plan if !tryConvertToColumnar(plan) => + insertRowGuard(plan) + case p: BroadcastQueryStageExec => + p + case other => + other.withNewChildren(other.children.map(insertRowGuardOrNot)) + } + } + + def apply(plan: SparkPlan): SparkPlan = { + insertRowGuardOrNot(plan) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala new file mode 100644 index 0000000000000000000000000000000000000000..74bbe2e5a753b44ec197bb9782f35a6a3d411fa2 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala @@ -0,0 +1,395 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.spark + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} +import org.apache.spark.sql.catalyst.expressions.DynamicPruningSubquery +import org.apache.spark.sql.catalyst.expressions.aggregate.Partial +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{RowToOmniColumnarExec, _} +import org.apache.spark.sql.execution.adaptive.QueryStageExec +import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ShuffleExchangeExec} +import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.execution.window.WindowExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.ColumnarBatchSupportUtil.checkColumnarBatchSupport + +case class ColumnarPreOverrides() extends Rule[SparkPlan] { + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + val enableColumnarFileScan: Boolean = columnarConf.enableColumnarFileScan + val enableColumnarProject: Boolean = columnarConf.enableColumnarProject + val enableColumnarFilter: Boolean = columnarConf.enableColumnarFilter + val enableColumnarHashAgg: Boolean = columnarConf.enableColumnarHashAgg + val enableTakeOrderedAndProject: Boolean = columnarConf.enableTakeOrderedAndProject && + columnarConf.enableColumnarShuffle + val enableColumnarBroadcastExchange: Boolean = columnarConf.enableColumnarBroadcastExchange && + columnarConf.enableColumnarBroadcastJoin + val enableColumnarBroadcastJoin: Boolean = columnarConf.enableColumnarBroadcastJoin && + columnarConf.enableColumnarBroadcastExchange + val enableColumnarSortMergeJoin: Boolean = columnarConf.enableColumnarSortMergeJoin + val enableColumnarSort: Boolean = columnarConf.enableColumnarSort + val enableColumnarWindow: Boolean = columnarConf.enableColumnarWindow + val enableColumnarShuffle: Boolean = columnarConf.enableColumnarShuffle + val enableShuffledHashJoin: Boolean = columnarConf.enableShuffledHashJoin + val enableColumnarUnion: Boolean = columnarConf.enableColumnarUnion + val enableFusion: Boolean = columnarConf.enableFusion + var isSupportAdaptive: Boolean = true + + def apply(plan: SparkPlan): SparkPlan = { + replaceWithColumnarPlan(plan) + } + + def setAdaptiveSupport(enable: Boolean): Unit = { isSupportAdaptive = enable } + + def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan = plan match { + case plan: RowGuard => + val actualPlan: SparkPlan = plan.child match { + case p: BroadcastHashJoinExec => + p.withNewChildren(p.children.map { + case plan: BroadcastExchangeExec => + // if BroadcastHashJoin is row-based, BroadcastExchange should also be row-based + RowGuard(plan) + case other => other + }) + case p: BroadcastNestedLoopJoinExec => + p.withNewChildren(p.children.map { + case plan: BroadcastExchangeExec => + // if BroadcastNestedLoopJoin is row-based, BroadcastExchange should also be row-based + RowGuard(plan) + case other => other + }) + case other => + other + } + logDebug(s"Columnar Processing for ${actualPlan.getClass} is under RowGuard.") + actualPlan.withNewChildren(actualPlan.children.map(replaceWithColumnarPlan)) + case plan: FileSourceScanExec + if enableColumnarFileScan && checkColumnarBatchSupport(conf, plan) => + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarFileSourceScanExec( + plan.relation, + plan.output, + plan.requiredSchema, + plan.partitionFilters, + 
plan.optionalBucketSet, + plan.optionalNumCoalescedBuckets, + plan.dataFilters, + plan.tableIdentifier, + plan.disableBucketedScan + ) + case range: RangeExec => + new ColumnarRangeExec(range.range) + case plan: ProjectExec if enableColumnarProject => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + child match { + case ColumnarFilterExec(condition, child) => + ColumnarConditionProjectExec(plan.projectList, condition, child) + case _ => + ColumnarProjectExec(plan.projectList, child) + } + case plan: FilterExec if enableColumnarFilter => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarFilterExec(plan.condition, child) + case plan: HashAggregateExec if enableColumnarHashAgg => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + if (enableFusion) { + if (plan.aggregateExpressions.forall(_.mode == Partial)) { + child match { + case proj1 @ ColumnarProjectExec(_, + join1 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj2 @ ColumnarProjectExec(_, + join2 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj3 @ ColumnarProjectExec(_, + join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj4 @ ColumnarProjectExec(_, + join4 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, + filter @ ColumnarFilterExec(_, + scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _) + ), _, _)), _, _)), _, _)), _, _)) => + ColumnarMultipleOperatorExec( + plan, + proj1, + join1, + proj2, + join2, + proj3, + join3, + proj4, + join4, + filter, + scan.relation, + plan.output, + scan.requiredSchema, + scan.partitionFilters, + scan.optionalBucketSet, + scan.optionalNumCoalescedBuckets, + scan.dataFilters, + scan.tableIdentifier, + scan.disableBucketedScan) + case proj1 @ ColumnarProjectExec(_, + join1 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj2 @ ColumnarProjectExec(_, + join2 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj3 @ ColumnarProjectExec(_, + join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, _, + filter @ ColumnarFilterExec(_, + scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _)) , _, _)), _, _)) => + ColumnarMultipleOperatorExec1( + plan, + proj1, + join1, + proj2, + join2, + proj3, + join3, + filter, + scan.relation, + plan.output, + scan.requiredSchema, + scan.partitionFilters, + scan.optionalBucketSet, + scan.optionalNumCoalescedBuckets, + scan.dataFilters, + scan.tableIdentifier, + scan.disableBucketedScan) + case proj1 @ ColumnarProjectExec(_, + join1 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj2 @ ColumnarProjectExec(_, + join2 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj3 @ ColumnarProjectExec(_, + join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, + filter @ ColumnarFilterExec(_, + scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _, _)) , _, _)), _, _)) => + ColumnarMultipleOperatorExec1( + plan, + proj1, + join1, + proj2, + join2, + proj3, + join3, + filter, + scan.relation, + plan.output, + scan.requiredSchema, + scan.partitionFilters, + scan.optionalBucketSet, + scan.optionalNumCoalescedBuckets, + scan.dataFilters, + scan.tableIdentifier, + scan.disableBucketedScan) + case _ => + new ColumnarHashAggregateExec( + plan.requiredChildDistributionExpressions, + plan.groupingExpressions, + plan.aggregateExpressions, + plan.aggregateAttributes, + plan.initialInputBufferOffset, + 
plan.resultExpressions, + child) + } + } else { + new ColumnarHashAggregateExec( + plan.requiredChildDistributionExpressions, + plan.groupingExpressions, + plan.aggregateExpressions, + plan.aggregateAttributes, + plan.initialInputBufferOffset, + plan.resultExpressions, + child) + } + } else { + new ColumnarHashAggregateExec( + plan.requiredChildDistributionExpressions, + plan.groupingExpressions, + plan.aggregateExpressions, + plan.aggregateAttributes, + plan.initialInputBufferOffset, + plan.resultExpressions, + child) + } + + case plan: TakeOrderedAndProjectExec if enableTakeOrderedAndProject => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarTakeOrderedAndProjectExec( + plan.limit, + plan.sortOrder, + plan.projectList, + child) + case plan: BroadcastExchangeExec if enableColumnarBroadcastExchange => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + new ColumnarBroadcastExchangeExec(plan.mode, child) + case plan: BroadcastHashJoinExec if enableColumnarBroadcastJoin => + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarBroadcastHashJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + left, + right) + case plan: ShuffledHashJoinExec if enableShuffledHashJoin => + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarShuffledHashJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + left, + right) + case plan: SortMergeJoinExec if enableColumnarSortMergeJoin => + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + new ColumnarSortMergeJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.condition, + left, + right, + plan.isSkewJoin) + case plan: SortExec if enableColumnarSort => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarSortExec(plan.sortOrder, plan.global, child, plan.testSpillFrequency) + case plan: WindowExec if enableColumnarWindow => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, child) + case plan: UnionExec if enableColumnarUnion => + val children = plan.children.map(replaceWithColumnarPlan) + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarUnionExec(children) + case plan: ShuffleExchangeExec if enableColumnarShuffle => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + new ColumnarShuffleExchangeExec(plan.outputPartitioning, child) + case p => + val children = plan.children.map(replaceWithColumnarPlan) + logInfo(s"Columnar Processing for ${p.getClass} is currently not supported.") + 
p.withNewChildren(children) + } +} + +case class ColumnarPostOverrides() extends Rule[SparkPlan] { + + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + var isSupportAdaptive: Boolean = true + + def apply(plan: SparkPlan): SparkPlan = { + replaceWithColumnarPlan(plan) + } + + def setAdaptiveSupport(enable: Boolean): Unit = { isSupportAdaptive = enable } + + def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan = plan match { + case plan: RowToColumnarExec => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported") + RowToOmniColumnarExec(child) + case ColumnarToRowExec(child: ColumnarBroadcastExchangeExec) => + replaceWithColumnarPlan(child) + case r: SparkPlan + if !r.isInstanceOf[QueryStageExec] && !r.supportsColumnar && r.children.exists(c => + c.isInstanceOf[ColumnarToRowExec]) => + val children = r.children.map { + case c: ColumnarToRowExec => + c.withNewChildren(c.children.map(replaceWithColumnarPlan)) + case other => + replaceWithColumnarPlan(other) + } + r.withNewChildren(children) + case p => + val children = p.children.map(replaceWithColumnarPlan) + p.withNewChildren(children) + } +} + +case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule with Logging { + def columnarEnabled: Boolean = session.sqlContext.getConf( + "org.apache.spark.sql.columnar.enabled", "true").trim.toBoolean + + def rowGuardOverrides: ColumnarGuardRule = ColumnarGuardRule() + def preOverrides: ColumnarPreOverrides = ColumnarPreOverrides() + def postOverrides: ColumnarPostOverrides = ColumnarPostOverrides() + + var isSupportAdaptive: Boolean = true + + private def supportAdaptive(plan: SparkPlan): Boolean = { + // TODO migrate dynamic-partition-pruning onto adaptive execution. 
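+    // Besides the leaf-exchange case below, adaptive support additionally requires adaptive
+    // execution to be enabled in SQLConf, a non-streaming plan with a logical link, no
+    // DynamicPruningSubquery expressions, and that all children support it as well.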
+ // Only QueryStage will have Exchange as Leaf Plan + val isLeafPlanExchange = plan match { + case e: Exchange => true + case other => false + } + isLeafPlanExchange || (SQLConf.get.adaptiveExecutionEnabled && (sanityCheck(plan) && + !plan.logicalLink.exists(_.isStreaming) && + !plan.expressions.exists(_.find(_.isInstanceOf[DynamicPruningSubquery]).isDefined) && + plan.children.forall(supportAdaptive))) + } + + private def sanityCheck(plan: SparkPlan): Boolean = + plan.logicalLink.isDefined + + override def preColumnarTransitions: Rule[SparkPlan] = plan => { + if (columnarEnabled) { + isSupportAdaptive = supportAdaptive(plan) + val rule = preOverrides + rule.setAdaptiveSupport(isSupportAdaptive) + logInfo("Using BoostKit Spark Native Sql Engine Extension ColumnarPreOverrides") + rule(rowGuardOverrides(plan)) + } else { + plan + } + } + + override def postColumnarTransitions: Rule[SparkPlan] = plan => { + if (columnarEnabled) { + val rule = postOverrides + rule.setAdaptiveSupport(isSupportAdaptive) + logInfo("Using BoostKit Spark Native Sql Engine Extension ColumnarPostOverrides") + rule(plan) + } else { + plan + } + } +} + +class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging { + override def apply(extensions: SparkSessionExtensions): Unit = { + logInfo("Using BoostKit Spark Native Sql Engine Extension to Speed Up Your Queries.") + extensions.injectColumnar(session => ColumnarOverrideRules(session)) + extensions.injectPlannerStrategy(_ => ShuffleJoinStrategy) + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala new file mode 100644 index 0000000000000000000000000000000000000000..39ac95e32e672a918d1fe52933cd79789724b949 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala @@ -0,0 +1,174 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.spark + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.internal.SQLConf + +class ColumnarPluginConfig(conf: SQLConf) extends Logging { + // enable or disable columnar exchange + val enableColumnarShuffle: Boolean = conf + .getConfString("spark.shuffle.manager", "sort") + .equals("org.apache.spark.shuffle.sort.ColumnarShuffleManager") + + // enable or disable columnar hashagg + val enableColumnarHashAgg: Boolean = + conf.getConfString("spark.omni.sql.columnar.hashagg", "true").toBoolean + + val enableColumnarProject: Boolean = + conf.getConfString("spark.omni.sql.columnar.project", "true").toBoolean + + val enableColumnarProjFilter: Boolean = + conf.getConfString("spark.omni.sql.columnar.projfilter", "true").toBoolean + + val enableColumnarFilter: Boolean = + conf.getConfString("spark.omni.sql.columnar.filter", "true").toBoolean + + // enable or disable columnar sort + val enableColumnarSort: Boolean = + conf.getConfString("spark.omni.sql.columnar.sort", "true").toBoolean + + val enableColumnarUnion: Boolean = + conf.getConfString("spark.omni.sql.columnar.union", "true").toBoolean + + // enable or disable columnar window + val enableColumnarWindow: Boolean = + conf.getConfString("spark.omni.sql.columnar.window", "true").toBoolean + + // enable or disable columnar broadcastexchange + val enableColumnarBroadcastExchange: Boolean = + conf.getConfString("spark.omni.sql.columnar.broadcastexchange", "true").toBoolean + + // enable or disable columnar wholestagecodegen + val enableColumnarWholeStageCodegen: Boolean = + conf.getConfString("spark.omni.sql.columnar.wholestagecodegen", "true").toBoolean + + // enable or disable columnar BroadcastHashJoin + val enableColumnarBroadcastJoin: Boolean = conf + .getConfString("spark.omni.sql.columnar.broadcastJoin", "true") + .toBoolean + + // enable native table scan + val enableColumnarFileScan: Boolean = conf + .getConfString("spark.omni.sql.columnar.nativefilescan", "true") + .toBoolean + + val enableColumnarSortMergeJoin: Boolean = conf + .getConfString("spark.omni.sql.columnar.sortMergeJoin", "true") + .toBoolean + + val enableTakeOrderedAndProject: Boolean = conf + .getConfString("spark.omni.sql.columnar.takeOrderedAndProject", "true").toBoolean + + val enableShuffleBatchMerge: Boolean = conf + .getConfString("spark.omni.sql.columnar.shuffle.merge", "true").toBoolean + + val enableJoinBatchMerge: Boolean = conf + .getConfString("spark.omni.sql.columnar.broadcastJoin.merge", "false").toBoolean + + val enableSortMergeJoinBatchMerge: Boolean = conf + .getConfString("spark.omni.sql.columnar.sortMergeJoin.merge", "true").toBoolean + + // prefer to use columnar operators if set to true + val enablePreferColumnar: Boolean = + conf.getConfString("spark.omni.sql.columnar.preferColumnar", "true").toBoolean + + // fallback to row operators if there are several continous joins + val joinOptimizationThrottle: Integer = + conf.getConfString("spark.omni.sql.columnar.joinOptimizationLevel", "12").toInt + + // columnar shuffle spill batch row number + val columnarShuffleSpillBatchRowNum = + conf.getConfString("spark.shuffle.columnar.shuffleSpillBatchRowNum", "10000").toInt + + // columnar shuffle spill memory threshold + val columnarShuffleSpillMemoryThreshold = + conf.getConfString("spark.shuffle.columnar.shuffleSpillMemoryThreshold", + "2147483648").toLong + + // columnar shuffle compress block size + val columnarShuffleCompressBlockSize = + conf.getConfString("spark.shuffle.columnar.compressBlockSize", 
"65536").toInt + + // enable shuffle compress + val enableShuffleCompress = + conf.getConfString("spark.shuffle.compress", "true").toBoolean + + // shuffle compress type, default lz4 + val columnarShuffleCompressionCodec = + conf.getConfString("spark.io.compression.codec", "lz4").toString + + // columnar shuffle native buffer size + val columnarShuffleNativeBufferSize = + conf.getConfString("spark.sql.execution.columnar.maxRecordsPerBatch", "4096").toInt + + // columnar sort spill threshold + val columnarSortSpillRowThreshold: Integer = + conf.getConfString("spark.omni.sql.columnar.sortSpill.rowThreshold", "200000").toInt + + // columnar sort spill dir disk reserve Size, default 10GB + val columnarSortSpillDirDiskReserveSize:Long = + conf.getConfString("spark.omni.sql.columnar.sortSpill.dirDiskReserveSize", "10737418240").toLong + + // enable or disable columnar sortSpill + val enableSortSpill: Boolean = conf + .getConfString("spark.omni.sql.columnar.sortSpill.enabled", "false") + .toBoolean + + // enable or disable columnar shuffledHashJoin + val enableShuffledHashJoin: Boolean = conf + .getConfString("spark.omni.sql.columnar.shuffledHashJoin", "true") + .toBoolean + + val enableFusion: Boolean = conf + .getConfString("spark.omni.sql.columnar.fusion", "true") + .toBoolean + + // Pick columnar shuffle hash join if one side join count > = 0 to build local hash map, and is + // bigger than the other side join count, and `spark.sql.join.columnar.preferShuffledHashJoin` + // is true. + val columnarPreferShuffledHashJoin = + conf.getConfString("spark.sql.join.columnar.preferShuffledHashJoin", "false").toBoolean + + val maxBatchSizeInBytes = + conf.getConfString("spark.sql.columnar.maxBatchSizeInBytes", "2097152").toInt + + val maxRowCount = + conf.getConfString("spark.sql.columnar.maxRowCount", "20000").toInt + + val enableJit: Boolean = conf.getConfString("spark.omni.sql.columnar.jit", "false").toBoolean + + val enableDecimalCheck : Boolean = conf.getConfString("spark.omni.sql.decimal.constraint.check", "true").toBoolean +} + + +object ColumnarPluginConfig { + var ins: ColumnarPluginConfig = null + + def getConf: ColumnarPluginConfig = synchronized { + if (ins == null) { + ins = getSessionConf + } + ins + } + + def getSessionConf: ColumnarPluginConfig = { + new ColumnarPluginConfig(SQLConf.get) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/Constant.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/Constant.scala new file mode 100644 index 0000000000000000000000000000000000000000..1460c618d401de3bdd13d18eb97445a08642e030 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/Constant.scala @@ -0,0 +1,40 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark + +import nova.hetu.omniruntime.`type`.DataType.DataTypeId + +/** + * @since 2022/4/15 + */ +object Constant { + val DEFAULT_STRING_TYPE_LENGTH = 2000 + val OMNI_VARCHAR_TYPE: String = DataTypeId.OMNI_VARCHAR.ordinal().toString + val OMNI_SHOR_TYPE: String = DataTypeId.OMNI_SHORT.ordinal().toString + val OMNI_INTEGER_TYPE: String = DataTypeId.OMNI_INT.ordinal().toString + val OMNI_LONG_TYPE: String = DataTypeId.OMNI_LONG.ordinal().toString + val OMNI_DOUBLE_TYPE: String = DataTypeId.OMNI_DOUBLE.ordinal().toString + val OMNI_BOOLEAN_TYPE: String = DataTypeId.OMNI_BOOLEAN.ordinal().toString + val OMNI_DATE_TYPE: String = DataTypeId.OMNI_DATE32.ordinal().toString + val IS_ENABLE_JIT: Boolean = ColumnarPluginConfig.getSessionConf.enableJit + val IS_DECIMAL_CHECK: Boolean = ColumnarPluginConfig.getSessionConf.enableDecimalCheck + val IS_SKIP_VERIFY_EXP: Boolean = true + val OMNI_DECIMAL64_TYPE: String = DataTypeId.OMNI_DECIMAL64.ordinal().toString + val OMNI_DECIMAL128_TYPE: String = DataTypeId.OMNI_DECIMAL128.ordinal().toString +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala new file mode 100644 index 0000000000000000000000000000000000000000..19da63cafad7b3bf2f0e1a863060b2eedae2935f --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.spark + +import org.apache.spark.sql.Strategy +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper} +import org.apache.spark.sql.catalyst.planning._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.{joins, SparkPlan} + +object ShuffleJoinStrategy extends Strategy + with PredicateHelper + with JoinSelectionHelper + with SQLConfHelper { + + private val columnarPreferShuffledHashJoin = + ColumnarPluginConfig.getConf.columnarPreferShuffledHashJoin + + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, nonEquiCond, left, right, hint) + if columnarPreferShuffledHashJoin => + val enable = getBroadcastBuildSide(left, right, joinType, hint, true, conf).isEmpty && + !hintToSortMergeJoin(hint) && + getShuffleHashJoinBuildSide(left, right, joinType, hint, true, conf).isEmpty && + !hintToShuffleReplicateNL(hint) && + getBroadcastBuildSide(left, right, joinType, hint, false, conf).isEmpty + if (enable) { + var buildLeft = false + var buildRight = false + var joinCountLeft = 0 + var joinCountRight = 0 + left.foreach(x => { + if (x.isInstanceOf[Join]) { + joinCountLeft = joinCountLeft + 1 + } + }) + right.foreach(x => { + if (x.isInstanceOf[Join]) { + joinCountRight = joinCountRight + 1 + } + }) + if ((joinCountLeft > 0) && (joinCountRight == 0)) { + buildLeft = true + } + if ((joinCountRight > 0) && (joinCountLeft == 0)) { + buildRight = true + } + + getBuildSide( + canBuildShuffledHashJoinLeft(joinType) && buildLeft, + canBuildShuffledHashJoinRight(joinType) && buildRight, + left, + right + ).map { + buildSide => + Seq(joins.ShuffledHashJoinExec( + leftKeys, + rightKeys, + joinType, + buildSide, + nonEquiCond, + planLater(left), + planLater(right))) + }.getOrElse(Nil) + } else { + Nil + } + + case _ => Nil + } + + private def getBuildSide( + canBuildLeft: Boolean, + canBuildRight: Boolean, + left: LogicalPlan, + right: LogicalPlan): Option[BuildSide] = { + if (canBuildLeft && canBuildRight) { + // returns the smaller side base on its estimated physical size, if we want to build the + // both sides. + Some(getSmallerSide(left, right)) + } else if (canBuildLeft) { + Some(BuildLeft) + } else if (canBuildRight) { + Some(BuildRight) + } else { + None + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala new file mode 100644 index 0000000000000000000000000000000000000000..6e56eba3f3e060378cef4695d5815c589ec0469b --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -0,0 +1,807 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.expression + +import com.huawei.boostkit.spark.Constant.{DEFAULT_STRING_TYPE_LENGTH, IS_DECIMAL_CHECK, OMNI_BOOLEAN_TYPE, OMNI_DATE_TYPE, OMNI_DECIMAL128_TYPE, OMNI_DECIMAL64_TYPE, OMNI_DOUBLE_TYPE, OMNI_INTEGER_TYPE, OMNI_LONG_TYPE, OMNI_SHOR_TYPE, OMNI_VARCHAR_TYPE} +import nova.hetu.omniruntime.`type`.{BooleanDataType, DataTypeSerializer, Date32DataType, Decimal128DataType, Decimal64DataType, DoubleDataType, IntDataType, LongDataType, ShortDataType, VarcharDataType} +import nova.hetu.omniruntime.constants.FunctionType +import nova.hetu.omniruntime.constants.FunctionType.{OMNI_AGGREGATION_TYPE_AVG, OMNI_AGGREGATION_TYPE_COUNT_ALL, OMNI_AGGREGATION_TYPE_COUNT_COLUMN, OMNI_AGGREGATION_TYPE_MAX, OMNI_AGGREGATION_TYPE_MIN, OMNI_AGGREGATION_TYPE_SUM, OMNI_WINDOW_TYPE_RANK, OMNI_WINDOW_TYPE_ROW_NUMBER} +import nova.hetu.omniruntime.operator.OmniExprVerify +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.util.CharVarcharUtils.getRawTypeString +import org.apache.spark.sql.types.{BooleanType, DataType, DateType, Decimal, DecimalType, DoubleType, IntegerType, LongType, Metadata, ShortType, StringType} + +import scala.collection.mutable.ArrayBuffer + +object OmniExpressionAdaptor extends Logging { + + def getRealExprId(expr: Expression): ExprId = { + // TODO support more complex expression + expr match { + case alias: Alias => getRealExprId(alias.child) + case subString: Substring => getRealExprId(subString.str) + case attr: Attribute => attr.exprId + case _ => + throw new UnsupportedOperationException(s"Unsupported expression: $expr") + } + } + def getExprIdMap(inputAttrs: Seq[Attribute]): Map[ExprId, Int] = { + var attrMap: Map[ExprId, Int] = Map() + inputAttrs.zipWithIndex.foreach { case (inputAttr, i) => + attrMap += (inputAttr.exprId -> i) + } + attrMap + } + + private def DECIMAL_ALLOWEDTYPES: Seq[DecimalType] = Seq(DecimalType(7,2), DecimalType(17,2), DecimalType(21,6), DecimalType(22,6), DecimalType(38,16)) + + def checkDecimalTypeWhiteList(dt: DecimalType): Unit = { + if (!IS_DECIMAL_CHECK) { + return + } + if (!DECIMAL_ALLOWEDTYPES.contains(dt)) { + throw new UnsupportedOperationException(s"decimal precision and scale not in support scope, ${dt}") + } + } + + def checkOmniJsonWhiteList(filterExpr: String, projections: Array[AnyRef]): Unit = { + if (!IS_DECIMAL_CHECK) { + return + } + // inputTypes will not be checked if parseFormat is json( == 1), + // only if its parseFormat is String (== 0) + val returnCode: Long = new OmniExprVerify().exprVerifyNative( + DataTypeSerializer.serialize(new Array[nova.hetu.omniruntime.`type`.DataType](0)), + 0, filterExpr, projections, projections.length, 1) + if (returnCode == 0) { + throw new UnsupportedOperationException(s"Unsupported OmniJson Expression \nfilter:${filterExpr} \nproejcts:${projections.mkString("=")}") + } + } + + def rewriteToOmniExpressionLiteral(expr: Expression, exprsIndexMap: Map[ExprId, Int]): String = { + expr match { + case 
unscaledValue: UnscaledValue => + "UnscaledValue:%s(%s, %d, %d)".format( + sparkTypeToOmniExpType(unscaledValue.dataType), + rewriteToOmniExpressionLiteral(unscaledValue.child, exprsIndexMap), + unscaledValue.child.dataType.asInstanceOf[DecimalType].precision, + unscaledValue.child.dataType.asInstanceOf[DecimalType].scale) + + // omni not support return null, now rewrite to if(IsOverflowDecimal())? NULL:MakeDecimal() + case checkOverflow: CheckOverflow => + ("IF:%s(IsOverflowDecimal:%s(%s,%d,%d,%d,%d), %s, MakeDecimal:%s(%s,%d,%d,%d,%d))") + .format(sparkTypeToOmniExpType(checkOverflow.dataType), + // IsOverflowDecimal returnType + sparkTypeToOmniExpType(BooleanType), + // IsOverflowDecimal arguments + rewriteToOmniExpressionLiteral(checkOverflow.child, exprsIndexMap), + checkOverflow.dataType.precision, checkOverflow.dataType.scale, + checkOverflow.dataType.precision, checkOverflow.dataType.scale, + // if_true + rewriteToOmniExpressionLiteral(Literal(null, checkOverflow.dataType), exprsIndexMap), + // if_false + sparkTypeToOmniExpJsonType(checkOverflow.dataType), + rewriteToOmniExpressionLiteral(checkOverflow.child, exprsIndexMap), + checkOverflow.dataType.precision, checkOverflow.dataType.scale, + checkOverflow.dataType.precision, checkOverflow.dataType.scale) + + case makeDecimal: MakeDecimal => + makeDecimal.child.dataType match { + case decimalChild: DecimalType => + ("MakeDecimal:%s(%s,%s,%s,%s,%s)") + .format(sparkTypeToOmniExpJsonType(makeDecimal.dataType), + rewriteToOmniExpressionLiteral(makeDecimal.child, exprsIndexMap), + decimalChild.precision,decimalChild.scale, + makeDecimal.precision, makeDecimal.scale) + case longChild: LongType => + ("MakeDecimal:%s(%s,%s,%s)") + .format(sparkTypeToOmniExpJsonType(makeDecimal.dataType), + rewriteToOmniExpressionLiteral(makeDecimal.child, exprsIndexMap), + makeDecimal.precision, makeDecimal.scale) + case _ => + throw new UnsupportedOperationException(s"Unsupported datatype for MakeDecimal: ${makeDecimal.child.dataType}") + } + + case promotePrecision: PromotePrecision => + rewriteToOmniExpressionLiteral(promotePrecision.child, exprsIndexMap) + + case sub: Subtract => + "$operator$SUBTRACT:%s(%s,%s)".format( + sparkTypeToOmniExpType(sub.dataType), + rewriteToOmniExpressionLiteral(sub.left, exprsIndexMap), + rewriteToOmniExpressionLiteral(sub.right, exprsIndexMap)) + + case add: Add => + "$operator$ADD:%s(%s,%s)".format( + sparkTypeToOmniExpType(add.dataType), + rewriteToOmniExpressionLiteral(add.left, exprsIndexMap), + rewriteToOmniExpressionLiteral(add.right, exprsIndexMap)) + + case mult: Multiply => + "$operator$MULTIPLY:%s(%s,%s)".format( + sparkTypeToOmniExpType(mult.dataType), + rewriteToOmniExpressionLiteral(mult.left, exprsIndexMap), + rewriteToOmniExpressionLiteral(mult.right, exprsIndexMap)) + + case divide: Divide => + "$operator$DIVIDE:%s(%s,%s)".format( + sparkTypeToOmniExpType(divide.dataType), + rewriteToOmniExpressionLiteral(divide.left, exprsIndexMap), + rewriteToOmniExpressionLiteral(divide.right, exprsIndexMap)) + + case mod: Remainder => + "$operator$MODULUS:%s(%s,%s)".format( + sparkTypeToOmniExpType(mod.dataType), + rewriteToOmniExpressionLiteral(mod.left, exprsIndexMap), + rewriteToOmniExpressionLiteral(mod.right, exprsIndexMap)) + + case greaterThan: GreaterThan => + "$operator$GREATER_THAN:%s(%s,%s)".format( + sparkTypeToOmniExpType(greaterThan.dataType), + rewriteToOmniExpressionLiteral(greaterThan.left, exprsIndexMap), + rewriteToOmniExpressionLiteral(greaterThan.right, exprsIndexMap)) + + case greaterThanOrEq: 
GreaterThanOrEqual => + "$operator$GREATER_THAN_OR_EQUAL:%s(%s,%s)".format( + sparkTypeToOmniExpType(greaterThanOrEq.dataType), + rewriteToOmniExpressionLiteral(greaterThanOrEq.left, exprsIndexMap), + rewriteToOmniExpressionLiteral(greaterThanOrEq.right, exprsIndexMap)) + + case lessThan: LessThan => + "$operator$LESS_THAN:%s(%s,%s)".format( + sparkTypeToOmniExpType(lessThan.dataType), + rewriteToOmniExpressionLiteral(lessThan.left, exprsIndexMap), + rewriteToOmniExpressionLiteral(lessThan.right, exprsIndexMap)) + + case lessThanOrEq: LessThanOrEqual => + "$operator$LESS_THAN_OR_EQUAL:%s(%s,%s)".format( + sparkTypeToOmniExpType(lessThanOrEq.dataType), + rewriteToOmniExpressionLiteral(lessThanOrEq.left, exprsIndexMap), + rewriteToOmniExpressionLiteral(lessThanOrEq.right, exprsIndexMap)) + + case equal: EqualTo => + "$operator$EQUAL:%s(%s,%s)".format( + sparkTypeToOmniExpType(equal.dataType), + rewriteToOmniExpressionLiteral(equal.left, exprsIndexMap), + rewriteToOmniExpressionLiteral(equal.right, exprsIndexMap)) + + case or: Or => + "OR:%s(%s,%s)".format( + sparkTypeToOmniExpType(or.dataType), + rewriteToOmniExpressionLiteral(or.left, exprsIndexMap), + rewriteToOmniExpressionLiteral(or.right, exprsIndexMap)) + + case and: And => + "AND:%s(%s,%s)".format( + sparkTypeToOmniExpType(and.dataType), + rewriteToOmniExpressionLiteral(and.left, exprsIndexMap), + rewriteToOmniExpressionLiteral(and.right, exprsIndexMap)) + + case alias: Alias => rewriteToOmniExpressionLiteral(alias.child, exprsIndexMap) + case literal: Literal => toOmniLiteral(literal) + case not: Not => + "not:%s(%s)".format( + sparkTypeToOmniExpType(BooleanType), + rewriteToOmniExpressionLiteral(not.child, exprsIndexMap)) + case isnotnull: IsNotNull => + "IS_NOT_NULL:%s(%s)".format( + sparkTypeToOmniExpType(BooleanType), + rewriteToOmniExpressionLiteral(isnotnull.child, exprsIndexMap)) + // Substring + case subString: Substring => + "substr:%s(%s,%s,%s)".format( + sparkTypeToOmniExpType(subString.dataType), + rewriteToOmniExpressionLiteral(subString.str, exprsIndexMap), + rewriteToOmniExpressionLiteral(subString.pos, exprsIndexMap), + rewriteToOmniExpressionLiteral(subString.len, exprsIndexMap)) + // Cast + case cast: Cast => + unsupportedCastCheck(expr, cast) + "CAST:%s(%s)".format( + sparkTypeToOmniExpType(cast.dataType), + rewriteToOmniExpressionLiteral(cast.child, exprsIndexMap)) + // Abs + case abs: Abs => + "abs:%s(%s)".format( + sparkTypeToOmniExpType(abs.dataType), + rewriteToOmniExpressionLiteral(abs.child, exprsIndexMap)) + // In + case in: In => + "IN:%s(%s)".format( + sparkTypeToOmniExpType(in.dataType), + in.children.map(child => rewriteToOmniExpressionLiteral(child, exprsIndexMap)) + .mkString(",")) + // coming from In expression with optimizerInSetConversionThreshold + case inSet: InSet => + "IN:%s(%s,%s)".format( + sparkTypeToOmniExpType(inSet.dataType), + rewriteToOmniExpressionLiteral(inSet.child, exprsIndexMap), + inSet.set.map(child => toOmniLiteral( + Literal(child, inSet.child.dataType))).mkString(",")) + // only support with one case condition, for omni rewrite to if(A, B, C) + case caseWhen: CaseWhen => + "IF:%s(%s, %s, %s)".format( + sparkTypeToOmniExpType(caseWhen.dataType), + rewriteToOmniExpressionLiteral(caseWhen.branches(0)._1, exprsIndexMap), + rewriteToOmniExpressionLiteral(caseWhen.branches(0)._2, exprsIndexMap), + rewriteToOmniExpressionLiteral(caseWhen.elseValue.get, exprsIndexMap)) + // Sum + case sum: Sum => + "SUM:%s(%s)".format( + sparkTypeToOmniExpType(sum.dataType), + sum.children.map(child 
=> rewriteToOmniExpressionLiteral(child, exprsIndexMap)) + .mkString(",")) + // Max + case max: Max => + "MAX:%s(%s)".format( + sparkTypeToOmniExpType(max.dataType), + max.children.map(child => rewriteToOmniExpressionLiteral(child, exprsIndexMap)) + .mkString(",")) + // Average + case avg: Average => + "AVG:%s(%s)".format( + sparkTypeToOmniExpType(avg.dataType), + avg.children.map(child => rewriteToOmniExpressionLiteral(child, exprsIndexMap)) + .mkString(",")) + // Min + case min: Min => + "MIN:%s(%s)".format( + sparkTypeToOmniExpType(min.dataType), + min.children.map(child => rewriteToOmniExpressionLiteral(child, exprsIndexMap)) + .mkString(",")) + + case coalesce: Coalesce => + "COALESCE:%s(%s)".format( + sparkTypeToOmniExpType(coalesce.dataType), + coalesce.children.map(child => rewriteToOmniExpressionLiteral(child, exprsIndexMap)) + .mkString(",")) + + case concat: Concat => + getConcatStr(concat, exprsIndexMap) + + case attr: Attribute => s"#${exprsIndexMap(attr.exprId).toString}" + case _ => + throw new UnsupportedOperationException(s"Unsupported expression: $expr") + } + } + + private def getConcatStr(concat: Concat, exprsIndexMap: Map[ExprId, Int]): String = { + val child: Seq[Expression] = concat.children + checkInputDataTypes(child) + val template = "concat:%s(%s,%s)" + val omniType = sparkTypeToOmniExpType(concat.dataType) + if (child.length == 1) { + return rewriteToOmniExpressionLiteral(child.head, exprsIndexMap) + } + // (a, b, c) => concat(concat(a,b),c) + var res = template.format(omniType, + rewriteToOmniExpressionLiteral(child.head, exprsIndexMap), + rewriteToOmniExpressionLiteral(child(1), exprsIndexMap)) + for (i <- 2 until child.length) { + res = template.format(omniType, res, + rewriteToOmniExpressionLiteral(child(i), exprsIndexMap)) + } + res + } + + private def unsupportedCastCheck(expr: Expression, cast: Cast): Unit = { + if (cast.dataType == StringType && cast.child.dataType != StringType) { + throw new UnsupportedOperationException(s"Unsupported expression: $expr") + } + } + + def toOmniLiteral(literal: Literal): String = { + val omniType = sparkTypeToOmniExpType(literal.dataType) + literal.dataType match { + case null => s"null:${omniType}" + case StringType => s"\'${literal.toString}\':${omniType}" + case _ => literal.toString + s":${omniType}" + } + } + + def rewriteToOmniJsonExpressionLiteral(expr: Expression, + exprsIndexMap: Map[ExprId, Int]): String = { + rewriteToOmniJsonExpressionLiteral(expr, exprsIndexMap, expr.dataType) + } + + def rewriteToOmniJsonExpressionLiteral(expr: Expression, + exprsIndexMap: Map[ExprId, Int], + returnDatatype: DataType): String = { + expr match { + case unscaledValue: UnscaledValue => + ("{\"exprType\":\"FUNCTION\",\"returnType\":%s," + + "\"function_name\":\"UnscaledValue\", \"arguments\":[%s, %s, %s]}") + .format(sparkTypeToOmniExpJsonType(unscaledValue.dataType), + rewriteToOmniJsonExpressionLiteral(unscaledValue.child, exprsIndexMap), + toOmniJsonLiteral( + Literal(unscaledValue.child.dataType.asInstanceOf[DecimalType].precision, IntegerType)), + toOmniJsonLiteral( + Literal(unscaledValue.child.dataType.asInstanceOf[DecimalType].scale, IntegerType))) + + // omni not support return null, now rewrite to if(IsOverflowDecimal())? 
NULL:MakeDecimal() + case checkOverflow: CheckOverflow => + ("{\"exprType\":\"IF\",\"returnType\":%s," + + "\"condition\":{\"exprType\":\"FUNCTION\",\"returnType\":%s," + + "\"function_name\":\"IsOverflowDecimal\",\"arguments\":[%s,%s,%s,%s,%s]}," + + "\"if_true\":%s," + + "\"if_false\":{\"exprType\":\"FUNCTION\",\"returnType\":%s," + + "\"function_name\":\"MakeDecimal\", \"arguments\":[%s,%s,%s,%s,%s]}" + + "}") + .format(sparkTypeToOmniExpJsonType(checkOverflow.dataType), + // IsOverflowDecimal returnType + sparkTypeToOmniExpJsonType(BooleanType), + // IsOverflowDecimal arguments + rewriteToOmniJsonExpressionLiteral(checkOverflow.child, exprsIndexMap, + DecimalType(checkOverflow.dataType.precision, checkOverflow.dataType.scale)), + toOmniJsonLiteral( + Literal(checkOverflow.dataType.precision, IntegerType)), + toOmniJsonLiteral( + Literal(checkOverflow.dataType.scale, IntegerType)), + toOmniJsonLiteral( + Literal(checkOverflow.dataType.precision, IntegerType)), + toOmniJsonLiteral( + Literal(checkOverflow.dataType.scale, IntegerType)), + // if_true + toOmniJsonLiteral( + Literal(null, checkOverflow.dataType)), + // if_false + sparkTypeToOmniExpJsonType( + DecimalType(checkOverflow.dataType.precision, checkOverflow.dataType.scale)), + rewriteToOmniJsonExpressionLiteral(checkOverflow.child, + exprsIndexMap, + DecimalType(checkOverflow.dataType.precision, checkOverflow.dataType.scale)), + toOmniJsonLiteral( + Literal(checkOverflow.dataType.precision, IntegerType)), + toOmniJsonLiteral( + Literal(checkOverflow.dataType.scale, IntegerType)), + toOmniJsonLiteral( + Literal(checkOverflow.dataType.precision, IntegerType)), + toOmniJsonLiteral( + Literal(checkOverflow.dataType.scale, IntegerType))) + + case makeDecimal: MakeDecimal => + makeDecimal.child.dataType match { + case decimalChild: DecimalType => + ("{\"exprType\": \"FUNCTION\", \"returnType\":%s," + + "\"function_name\": \"MakeDecimal\", \"arguments\": [%s,%s,%s,%s,%s]}") + .format(sparkTypeToOmniExpJsonType(makeDecimal.dataType), + rewriteToOmniJsonExpressionLiteral(makeDecimal.child, exprsIndexMap), + toOmniJsonLiteral( + Literal(decimalChild.precision, IntegerType)), + toOmniJsonLiteral( + Literal(decimalChild.scale, IntegerType)), + toOmniJsonLiteral( + Literal(makeDecimal.precision, IntegerType)), + toOmniJsonLiteral( + Literal(makeDecimal.scale, IntegerType))) + + case longChild: LongType => + ("{\"exprType\": \"FUNCTION\", \"returnType\":%s," + + "\"function_name\": \"MakeDecimal\", \"arguments\": [%s,%s,%s]}") + .format(sparkTypeToOmniExpJsonType(makeDecimal.dataType), + rewriteToOmniJsonExpressionLiteral(makeDecimal.child, exprsIndexMap), + toOmniJsonLiteral( + Literal(makeDecimal.precision, IntegerType)), + toOmniJsonLiteral( + Literal(makeDecimal.scale, IntegerType))) + case _ => + throw new UnsupportedOperationException(s"Unsupported datatype for MakeDecimal: ${makeDecimal.child.dataType}") + } + + case promotePrecision: PromotePrecision => + rewriteToOmniJsonExpressionLiteral(promotePrecision.child, exprsIndexMap) + + case sub: Subtract => + ("{\"exprType\":\"BINARY\",\"returnType\":%s," + + "\"operator\":\"SUBTRACT\",\"left\":%s,\"right\":%s}").format( + sparkTypeToOmniExpJsonType(returnDatatype), + rewriteToOmniJsonExpressionLiteral(sub.left, exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(sub.right, exprsIndexMap)) + + case add: Add => + ("{\"exprType\":\"BINARY\",\"returnType\":%s," + + "\"operator\":\"ADD\",\"left\":%s,\"right\":%s}").format( + sparkTypeToOmniExpJsonType(returnDatatype), + 
rewriteToOmniJsonExpressionLiteral(add.left, exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(add.right, exprsIndexMap)) + + case mult: Multiply => + ("{\"exprType\":\"BINARY\",\"returnType\":%s," + + "\"operator\":\"MULTIPLY\",\"left\":%s,\"right\":%s}").format( + sparkTypeToOmniExpJsonType(returnDatatype), + rewriteToOmniJsonExpressionLiteral(mult.left, exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(mult.right, exprsIndexMap)) + + case divide: Divide => + ("{\"exprType\":\"BINARY\",\"returnType\":%s," + + "\"operator\":\"DIVIDE\",\"left\":%s,\"right\":%s}").format( + sparkTypeToOmniExpJsonType(returnDatatype), + rewriteToOmniJsonExpressionLiteral(divide.left, exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(divide.right, exprsIndexMap)) + + case mod: Remainder => + ("{\"exprType\":\"BINARY\",\"returnType\":%s," + + "\"operator\":\"MODULUS\",\"left\":%s,\"right\":%s}").format( + sparkTypeToOmniExpJsonType(returnDatatype), + rewriteToOmniJsonExpressionLiteral(mod.left, exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(mod.right, exprsIndexMap)) + + case greaterThan: GreaterThan => + ("{\"exprType\":\"BINARY\",\"returnType\":%s," + + "\"operator\":\"GREATER_THAN\",\"left\":%s,\"right\":%s}").format( + sparkTypeToOmniExpJsonType(greaterThan.dataType), + rewriteToOmniJsonExpressionLiteral(greaterThan.left, exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(greaterThan.right, exprsIndexMap)) + + case greaterThanOrEq: GreaterThanOrEqual => + ("{\"exprType\":\"BINARY\",\"returnType\":%s," + + "\"operator\":\"GREATER_THAN_OR_EQUAL\",\"left\":%s,\"right\":%s}").format( + sparkTypeToOmniExpJsonType(greaterThanOrEq.dataType), + rewriteToOmniJsonExpressionLiteral(greaterThanOrEq.left, exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(greaterThanOrEq.right, exprsIndexMap)) + + case lessThan: LessThan => + ("{\"exprType\":\"BINARY\",\"returnType\":%s," + + "\"operator\":\"LESS_THAN\",\"left\":%s,\"right\":%s}").format( + sparkTypeToOmniExpJsonType(lessThan.dataType), + rewriteToOmniJsonExpressionLiteral(lessThan.left, exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(lessThan.right, exprsIndexMap)) + + case lessThanOrEq: LessThanOrEqual => + ("{\"exprType\":\"BINARY\",\"returnType\":%s," + + "\"operator\":\"LESS_THAN_OR_EQUAL\",\"left\":%s,\"right\":%s}").format( + sparkTypeToOmniExpJsonType(lessThanOrEq.dataType), + rewriteToOmniJsonExpressionLiteral(lessThanOrEq.left, exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(lessThanOrEq.right, exprsIndexMap)) + + case equal: EqualTo => + ("{\"exprType\":\"BINARY\",\"returnType\":%s," + + "\"operator\":\"EQUAL\",\"left\":%s,\"right\":%s}").format( + sparkTypeToOmniExpJsonType(equal.dataType), + rewriteToOmniJsonExpressionLiteral(equal.left, exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(equal.right, exprsIndexMap)) + + case or: Or => + ("{\"exprType\":\"BINARY\",\"returnType\":%s," + + "\"operator\":\"OR\",\"left\":%s,\"right\":%s}").format( + sparkTypeToOmniExpJsonType(or.dataType), + rewriteToOmniJsonExpressionLiteral(or.left, exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(or.right, exprsIndexMap)) + + case and: And => + ("{\"exprType\":\"BINARY\",\"returnType\":%s," + + "\"operator\":\"AND\",\"left\":%s,\"right\":%s}").format( + sparkTypeToOmniExpJsonType(and.dataType), + rewriteToOmniJsonExpressionLiteral(and.left, exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(and.right, exprsIndexMap)) + + case alias: Alias => rewriteToOmniJsonExpressionLiteral(alias.child, exprsIndexMap) + case literal: Literal => 
toOmniJsonLiteral(literal) + case not: Not => + "{\"exprType\":\"UNARY\",\"returnType\":%s,\"operator\":\"not\",\"expr\":%s}".format( + sparkTypeToOmniExpJsonType(BooleanType), + rewriteToOmniJsonExpressionLiteral(not.child, exprsIndexMap)) + + case isnotnull: IsNotNull => + ("{\"exprType\":\"UNARY\",\"returnType\":%s, \"operator\":\"not\"," + + "\"expr\":{\"exprType\":\"IS_NULL\",\"returnType\":%s," + + "\"arguments\":[%s]}}").format(sparkTypeToOmniExpJsonType(BooleanType), + sparkTypeToOmniExpJsonType(BooleanType), + rewriteToOmniJsonExpressionLiteral(isnotnull.child, exprsIndexMap)) + + // Substring + case subString: Substring => + ("{\"exprType\":\"FUNCTION\",\"returnType\":%s," + + "\"function_name\":\"substr\", \"arguments\":[%s,%s,%s]}") + .format(sparkTypeToOmniExpJsonType(subString.dataType), + rewriteToOmniJsonExpressionLiteral(subString.str, exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(subString.pos, exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(subString.len, exprsIndexMap)) + + // Cast + case cast: Cast => + unsupportedCastCheck(expr, cast) + val returnType = sparkTypeToOmniExpJsonType(cast.dataType) + cast.dataType match { + case StringType => + ("{\"exprType\":\"FUNCTION\",\"returnType\":%s," + + "\"width\":50,\"function_name\":\"CAST\", \"arguments\":[%s]}") + .format(returnType, rewriteToOmniJsonExpressionLiteral(cast.child, exprsIndexMap)) + // for to decimal omni default cast no precision and scale handle + // use MakeDecimal to take it + case dt: DecimalType => + if (cast.child.dataType.isInstanceOf[DoubleType]) { + ("{\"exprType\":\"FUNCTION\",\"returnType\":%s," + + "\"function_name\":\"CAST\", \"arguments\":[%s,%s,%s]}") + .format(returnType, rewriteToOmniJsonExpressionLiteral(cast.child, exprsIndexMap), + toOmniJsonLiteral(Literal(dt.precision, IntegerType)), + toOmniJsonLiteral(Literal(dt.scale, IntegerType))) + } else { + rewriteToOmniJsonExpressionLiteral( + MakeDecimal(cast.child, dt.precision, dt.scale), exprsIndexMap) + } + case _ => + ("{\"exprType\":\"FUNCTION\",\"returnType\":%s," + + "\"function_name\":\"CAST\",\"arguments\":[%s]}") + .format(returnType, rewriteToOmniJsonExpressionLiteral(cast.child, exprsIndexMap)) + } + // Abs + case abs: Abs => + "{\"exprType\":\"FUNCTION\",\"returnType\":%s,\"function_name\":\"abs\", \"arguments\":[%s]}" + .format(sparkTypeToOmniExpJsonType(abs.dataType), + rewriteToOmniJsonExpressionLiteral(abs.child, exprsIndexMap)) + + // In + case in: In => + "{\"exprType\":\"IN\",\"returnType\":%s, \"arguments\":%s}".format( + sparkTypeToOmniExpJsonType(in.dataType), + in.children.map(child => rewriteToOmniJsonExpressionLiteral(child, exprsIndexMap)) + .mkString("[", ",", "]")) + + // coming from In expression with optimizerInSetConversionThreshold + case inSet: InSet => + "{\"exprType\":\"IN\",\"returnType\":%s, \"arguments\":[%s, %s]}" + .format(sparkTypeToOmniExpJsonType(inSet.dataType), + rewriteToOmniJsonExpressionLiteral(inSet.child, exprsIndexMap), + inSet.set.map(child => + toOmniJsonLiteral(Literal(child, inSet.child.dataType))).mkString(",")) + + case ifExp: If => + "{\"exprType\":\"IF\",\"returnType\":%s,\"condition\":%s,\"if_true\":%s,\"if_false\":%s}" + .format(sparkTypeToOmniExpJsonType(ifExp.dataType), + rewriteToOmniJsonExpressionLiteral(ifExp.predicate, exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(ifExp.trueValue, exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(ifExp.falseValue, exprsIndexMap)) + + // only support with one case condition, for omni rewrite to if(A, B, C) + case caseWhen: 
CaseWhen => + "{\"exprType\":\"IF\",\"returnType\":%s,\"condition\":%s,\"if_true\":%s,\"if_false\":%s}" + .format(sparkTypeToOmniExpJsonType(caseWhen.dataType), + rewriteToOmniJsonExpressionLiteral(caseWhen.branches(0)._1, exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(caseWhen.branches(0)._2, exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(caseWhen.elseValue.get, exprsIndexMap)) + + case coalesce: Coalesce => + "{\"exprType\":\"COALESCE\",\"returnType\":%s, \"value1\":%s,\"value2\":%s}".format( + sparkTypeToOmniExpJsonType(coalesce.dataType), + rewriteToOmniJsonExpressionLiteral(coalesce.children(0), exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(coalesce.children(1), exprsIndexMap)) + + case concat: Concat => + getConcatJsonStr(concat, exprsIndexMap) + case attr: Attribute => toOmniJsonAttribute(attr, exprsIndexMap(attr.exprId)) + case _ => + throw new UnsupportedOperationException(s"Unsupported expression: $expr") + } + } + + private def checkInputDataTypes(children: Seq[Expression]): Unit = { + val childTypes = children.map(_.dataType) + for (dataType <- childTypes) { + if (!dataType.isInstanceOf[StringType]) { + throw new UnsupportedOperationException(s"Invalid input dataType:$dataType for concat") + } + } + } + + private def getConcatJsonStr(concat: Concat, exprsIndexMap: Map[ExprId, Int]): String = { + val children: Seq[Expression] = concat.children + checkInputDataTypes(children) + val template = "{\"exprType\": \"FUNCTION\",\"returnType\":%s," + + "\"function_name\": \"concat\", \"arguments\": [%s, %s]}" + val returnType = sparkTypeToOmniExpJsonType(concat.dataType) + if (children.length == 1) { + return rewriteToOmniJsonExpressionLiteral(children.head, exprsIndexMap) + } + var res = template.format(returnType, + rewriteToOmniJsonExpressionLiteral(children.head, exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(children(1), exprsIndexMap)) + for (i <- 2 until children.length) { + res = template.format(returnType, res, + rewriteToOmniJsonExpressionLiteral(children(i), exprsIndexMap)) + } + res + } + + def toOmniJsonAttribute(attr: Attribute, colVal: Int): String = { + + val omniDataType = sparkTypeToOmniExpType(attr.dataType) + attr.dataType match { + case StringType => + ("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":%s," + + "\"colVal\":%d,\"width\":%d}").format(omniDataType, colVal, + getStringLength(attr.metadata)) + case dt: DecimalType => + ("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":%s," + + "\"colVal\":%d,\"precision\":%s, \"scale\":%s}").format(omniDataType, + colVal, dt.precision, dt.scale) + case _ => ("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":%s," + + "\"colVal\":%d}").format(omniDataType, colVal) + } + } + + def toOmniJsonLiteral(literal: Literal): String = { + val omniType = sparkTypeToOmniExpType(literal.dataType) + val value = literal.value + if (value == null) { + return "{\"exprType\":\"LITERAL\",\"dataType\":%s,\"isNull\":%b}".format(sparkTypeToOmniExpJsonType(literal.dataType), true) + } + literal.dataType match { + case StringType => + ("{\"exprType\":\"LITERAL\",\"dataType\":%s," + + "\"isNull\":%b, \"value\":\"%s\",\"width\":%d}") + .format(omniType, false, value.toString, value.toString.length) + case dt: DecimalType => + if (DecimalType.is64BitDecimalType(dt)) { + ("{\"exprType\":\"LITERAL\",\"dataType\":%s," + + "\"isNull\":%b,\"value\":%s,\"precision\":%s, \"scale\":%s}").format(omniType, + false, value.asInstanceOf[Decimal].toUnscaledLong, dt.precision, dt.scale) + } else { + // NOTES: decimal128 literal value need use 
string format + ("{\"exprType\":\"LITERAL\",\"dataType\":%s," + + "\"isNull\":%b, \"value\":\"%s\", \"precision\":%s, \"scale\":%s}").format(omniType, + false, value.asInstanceOf[Decimal].toJavaBigDecimal.unscaledValue().toString(), + dt.precision, dt.scale) + } + case _ => + "{\"exprType\":\"LITERAL\",\"dataType\":%s, \"isNull\":%b, \"value\":%s}" + .format(omniType, false, value) + } + } + + def toOmniAggFunTypeWithFinal(agg: AggregateExpression, isHashAgg: Boolean = false, isFinal: Boolean): FunctionType = { + agg.aggregateFunction match { + case Sum(_) => { + if (isHashAgg) { + if (agg.dataType.isInstanceOf[DecimalType]) { + new UnsupportedOperationException("HashAgg not supported decimal input") + } + } + OMNI_AGGREGATION_TYPE_SUM + } + case Max(_) => OMNI_AGGREGATION_TYPE_MAX + case Average(_) => OMNI_AGGREGATION_TYPE_AVG + case Min(_) => OMNI_AGGREGATION_TYPE_MIN + case Count(Literal(1, IntegerType) :: Nil) | Count(ArrayBuffer(Literal(1, IntegerType))) => + if (isFinal) { + OMNI_AGGREGATION_TYPE_COUNT_COLUMN + } else { + OMNI_AGGREGATION_TYPE_COUNT_ALL + } + case Count(_) => OMNI_AGGREGATION_TYPE_COUNT_COLUMN + case _ => throw new UnsupportedOperationException(s"Unsupported aggregate function: $agg") + } + } + + def toOmniAggFunType(agg: AggregateExpression, isHashAgg: Boolean = false): FunctionType = { + toOmniAggFunTypeWithFinal(agg, isHashAgg, false) + } + + def toOmniWindowFunType(window: Expression): FunctionType = { + window match { + case Rank(_) => OMNI_WINDOW_TYPE_RANK + case RowNumber() => OMNI_WINDOW_TYPE_ROW_NUMBER + case _ => throw new UnsupportedOperationException(s"Unsupported window function: $window") + } + } + + def sparkTypeToOmniExpType(datatype: DataType): String = { + datatype match { + case ShortType => OMNI_SHOR_TYPE + case IntegerType => OMNI_INTEGER_TYPE + case LongType => OMNI_LONG_TYPE + case DoubleType => OMNI_DOUBLE_TYPE + case BooleanType => OMNI_BOOLEAN_TYPE + case StringType => OMNI_VARCHAR_TYPE + case DateType => OMNI_DATE_TYPE + case dt: DecimalType => + checkDecimalTypeWhiteList(dt) + if (DecimalType.is64BitDecimalType(dt)) { + OMNI_DECIMAL64_TYPE + } else { + OMNI_DECIMAL128_TYPE + } + case _ => + throw new UnsupportedOperationException(s"Unsupported datatype: $datatype") + } + } + + def sparkTypeToOmniExpJsonType(datatype: DataType): String = { + val omniTypeIdStr = sparkTypeToOmniExpType(datatype) + datatype match { + case StringType => + "%s,\"width\":%s".format(omniTypeIdStr, DEFAULT_STRING_TYPE_LENGTH) + case dt: DecimalType => + "%s,\"precision\":%s,\"scale\":%s".format(omniTypeIdStr, dt.precision, dt.scale) + case _ => + omniTypeIdStr + } + } + + def sparkTypeToOmniType(dataType: DataType, metadata: Metadata = Metadata.empty): + nova.hetu.omniruntime.`type`.DataType = { + dataType match { + case ShortType => + ShortDataType.SHORT + case IntegerType => + IntDataType.INTEGER + case LongType => + LongDataType.LONG + case DoubleType => + DoubleDataType.DOUBLE + case BooleanType => + BooleanDataType.BOOLEAN + case StringType => + new VarcharDataType(getStringLength(metadata)) + case DateType => + Date32DataType.DATE32 + case dt: DecimalType => + checkDecimalTypeWhiteList(dt) + if (DecimalType.is64BitDecimalType(dt)) { + new Decimal64DataType(dt.precision, dt.scale) + } else { + new Decimal128DataType(dt.precision, dt.scale) + } + case _ => + throw new UnsupportedOperationException(s"Unsupported datatype: $dataType") + } + } + + def sparkProjectionToOmniJsonProjection(attr: Attribute, colVal: Int): String = { + val dataType: DataType = 
attr.dataType + val metadata = attr.metadata + val omniDataType: String = sparkTypeToOmniExpType(dataType) + dataType match { + case ShortType | IntegerType | LongType | DoubleType | BooleanType | DateType => + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":%s,\"colVal\":%d}" + .format(omniDataType, colVal) + case StringType => + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":%s,\"colVal\":%d,\"width\":%d}" + .format(omniDataType, colVal, getStringLength(metadata)) + case dt: DecimalType => + checkDecimalTypeWhiteList(dt) + var omniDataType = OMNI_DECIMAL128_TYPE + if (DecimalType.is64BitDecimalType(dt)) { + omniDataType = OMNI_DECIMAL64_TYPE + } + ("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":%s,\"colVal\":%d," + + "\"precision\":%s,\"scale\":%s}") + .format(omniDataType, colVal, dt.precision, dt.scale) + case _ => + throw new UnsupportedOperationException(s"Unsupported datatype: $dataType") + } + } + + private def getStringLength(metadata: Metadata): Int = { + var width = DEFAULT_STRING_TYPE_LENGTH + if (getRawTypeString(metadata).isDefined) { + val CHAR_TYPE = """char\(\s*(\d+)\s*\)""".r + val VARCHAR_TYPE = """varchar\(\s*(\d+)\s*\)""".r + val stringOrigDefine = getRawTypeString(metadata).get + stringOrigDefine match { + case CHAR_TYPE(length) => width = length.toInt + case VARCHAR_TYPE(length) => width = length.toInt + case _ => + } + } + width + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/serialize/ColumnarBatchSerializer.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/serialize/ColumnarBatchSerializer.scala new file mode 100644 index 0000000000000000000000000000000000000000..de5638f0a7d927380a129538ba6f664830765f90 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/serialize/ColumnarBatchSerializer.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.serialize + +import com.google.common.io.ByteStreams +import com.huawei.boostkit.spark.ColumnarPluginConfig +import com.huawei.boostkit.spark.compress.{CompressionUtil, DecompressionStream} +import java.io.{BufferedInputStream, DataInputStream, EOFException, InputStream, OutputStream} +import java.nio.ByteBuffer +import scala.reflect.ClassTag +import org.apache.spark.internal.Logging +import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance} +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.vectorized.ColumnarBatch + +class ColumnarBatchSerializer(readBatchNumRows: SQLMetric, numOutputRows: SQLMetric) + extends Serializer + with Serializable { + /** Creates a new [[SerializerInstance]]. 
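+   * The returned instance counts the batches and rows it deserializes so that
+   * readBatchNumRows and numOutputRows can be updated when the stream is closed.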
*/ + override def newInstance(): SerializerInstance = + new ColumnarBatchSerializerInstance(readBatchNumRows, numOutputRows) +} + +private class ColumnarBatchSerializerInstance( + readBatchNumRows: SQLMetric, + numOutputRows: SQLMetric) + extends SerializerInstance with Logging { + override def deserializeStream(in: InputStream): DeserializationStream = { + new DeserializationStream { + val columnarConf = ColumnarPluginConfig.getSessionConf + val shuffleCompressBlockSize = columnarConf.columnarShuffleCompressBlockSize + val enableShuffleCompress = columnarConf.enableShuffleCompress + var shuffleCompressionCodec = columnarConf.columnarShuffleCompressionCodec + + if (!enableShuffleCompress) { + shuffleCompressionCodec = "uncompressed" + } + + private var numBatchesTotal: Long = _ + private var numRowsTotal: Long = _ + + private[this] val dIn: DataInputStream = if (enableShuffleCompress) { + val codec = CompressionUtil.createCodec(shuffleCompressionCodec) + new DataInputStream(new BufferedInputStream( + new DecompressionStream(in, codec, shuffleCompressBlockSize))) + } else { + new DataInputStream(new BufferedInputStream(in)) + } + private[this] var columnarBuffer: Array[Byte] = new Array[Byte](1024) + val ibuffer: ByteBuffer = ByteBuffer.allocateDirect(4) + + private[this] val EOF: Int = -1 + + override def asKeyValueIterator: Iterator[(Int, ColumnarBatch)] = { + new Iterator[(Int, ColumnarBatch)] { + private[this] def readSize(): Int = try { + dIn.readInt() + } catch { + case e: EOFException => + dIn.close() + EOF + } + + private[this] var dataSize: Int = readSize() + override def hasNext: Boolean = dataSize != EOF + + override def next(): (Int, ColumnarBatch) = { + if (columnarBuffer.length < dataSize) { + columnarBuffer = new Array[Byte](dataSize) + } + ByteStreams.readFully(dIn, columnarBuffer, 0, dataSize) + // protobuf serialize + val columnarBatch: ColumnarBatch = ShuffleDataSerializer.deserialize(columnarBuffer.slice(0, dataSize)) + dataSize = readSize() + if (dataSize == EOF) { + dIn.close() + columnarBuffer = null + } + (0, columnarBatch) + } + } + } + + override def asIterator: Iterator[Any] = { + // This method is never called by shuffle code. + throw new UnsupportedOperationException + } + + override def readKey[T: ClassTag](): T = { + // We skipped serialization of the key in writeKey(), so just return a dummy value since + // this is going to be discarded anyways. + null.asInstanceOf[T] + } + + override def readValue[T: ClassTag](): T = { + val dataSize = dIn.readInt() + if (columnarBuffer.size < dataSize) { + columnarBuffer = new Array[Byte](dataSize) + } + ByteStreams.readFully(dIn, columnarBuffer, 0, dataSize) + // protobuf serialize + val columnarBatch: ColumnarBatch = ShuffleDataSerializer.deserialize(columnarBuffer.slice(0, dataSize)) + numBatchesTotal += 1 + numRowsTotal += columnarBatch.numRows() + columnarBatch.asInstanceOf[T] + } + + override def readObject[T: ClassTag](): T = { + // This method is never called by shuffle code. 
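+        // Batches are consumed via asKeyValueIterator or readKey()/readValue() instead.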
+ throw new UnsupportedOperationException + } + + override def close(): Unit = { + if (numBatchesTotal > 0) { + readBatchNumRows.set(numRowsTotal.toDouble / numBatchesTotal) + } + numOutputRows += numRowsTotal + dIn.close() + } + } + } + + override def serialize[T: ClassTag](t: T): ByteBuffer = + throw new UnsupportedOperationException + + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = + throw new UnsupportedOperationException + + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = + throw new UnsupportedOperationException + + override def serializeStream(s: OutputStream): SerializationStream = + throw new UnsupportedOperationException +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala new file mode 100644 index 0000000000000000000000000000000000000000..1661874db7c6e0302dba52d2be971da9e9167ecc --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.spark.util + +import java.util.concurrent.TimeUnit.NANOSECONDS + +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._ +import nova.hetu.omniruntime.operator.OmniOperator +import nova.hetu.omniruntime.vector._ + +import org.apache.spark.sql.catalyst.expressions.{Attribute, ExprId, SortOrder} +import org.apache.spark.sql.execution.datasources.orc.OrcColumnVector +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.vectorized.{OmniColumnVector, OnHeapColumnVector} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} + +import java.util + +object OmniAdaptorUtil { + def transColBatchToOmniVecs(cb: ColumnarBatch): Array[Vec] = { + transColBatchToOmniVecs(cb, false) + } + + def transColBatchToOmniVecs(cb: ColumnarBatch, isSlice: Boolean): Array[Vec] = { + val input = new Array[Vec](cb.numCols()) + for (i <- 0 until cb.numCols()) { + val omniVec: Vec = cb.column(i) match { + case vector: OrcColumnVector => + transColumnVector(vector, cb.numRows()) + case vector: OnHeapColumnVector => + transColumnVector(vector, cb.numRows()) + case vector: OmniColumnVector => + if (!isSlice) { + vector.getVec + } else { + vector.getVec.slice(0, cb.numRows()) + } + case _ => + throw new UnsupportedOperationException("unsupport column vector!") + } + input(i) = omniVec + } + input + } + + def transColumnVector(columnVector: ColumnVector, columnSize : Int): Vec = { + val datatype: DataType = columnVector.dataType() + val vec: Vec = datatype match { + case LongType => + val vec = new LongVec(columnSize) + val values = new Array[Long](columnSize) + for (i <- 0 until columnSize) { + if (!columnVector.isNullAt(i)) { + values(i) = columnVector.getLong(i) + } else { + vec.setNull(i) + } + } + vec.put(values, 0, 0, columnSize) + vec + case DateType | IntegerType => + val vec = new IntVec(columnSize) + val values = new Array[Int](columnSize) + for (i <- 0 until columnSize) { + if (!columnVector.isNullAt(i)) { + values(i) = columnVector.getInt(i) + } else { + vec.setNull(i) + } + } + vec.put(values, 0, 0, columnSize) + vec + case ShortType => + val vec = new ShortVec(columnSize) + val values = new Array[Short](columnSize) + for (i <- 0 until columnSize) { + if (!columnVector.isNullAt(i)) { + values(i) = columnVector.getShort(i) + } else { + vec.setNull(i) + } + } + vec.put(values, 0, 0, columnSize) + vec + case DoubleType => + val vec = new DoubleVec(columnSize) + val values = new Array[Double](columnSize) + for (i <- 0 until columnSize) { + if (!columnVector.isNullAt(i)) { + values(i) = columnVector.getDouble(i) + } else { + vec.setNull(i) + } + } + vec.put(values, 0, 0, columnSize) + vec + case StringType => + var totalSize = 0 + val offsets = new Array[Int](columnSize + 1) + for (i <- 0 until columnSize) { + if (null != columnVector.getUTF8String(i)) { + val strLen: Int = columnVector.getUTF8String(i).getBytes.length + totalSize += strLen + } + offsets(i + 1) = totalSize + } + val vec = new VarcharVec(totalSize, columnSize) + val values = new Array[Byte](totalSize) + for (i <- 0 until columnSize) { + if (null != columnVector.getUTF8String(i)) { + System.arraycopy(columnVector.getUTF8String(i).getBytes, 0, values, + offsets(i), offsets(i + 1) - offsets(i)) + } else { + vec.setNull(i) + } + } + vec.put(0, values, 0, offsets, 0, columnSize) + vec + case BooleanType => + val vec = new BooleanVec(columnSize) + val values = new Array[Boolean](columnSize) + for (i <- 0 until 
columnSize) { + if (!columnVector.isNullAt(i)) { + values(i) = columnVector.getBoolean(i) + } else { + vec.setNull(i) + } + } + vec.put(values, 0, 0, columnSize) + vec + case t: DecimalType => + if (DecimalType.is64BitDecimalType(datatype)) { + val vec = new LongVec(columnSize) + val values = new Array[Long](columnSize) + for (i <- 0 until columnSize) { + if (!columnVector.isNullAt(i)) { + values(i) = columnVector.getDecimal(i, t.precision, t.scale).toUnscaledLong + } else { + vec.setNull(i) + } + } + vec.put(values, 0, 0, columnSize) + vec + } else { + val vec = new Decimal128Vec(columnSize) + for (i <- 0 until columnSize) { + if (!columnVector.isNullAt(i)) { + vec.setBigInteger(i, + columnVector.getDecimal(i, t.precision, t.scale).toJavaBigDecimal.unscaledValue()) + } else { + vec.setNull(i) + } + } + vec + } + case _ => + throw new UnsupportedOperationException("unsupport column vector!") + } + vec + } + + def genSortParam(output: Seq[Attribute], sortOrder: Seq[SortOrder]): + (Array[nova.hetu.omniruntime.`type`.DataType], Array[Int], Array[Int], Array[String]) = { + val inputColSize: Int = output.size + val sourceTypes = new Array[nova.hetu.omniruntime.`type`.DataType](inputColSize) + val ascendings = new Array[Int](sortOrder.size) + val nullFirsts = new Array[Int](sortOrder.size) + val sortColsExp = new Array[String](sortOrder.size) + val omniAttrExpsIdMap: Map[ExprId, Int] = getExprIdMap(output) + + output.zipWithIndex.foreach { case (inputAttr, i) => + sourceTypes(i) = sparkTypeToOmniType(inputAttr.dataType, inputAttr.metadata) + } + sortOrder.zipWithIndex.foreach { case (sortAttr, i) => + sortColsExp(i) = rewriteToOmniJsonExpressionLiteral(sortAttr.child, omniAttrExpsIdMap) + ascendings(i) = if (sortAttr.isAscending) { + 1 + } else { + 0 + } + nullFirsts(i) = sortAttr.nullOrdering.sql match { + case "NULLS LAST" => 0 + case _ => 1 + } + } + checkOmniJsonWhiteList("", sortColsExp.asInstanceOf[Array[AnyRef]]) + (sourceTypes, ascendings, nullFirsts, sortColsExp) + } + + def addAllAndGetIterator(operator: OmniOperator, + inputIter: Iterator[ColumnarBatch], schema: StructType, + addInputTime: SQLMetric, numInputVecBatchs: SQLMetric, + numInputRows: SQLMetric, getOutputTime: SQLMetric, + numOutputVecBatchs: SQLMetric, numOutputRows: SQLMetric, + outputDataSize: SQLMetric): Iterator[ColumnarBatch] = { + while (inputIter.hasNext) { + val batch: ColumnarBatch = inputIter.next() + numInputVecBatchs += 1 + val input: Array[Vec] = transColBatchToOmniVecs(batch) + val vecBatch = new VecBatch(input, batch.numRows()) + val startInput: Long = System.nanoTime() + operator.addInput(vecBatch) + addInputTime += NANOSECONDS.toMillis(System.nanoTime() - startInput) + numInputRows += batch.numRows() + } + val startGetOp: Long = System.nanoTime() + val results: util.Iterator[VecBatch] = operator.getOutput + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + + new Iterator[ColumnarBatch] { + override def hasNext: Boolean = { + val startGetOp: Long = System.nanoTime() + val hasNext = results.hasNext + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + hasNext + } + + override def next(): ColumnarBatch = { + val startGetOp: Long = System.nanoTime() + val vecBatch: VecBatch = results.next() + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( + vecBatch.getRowCount, schema, false) + vectors.zipWithIndex.foreach { case (vector, i) => + vector.reset() + 
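+          // bind the operator's native output vector to the Spark-facing column vector and
+          // count its value/null/offset buffers towards outputDataSize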
vector.setVec(vecBatch.getVectors()(i)) + outputDataSize += vecBatch.getVectors()(i).getRealValueBufCapacityInBytes + outputDataSize += vecBatch.getVectors()(i).getRealNullBufCapacityInBytes + outputDataSize += vecBatch.getVectors()(i).getRealOffsetBufCapacityInBytes + } + val sourceLength = vecBatch.getVectorCount + var destLength = schema.fields.length + while (destLength < sourceLength) { + vecBatch.getVectors()(destLength).close() // vecBatch releasing redundant columns + destLength += 1 + } + + // metrics + val rowCnt: Int = vecBatch.getRowCount + numOutputRows += rowCnt + numOutputVecBatchs += 1 + // close omni vecbetch + vecBatch.close() + new ColumnarBatch(vectors.toArray, rowCnt) + } + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala new file mode 100644 index 0000000000000000000000000000000000000000..4c27688cb741eec4913ea51e23e14dfa16aa6b64 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle + +import com.huawei.boostkit.spark.vectorized.PartitionInfo +import scala.reflect.ClassTag + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.sql.execution.metric.SQLMetric + +/** + * :: DeveloperApi :: + * Represents a dependency on the output of a shuffle stage. Note that in the case of shuffle, + * the RDD is transient since we don't need it on the executor side. + * + * @param _rdd the parent RDD + * @param partitioner partitioner used to partition the shuffle output + * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If not set + * explicitly then the default serializer, as specified by `spark.serializer` + * config option, will be used. 
+ * @param keyOrdering key ordering for RDD's shuffles + * @param aggregator map/reduce-side aggregator for RDD's shuffle + * @param mapSideCombine whether to perform partial aggregation (also known as map-side combine) + * @param shuffleWriterProcessor the processor to control the write behavior in ShuffleMapTask + * @param partitionInfo hold partitioning parameters needed by native splitter + */ +class ColumnarShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( + @transient private val _rdd: RDD[_ <: Product2[K, V]], + override val partitioner: Partitioner, + override val serializer: Serializer = SparkEnv.get.serializer, + override val keyOrdering: Option[Ordering[K]] = None, + override val aggregator: Option[Aggregator[K, V, C]] = None, + override val mapSideCombine: Boolean = false, + override val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor, + val partitionInfo: PartitionInfo, + val dataSize: SQLMetric, + val bytesSpilled: SQLMetric, + val numInputRows: SQLMetric, + val splitTime: SQLMetric, + val spillTime: SQLMetric) + extends ShuffleDependency[K, V, C]( + _rdd, + partitioner, + serializer, + keyOrdering, + aggregator, + mapSideCombine, + shuffleWriterProcessor) {} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala new file mode 100644 index 0000000000000000000000000000000000000000..7eca3427ec3f6c618f84e70aeb85ce98d0267176 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.shuffle + +import com.google.common.annotations.VisibleForTesting +import com.huawei.boostkit.spark.ColumnarPluginConfig +import com.huawei.boostkit.spark.jni.SparkJniWrapper +import com.huawei.boostkit.spark.vectorized.SplitResult +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import nova.hetu.omniruntime.vector.VecBatch +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.Utils + +import java.io.IOException + +class ColumnarShuffleWriter[K, V]( + shuffleBlockResolver: IndexShuffleBlockResolver, + handle: BaseShuffleHandle[K, V, V], + mapId: Long, + writeMetrics: ShuffleWriteMetricsReporter) + extends ShuffleWriter[K, V] with Logging { + + private val dep = handle.dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]] + + private val blockManager = SparkEnv.get.blockManager + + private var stopping = false + + private var mapStatus: MapStatus = _ + + private val localDirs = blockManager.diskBlockManager.localDirs.mkString(",") + + val columnarConf = ColumnarPluginConfig.getSessionConf + val shuffleSpillBatchRowNum = columnarConf.columnarShuffleSpillBatchRowNum + val shuffleSpillMemoryThreshold = columnarConf.columnarShuffleSpillMemoryThreshold + val shuffleCompressBlockSize = columnarConf.columnarShuffleCompressBlockSize + val shuffleNativeBufferSize = columnarConf.columnarShuffleNativeBufferSize + val enableShuffleCompress = columnarConf.enableShuffleCompress + var shuffleCompressionCodec = columnarConf.columnarShuffleCompressionCodec + + if (!enableShuffleCompress) { + shuffleCompressionCodec = "uncompressed" + } + + private val jniWrapper = new SparkJniWrapper() + + private var nativeSplitter: Long = 0 + + private var splitResult: SplitResult = _ + + private var partitionLengths: Array[Long] = _ + + @throws[IOException] + override def write(records: Iterator[Product2[K, V]]): Unit = { + if (!records.hasNext) { + partitionLengths = new Array[Long](dep.partitioner.numPartitions) + shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, null) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) + return + } + + val dataTmp = Utils.tempFileWith(shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)) + if (nativeSplitter == 0) { + nativeSplitter = jniWrapper.make( + dep.partitionInfo, + shuffleNativeBufferSize, + shuffleCompressionCodec, + dataTmp.getAbsolutePath, + blockManager.subDirsPerLocalDir, + localDirs, + shuffleCompressBlockSize, + shuffleSpillBatchRowNum, + shuffleSpillMemoryThreshold) + } + + + while (records.hasNext) { + val cb = records.next()._2.asInstanceOf[ColumnarBatch] + if (cb.numRows == 0 || cb.numCols == 0) { + logInfo(s"Skip ColumnarBatch of ${cb.numRows} rows, ${cb.numCols} cols") + } else { + val startTime = System.nanoTime() + val input = transColBatchToOmniVecs(cb) + for (col <- 0 until cb.numCols()) { + dep.dataSize += input(col).getRealValueBufCapacityInBytes + dep.dataSize += input(col).getRealNullBufCapacityInBytes + dep.dataSize += input(col).getRealOffsetBufCapacityInBytes + } + val vb = new VecBatch(input, cb.numRows()) + jniWrapper.split(nativeSplitter, vb.getNativeVectorBatch) + dep.splitTime.add(System.nanoTime() - startTime) + dep.numInputRows.add(cb.numRows) + writeMetrics.incRecordsWritten(1) + } + } + val startTime = System.nanoTime() + splitResult = 
jniWrapper.stop(nativeSplitter) + + dep.splitTime.add(System.nanoTime() - startTime - splitResult.getTotalSpillTime - + splitResult.getTotalWriteTime - splitResult.getTotalComputePidTime) + dep.spillTime.add(splitResult.getTotalSpillTime) + dep.bytesSpilled.add(splitResult.getTotalBytesSpilled) + writeMetrics.incBytesWritten(splitResult.getTotalBytesWritten) + writeMetrics.incWriteTime(splitResult.getTotalWriteTime + splitResult.getTotalSpillTime) + + partitionLengths = splitResult.getPartitionLengths + try { + shuffleBlockResolver.writeIndexFileAndCommit( + dep.shuffleId, + mapId, + partitionLengths, + dataTmp) + } finally { + if (dataTmp.exists() && !dataTmp.delete()) { + logError(s"Error while deleting temp file ${dataTmp.getAbsolutePath}") + } + } + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) + } + + override def stop(success: Boolean): Option[MapStatus] = { + try { + if (stopping) { + None + } else { + stopping = true + if (success) { + Option(mapStatus) + } else { + None + } + } + } finally { + if (nativeSplitter != 0) { + jniWrapper.close(nativeSplitter) + nativeSplitter = 0 + } + } + } + + @VisibleForTesting + def getPartitionLengths: Array[Long] = partitionLengths + +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala new file mode 100644 index 0000000000000000000000000000000000000000..3940dc0dc871ff20e65da5b052aa6e77bb295a01 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.sort + +import java.io.InputStream +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ + +import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch +import org.apache.spark.storage.BlockId +import org.apache.spark.util.collection.OpenHashSet + +class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + import ColumnarShuffleManager._ + + /** + * A mapping from shuffle ids to the task ids of mappers producing output for those shuffles. 
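+ * Entries are removed in [[unregisterShuffle]] so the corresponding map output files can be deleted.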
+ */ + private[this] val taskIdMapsForShuffle = new ConcurrentHashMap[Int, OpenHashSet[Long]]() + + private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) + + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) + + /** + * Obtains a [[ShuffleHandle]] to pass to tasks. + */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + if (dependency.isInstanceOf[ColumnarShuffleDependency[_, _, _]]) { + logInfo(s"Registering ColumnarShuffle shuffleId: ${shuffleId}") + new ColumnarShuffleHandle[K, V]( + shuffleId, + dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]]) + } else if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) { + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't + // need map-side aggregation, then write numPartitions files directly and just concatenate + // them at the end. This avoids doing serialization and deserialization twice to merge + // together the spilled files, which would happen with the normal code path. The downside is + // having multiple files open at a time and thus more memory allocated to buffers. + new BypassMergeSortShuffleHandle[K, V]( + shuffleId, + dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) { + // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient: + new SerializedShuffleHandle[K, V]( + shuffleId, + dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else { + // Otherwise, buffer map outputs in a deserialized form: + new BaseShuffleHandle(shuffleId, dependency) + } + } + + /** Get a writer for a given partition. Called on executors by map tasks. */ + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { + val mapTaskIds = + taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16)) + mapTaskIds.synchronized { + mapTaskIds.add(context.taskAttemptId()) + } + val env = SparkEnv.get + handle match { + case columnarShuffleHandle: ColumnarShuffleHandle[K@unchecked, V@unchecked] => + new ColumnarShuffleWriter(shuffleBlockResolver, columnarShuffleHandle, mapId, metrics) + case unsafeShuffleHandle: SerializedShuffleHandle[K@unchecked, V@unchecked] => + new UnsafeShuffleWriter( + env.blockManager, + context.taskMemoryManager(), + unsafeShuffleHandle, + mapId, + context, + env.conf, + metrics, + shuffleExecutorComponents) + case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K@unchecked, V@unchecked] => + new BypassMergeSortShuffleWriter( + env.blockManager, + bypassMergeSortHandle, + mapId, + env.conf, + metrics, + shuffleExecutorComponents) + case other: BaseShuffleHandle[K@unchecked, V@unchecked, _] => + new SortShuffleWriter( + shuffleBlockResolver, + other, + mapId, + context, + shuffleExecutorComponents) + } + } + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. 
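+ * For a [[ColumnarShuffleHandle]], the reader is created with a serializer manager that bypasses shuffle-read decompression (see bypassDecompressionSerializerManger).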
+ */ + override def getReader[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) + if (handle.isInstanceOf[ColumnarShuffleHandle[K, _]]) { + new BlockStoreShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + blocksByAddress, + context, + metrics, + serializerManager = bypassDecompressionSerializerManger, + shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context)) + } else { + new BlockStoreShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + blocksByAddress, + context, + metrics, + shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context)) + } + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Boolean = { + Option(taskIdMapsForShuffle.remove(shuffleId)).foreach { mapTaskIds => + mapTaskIds.iterator.foreach { mapId => + shuffleBlockResolver.removeDataByMap(shuffleId, mapId) + } + } + true + } + + /** Shut down this ShuffleManager. */ + override def stop(): Unit = { + shuffleBlockResolver.stop() + } +} + +private[spark] object ColumnarShuffleManager extends Logging { + private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = { + val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor() + val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap + executorComponents.initializeExecutor( + conf.getAppId, + SparkEnv.get.executorId, + extraConfigs.asJava) + executorComponents + } + + private def bypassDecompressionSerializerManger = + new SerializerManager( + SparkEnv.get.serializer, + SparkEnv.get.conf, + SparkEnv.get.securityManager.getIOEncryptionKey()) { + // Bypass the shuffle read decompression, decryption is not supported + override def wrapStream(blockId: BlockId, s: InputStream): InputStream = { + s + } + } +} + +private[spark] class ColumnarShuffleHandle[K, V]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, dependency) {} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala new file mode 100644 index 0000000000000000000000000000000000000000..96d3189b17f218bd4d2f4ca4467d88f28f799498 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala @@ -0,0 +1,485 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.concurrent.TimeUnit.NANOSECONDS +import com.huawei.boostkit.spark.Constant.{IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP} +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._ +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import nova.hetu.omniruntime.`type`.DataType +import nova.hetu.omniruntime.operator.config.OperatorConfig +import nova.hetu.omniruntime.operator.filter.OmniFilterAndProjectOperatorFactory +import nova.hetu.omniruntime.vector.VecBatch + +import scala.collection.JavaConverters.seqAsJavaList +import org.apache.spark.{InterruptibleIterator, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.ColumnarProjection.dealPartitionData +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.types.{LongType, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch + +case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) + extends UnaryExecNode + with AliasAwareOutputPartitioning + with AliasAwareOutputOrdering { + + override def supportsColumnar: Boolean = true + + override def nodeName: String = "OmniColumnarProject" + + override def output: Seq[Attribute] = projectList.map(_.toAttribute) + + override lazy val metrics = Map( + "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), + "omniCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), + "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs")) + + def buildCheck(): Unit = { + val omniAttrExpsIdMap = getExprIdMap(child.output) + child.output.map( + exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + val omniExpressions: Array[AnyRef] = projectList.map( + exp => rewriteToOmniJsonExpressionLiteral(exp, omniAttrExpsIdMap)).toArray + checkOmniJsonWhiteList("", omniExpressions) + } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val addInputTime = longMetric("addInputTime") + val omniCodegenTime = longMetric("omniCodegenTime") + val getOutputTime = longMetric("getOutputTime") + + val omniAttrExpsIdMap = getExprIdMap(child.output) + val omniInputTypes = child.output.map( + exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + val omniExpressions = projectList.map( + exp => rewriteToOmniJsonExpressionLiteral(exp, omniAttrExpsIdMap)).toArray + + 
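// delegate per-partition processing to dealPartitionData, passing the Omni input types, rewritten expressions and metrics +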
child.executeColumnar().mapPartitionsWithIndexInternal { (index, iter) => + dealPartitionData(numOutputRows, numOutputVecBatchs, addInputTime, omniCodegenTime, + getOutputTime, omniInputTypes, omniExpressions, iter, this.schema) + } + } + + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException(s"This operator doesn't support doExecute().") + } + + override protected def outputExpressions: Seq[NamedExpression] = projectList + + override protected def orderingExpressions: Seq[SortOrder] = child.outputOrdering + + override def verboseStringWithOperatorId(): String = { + s""" + |$formattedNodeName + |${ExplainUtils.generateFieldString("Output", projectList)} + |${ExplainUtils.generateFieldString("Input", child.output)} + |""".stripMargin + } +} + +case class ColumnarFilterExec(condition: Expression, child: SparkPlan) + extends UnaryExecNode with PredicateHelper { + + override def supportsColumnar: Boolean = true + override def nodeName: String = "OmniColumnarFilter" + + // Split out all the IsNotNulls from condition. + private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition { + case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet) + case _ => false + } + + // If one expression and its children are null intolerant, it is null intolerant. + private def isNullIntolerant(expr: Expression): Boolean = expr match { + case e: NullIntolerant => e.children.forall(isNullIntolerant) + case _ => false + } + + // The columns that will filtered out by `IsNotNull` could be considered as not nullable. + private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId) + + override def output: Seq[Attribute] = { + child.output.map { a => + if (a.nullable && notNullAttributes.contains(a.exprId)) { + a.withNullability(false) + } else { + a + } + } + } + + override lazy val metrics = Map( + "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), + "numInputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatchs"), + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "omniCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), + "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs")) + + + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + child.execute().mapPartitionsWithIndexInternal { (index, iter) => + val predicate = Predicate.create(condition, child.output) + predicate.initialize(0) + iter.filter { row => + val r = predicate.eval(row) + if (r) numOutputRows += 1 + r + } + } + } + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def verboseStringWithOperatorId(): String = { + s""" + |$formattedNodeName + |${ExplainUtils.generateFieldString("Input", child.output)} + |Condition : ${condition} + |""".stripMargin + } + + def buildCheck(): Unit = { + val omniAttrExpsIdMap = getExprIdMap(child.output) + child.output.map( + exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + val filterExpression = rewriteToOmniJsonExpressionLiteral(condition, omniAttrExpsIdMap) + 
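// validate the rewritten condition against the Omni JSON expression whitelist +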
checkOmniJsonWhiteList(filterExpression, new Array[AnyRef](0)) + } + + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numInputRows = longMetric("numInputRows") + val numInputVecBatchs = longMetric("numInputVecBatchs") + val numOutputRows = longMetric("numOutputRows") + val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val addInputTime = longMetric("addInputTime") + val omniCodegenTime = longMetric("omniCodegenTime") + val getOutputTime = longMetric("getOutputTime") + + val omniAttrExpsIdMap = getExprIdMap(child.output) + val omniInputTypes = child.output.map( + exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + val omniProjectIndices = child.output.map( + exp => sparkProjectionToOmniJsonProjection(exp, omniAttrExpsIdMap(exp.exprId))).toArray + val omniExpression = rewriteToOmniJsonExpressionLiteral(condition, omniAttrExpsIdMap) + + child.executeColumnar().mapPartitionsWithIndexInternal { (index, iter) => + val startCodegen = System.nanoTime() + val filterOperatorFactory = new OmniFilterAndProjectOperatorFactory( + omniExpression, omniInputTypes, seqAsJavaList(omniProjectIndices), 1, new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val filterOperator = filterOperatorFactory.createOperator + omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + + // close operator + addLeakSafeTaskCompletionListener[Unit](_ => { + filterOperator.close() + }) + + val localSchema = this.schema + new Iterator[ColumnarBatch] { + private var results: java.util.Iterator[VecBatch] = _ + override def hasNext: Boolean = { + while ((results == null || !results.hasNext) && iter.hasNext) { + val batch = iter.next() + val input = transColBatchToOmniVecs(batch) + val vecBatch = new VecBatch(input, batch.numRows()) + val startInput = System.nanoTime() + filterOperator.addInput(vecBatch) + addInputTime += NANOSECONDS.toMillis(System.nanoTime() - startInput) + numInputVecBatchs += 1 + numInputRows += batch.numRows() + + val startGetOp = System.nanoTime() + results = filterOperator.getOutput + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + } + if (results == null) { + false + } else { + val startGetOp: Long = System.nanoTime() + val hasNext = results.hasNext + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + hasNext + } + } + + override def next(): ColumnarBatch = { + val startGetOp = System.nanoTime() + val vecBatch = results.next() + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( + vecBatch.getRowCount, localSchema, false) + vectors.zipWithIndex.foreach { case (vector, i) => + vector.reset() + vector.setVec(vecBatch.getVectors()(i)) + } + numOutputRows += vecBatch.getRowCount + numOutputVecBatchs += 1 + vecBatch.close() + new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) + } + } + } + } +} + +case class ColumnarConditionProjectExec(projectList: Seq[NamedExpression], + condition: Expression, + child: SparkPlan) + extends UnaryExecNode + with AliasAwareOutputPartitioning + with AliasAwareOutputOrdering { + + override def supportsColumnar: Boolean = true + + override def nodeName: String = "OmniColumnarConditionProject" + + override def output: Seq[Attribute] = projectList.map(_.toAttribute) + + override lazy val metrics = Map( + "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), + "numInputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of input 
vecBatchs"), + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "omniCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), + "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs")) + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numInputRows = longMetric("numInputRows") + val numInputVecBatchs = longMetric("numInputVecBatchs") + val numOutputRows = longMetric("numOutputRows") + val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val addInputTime = longMetric("addInputTime") + val omniCodegenTime = longMetric("omniCodegenTime") + val getOutputTime = longMetric("getOutputTime") + + val omniAttrExpsIdMap = getExprIdMap(child.output) + val omniInputTypes = child.output.map( + exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + val omniExpressions = projectList.map( + exp => rewriteToOmniJsonExpressionLiteral(exp, omniAttrExpsIdMap)).toArray + val conditionExpression = rewriteToOmniJsonExpressionLiteral(condition, omniAttrExpsIdMap) + + child.executeColumnar().mapPartitionsWithIndexInternal { (index, iter) => + val startCodegen = System.nanoTime() + val operatorFactory = new OmniFilterAndProjectOperatorFactory( + conditionExpression, omniInputTypes, seqAsJavaList(omniExpressions), 1, new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val operator = operatorFactory.createOperator + omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + // close operator + addLeakSafeTaskCompletionListener[Unit](_ => { + operator.close() + }) + + val localSchema = this.schema + new Iterator[ColumnarBatch] { + private var results: java.util.Iterator[VecBatch] = _ + override def hasNext: Boolean = { + while ((results == null || !results.hasNext) && iter.hasNext) { + val batch = iter.next() + val input = transColBatchToOmniVecs(batch) + val vecBatch = new VecBatch(input, batch.numRows()) + val startInput = System.nanoTime() + operator.addInput(vecBatch) + addInputTime += NANOSECONDS.toMillis(System.nanoTime() - startInput) + numInputVecBatchs += 1 + numInputRows += batch.numRows() + + val startGetOp = System.nanoTime() + results = operator.getOutput + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + } + if (results == null) { + false + } else { + val startGetOp: Long = System.nanoTime() + val hasNext = results.hasNext + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + hasNext + } + } + + override def next(): ColumnarBatch = { + val startGetOp = System.nanoTime() + val vecBatch = results.next() + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( + vecBatch.getRowCount, localSchema, false) + vectors.zipWithIndex.foreach { case (vector, i) => + vector.reset() + vector.setVec(vecBatch.getVectors()(i)) + } + numOutputRows += vecBatch.getRowCount + numOutputVecBatchs += 1 + vecBatch.close() + new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) + } + } + } + } + + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException(s"This operator doesn't support doExecute().") + } + + override protected def outputExpressions: Seq[NamedExpression] = projectList + + override protected def orderingExpressions: 
Seq[SortOrder] = child.outputOrdering + + override def verboseStringWithOperatorId(): String = { + s""" + |$formattedNodeName + |${ExplainUtils.generateFieldString("Output", projectList)} + |${ExplainUtils.generateFieldString("Input", child.output)} + |""".stripMargin + } +} + +/** + * Physical plan for unioning two plans, without a distinct. This is UNION ALL in SQL. + * + * If we change how this is implemented physically, we'd need to update + * [[org.apache.spark.sql.catalyst.plans.logical.Union.maxRowsPerPartition]]. + */ +case class ColumnarUnionExec(children: Seq[SparkPlan]) extends SparkPlan { + + override def nodeName: String = "OmniColumnarUnion" + + // updating nullability to make all the children consistent + override def output: Seq[Attribute] = { + children.map(_.output).transpose.map { attrs => + val firstAttr = attrs.head + val nullable = attrs.exists(_.nullable) + val newDt = attrs.map(_.dataType).reduce(StructType.merge) + if (firstAttr.dataType == newDt) { + firstAttr.withNullability(nullable) + } else { + AttributeReference(firstAttr.name, newDt, nullable, firstAttr.metadata)( + firstAttr.exprId, firstAttr.qualifier) + } + } + } + + def buildCheck(): Unit = { + val inputTypes = new Array[DataType](output.size) + output.zipWithIndex.foreach { + case (attr, i) => + inputTypes(i) = sparkTypeToOmniType(attr.dataType, attr.metadata) + } + } + + protected override def doExecute(): RDD[InternalRow] = + sparkContext.union(children.map(_.execute())) + + override def supportsColumnar: Boolean = true + + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = + sparkContext.union(children.map(_.executeColumnar())) +} + +class ColumnarRangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) + extends RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) { + + private val maxRowCountPerBatch = 10000 + override def supportsColumnar: Boolean = true + override def supportCodegen: Boolean = false + + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + + sqlContext + .sparkContext + .parallelize(0 until numSlices, numSlices) + .mapPartitionsWithIndex { (i, _) => + val partitionStart = (i * numElements) / numSlices * step + start + val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start + + def getSafeMargin(bi: BigInt): Long = + if (bi.isValidLong) { + bi.toLong + } else if (bi > 0) { + Long.MaxValue + } else { + Long.MinValue + } + val safePartitionStart = getSafeMargin(partitionStart) // inclusive + val safePartitionEnd = getSafeMargin(partitionEnd) // exclusive, unless start == this + val taskContext = TaskContext.get() + + val iter: Iterator[ColumnarBatch] = new Iterator[ColumnarBatch] { + private[this] var number: Long = safePartitionStart + private[this] var overflow: Boolean = false + + override def hasNext: Boolean = + if (!overflow) { + if (step > 0) { + number < safePartitionEnd + } else { + number > safePartitionEnd + } + } else false + + override def next(): ColumnarBatch = { + val start = number + val remainingSteps = (safePartitionEnd - start) / step + // Start is inclusive so we need to produce at least one row + val rowsThisBatch = Math.max(1, Math.min(remainingSteps, maxRowCountPerBatch)) + val endInclusive = start + ((rowsThisBatch - 1) * step) + number = endInclusive + step + if (number < endInclusive ^ step < 0) { + // we have Long.MaxValue + Long.MaxValue < Long.MaxValue + // and Long.MinValue + Long.MinValue > Long.MinValue, so iff the step causes a + 
// step back, we are pretty sure that we have an overflow. + overflow = true + } + val vec = new OmniColumnVector(rowsThisBatch.toInt, LongType, true) + var s = start + for (i <- 0 until rowsThisBatch.toInt) { + vec.putLong(i, s) + s += step + } + numOutputRows += rowsThisBatch.toInt + new ColumnarBatch(Array(vec), rowsThisBatch.toInt) + } + } + new InterruptibleIterator(taskContext, iter) + } + } + + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + + override def doCanonicalize(): SparkPlan = { + new ColumnarRangeExec( + range.canonicalized.asInstanceOf[org.apache.spark.sql.catalyst.plans.logical.Range]) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..bc7d79a91f12b1410ce43e160cb897abd46edd0d --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import java.util.concurrent._ + +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import nova.hetu.omniruntime.vector.VecBatch +import nova.hetu.omniruntime.vector.serialize.VecBatchSerializerFactory +import scala.concurrent.{ExecutionContext, Promise} +import scala.concurrent.duration.NANOSECONDS +import scala.util.control.NonFatal + +import org.apache.spark.{broadcast, SparkException} +import org.apache.spark.launcher.SparkLauncher +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode +import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.unsafe.map.BytesToBytesMap +import org.apache.spark.util.{SparkFatalException, ThreadUtils} + + +class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan) + extends BroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan) { + import ColumnarBroadcastExchangeExec._ + + override def nodeName: String = "OmniColumnarBroadcastExchange" + override def supportsColumnar: Boolean = true + + override lazy val metrics = Map( + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to collect"), + "broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to broadcast")) + + @transient + private lazy val promise = Promise[broadcast.Broadcast[Any]]() + + @transient + private val timeout: Long = SQLConf.get.broadcastTimeout + + @transient + override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { + SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]]( + sqlContext.sparkSession, ColumnarBroadcastExchangeExec.executionContext) { + try { + // Setup a job group here so later it may get cancelled by groupId if necessary. 
+ sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)", + interruptOnCancel = true) + val beforeCollect = System.nanoTime() + val numRows = longMetric("numOutputRows") + val dataSize = longMetric("dataSize") + // Use executeCollect/executeCollectIterator to avoid conversion to Scala types + val input = child.executeColumnar().mapPartitions { iter => + val serializer = VecBatchSerializerFactory.create() + new Iterator[Array[Byte]] { + override def hasNext: Boolean = { + iter.hasNext + } + + override def next(): Array[Byte] = { + val batch = iter.next() + val vectors = transColBatchToOmniVecs(batch) + val vecBatch = new VecBatch(vectors, batch.numRows()) + numRows += vecBatch.getRowCount + val vecBatchSer = serializer.serialize(vecBatch) + dataSize += vecBatchSer.length + // close omni vec + vecBatch.releaseAllVectors() + vecBatch.close() + vecBatchSer + } + } + }.collect() + val numOutputRows = numRows.value + if (numOutputRows >= MAX_BROADCAST_TABLE_ROWS) { + throw new SparkException(s"Cannot broadcast the table over " + + s"$MAX_BROADCAST_TABLE_ROWS rows: $numOutputRows rows") + } + + val beforeBroadcast = System.nanoTime() + longMetric("collectTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeCollect) + + // Broadcast the relation + val broadcasted: broadcast.Broadcast[Any] = sparkContext.broadcast(input) + longMetric("broadcastTime") += NANOSECONDS.toMillis( + System.nanoTime() - beforeBroadcast) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) + promise.trySuccess(broadcasted) + broadcasted + } catch { + // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw + // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult + // will catch this exception and re-throw the wrapped fatal throwable. + case oe: OutOfMemoryError => + val ex = new SparkFatalException( + new OutOfMemoryError("Not enough memory to build and broadcast the table to all " + + "worker nodes. As a workaround, you can either disable broadcast by setting " + + s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark " + + s"driver memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value.") + .initCause(oe.getCause)) + promise.tryFailure(ex) + throw ex + case e if !NonFatal(e) => + val ex = new SparkFatalException(e) + promise.tryFailure(ex) + throw ex + case e: Throwable => + promise.tryFailure(e) + throw e + } + } + } + + override protected def doPrepare(): Unit = { + // Materialize the future. + relationFuture + } + + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException( + "BroadcastExchange does not support the execute() code path.") + } + + override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { + try { + relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]] + } catch { + case ex: TimeoutException => + logError(s"Could not execute broadcast in $timeout secs.", ex) + if (!relationFuture.isDone) { + sparkContext.cancelJobGroup(runId.toString) + relationFuture.cancel(true) + } + throw new SparkException(s"Could not execute broadcast in $timeout secs. 
" + + s"You can increase the timeout for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or " + + s"disable broadcast join by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1", + ex) + } + } +} + +object ColumnarBroadcastExchangeExec { + // Since the maximum number of keys that BytesToBytesMap supports is 1 << 29, + // and only 70% of the slots can be used before growing in HashedRelation, + // here the limitation should not be over 341 million. + val MAX_BROADCAST_TABLE_ROWS = (BytesToBytesMap.MAX_CAPACITY / 1.5).toLong + + val MAX_BROADCAST_TABLE_BYTES = 8L << 30 + + private[execution] val executionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("broadcast-exchange", + SQLConf.get.getConf(StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD))) +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..6f27cf3f0515824c42236d52ac3701e1d7755ca4 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala @@ -0,0 +1,383 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import nova.hetu.omniruntime.vector.Vec + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer + +import org.apache.spark.broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, SpecializedGetters, UnsafeProjection} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.util.SparkMemoryUtils +import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OmniColumnVector, WritableColumnVector} +import org.apache.spark.sql.types.{BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DecimalType, DoubleType, IntegerType, LongType, ShortType, StringType, StructType, TimestampType} +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Holds a user defined rule that can be used to inject columnar implementations of various + * operators in the plan. The [[preColumnarTransitions]] [[Rule]] can be used to replace + * [[SparkPlan]] instances with versions that support a columnar implementation. After this + * Spark will insert any transitions necessary. 
This includes transitions from row to columnar + * [[RowToColumnarExec]] and from columnar to row [[ColumnarToRowExec]]. At this point the + * [[postColumnarTransitions]] [[Rule]] is called to allow replacing any of the implementations + * of the transitions or doing cleanup of the plan, like inserting stages to build larger batches + * for more efficient processing, or stages that transition the data to/from an accelerator's + * memory. + */ +class ColumnarRule { + def preColumnarTransitions: Rule[SparkPlan] = plan => plan + def postColumnarTransitions: Rule[SparkPlan] = plan => plan +} + +/** + * A trait that is used as a tag to indicate a transition from columns to rows. This allows plugins + * to replace the current [[ColumnarToRowExec]] with an optimized version and still have operations + * that walk a spark plan looking for this type of transition properly match it. + */ +trait ColumnarToRowTransition extends UnaryExecNode + + +/** + * Provides an optimized set of APIs to append row based data to an array of + * [[WritableColumnVector]]. + */ +private[execution] class RowToColumnConverter(schema: StructType) extends Serializable { + private val converters = schema.fields.map { + f => RowToColumnConverter.getConverterForType(f.dataType, f.nullable) + } + + final def convert(row: InternalRow, vectors: Array[WritableColumnVector]): Unit = { + var idx = 0 + while (idx < row.numFields) { + converters(idx).append(row, idx, vectors(idx)) + idx += 1 + } + } +} + +/** + * Provides an optimized set of APIs to extract a column from a row and append it to a + * [[WritableColumnVector]]. + */ +private object RowToColumnConverter { + SparkMemoryUtils.init() + + private abstract class TypeConverter extends Serializable { + def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit + } + + private final case class BasicNullableTypeConverter(base: TypeConverter) extends TypeConverter { + override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = { + if (row.isNullAt(column)) { + cv.appendNull + } else { + base.append(row, column, cv) + } + } + } + + private def getConverterForType(dataType: DataType, nullable: Boolean): TypeConverter = { + val core = dataType match { + case BooleanType => BooleanConverter + case ByteType => ByteConverter + case ShortType => ShortConverter + case IntegerType | DateType => IntConverter + case LongType | TimestampType => LongConverter + case DoubleType => DoubleConverter + case StringType => StringConverter + case CalendarIntervalType => CalendarConverter + case dt: DecimalType => DecimalConverter(dt) + case unknown => throw new UnsupportedOperationException( + s"Type $unknown not supported") + } + + if (nullable) { + dataType match { + case _ => new BasicNullableTypeConverter(core) + } + } else { + core + } + } + + private object BooleanConverter extends TypeConverter { + override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = + cv.appendBoolean(row.getBoolean(column)) + } + + private object ByteConverter extends TypeConverter { + override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = + cv.appendByte(row.getByte(column)) + } + + private object ShortConverter extends TypeConverter { + override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = + cv.appendShort(row.getShort(column)) + } + + private object IntConverter extends TypeConverter { + override def append(row: SpecializedGetters, column: Int, cv: 
WritableColumnVector): Unit = + cv.appendInt(row.getInt(column)) + } + + private object LongConverter extends TypeConverter { + override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = + cv.appendLong(row.getLong(column)) + } + + private object DoubleConverter extends TypeConverter { + override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = + cv.appendDouble(row.getDouble(column)) + } + + private object StringConverter extends TypeConverter { + override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = { + val data = row.getUTF8String(column).getBytes + cv.asInstanceOf[OmniColumnVector].appendString(data.length, data, 0) + } + } + + private object CalendarConverter extends TypeConverter { + override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = { + val c = row.getInterval(column) + cv.appendStruct(false) + cv.getChild(0).appendInt(c.months) + cv.getChild(1).appendInt(c.days) + cv.getChild(2).appendLong(c.microseconds) + } + } + + private case class DecimalConverter(dt: DecimalType) extends TypeConverter { + override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = { + val d = row.getDecimal(column, dt.precision, dt.scale) + if (DecimalType.is64BitDecimalType(dt)) { + cv.appendLong(d.toUnscaledLong) + } else { + cv.asInstanceOf[OmniColumnVector].appendDecimal(d) + } + } + } +} + +/** + * A trait that is used as a tag to indicate a transition from rows to columns. This allows plugins + * to replace the current [[RowToColumnarExec]] with an optimized version and still have operations + * that walk a spark plan looking for this type of transition properly match it. + */ +trait RowToColumnarTransition extends UnaryExecNode + +/** + * Provides a common executor to translate an [[RDD]] of [[InternalRow]] into an [[RDD]] of + * [[ColumnarBatch]]. This is inserted whenever such a transition is determined to be needed. + * + * This is similar to some of the code in ArrowConverters.scala and + * [[org.apache.spark.sql.execution.arrow.ArrowWriter]]. That code is more specialized + * to convert [[InternalRow]] to Arrow formatted data, but in the future if we make + * [[OffHeapColumnVector]] internally Arrow formatted we may be able to replace much of that code. + * + * This is also similar to + * [[org.apache.spark.sql.execution.vectorized.ColumnVectorUtils.populate()]] and + * [[org.apache.spark.sql.execution.vectorized.ColumnVectorUtils.toBatch()]] toBatch is only ever + * called from tests and can probably be removed, but populate is used by both Orc and Parquet + * to initialize partition and missing columns. There is some chance that we could replace + * populate with [[RowToColumnConverter]], but the performance requirements are different and it + * would only be to reduce code. 
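+ * The conversion below packs at most `conf.columnBatchSize` rows into each OmniColumnVector-backed batch.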
+ */ + +case class RowToOmniColumnarExec(child: SparkPlan) extends RowToColumnarTransition { + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def doExecute(): RDD[InternalRow] = { + child.execute() + } + + override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { + child.doExecuteBroadcast() + } + + override def nodeName: String = "RowToOmniColumnar" + + override def supportsColumnar: Boolean = true + + override lazy val metrics: Map[String, SQLMetric] = Map( + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of output batches") + ) + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val enableOffHeapColumnVector = sqlContext.conf.offHeapColumnVectorEnabled + val numInputRows = longMetric("numInputRows") + val numOutputBatches = longMetric("numOutputBatches") + // Instead of creating a new config we are reusing columnBatchSize. In the future if we do + // combine with some of the Arrow conversion tools we will need to unify some of the configs. + val numRows = conf.columnBatchSize + // This avoids calling `schema` in the RDD closure, so that we don't need to include the entire + // plan (this) in the closure. + val localSchema = this.schema + child.execute().mapPartitionsInternal { rowIterator => + if (rowIterator.hasNext) { + new Iterator[ColumnarBatch] { + private val converters = new RowToColumnConverter(localSchema) + + override def hasNext: Boolean = { + rowIterator.hasNext + } + + override def next(): ColumnarBatch = { + val vectors: Seq[WritableColumnVector] = OmniColumnVector.allocateColumns(numRows, + localSchema, true) + val cb: ColumnarBatch = new ColumnarBatch(vectors.toArray) + cb.setNumRows(0) + vectors.foreach(_.reset()) + var rowCount = 0 + while (rowCount < numRows && rowIterator.hasNext) { + val row = rowIterator.next() + converters.convert(row, vectors.toArray) + rowCount += 1 + } + if (!enableOffHeapColumnVector) { + vectors.foreach { v => + v.asInstanceOf[OmniColumnVector].getVec.setSize(rowCount) + } + } + cb.setNumRows(rowCount) + numInputRows += rowCount + numOutputBatches += 1 + cb + } + } + } else { + Iterator.empty + } + } + } +} + + +case class OmniColumnarToRowExec(child: SparkPlan) extends ColumnarToRowTransition { + assert(child.supportsColumnar) + + override def nodeName: String = "OmniColumnarToRow" + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override lazy val metrics: Map[String, SQLMetric] = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numInputBatches" -> SQLMetrics.createMetric(sparkContext, "number of input batches") + ) + + override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + val numInputBatches = longMetric("numInputBatches") + // This avoids calling `output` in the RDD closure, so that we don't need to include the entire + // plan (this) in the closure. 
+ val localOutput = this.output + child.executeColumnar().mapPartitionsInternal { batches => + val toUnsafe = UnsafeProjection.create(localOutput, localOutput) + val vecsTmp = new ListBuffer[Vec] + + val batchIter = batches.flatMap { batch => + // store vec since tablescan reuse batch + for (i <- 0 until batch.numCols()) { + batch.column(i) match { + case vector: OmniColumnVector => + vecsTmp.append(vector.getVec) + case _ => + } + } + numInputBatches += 1 + numOutputRows += batch.numRows() + batch.rowIterator().asScala.map(toUnsafe) + } + + SparkMemoryUtils.addLeakSafeTaskCompletionListener { _ => + vecsTmp.foreach {vec => + vec.close() + } + } + batchIter + } + } +} + +/** + * Apply any user defined [[ColumnarRule]]s and find the correct place to insert transitions + * to/from columnar formatted data. + */ +case class ApplyColumnarRulesAndInsertTransitions(columnarRules: Seq[ColumnarRule]) + extends Rule[SparkPlan] { + + /** + * Inserts an transition to columnar formatted data. + */ + private def insertRowToColumnar(plan: SparkPlan): SparkPlan = { + if (!plan.supportsColumnar) { + // The tree feels kind of backwards + // Columnar Processing will start here, so transition from row to columnar + RowToOmniColumnarExec(insertTransitions(plan)) + } else if (!plan.isInstanceOf[RowToColumnarTransition]) { + plan.withNewChildren(plan.children.map(insertRowToColumnar)) + } else { + plan + } + } + + /** + * Inserts RowToColumnarExecs and ColumnarToRowExecs where needed. + */ + private def insertTransitions(plan: SparkPlan): SparkPlan = { + if (plan.supportsColumnar) { + // The tree feels kind of backwards + // This is the end of the columnar processing so go back to rows + if (conf.getConfString("spark.omni.sql.columnar.columnarToRow", "true").toBoolean) { + OmniColumnarToRowExec(insertRowToColumnar(plan)) + } else { + ColumnarToRowExec(insertRowToColumnar(plan)) + } + } else if (!plan.isInstanceOf[ColumnarToRowTransition]) { + plan.withNewChildren(plan.children.map(insertTransitions)) + } else { + plan + } + } + + def apply(plan: SparkPlan): SparkPlan = { + var preInsertPlan: SparkPlan = plan + columnarRules.foreach((r: ColumnarRule) => + preInsertPlan = r.preColumnarTransitions(preInsertPlan)) + var postInsertPlan = insertTransitions(preInsertPlan) + columnarRules.reverse.foreach((r: ColumnarRule) => + postInsertPlan = r.postColumnarTransitions(postInsertPlan)) + postInsertPlan + } +} + diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..cf78751c3db301fdc2fdd060aaa717a12d617ff1 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala @@ -0,0 +1,1381 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.Optional +import java.util.concurrent.TimeUnit.NANOSECONDS +import com.huawei.boostkit.spark.Constant.{IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP} +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor + +import scala.collection.mutable.HashMap +import scala.collection.JavaConverters._ +import nova.hetu.omniruntime.`type`.DataType +import nova.hetu.omniruntime.constants.FunctionType +import nova.hetu.omniruntime.operator.aggregator.OmniHashAggregationWithExprOperatorFactory +import nova.hetu.omniruntime.operator.filter.OmniFilterAndProjectOperatorFactory +import nova.hetu.omniruntime.vector.{Vec, VecBatch} +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._ +import com.huawei.boostkit.spark.util.OmniAdaptorUtil._ +import nova.hetu.omniruntime.constants.JoinType.OMNI_JOIN_TYPE_INNER +import nova.hetu.omniruntime.operator.config.OperatorConfig +import nova.hetu.omniruntime.operator.join.{OmniHashBuilderWithExprOperatorFactory, OmniLookupJoinWithExprOperatorFactory} +import nova.hetu.omniruntime.operator.project.OmniProjectOperatorFactory +import nova.hetu.omniruntime.vector.serialize.VecBatchSerializerFactory +import org.apache.hadoop.fs.Path +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.optimizer.BuildLeft +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.orc.{OmniOrcFileFormat, OrcFileFormat} +import org.apache.spark.sql.execution.joins.ColumnarBroadcastHashJoinExec +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.util.SparkMemoryUtils +import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DecimalType, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.collection.BitSet + + + +abstract class BaseColumnarFileSourceScanExec( + @transient relation: HadoopFsRelation, + output: Seq[Attribute], + requiredSchema: StructType, + partitionFilters: Seq[Expression], + optionalBucketSet: Option[BitSet], + optionalNumCoalescedBuckets: Option[Int], + dataFilters: Seq[Expression], + tableIdentifier: Option[TableIdentifier], + disableBucketedScan: Boolean = false) + extends DataSourceScanExec { + + override lazy val supportsColumnar: Boolean = true + + override def vectorTypes: Option[Seq[String]] = + relation.fileFormat.vectorTypes( + requiredSchema = requiredSchema, + 
partitionSchema = relation.partitionSchema, + relation.sparkSession.sessionState.conf) + + private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty + + /** + * Send the driver-side metrics. Before calling this function, selectedPartitions has + * been initialized. See SPARK-26327 for more details. + */ + private def sendDriverMetrics(): Unit = { + driverMetrics.foreach(e => metrics(e._1).add(e._2)) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, + metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq) + } + + private def isDynamicPruningFilter(e: Expression): Boolean = + e.find(_.isInstanceOf[PlanExpression[_]]).isDefined + + @transient lazy val selectedPartitions: Array[PartitionDirectory] = { + val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) + val startTime = System.nanoTime() + val ret = + relation.location.listFiles( + partitionFilters.filterNot(isDynamicPruningFilter), dataFilters) + setFilesNumAndSizeMetric(ret, true) + val timeTakenMs = NANOSECONDS.toMillis( + (System.nanoTime() - startTime) + optimizerMetadataTimeNs) + driverMetrics("metadataTime") = timeTakenMs + ret + }.toArray + + // We can only determine the actual partitions at runtime when a dynamic partition filter is + // present. This is because such a filter relies on information that is only available at run + // time (for instance the keys used in the other side of a join). + @transient private lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = { + val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter) + + if (dynamicPartitionFilters.nonEmpty) { + val startTime = System.nanoTime() + // call the file index for the files matching all filters except dynamic partition filters + val predicate = dynamicPartitionFilters.reduce(And) + val partitionColumns = relation.partitionSchema + val boundPredicate = Predicate.create(predicate.transform { + case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }, Nil) + val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) + setFilesNumAndSizeMetric(ret, false) + val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000 + driverMetrics("pruningTime") = timeTakenMs + ret + } else { + selectedPartitions + } + } + + /** + * [[partitionFilters]] can contain subqueries whose results are available only at runtime so + * accessing [[selectedPartitions]] should be guarded by this method during planning + */ + private def hasPartitionsAvailableAtRunTime: Boolean = { + partitionFilters.exists(ExecSubqueryExpression.hasSubquery) + } + + private def toAttribute(colName: String): Option[Attribute] = + output.find(_.name == colName) + + // exposed for testing + lazy val bucketedScan: Boolean = { + if (relation.sparkSession.sessionState.conf.bucketingEnabled && relation.bucketSpec.isDefined + && !disableBucketedScan) { + val spec = relation.bucketSpec.get + val bucketColumns = spec.bucketColumnNames.flatMap(n => toAttribute(n)) + bucketColumns.size == spec.bucketColumnNames.size + } else { + false + } + } + + override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = { + if (bucketedScan) { + // For bucketed columns: + // ----------------------- + // `HashPartitioning` would be used only when: + // 1. 
ALL the bucketing columns are being read from the table + // + // For sorted columns: + // --------------------- + // Sort ordering should be used when ALL these criteria's match: + // 1. `HashPartitioning` is being used + // 2. A prefix (or all) of the sort columns are being read from the table. + // + // Sort ordering would be over the prefix subset of `sort columns` being read + // from the table. + // e.g. + // Assume (col0, col2, col3) are the columns read from the table + // If sort columns are (col0, col1), then sort ordering would be considered as (col0) + // If sort columns are (col1, col0), then sort ordering would be empty as per rule #2 + // above + val spec = relation.bucketSpec.get + val bucketColumns = spec.bucketColumnNames.flatMap(n => toAttribute(n)) + val numPartitions = optionalNumCoalescedBuckets.getOrElse(spec.numBuckets) + val partitioning = HashPartitioning(bucketColumns, numPartitions) + val sortColumns = + spec.sortColumnNames.map(x => toAttribute(x)).takeWhile(x => x.isDefined).map(_.get) + val shouldCalculateSortOrder = + conf.getConf(SQLConf.LEGACY_BUCKETED_TABLE_SCAN_OUTPUT_ORDERING) && + sortColumns.nonEmpty && + !hasPartitionsAvailableAtRunTime + + val sortOrder = if (shouldCalculateSortOrder) { + // In case of bucketing, its possible to have multiple files belonging to the + // same bucket in a given relation. Each of these files are locally sorted + // but those files combined together are not globally sorted. Given that, + // the RDD partition will not be sorted even if the relation has sort columns set + // Current solution is to check if all the buckets have a single file in it + + val files = selectedPartitions.flatMap(partition => partition.files) + val bucketToFilesGrouping = + files.map(_.getPath.getName).groupBy(file => BucketingUtils.getBucketId(file)) + val singleFilePartitions = bucketToFilesGrouping.forall(p => p._2.length <= 1) + + // TODO SPARK-24528 Sort order is currently ignored if buckets are coalesced. + if (singleFilePartitions && optionalNumCoalescedBuckets.isEmpty) { + // TODO Currently Spark does not support writing columns sorting in descending order + // so using Ascending order. 
This can be fixed in future + sortColumns.map(attribute => SortOrder(attribute, Ascending)) + } else { + Nil + } + } else { + Nil + } + (partitioning, sortOrder) + } else { + (UnknownPartitioning(0), Nil) + } + } + + @transient + private lazy val pushedDownFilters = { + val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) + dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) + } + + override protected def metadata: Map[String, String] = { + def seqToString(seq: Seq[Any]) = seq.mkString("[", ", ", "]") + + val location = relation.location + val locationDesc = + location.getClass.getSimpleName + seqToString(location.rootPaths) + val metadata = + Map( + "Format" -> relation.fileFormat.toString, + "ReadSchema" -> requiredSchema.catalogString, + "Batched" -> supportsColumnar.toString, + "PartitionFilters" -> seqToString(partitionFilters), + "PushedFilters" -> seqToString(pushedDownFilters), + "DataFilters" -> seqToString(dataFilters), + "Location" -> locationDesc) + + // TODO(SPARK-32986): Add bucketed scan info in explain output of FileSourceScanExec + if (bucketedScan) { + relation.bucketSpec.map { spec => + val numSelectedBuckets = optionalBucketSet.map { b => + b.cardinality() + } getOrElse { + spec.numBuckets + } + metadata + ("SelectedBucketsCount" -> + (s"$numSelectedBuckets out of ${spec.numBuckets}" + + optionalNumCoalescedBuckets.map { b => s" (Coalesced to $b)" }.getOrElse(""))) + } getOrElse { + metadata + } + } else { + metadata + } + } + + override def verboseStringWithOperatorId(): String = { + val metadataStr = metadata.toSeq.sorted.filterNot { + case (_, value) if (value.isEmpty || value.equals("[]")) => true + case (key, _) if (key.equals("DataFilters") || key.equals("Format")) => true + case (_, _) => false + }.map { + case (key, _) if (key.equals("Location")) => + val location = relation.location + val numPaths = location.rootPaths.length + val abbreviatedLocation = if (numPaths <= 1) { + location.rootPaths.mkString("[", ", ", "]") + } else { + "[" + location.rootPaths.head + s", ... ${numPaths - 1} entries]" + } + s"$key: ${location.getClass.getSimpleName} ${redact(abbreviatedLocation)}" + case (key, value) => s"$key: ${redact(value)}" + } + + s""" + |$formattedNodeName + |${ExplainUtils.generateFieldString("Output", output)} + |${metadataStr.mkString("\n")} + |""".stripMargin + } + + lazy val inputRDD: RDD[InternalRow] = { + val fileFormat: FileFormat = relation.fileFormat match { + case orcFormat: OrcFileFormat => + new OmniOrcFileFormat() + case _ => + throw new UnsupportedOperationException("Unsupported FileFormat!") + } + val readFile: (PartitionedFile) => Iterator[InternalRow] = + fileFormat.buildReaderWithPartitionValues( + sparkSession = relation.sparkSession, + dataSchema = relation.dataSchema, + partitionSchema = relation.partitionSchema, + requiredSchema = requiredSchema, + filters = pushedDownFilters, + options = relation.options, + hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) + + val readRDD = if (bucketedScan) { + createBucketedReadRDD(relation.bucketSpec.get, readFile, dynamicallySelectedPartitions, + relation) + } else { + createNonBucketedReadRDD(readFile, dynamicallySelectedPartitions, relation) + } + sendDriverMetrics() + readRDD + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + inputRDD :: Nil + } + + /** SQL metrics generated only for scans using dynamic partition pruning. 
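+ * When dynamic partition pruning applies, these record the file count and size measured before
+ * runtime pruning, so they can be compared against the pruned `numFiles`/`filesSize` values.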
*/ + protected lazy val staticMetrics = if (partitionFilters.exists(isDynamicPruningFilter)) { + Map("staticFilesNum" -> SQLMetrics.createMetric(sparkContext, "static number of files read"), + "staticFilesSize" -> SQLMetrics.createSizeMetric(sparkContext, "static size of files read")) + } else { + Map.empty[String, SQLMetric] + } + + /** Helper for computing total number and size of files in selected partitions. */ + private def setFilesNumAndSizeMetric( + partitions: Seq[PartitionDirectory], + static: Boolean): Unit = { + val filesNum = partitions.map(_.files.size.toLong).sum + val filesSize = partitions.map(_.files.map(_.getLen).sum).sum + if (!static || !partitionFilters.exists(isDynamicPruningFilter)) { + driverMetrics("numFiles") = filesNum + driverMetrics("filesSize") = filesSize + } else { + driverMetrics("staticFilesNum") = filesNum + driverMetrics("staticFilesSize") = filesSize + } + if (relation.partitionSchemaOption.isDefined) { + driverMetrics("numPartitions") = partitions.length + } + } + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of files read"), + "metadataTime" -> SQLMetrics.createTimingMetric(sparkContext, "metadata time"), + "filesSize" -> SQLMetrics.createSizeMetric(sparkContext, "size of files read"), + "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs") + ) ++ { + // Tracking scan time has overhead, we can't afford to do it for each row, and can only do + // it for each batch. + if (supportsColumnar) { + Some("scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) + } else { + None + } + } ++ { + if (relation.partitionSchemaOption.isDefined) { + Map( + "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions read"), + "pruningTime" -> + SQLMetrics.createTimingMetric(sparkContext, "dynamic partition pruning time")) + } else { + Map.empty[String, SQLMetric] + } + } ++ staticMetrics + + protected override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException() + } + + def buildCheck(): Unit = { + output.zipWithIndex.foreach { + case (attr, i) => + sparkTypeToOmniType(attr.dataType, attr.metadata) + if (attr.dataType.isInstanceOf[DecimalType]) { + val dt = attr.dataType.asInstanceOf[DecimalType] + if (!DecimalType.is64BitDecimalType(dt)) { + throw new UnsupportedOperationException(s"ColumnarTableScan is not supported for type:${dt}"); + } + } + } + } + + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val scanTime = longMetric("scanTime") + val numOutputVecBatchs = longMetric("numOutputVecBatchs") + inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => + new Iterator[ColumnarBatch] { + + override def hasNext: Boolean = { + // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call. + val startNs = System.nanoTime() + val res = batches.hasNext + scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs) + res + } + + override def next(): ColumnarBatch = { + val batch = batches.next() + numOutputRows += batch.numRows() + numOutputVecBatchs += 1 + batch + } + } + } + } + + override val nodeNamePrefix: String = "ColumnarFile" + + /** + * Create an RDD for bucketed reads. + * The non-bucketed variant of this function is [[createNonBucketedReadRDD]]. 
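+ * Bucket filter pruning ([[optionalBucketSet]]) and optional bucket coalescing
+ * ([[optionalNumCoalescedBuckets]]) are applied before the partitions are built.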
+ * + * The algorithm is pretty simple: each RDD partition being returned should include all the files + * with the same bucket id from all the given Hive partitions. + * + * @param bucketSpec the bucketing spec. + * @param readFile a function to read each (part of a) file. + * @param selectedPartitions Hive-style partition that are part of the read. + * @param fsRelation [[HadoopFsRelation]] associated with the read. + */ + private def createBucketedReadRDD( + bucketSpec: BucketSpec, + readFile: (PartitionedFile) => Iterator[InternalRow], + selectedPartitions: Array[PartitionDirectory], + fsRelation: HadoopFsRelation): RDD[InternalRow] = { + logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") + val filesGroupedToBuckets = + selectedPartitions.flatMap { p => + p.files.map { f => + PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values) + } + }.groupBy { f => + BucketingUtils + .getBucketId(new Path(f.filePath).getName) + .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}")) + } + + // TODO(SPARK-32985): Decouple bucket filter pruning and bucketed table scan + val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { + val bucketSet = optionalBucketSet.get + filesGroupedToBuckets.filter { + f => bucketSet.get(f._1) + } + } else { + filesGroupedToBuckets + } + + val filePartitions = optionalNumCoalescedBuckets.map { numCoalescedBuckets => + logInfo(s"Coalescing to ${numCoalescedBuckets} buckets") + val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets) + Seq.tabulate(numCoalescedBuckets) { bucketId => + val partitionedFiles = coalescedBuckets.get(bucketId).map { + _.values.flatten.toArray + }.getOrElse(Array.empty) + FilePartition(bucketId, partitionedFiles) + } + }.getOrElse { + Seq.tabulate(bucketSpec.numBuckets) { bucketId => + FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty)) + } + } + + new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions) + } + + /** + * Create an RDD for non-bucketed reads. + * The bucketed variant of this function is [[createBucketedReadRDD]]. + * + * @param readFile a function to read each (part of a) file. + * @param selectedPartitions Hive-style partition that are part of the read. + * @param fsRelation [[HadoopFsRelation]] associated with the read. 
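+ * @return a [[FileScanRDD]] whose partitions are produced by bin packing the split files
+ * (largest first) against `maxSplitBytes`.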
+ */ + private def createNonBucketedReadRDD( + readFile: (PartitionedFile) => Iterator[InternalRow], + selectedPartitions: Array[PartitionDirectory], + fsRelation: HadoopFsRelation): RDD[InternalRow] = { + val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes + val maxSplitBytes = + FilePartition.maxSplitBytes(fsRelation.sparkSession, selectedPartitions) + logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + + s"open cost is considered as scanning $openCostInBytes bytes.") + + val splitFiles = selectedPartitions.flatMap { partition => + partition.files.flatMap { file => + // getPath() is very expensive so we only want to call it once in this block: + val filePath = file.getPath + val isSplitable = relation.fileFormat.isSplitable( + relation.sparkSession, relation.options, filePath) + PartitionedFileUtil.splitFiles( + sparkSession = relation.sparkSession, + file = file, + filePath = filePath, + isSplitable = isSplitable, + maxSplitBytes = maxSplitBytes, + partitionValues = partition.values + ) + } + }.sortBy(_.length)(implicitly[Ordering[Long]].reverse) + + val partitions = + FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes) + + new FileScanRDD(fsRelation.sparkSession, readFile, partitions) + } + + // Filters unused DynamicPruningExpression expressions - one which has been replaced + // with DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning + protected def filterUnusedDynamicPruningExpressions( + predicates: Seq[Expression]): Seq[Expression] = { + predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)) + } + + def genAggOutput(agg: HashAggregateExec) = { + val attrAggExpsIdMap = getExprIdMap(agg.child.output) + val omniGroupByChanel = agg.groupingExpressions.map( + exp => rewriteToOmniJsonExpressionLiteral(exp, attrAggExpsIdMap)).toArray + + var omniOutputExressionOrder: Map[ExprId, Int] = Map() + var aggIndexOffset = 0 + agg.groupingExpressions.zipWithIndex.foreach { case (exp, index) => + omniOutputExressionOrder += (exp.exprId -> (index + aggIndexOffset)) + } + aggIndexOffset += agg.groupingExpressions.size + + val omniAggInputRaw = true + val omniAggOutputPartial = true + val omniAggTypes = new Array[DataType](agg.aggregateExpressions.size) + val omniAggFunctionTypes = new Array[FunctionType](agg.aggregateExpressions.size) + val omniAggOutputTypes = new Array[DataType](agg.aggregateExpressions.size) + val omniAggChannels = new Array[String](agg.aggregateExpressions.size) + var omniAggindex = 0 + for (exp <- agg.aggregateExpressions) { + if (exp.mode == Final) { + throw new UnsupportedOperationException(s"Unsupported final aggregate expression in operator fusion, exp: $exp") + } else if (exp.mode == Partial) { + exp.aggregateFunction match { + case Sum(_) | Min(_) | Average(_) | Max(_) | Count(_) => + val aggExp = exp.aggregateFunction.children.head + omniOutputExressionOrder += { + exp.aggregateFunction.inputAggBufferAttributes.head.exprId -> + (omniAggindex + aggIndexOffset) + } + omniAggTypes(omniAggindex) = sparkTypeToOmniType(aggExp.dataType) + omniAggFunctionTypes(omniAggindex) = toOmniAggFunType(exp, true) + omniAggOutputTypes(omniAggindex) = + sparkTypeToOmniType(exp.aggregateFunction.dataType) + omniAggChannels(omniAggindex) = + rewriteToOmniJsonExpressionLiteral(aggExp, attrAggExpsIdMap) + case _ => throw new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: $exp") + } + } else { + throw new 
UnsupportedOperationException(s"Unsupported aggregate mode: $exp.mode") + } + omniAggindex += 1 + } + + var resultIdxToOmniResultIdxMap: Map[Int, Int] = Map() + agg.resultExpressions.zipWithIndex.foreach { case (exp, index) => + if (omniOutputExressionOrder.contains(getRealExprId(exp))) { + resultIdxToOmniResultIdxMap += + (index -> omniOutputExressionOrder(getRealExprId(exp))) + } + } + + val omniAggSourceTypes = new Array[DataType](agg.child.outputSet.size) + val inputIter = agg.child.outputSet.toIterator + var i = 0 + while (inputIter.hasNext) { + val inputAttr = inputIter.next() + omniAggSourceTypes(i) = sparkTypeToOmniType(inputAttr.dataType, inputAttr.metadata) + i += 1 + } + (omniGroupByChanel, omniAggChannels, omniAggSourceTypes, omniAggFunctionTypes, omniAggOutputTypes, + omniAggInputRaw, omniAggOutputPartial, resultIdxToOmniResultIdxMap) + } + + def genProjectOutput(project: ColumnarProjectExec) = { + val omniAttrExpsIdMap = getExprIdMap(project.child.output) + val omniInputTypes = project.child.output.map( + exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + val omniExpressions = project.projectList.map( + exp => rewriteToOmniJsonExpressionLiteral(exp, omniAttrExpsIdMap)).toArray + (omniExpressions, omniInputTypes) + } + + def genJoinOutput(join: ColumnarBroadcastHashJoinExec) = { + val buildTypes = new Array[DataType](join.getBuildOutput.size) // {2,2}, buildOutput:col1#12,col2#13 + join.getBuildOutput.zipWithIndex.foreach { case (att, i) => + buildTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata) + } + + // {0}, buildKeys: col1#12 + val buildOutputCols = join.getBuildOutput.indices.toArray // {0,1} + val buildJoinColsExp = join.getBuildKeys.map { x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(join.getBuildOutput.map(_.toAttribute))) + }.toArray + val buildData = join.getBuildPlan.executeBroadcast[Array[Array[Byte]]]() + + val buildOutputTypes = buildTypes // {1,1} + + val probeTypes = new Array[DataType](join.getStreamedOutput.size) // {2,2},streamedOutput:col1#10,col2#11 + join.getStreamedOutput.zipWithIndex.foreach { case (attr, i) => + probeTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) + } + val probeOutputCols = join.getStreamedOutput.indices.toArray// {0,1} + val probeHashColsExp = join.getStreamedKeys.map {x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(join.getStreamedOutput.map(_.toAttribute))) + }.toArray + val filter: Option[String] = join.condition match { + case Some(expr) => + Some(OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(expr, + OmniExpressionAdaptor.getExprIdMap((join.getStreamedOutput ++ join.getBuildOutput).map(_.toAttribute)))) + case _ => None + } + (buildTypes, buildJoinColsExp, filter, probeTypes, probeOutputCols, + probeHashColsExp, buildOutputCols, buildOutputTypes, buildData) + } + + def genFilterOutput(cond: ColumnarFilterExec) = { + val omniCondAttrExpsIdMap = getExprIdMap(cond.child.output) + val omniCondInputTypes = cond.child.output.map( + exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + val omniCondExpressions = cond.child.output.map( + exp => rewriteToOmniJsonExpressionLiteral(exp, omniCondAttrExpsIdMap)).toArray + val conditionExpression = rewriteToOmniJsonExpressionLiteral(cond.condition, omniCondAttrExpsIdMap) + (conditionExpression, omniCondInputTypes, omniCondExpressions) + } + + def genJoinOutputWithReverse(join: 
ColumnarBroadcastHashJoinExec) = { + val buildTypes = new Array[DataType](join.getBuildOutput.size) // {2,2}, buildOutput:col1#12,col2#13 + join.getBuildOutput.zipWithIndex.foreach { case (att, i) => + buildTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata) + } + + // {0}, buildKeys: col1#12 + val buildOutputCols = join.getBuildOutput.indices.toArray // {0,1} + val buildJoinColsExp = join.getBuildKeys.map { x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(join.getBuildOutput.map(_.toAttribute))) + }.toArray + val buildData = join.getBuildPlan.executeBroadcast[Array[Array[Byte]]]() + + val buildOutputTypes = buildTypes // {1,1} + + val probeTypes = new Array[DataType](join.getStreamedOutput.size) // {2,2},streamedOutput:col1#10,col2#11 + join.getStreamedOutput.zipWithIndex.foreach { case (attr, i) => + probeTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) + } + val probeOutputCols = join.getStreamedOutput.indices.toArray// {0,1} + val probeHashColsExp = join.getStreamedKeys.map {x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(join.getStreamedOutput.map(_.toAttribute))) + }.toArray + val filter: Option[String] = join.condition match { + case Some(expr) => + Some(OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(expr, + OmniExpressionAdaptor.getExprIdMap((join.getStreamedOutput ++ join.getBuildOutput).map(_.toAttribute)))) + case _ => None + } + + val reverse = join.buildSide == BuildLeft + var left = 0 + var leftLen = join.getStreamPlan.output.size + var right = join.getStreamPlan.output.size + var rightLen = join.output.size + if (reverse) { + left = join.getStreamPlan.output.size + leftLen = join.output.size + right = 0 + rightLen = join.getStreamPlan.output.size + } + (buildTypes, buildJoinColsExp, filter, probeTypes, probeOutputCols, + probeHashColsExp, buildOutputCols, buildOutputTypes, buildData, (left, leftLen, right, rightLen)) + } +} + +case class ColumnarFileSourceScanExec( + @transient relation: HadoopFsRelation, + output: Seq[Attribute], + requiredSchema: StructType, + partitionFilters: Seq[Expression], + optionalBucketSet: Option[BitSet], + optionalNumCoalescedBuckets: Option[Int], + dataFilters: Seq[Expression], + tableIdentifier: Option[TableIdentifier], + disableBucketedScan: Boolean = false) + extends BaseColumnarFileSourceScanExec( + relation, + output, + requiredSchema, + partitionFilters, + optionalBucketSet, + optionalNumCoalescedBuckets, + dataFilters, + tableIdentifier, + disableBucketedScan) { + override def doCanonicalize(): ColumnarFileSourceScanExec = { + ColumnarFileSourceScanExec( + relation, + output.map(QueryPlan.normalizeExpressions(_, output)), + requiredSchema, + QueryPlan.normalizePredicates( + filterUnusedDynamicPruningExpressions(partitionFilters), output), + optionalBucketSet, + optionalNumCoalescedBuckets, + QueryPlan.normalizePredicates(dataFilters, output), + None, + disableBucketedScan) + } +} + +case class WrapperLeafExec() extends LeafExecNode { + + override def supportsColumnar: Boolean = true + + + override protected def doExecuteColumnar(): RDD[ColumnarBatch] = throw new UnsupportedOperationException + + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException + + override def output: Seq[Attribute] = Seq() +} + +case class ColumnarMultipleOperatorExec( + aggregate: HashAggregateExec, + proj1: ColumnarProjectExec, + join1: 
ColumnarBroadcastHashJoinExec, + proj2: ColumnarProjectExec, + join2: ColumnarBroadcastHashJoinExec, + proj3: ColumnarProjectExec, + join3: ColumnarBroadcastHashJoinExec, + proj4: ColumnarProjectExec, + join4: ColumnarBroadcastHashJoinExec, + filter: ColumnarFilterExec, + @transient relation: HadoopFsRelation, + output: Seq[Attribute], + requiredSchema: StructType, + partitionFilters: Seq[Expression], + optionalBucketSet: Option[BitSet], + optionalNumCoalescedBuckets: Option[Int], + dataFilters: Seq[Expression], + tableIdentifier: Option[TableIdentifier], + disableBucketedScan: Boolean = false) + extends BaseColumnarFileSourceScanExec( + relation, + output, + requiredSchema, + partitionFilters, + optionalBucketSet, + optionalNumCoalescedBuckets, + dataFilters, + tableIdentifier, + disableBucketedScan) { + + protected override def doPrepare(): Unit = { + super.doPrepare() + join1.getBuildPlan.asInstanceOf[ColumnarBroadcastExchangeExec].relationFuture + join2.getBuildPlan.asInstanceOf[ColumnarBroadcastExchangeExec].relationFuture + join3.getBuildPlan.asInstanceOf[ColumnarBroadcastExchangeExec].relationFuture + join4.getBuildPlan.asInstanceOf[ColumnarBroadcastExchangeExec].relationFuture + } + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of files read"), + "metadataTime" -> SQLMetrics.createTimingMetric(sparkContext, "metadata time"), + "filesSize" -> SQLMetrics.createSizeMetric(sparkContext, "size of files read"), + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), + "outputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), + "omniJitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), + "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs") + ) ++ { + // Tracking scan time has overhead, we can't afford to do it for each row, and can only do + // it for each batch. 
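+ // addInputTime/outputTime/omniJitTime above break the fused pipeline down by phase, while
+ // scanTime below mirrors the metric exposed by the plain columnar scan node.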
+ if (supportsColumnar) { + Some("scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) + } else { + None + } + } ++ { + if (relation.partitionSchemaOption.isDefined) { + Map( + "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions read"), + "pruningTime" -> + SQLMetrics.createTimingMetric(sparkContext, "dynamic partition pruning time")) + } else { + Map.empty[String, SQLMetric] + } + } ++ staticMetrics + + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val scanTime = longMetric("scanTime") + val numInputRows = longMetric("numInputRows") + val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val addInputTime = longMetric("addInputTime") + val omniCodegenTime = longMetric("omniJitTime") + val getOutputTime = longMetric("outputTime") + + val (omniGroupByChanel, omniAggChannels, omniAggSourceTypes, omniAggFunctionTypes, omniAggOutputTypes, + omniAggInputRaw, omniAggOutputPartial, resultIdxToOmniResultIdxMap) = genAggOutput(aggregate) + val (proj1OmniExpressions, proj1OmniInputTypes) = genProjectOutput(proj1) + val (buildTypes1, buildJoinColsExp1, joinFilter1, probeTypes1, probeOutputCols1, + probeHashColsExp1, buildOutputCols1, buildOutputTypes1, buildData1) = genJoinOutput(join1) + val (proj2OmniExpressions, proj2OmniInputTypes) = genProjectOutput(proj2) + val (buildTypes2, buildJoinColsExp2, joinFilter2, probeTypes2, probeOutputCols2, + probeHashColsExp2, buildOutputCols2, buildOutputTypes2, buildData2) = genJoinOutput(join2) + val (proj3OmniExpressions, proj3OmniInputTypes) = genProjectOutput(proj3) + val (buildTypes3, buildJoinColsExp3, joinFilter3, probeTypes3, probeOutputCols3, + probeHashColsExp3, buildOutputCols3, buildOutputTypes3, buildData3) = genJoinOutput(join3) + val (proj4OmniExpressions, proj4OmniInputTypes) = genProjectOutput(proj4) + val (buildTypes4, buildJoinColsExp4, joinFilter4, probeTypes4, probeOutputCols4, + probeHashColsExp4, buildOutputCols4, buildOutputTypes4, buildData4) = genJoinOutput(join4) + val (conditionExpression, omniCondInputTypes, omniCondExpressions) = genFilterOutput(filter) + + inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => + // for join + val deserializer = VecBatchSerializerFactory.create() + val startCodegen = System.nanoTime() + val aggFactory = new OmniHashAggregationWithExprOperatorFactory( + omniGroupByChanel, + omniAggChannels, + omniAggSourceTypes, + omniAggFunctionTypes, + omniAggOutputTypes, + omniAggInputRaw, + omniAggOutputPartial, + new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val aggOperator = aggFactory.createOperator + omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + aggOperator.close() + }) + + val projectOperatorFactory1 = new OmniProjectOperatorFactory(proj1OmniExpressions, proj1OmniInputTypes, 1, new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val projectOperator1 = projectOperatorFactory1.createOperator + // close operator + addLeakSafeTaskCompletionListener[Unit](_ => { + projectOperator1.close() + }) + + val buildOpFactory1 = new OmniHashBuilderWithExprOperatorFactory(buildTypes1, + buildJoinColsExp1, if (joinFilter1.nonEmpty) {Optional.of(joinFilter1.get)} else {Optional.empty()}, 1, + new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val buildOp1 = buildOpFactory1.createOperator() + buildData1.value.foreach { input => + 
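+ // each broadcast element is a serialized VecBatch from the build side; it is deserialized
+ // and fed to the hash builder before any probe batches arrive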
buildOp1.addInput(deserializer.deserialize(input)) + } + buildOp1.getOutput + val lookupOpFactory1 = new OmniLookupJoinWithExprOperatorFactory(probeTypes1, probeOutputCols1, + probeHashColsExp1, buildOutputCols1, buildOutputTypes1, OMNI_JOIN_TYPE_INNER, buildOpFactory1, + new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val lookupOp1 = lookupOpFactory1.createOperator() + // close operator + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]( _ => { + buildOp1.close() + lookupOp1.close() + buildOpFactory1.close() + lookupOpFactory1.close() + }) + + val projectOperatorFactory2 = new OmniProjectOperatorFactory(proj2OmniExpressions, proj2OmniInputTypes, 1, new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val projectOperator2 = projectOperatorFactory2.createOperator + // close operator + addLeakSafeTaskCompletionListener[Unit](_ => { + projectOperator2.close() + }) + + val buildOpFactory2 = new OmniHashBuilderWithExprOperatorFactory(buildTypes2, + buildJoinColsExp2, if (joinFilter2.nonEmpty) {Optional.of(joinFilter2.get)} else {Optional.empty()}, 1, + new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val buildOp2 = buildOpFactory2.createOperator() + buildData2.value.foreach { input => + buildOp2.addInput(deserializer.deserialize(input)) + } + buildOp2.getOutput + val lookupOpFactory2 = new OmniLookupJoinWithExprOperatorFactory(probeTypes2, probeOutputCols2, + probeHashColsExp2, buildOutputCols2, buildOutputTypes2, OMNI_JOIN_TYPE_INNER, buildOpFactory2, + new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val lookupOp2 = lookupOpFactory2.createOperator() + // close operator + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]( _ => { + buildOp2.close() + lookupOp2.close() + buildOpFactory2.close() + lookupOpFactory2.close() + }) + + val projectOperatorFactory3 = new OmniProjectOperatorFactory(proj3OmniExpressions, proj3OmniInputTypes, 1, new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val projectOperator3 = projectOperatorFactory3.createOperator + // close operator + addLeakSafeTaskCompletionListener[Unit](_ => { + projectOperator3.close() + }) + + val buildOpFactory3 = new OmniHashBuilderWithExprOperatorFactory(buildTypes3, + buildJoinColsExp3, if (joinFilter3.nonEmpty) {Optional.of(joinFilter3.get)} else {Optional.empty()}, 1, + new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val buildOp3 = buildOpFactory3.createOperator() + buildData3.value.foreach { input => + buildOp3.addInput(deserializer.deserialize(input)) + } + buildOp3.getOutput + val lookupOpFactory3 = new OmniLookupJoinWithExprOperatorFactory(probeTypes3, probeOutputCols3, + probeHashColsExp3, buildOutputCols3, buildOutputTypes3, OMNI_JOIN_TYPE_INNER, buildOpFactory3, + new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val lookupOp3 = lookupOpFactory3.createOperator() + // close operator + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]( _ => { + buildOp3.close() + lookupOp3.close() + buildOpFactory3.close() + lookupOpFactory3.close() + }) + + val projectOperatorFactory4 = new OmniProjectOperatorFactory(proj4OmniExpressions, proj4OmniInputTypes, 1, new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val projectOperator4 = projectOperatorFactory4.createOperator + // close operator + addLeakSafeTaskCompletionListener[Unit](_ => { + projectOperator4.close() + }) + + val buildOpFactory4 = new OmniHashBuilderWithExprOperatorFactory(buildTypes4, + buildJoinColsExp4, if (joinFilter4.nonEmpty) {Optional.of(joinFilter4.get)} else {Optional.empty()}, 1, + new 
OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val buildOp4 = buildOpFactory4.createOperator() + buildData4.value.foreach { input => + buildOp4.addInput(deserializer.deserialize(input)) + } + buildOp4.getOutput + val lookupOpFactory4 = new OmniLookupJoinWithExprOperatorFactory(probeTypes4, probeOutputCols4, + probeHashColsExp4, buildOutputCols4, buildOutputTypes4, OMNI_JOIN_TYPE_INNER, buildOpFactory4, + new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val lookupOp4 = lookupOpFactory4.createOperator() + // close operator + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]( _ => { + buildOp4.close() + lookupOp4.close() + buildOpFactory4.close() + lookupOpFactory4.close() + }) + + val condOperatorFactory = new OmniFilterAndProjectOperatorFactory( + conditionExpression, omniCondInputTypes, seqAsJavaList(omniCondExpressions), 1, new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val condOperator = condOperatorFactory.createOperator + omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + // close operator + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + condOperator.close() + }) + + while (batches.hasNext) { + val batch = batches.next() + val startInput = System.nanoTime() + val input = transColBatchToOmniVecs(batch) + val vecBatch = new VecBatch(input, batch.numRows()) + condOperator.addInput(vecBatch) + val condOutput = condOperator.getOutput + while (condOutput.hasNext) { + val output = condOutput.next() + lookupOp4.addInput(output) + val joinOutput4 = lookupOp4.getOutput + while (joinOutput4.hasNext) { + val output = joinOutput4.next() + projectOperator4.addInput(output) + val projOutput4 = projectOperator4.getOutput + while (projOutput4.hasNext) { + val output = projOutput4.next() + lookupOp3.addInput(output) + val joinOutput3 = lookupOp3.getOutput + while (joinOutput3.hasNext) { + val output = joinOutput3.next() + projectOperator3.addInput(output) + val projOutput3 = projectOperator3.getOutput + while (projOutput3.hasNext) { + val output = projOutput3.next() + lookupOp2.addInput(output) + val joinOutput2 = lookupOp2.getOutput + while (joinOutput2.hasNext) { + val output = joinOutput2.next() + projectOperator2.addInput(output) + val projOutput2 = projectOperator2.getOutput + while (projOutput2.hasNext) { + val output = projOutput2.next() + lookupOp1.addInput(output) + val joinOutput1 = lookupOp1.getOutput + while (joinOutput1.hasNext) { + val output = joinOutput1.next() + projectOperator1.addInput(output) + val proj1Output = projectOperator1.getOutput + while (proj1Output.hasNext) { + val output = proj1Output.next() + numInputRows += output.getRowCount() + aggOperator.addInput(output) + } + } + } + } + } + } + } + } + } + addInputTime += NANOSECONDS.toMillis(System.nanoTime() - startInput) + } + val startGetOp = System.nanoTime() + val aggOutput = aggOperator.getOutput + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + val localSchema = aggregate.schema + + new Iterator[ColumnarBatch] { + override def hasNext: Boolean = { + // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call. 
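+ // (here the iterator being timed is the fused aggregation's output rather than the
+ // FileScanRDD itself)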
+ val startNs = System.nanoTime() + val res = aggOutput.hasNext + scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs) + res + } + override def next(): ColumnarBatch = { + val vecBatch = aggOutput.next() + val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( + vecBatch.getRowCount, localSchema, false) + vectors.zipWithIndex.foreach { case (vector, i) => + vector.reset() + vector.setVec(vecBatch.getVectors()(resultIdxToOmniResultIdxMap(i))) + } + numOutputRows += vecBatch.getRowCount + numOutputVecBatchs += 1 + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + vecBatch.close() + new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) + } + } + } + } + + override val nodeNamePrefix: String = "" + + override val nodeName: String = "OmniColumnarMultipleOperatorExec" + + // TODO: + override protected def doCanonicalize(): SparkPlan = WrapperLeafExec() +} + +case class ColumnarMultipleOperatorExec1( + aggregate: HashAggregateExec, + proj1: ColumnarProjectExec, + join1: ColumnarBroadcastHashJoinExec, + proj2: ColumnarProjectExec, + join2: ColumnarBroadcastHashJoinExec, + proj3: ColumnarProjectExec, + join3: ColumnarBroadcastHashJoinExec, + filter: ColumnarFilterExec, + @transient relation: HadoopFsRelation, + output: Seq[Attribute], + requiredSchema: StructType, + partitionFilters: Seq[Expression], + optionalBucketSet: Option[BitSet], + optionalNumCoalescedBuckets: Option[Int], + dataFilters: Seq[Expression], + tableIdentifier: Option[TableIdentifier], + disableBucketedScan: Boolean = false) + extends BaseColumnarFileSourceScanExec( + relation, + output, + requiredSchema, + partitionFilters, + optionalBucketSet, + optionalNumCoalescedBuckets, + dataFilters, + tableIdentifier, + disableBucketedScan) { + + protected override def doPrepare(): Unit = { + super.doPrepare() + join1.getBuildPlan.asInstanceOf[ColumnarBroadcastExchangeExec].relationFuture + join2.getBuildPlan.asInstanceOf[ColumnarBroadcastExchangeExec].relationFuture + join3.getBuildPlan.asInstanceOf[ColumnarBroadcastExchangeExec].relationFuture + } + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of files read"), + "metadataTime" -> SQLMetrics.createTimingMetric(sparkContext, "metadata time"), + "filesSize" -> SQLMetrics.createSizeMetric(sparkContext, "size of files read"), + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), + "outputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), + "omniJitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), + "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs"), + //operator metric + "lookupAddInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni lookup addInput"), + // + ) ++ { + // Tracking scan time has overhead, we can't afford to do it for each row, and can only do + // it for each batch. 
+ if (supportsColumnar) { + Some("scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) + } else { + None + } + } ++ { + if (relation.partitionSchemaOption.isDefined) { + Map( + "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions read"), + "pruningTime" -> + SQLMetrics.createTimingMetric(sparkContext, "dynamic partition pruning time")) + } else { + Map.empty[String, SQLMetric] + } + } ++ staticMetrics + + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val scanTime = longMetric("scanTime") + val numInputRows = longMetric("numInputRows") + val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val addInputTime = longMetric("addInputTime") + val omniCodegenTime = longMetric("omniJitTime") + val getOutputTime = longMetric("outputTime") + + val (omniGroupByChanel, omniAggChannels, omniAggSourceTypes, omniAggFunctionTypes, omniAggOutputTypes, + omniAggInputRaw, omniAggOutputPartial, resultIdxToOmniResultIdxMap) = genAggOutput(aggregate) + val (proj1OmniExpressions, proj1OmniInputTypes) = genProjectOutput(proj1) + val (buildTypes1, buildJoinColsExp1, joinFilter1, probeTypes1, probeOutputCols1, + probeHashColsExp1, buildOutputCols1, buildOutputTypes1, buildData1, reserved1) = genJoinOutputWithReverse(join1) + val (proj2OmniExpressions, proj2OmniInputTypes) = genProjectOutput(proj2) + val (buildTypes2, buildJoinColsExp2, joinFilter2, probeTypes2, probeOutputCols2, + probeHashColsExp2, buildOutputCols2, buildOutputTypes2, buildData2, reserved2) = genJoinOutputWithReverse(join2) + val (proj3OmniExpressions, proj3OmniInputTypes) = genProjectOutput(proj3) + val (buildTypes3, buildJoinColsExp3, joinFilter3, probeTypes3, probeOutputCols3, + probeHashColsExp3, buildOutputCols3, buildOutputTypes3, buildData3, reserved3) = genJoinOutputWithReverse(join3) + val (conditionExpression, omniCondInputTypes, omniCondExpressions) = genFilterOutput(filter) + + def reserveVec(o: VecBatch): VecBatch = { + val omniVecs = o.getVectors + val newOmniVecs = new Array[Vec](omniVecs.length) + var index = 0 + for (i <- reserved3._1 until reserved3._2) { + newOmniVecs(index) = omniVecs(i) + index += 1 + } + for (i <- reserved3._3 until reserved3._4) { + newOmniVecs(index) = omniVecs(i) + index += 1 + } + o.close() + new VecBatch(newOmniVecs) + } + + inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => + // for join + val deserializer = VecBatchSerializerFactory.create() + val startCodegen = System.nanoTime() + val aggFactory = new OmniHashAggregationWithExprOperatorFactory( + omniGroupByChanel, + omniAggChannels, + omniAggSourceTypes, + omniAggFunctionTypes, + omniAggOutputTypes, + omniAggInputRaw, + omniAggOutputPartial, + new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val aggOperator = aggFactory.createOperator + omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + aggOperator.close() + }) + + val projectOperatorFactory1 = new OmniProjectOperatorFactory(proj1OmniExpressions, proj1OmniInputTypes, 1, new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val projectOperator1 = projectOperatorFactory1.createOperator + // close operator + addLeakSafeTaskCompletionListener[Unit](_ => { + projectOperator1.close() + }) + + val buildOpFactory1 = new OmniHashBuilderWithExprOperatorFactory(buildTypes1, + buildJoinColsExp1, if (joinFilter1.nonEmpty) {Optional.of(joinFilter1.get)} else 
{Optional.empty()}, 1, + new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val buildOp1 = buildOpFactory1.createOperator() + buildData1.value.foreach { input => + buildOp1.addInput(deserializer.deserialize(input)) + } + buildOp1.getOutput + val lookupOpFactory1 = new OmniLookupJoinWithExprOperatorFactory(probeTypes1, probeOutputCols1, + probeHashColsExp1, buildOutputCols1, buildOutputTypes1, OMNI_JOIN_TYPE_INNER, buildOpFactory1, + new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val lookupOp1 = lookupOpFactory1.createOperator() + // close operator + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]( _ => { + buildOp1.close() + lookupOp1.close() + buildOpFactory1.close() + lookupOpFactory1.close() + }) + + val projectOperatorFactory2 = new OmniProjectOperatorFactory(proj2OmniExpressions, proj2OmniInputTypes, 1, new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val projectOperator2 = projectOperatorFactory2.createOperator + // close operator + addLeakSafeTaskCompletionListener[Unit](_ => { + projectOperator2.close() + }) + + val buildOpFactory2 = new OmniHashBuilderWithExprOperatorFactory(buildTypes2, + buildJoinColsExp2, if (joinFilter2.nonEmpty) {Optional.of(joinFilter2.get)} else {Optional.empty()}, 1, + new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val buildOp2 = buildOpFactory2.createOperator() + buildData2.value.foreach { input => + buildOp2.addInput(deserializer.deserialize(input)) + } + buildOp2.getOutput + val lookupOpFactory2 = new OmniLookupJoinWithExprOperatorFactory(probeTypes2, probeOutputCols2, + probeHashColsExp2, buildOutputCols2, buildOutputTypes2, OMNI_JOIN_TYPE_INNER, buildOpFactory2, + new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val lookupOp2 = lookupOpFactory2.createOperator() + // close operator + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]( _ => { + buildOp2.close() + lookupOp2.close() + buildOpFactory2.close() + lookupOpFactory2.close() + }) + + val projectOperatorFactory3 = new OmniProjectOperatorFactory(proj3OmniExpressions, proj3OmniInputTypes, 1, new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val projectOperator3 = projectOperatorFactory3.createOperator + // close operator + addLeakSafeTaskCompletionListener[Unit](_ => { + projectOperator3.close() + }) + + val buildOpFactory3 = new OmniHashBuilderWithExprOperatorFactory(buildTypes3, + buildJoinColsExp3, if (joinFilter3.nonEmpty) {Optional.of(joinFilter3.get)} else {Optional.empty()}, 1, + new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val buildOp3 = buildOpFactory3.createOperator() + buildData3.value.foreach { input => + buildOp3.addInput(deserializer.deserialize(input)) + } + buildOp3.getOutput + val lookupOpFactory3 = new OmniLookupJoinWithExprOperatorFactory(probeTypes3, probeOutputCols3, + probeHashColsExp3, buildOutputCols3, buildOutputTypes3, OMNI_JOIN_TYPE_INNER, buildOpFactory3, + new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val lookupOp3 = lookupOpFactory3.createOperator() + // close operator + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]( _ => { + buildOp3.close() + lookupOp3.close() + buildOpFactory3.close() + lookupOpFactory3.close() + }) + + val condOperatorFactory = new OmniFilterAndProjectOperatorFactory( + conditionExpression, omniCondInputTypes, seqAsJavaList(omniCondExpressions), 1, new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val condOperator = condOperatorFactory.createOperator + omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + // close operator 
+ SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + condOperator.close() + }) + + while (batches.hasNext) { + val batch = batches.next() + val startInput = System.nanoTime() + val input = transColBatchToOmniVecs(batch) + val vecBatch = new VecBatch(input, batch.numRows()) + condOperator.addInput(vecBatch) + val condOutput = condOperator.getOutput + while (condOutput.hasNext) { + val output = condOutput.next() + lookupOp3.addInput(output) + val joinOutput3 = lookupOp3.getOutput + while (joinOutput3.hasNext) { + val output = if (reserved3._1 > 0) { + reserveVec(joinOutput3.next()) + } else { + joinOutput3.next() + } + projectOperator3.addInput(output) + val projOutput3 = projectOperator3.getOutput + while (projOutput3.hasNext) { + val output = projOutput3.next() + lookupOp2.addInput(output) + val joinOutput2 = lookupOp2.getOutput + while (joinOutput2.hasNext) { + val output = if (reserved2._1 > 0) { + reserveVec(joinOutput2.next()) + } else { + joinOutput2.next() + } + projectOperator2.addInput(output) + val projOutput2 = projectOperator2.getOutput + while (projOutput2.hasNext) { + val output = projOutput2.next() + lookupOp1.addInput(output) + val joinOutput1 = lookupOp1.getOutput + while (joinOutput1.hasNext) { + val output = if (reserved1._1 > 0) { + reserveVec(joinOutput1.next()) + } else { + joinOutput1.next() + } + projectOperator1.addInput(output) + val proj1Output = projectOperator1.getOutput + while (proj1Output.hasNext) { + val output = proj1Output.next() + numInputRows += output.getRowCount() + aggOperator.addInput(output) + } + } + } + } + } + } + } + addInputTime += NANOSECONDS.toMillis(System.nanoTime() - startInput) + } + val startGetOp = System.nanoTime() + val aggOutput = aggOperator.getOutput + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + val localSchema = aggregate.schema + + new Iterator[ColumnarBatch] { + override def hasNext: Boolean = { + // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call. 
+ val startNs = System.nanoTime() + val res = aggOutput.hasNext + scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs) + res + } + override def next(): ColumnarBatch = { + val vecBatch = aggOutput.next() + val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( + vecBatch.getRowCount, localSchema, false) + vectors.zipWithIndex.foreach { case (vector, i) => + vector.reset() + vector.setVec(vecBatch.getVectors()(resultIdxToOmniResultIdxMap(i))) + } + numOutputRows += vecBatch.getRowCount + numOutputVecBatchs += 1 + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + vecBatch.close() + new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) + } + } + } + } + + override val nodeNamePrefix: String = "" + + override val nodeName: String = "ColumnarMultipleOperatorExec1" + + // TODO: exchange reuse +/* override def doCanonicalize(): ColumnarMultipleOperatorExec1 = { + ColumnarMultipleOperatorExec1( + aggregate, + proj1, + join1, + proj2, + join2, + proj3, + join3, + filter, + relation, + output, + requiredSchema, + partitionFilters, + optionalBucketSet, + optionalNumCoalescedBuckets, + dataFilters, + tableIdentifier, + disableBucketedScan) + }*/ + override protected def doCanonicalize(): SparkPlan = WrapperLeafExec() +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..8fd63a24dfd5e3c29cbcd216ba7aef9bded6d57e --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala @@ -0,0 +1,309 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import java.util.concurrent.TimeUnit.NANOSECONDS +import com.huawei.boostkit.spark.Constant.{IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP} +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._ +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import nova.hetu.omniruntime.`type`.DataType +import nova.hetu.omniruntime.constants.FunctionType +import nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_COUNT_ALL +import nova.hetu.omniruntime.operator.aggregator.OmniHashAggregationWithExprOperatorFactory +import nova.hetu.omniruntime.operator.config.OperatorConfig +import nova.hetu.omniruntime.vector.VecBatch +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.execution.ColumnarProjection.dealPartitionData +import org.apache.spark.sql.execution.aggregate.BaseAggregateExec +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.util.SparkMemoryUtils +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Hash-based aggregate operator that can also fallback to sorting when data exceeds memory size. + */ +case class ColumnarHashAggregateExec( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends BaseAggregateExec + with AliasAwareOutputPartitioning { + + override lazy val metrics = Map( + "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "numInputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatchs"), + "omniCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), + "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs")) + + + override def supportsColumnar: Boolean = true + + override def nodeName: String = "OmniColumnarHashAggregate" + + def buildCheck(): Unit = { + val attrExpsIdMap = getExprIdMap(child.output) + val omniGroupByChanel: Array[AnyRef] = groupingExpressions.map( + exp => rewriteToOmniJsonExpressionLiteral(exp, attrExpsIdMap)).toArray + + var omniInputRaw = false + var omniOutputPartial = false + val omniAggTypes = new Array[DataType](aggregateExpressions.size) + val omniAggFunctionTypes = new Array[FunctionType](aggregateExpressions.size) + val omniAggOutputTypes = new Array[DataType](aggregateExpressions.size) + var omniAggChannels = new Array[AnyRef](aggregateExpressions.size) + var index = 0 + for (exp <- aggregateExpressions) { + if (exp.filter.isDefined) { + throw new UnsupportedOperationException("Unsupported filter in AggregateExpression") + } + if (exp.isDistinct) { + throw new UnsupportedOperationException(s"Unsupported aggregate expression with distinct flag") + } + if (exp.mode == Final) { + exp.aggregateFunction 
match { + case Sum(_) | Min(_) | Max(_) | Count(_) => + val aggExp = exp.aggregateFunction.inputAggBufferAttributes.head + omniAggTypes(index) = sparkTypeToOmniType(aggExp.dataType, aggExp.metadata) + omniAggFunctionTypes(index) = toOmniAggFunTypeWithFinal(exp, true, true) + omniAggOutputTypes(index) = + sparkTypeToOmniType(exp.aggregateFunction.dataType) + omniAggChannels(index) = + rewriteToOmniJsonExpressionLiteral(aggExp, attrExpsIdMap) + case _ => throw new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: ${exp}") + } + } else if (exp.mode == Partial) { + omniInputRaw = true + omniOutputPartial = true + exp.aggregateFunction match { + case Sum(_) | Min(_) | Max(_) | Count(_) => + val aggExp = exp.aggregateFunction.children.head + omniAggTypes(index) = sparkTypeToOmniType(aggExp.dataType) + omniAggFunctionTypes(index) = toOmniAggFunType(exp, true) + omniAggOutputTypes(index) = + sparkTypeToOmniType(exp.aggregateFunction.dataType) + omniAggChannels(index) = + rewriteToOmniJsonExpressionLiteral(aggExp, attrExpsIdMap) + if (omniAggFunctionTypes(index) == OMNI_AGGREGATION_TYPE_COUNT_ALL) { + omniAggChannels(index) = null + } + case _ => throw new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: $exp") + } + } else { + throw new UnsupportedOperationException(s"Unsupported aggregate mode: $exp.mode") + } + index += 1 + } + omniAggChannels = omniAggChannels.filter(key => key != null) + val omniSourceTypes = new Array[DataType](child.outputSet.size) + val inputIter = child.outputSet.toIterator + var i = 0 + while (inputIter.hasNext) { + val inputAttr = inputIter.next() + omniSourceTypes(i) = sparkTypeToOmniType(inputAttr.dataType, inputAttr.metadata) + i += 1 + } + + checkOmniJsonWhiteList("", omniAggChannels) + checkOmniJsonWhiteList("", omniGroupByChanel) + + // check for final project + if (!omniOutputPartial) { + val finalOut = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes + val projectInputTypes = finalOut.map( + exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + val projectExpressions: Array[AnyRef] = resultExpressions.map( + exp => rewriteToOmniJsonExpressionLiteral(exp, getExprIdMap(finalOut))).toArray + checkOmniJsonWhiteList("", projectExpressions) + } + } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val addInputTime = longMetric("addInputTime") + val numInputRows = longMetric("numInputRows") + val numInputVecBatchs = longMetric("numInputVecBatchs") + val omniCodegenTime = longMetric("omniCodegenTime") + val getOutputTime = longMetric("getOutputTime") + val numOutputRows = longMetric("numOutputRows") + val numOutputVecBatchs = longMetric("numOutputVecBatchs") + + val attrExpsIdMap = getExprIdMap(child.output) + val omniGroupByChanel = groupingExpressions.map( + exp => rewriteToOmniJsonExpressionLiteral(exp, attrExpsIdMap)).toArray + + var omniInputRaw = false + var omniOutputPartial = false + val omniAggTypes = new Array[DataType](aggregateExpressions.size) + val omniAggFunctionTypes = new Array[FunctionType](aggregateExpressions.size) + val omniAggOutputTypes = new Array[DataType](aggregateExpressions.size) + var omniAggChannels = new Array[String](aggregateExpressions.size) + var index = 0 + for (exp <- aggregateExpressions) { + if (exp.filter.isDefined) { + throw new UnsupportedOperationException("Unsupported filter in AggregateExpression") + } + if (exp.isDistinct) { + throw new UnsupportedOperationException("Unsupported aggregate expression with distinct flag") + } + if (exp.mode == 
Final) { + exp.aggregateFunction match { + case Sum(_) | Min(_) | Max(_) | Count(_) => + val aggExp = exp.aggregateFunction.inputAggBufferAttributes.head + omniAggTypes(index) = sparkTypeToOmniType(aggExp.dataType, aggExp.metadata) + omniAggFunctionTypes(index) = toOmniAggFunTypeWithFinal(exp, true, true) + omniAggOutputTypes(index) = + sparkTypeToOmniType(exp.aggregateFunction.dataType) + omniAggChannels(index) = + rewriteToOmniJsonExpressionLiteral(aggExp, attrExpsIdMap) + case _ => throw new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: ${exp}") + } + } else if (exp.mode == Partial) { + omniInputRaw = true + omniOutputPartial = true + exp.aggregateFunction match { + case Sum(_) | Min(_) | Max(_) | Count(_) => + val aggExp = exp.aggregateFunction.children.head + omniAggTypes(index) = sparkTypeToOmniType(aggExp.dataType) + omniAggFunctionTypes(index) = toOmniAggFunType(exp, true) + omniAggOutputTypes(index) = + sparkTypeToOmniType(exp.aggregateFunction.dataType) + omniAggChannels(index) = + rewriteToOmniJsonExpressionLiteral(aggExp, attrExpsIdMap) + if (omniAggFunctionTypes(index) == OMNI_AGGREGATION_TYPE_COUNT_ALL) { + omniAggChannels(index) = null + } + case _ => throw new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: ${exp}") + } + } else { + throw new UnsupportedOperationException(s"Unsupported aggregate mode: ${exp.mode}") + } + index += 1 + } + + omniAggChannels = omniAggChannels.filter(key => key != null) + val omniSourceTypes = new Array[DataType](child.outputSet.size) + val inputIter = child.outputSet.toIterator + var i = 0 + while (inputIter.hasNext) { + val inputAttr = inputIter.next() + omniSourceTypes(i) = sparkTypeToOmniType(inputAttr.dataType, inputAttr.metadata) + i += 1 + } + + child.executeColumnar().mapPartitionsWithIndex { (index, iter) => + val startCodegen = System.nanoTime() + val factory = new OmniHashAggregationWithExprOperatorFactory( + omniGroupByChanel, + omniAggChannels, + omniSourceTypes, + omniAggFunctionTypes, + omniAggOutputTypes, + omniInputRaw, + omniOutputPartial, + new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val operator = factory.createOperator + omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + + while (iter.hasNext) { + val batch = iter.next() + val input = transColBatchToOmniVecs(batch) + val startInput = System.nanoTime() + val vecBatch = new VecBatch(input, batch.numRows()) + operator.addInput(vecBatch) + addInputTime += NANOSECONDS.toMillis(System.nanoTime() - startInput) + numInputVecBatchs += 1 + numInputRows += batch.numRows() + } + val startGetOp = System.nanoTime() + val opOutput = operator.getOutput + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + + // close operator + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + operator.close() + }) + + var localSchema = this.schema + if (!omniOutputPartial) { + val omnifinalOutSchama = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes + localSchema = StructType.fromAttributes(omnifinalOutSchama) + } + + val hashAggIter = new Iterator[ColumnarBatch] { + override def hasNext: Boolean = { + val startGetOp: Long = System.nanoTime() + val hasNext = opOutput.hasNext + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + hasNext + } + + override def next(): ColumnarBatch = { + val startGetOp = System.nanoTime() + val vecBatch = opOutput.next() + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + val vectors: 
Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( + vecBatch.getRowCount, localSchema, false) + vectors.zipWithIndex.foreach { case (vector, i) => + vector.reset() + vector.setVec(vecBatch.getVectors()(i)) + } + + numOutputRows += vecBatch.getRowCount + numOutputVecBatchs += 1 + + vecBatch.close() + new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) + } + } + if (!omniOutputPartial) { + val finalOut = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes + val projectInputTypes = finalOut.map( + exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + val projectExpressions = resultExpressions.map( + exp => rewriteToOmniJsonExpressionLiteral(exp, getExprIdMap(finalOut))).toArray + + dealPartitionData(null, null, addInputTime, omniCodegenTime, + getOutputTime, projectInputTypes, projectExpressions, hashAggIter, this.schema) + } else { + hashAggIter + } + } + } + + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException("This operator doesn't support doExecute().") + } +} + +object ColumnarHashAggregateExec { + def supportsAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = { + val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarProjection.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarProjection.scala new file mode 100644 index 0000000000000000000000000000000000000000..6c8805589104e4bda04fd1d1776565df36ebcc58 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarProjection.scala @@ -0,0 +1,101 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import java.util.concurrent.TimeUnit.NANOSECONDS +import com.huawei.boostkit.spark.Constant.{IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP} +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import nova.hetu.omniruntime.`type`.DataType +import nova.hetu.omniruntime.operator.config.OperatorConfig +import nova.hetu.omniruntime.operator.project.OmniProjectOperatorFactory +import nova.hetu.omniruntime.vector.VecBatch +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * @since 2022/3/5 + */ +object ColumnarProjection { + def dealPartitionData(numOutputRows: SQLMetric, numOutputVecBatchs: SQLMetric, + addInputTime: SQLMetric, + omniCodegenTime: SQLMetric, + getOutputTime: SQLMetric, omniInputTypes: Array[DataType], + omniExpressions: Array[String], iter: Iterator[ColumnarBatch], + schema: StructType): Iterator[ColumnarBatch] = { + val startCodegen = System.nanoTime() + val projectOperatorFactory = new OmniProjectOperatorFactory(omniExpressions, omniInputTypes, 1, new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val projectOperator = projectOperatorFactory.createOperator + omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + // close operator + addLeakSafeTaskCompletionListener[Unit](_ => { + projectOperator.close() + }) + + new Iterator[ColumnarBatch] { + private var results: java.util.Iterator[VecBatch] = _ + + override def hasNext: Boolean = { + while ((results == null || !results.hasNext) && iter.hasNext) { + val batch = iter.next() + val input = transColBatchToOmniVecs(batch) + val vecBatch = new VecBatch(input, batch.numRows()); + val startInput = System.nanoTime() + projectOperator.addInput(vecBatch) + addInputTime += NANOSECONDS.toMillis(System.nanoTime() - startInput) + + val startGetOp = System.nanoTime() + results = projectOperator.getOutput + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + } + if (results == null) { + false + } else { + val startGetOp: Long = System.nanoTime() + val hasNext = results.hasNext + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + hasNext + } + } + + override def next(): ColumnarBatch = { + val startGetOp = System.nanoTime() + val result = results.next() + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + + val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( + result.getRowCount, schema, false) + vectors.zipWithIndex.foreach { case (vector, i) => + vector.reset() + vector.setVec(result.getVectors()(i)) + } + if(numOutputRows != null) { + numOutputRows += result.getRowCount + } + if (numOutputVecBatchs != null) { + numOutputVecBatchs += 1 + } + result.close() + new ColumnarBatch(vectors.toArray, result.getRowCount) + } + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..96cb162a34b8db2f3d03a2c45024b27dd6ab81f3 --- /dev/null +++ 
b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala @@ -0,0 +1,369 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import com.huawei.boostkit.spark.ColumnarPluginConfig + +import java.util.Random +import com.huawei.boostkit.spark.Constant.{IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP} + +import scala.collection.JavaConverters._ +import scala.concurrent.Future +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._ +import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import com.huawei.boostkit.spark.vectorized.PartitionInfo +import nova.hetu.omniruntime.`type`.{DataType, DataTypeSerializer} +import nova.hetu.omniruntime.operator.config.OperatorConfig +import nova.hetu.omniruntime.operator.project.OmniProjectOperatorFactory +import nova.hetu.omniruntime.vector.{IntVec, VecBatch} +import org.apache.spark._ +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ColumnarShuffleDependency +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleOrigin} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.createShuffleWriteProcessor +import org.apache.spark.sql.execution.metric._ +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleWriteMetricsReporter} +import org.apache.spark.sql.execution.util.MergeIterator +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.MutablePair + +class ColumnarShuffleExchangeExec( + override val outputPartitioning: Partitioning, + child: SparkPlan, + shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS) + extends ShuffleExchangeExec(outputPartitioning, child) { + + private lazy val writeMetrics = + SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) + override lazy val readMetrics = + SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) + override lazy val metrics: Map[String, SQLMetric] = Map( + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), + "bytesSpilled" -> SQLMetrics.createSizeMetric(sparkContext, "shuffle bytes spilled"), + "splitTime" -> 
SQLMetrics.createNanoTimingMetric(sparkContext, "totaltime_split"), + "spillTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "shuffle spill time"), + "compressTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "totaltime_compress"), + "avgReadBatchNumRows" -> SQLMetrics + .createAverageMetric(sparkContext, "avg read batch num rows"), + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "numMergedVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatchs"), + "numOutputRows" -> SQLMetrics + .createMetric(sparkContext, "number of output rows")) ++ readMetrics ++ writeMetrics + + override def nodeName: String = "OmniColumnarShuffleExchange" + + override def supportsColumnar: Boolean = true + + val serializer: Serializer = new ColumnarBatchSerializer( + longMetric("avgReadBatchNumRows"), + longMetric("numOutputRows")) + + @transient lazy val inputColumnarRDD: RDD[ColumnarBatch] = child.executeColumnar() + + // 'mapOutputStatisticsFuture' is only needed when enable AQE. + @transient override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = { + if (inputColumnarRDD.getNumPartitions == 0) { + Future.successful(null) + } else { + sparkContext.submitMapStage(columnarShuffleDependency) + } + } + + @transient + lazy val columnarShuffleDependency: ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { + ColumnarShuffleExchangeExec.prepareShuffleDependency( + inputColumnarRDD, + child.output, + outputPartitioning, + serializer, + writeMetrics, + longMetric("dataSize"), + longMetric("bytesSpilled"), + longMetric("numInputRows"), + longMetric("splitTime"), + longMetric("spillTime")) + } + + var cachedShuffleRDD: ShuffledColumnarRDD = _ + + override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException() + } + + def buildCheck(): Unit = { + val inputTypes = new Array[DataType](child.output.size) + child.output.zipWithIndex.foreach { + case (attr, i) => + inputTypes(i) = sparkTypeToOmniType(attr.dataType, attr.metadata) + } + + outputPartitioning match { + case HashPartitioning(expressions, numPartitions) => + val genHashExpressionFunc = ColumnarShuffleExchangeExec.genHashExpr() + val hashJSonExpressions = genHashExpressionFunc(expressions, numPartitions, ColumnarShuffleExchangeExec.defaultMm3HashSeed, child.output) + checkOmniJsonWhiteList("", Array(hashJSonExpressions)) + case _ => + } + } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + if (cachedShuffleRDD == null) { + cachedShuffleRDD = new ShuffledColumnarRDD(columnarShuffleDependency, readMetrics) + } + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + val enableShuffleBatchMerge: Boolean = columnarConf.enableShuffleBatchMerge + if (enableShuffleBatchMerge) { + cachedShuffleRDD.mapPartitionsWithIndexInternal { (index, iter) => + new MergeIterator(iter, + StructType.fromAttributes(child.output), + longMetric("numMergedVecBatchs")) + } + } else { + cachedShuffleRDD + } + } +} + +object ColumnarShuffleExchangeExec extends Logging { + val defaultMm3HashSeed: Int = 42; + + def prepareShuffleDependency( + rdd: RDD[ColumnarBatch], + outputAttributes: Seq[Attribute], + newPartitioning: Partitioning, + serializer: Serializer, + writeMetrics: Map[String, SQLMetric], + dataSize: SQLMetric, + bytesSpilled: SQLMetric, + numInputRows: SQLMetric, + splitTime: SQLMetric, + spillTime: SQLMetric): + ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { + + + val rangePartitioner: Option[Partitioner] = 
newPartitioning match { + case RangePartitioning(sortingExpressions, numPartitions) => + // Extract only fields used for sorting to avoid collecting large fields that does not + // affect sorting result when deciding partition bounds in RangePartitioner + val rddForSampling = rdd.mapPartitionsInternal { iter => + // Internally, RangePartitioner runs a job on the RDD that samples keys to compute + // partition bounds. To get accurate samples, we need to copy the mutable keys. + iter.flatMap(batch => { + val rows: Iterator[InternalRow] = batch.rowIterator.asScala + val projection = + UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) + val mutablePair = new MutablePair[InternalRow, Null]() + new Iterator[MutablePair[InternalRow, Null]] { + var closed = false + override def hasNext: Boolean = { + val has: Boolean = rows.hasNext + if (!has && !closed) { + batch.close() + closed = true + } + has + } + override def next(): MutablePair[InternalRow, Null] = { + mutablePair.update(projection(rows.next()).copy(), null) + } + } + }) + } + // Construct ordering on extracted sort key. + val orderingAttributes: Seq[SortOrder] = sortingExpressions.zipWithIndex.map { + case (ord, i) => + ord.copy(child = BoundReference(i, ord.dataType, ord.nullable)) + } + implicit val ordering = new LazilyGeneratedOrdering(orderingAttributes) + val part = new RangePartitioner( + numPartitions, + rddForSampling, + ascending = true, + samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition) + Some(part) + case _ => None + } + + val inputTypes = new Array[DataType](outputAttributes.size) + outputAttributes.zipWithIndex.foreach { + case (attr, i) => + inputTypes(i) = sparkTypeToOmniType(attr.dataType, attr.metadata) + } + + // gen RoundRobin pid + def getRoundRobinPartitionKey: (ColumnarBatch, Int) => IntVec = { + // 随机数 + (columnarBatch: ColumnarBatch, numPartitions: Int) => { + val pidArr = new Array[Int](columnarBatch.numRows()) + for (i <- 0 until columnarBatch.numRows()) { + val position = new Random(TaskContext.get().partitionId()).nextInt(numPartitions) + pidArr(i) = position + 1 + } + val vec = new IntVec(columnarBatch.numRows()) + vec.put(pidArr, 0, 0, pidArr.length) + vec + } + } + + def addPidToColumnBatch(): (IntVec, ColumnarBatch) => (Int, ColumnarBatch) = (pidVec, cb) => { + val pidVecTmp = new OmniColumnVector(cb.numRows(), IntegerType, false) + pidVecTmp.setVec(pidVec) + val newColumns = (pidVecTmp +: (0 until cb.numCols).map(cb.column)).toArray + (0, new ColumnarBatch(newColumns, cb.numRows)) + } + + // only used for fallback range partitioning + def computeAndAddRangePartitionId( + cbIter: Iterator[ColumnarBatch], + partitionKeyExtractor: InternalRow => Any): Iterator[(Int, ColumnarBatch)] = { + val addPid2ColumnBatch = addPidToColumnBatch() + cbIter.filter(cb => cb.numRows != 0 && cb.numCols != 0).map { + cb => + val pidArr = new Array[Int](cb.numRows) + (0 until cb.numRows).foreach { i => + val row = cb.getRow(i) + val pid = rangePartitioner.get.getPartition(partitionKeyExtractor(row)) + pidArr(i) = pid + } + val pidVec = new IntVec(cb.numRows) + pidVec.put(pidArr, 0, 0, cb.numRows) + + addPid2ColumnBatch(pidVec, cb) + } + } + + val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] && + newPartitioning.numPartitions > 1 + val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition + + val rddWithPartitionId: RDD[Product2[Int, ColumnarBatch]] = newPartitioning match { + case RoundRobinPartitioning(numPartitions) => + // 按随机数分区 + 
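+ // (Round-robin repartitioning: each row of the incoming batch is given a partition id of nextInt(numPartitions) + 1, drawn from a Random seeded with the task's partition id, and that id vector is prepended to the batch as an extra IntVec column via addPidToColumnBatch.)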
rdd.mapPartitionsWithIndexInternal((_, cbIter) => { + val getRoundRobinPid = getRoundRobinPartitionKey + val addPid2ColumnBatch = addPidToColumnBatch() + cbIter.map { cb => + val pidVec = getRoundRobinPid(cb, numPartitions) + addPid2ColumnBatch(pidVec, cb) + } + }, isOrderSensitive = isOrderSensitive) + case RangePartitioning(sortingExpressions, _) => + // 排序,按采样数据进行分区 + rdd.mapPartitionsWithIndexInternal((_, cbIter) => { + val partitionKeyExtractor: InternalRow => Any = { + val projection = + UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) + row => projection(row) + } + val newIter = computeAndAddRangePartitionId(cbIter, partitionKeyExtractor) + newIter + }, isOrderSensitive = isOrderSensitive) + case HashPartitioning(expressions, numPartitions) => + rdd.mapPartitionsWithIndexInternal((_, cbIter) => { + val addPid2ColumnBatch = addPidToColumnBatch() + // omni project + val genHashExpression = genHashExpr() + val omniExpr: String = genHashExpression(expressions, numPartitions, defaultMm3HashSeed, outputAttributes) + val factory = new OmniProjectOperatorFactory(Array(omniExpr), inputTypes, 1, new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val op = factory.createOperator() + + cbIter.map { cb => + val vecs = transColBatchToOmniVecs(cb, true) + op.addInput(new VecBatch(vecs, cb.numRows())) + val res = op.getOutput + if (res.hasNext) { + // TODO call next() once while get all result? + val retBatch = res.next() + val pidVec = retBatch.getVectors()(0) + // close return VecBatch + retBatch.close() + addPid2ColumnBatch(pidVec.asInstanceOf[IntVec], cb) + } else { + throw new Exception("Empty Project Operator Result...") + } + } + }, isOrderSensitive = isOrderSensitive) + case SinglePartition => + rdd.mapPartitionsWithIndexInternal((_, cbIter) => { + cbIter.map { cb => (0, cb) } + }, isOrderSensitive = isOrderSensitive) + } + + val numCols = outputAttributes.size + val intputTypeArr: Seq[DataType] = outputAttributes.map { attr => + sparkTypeToOmniType(attr.dataType, attr.metadata) + } + val intputTypes = DataTypeSerializer.serialize(intputTypeArr.toArray) + + val partitionInfo: PartitionInfo = newPartitioning match { + case SinglePartition => + new PartitionInfo("single", 1, numCols, intputTypes) + case RoundRobinPartitioning(numPartitions) => + new PartitionInfo("rr", numPartitions, numCols, intputTypes) + case HashPartitioning(expressions, numPartitions) => + new PartitionInfo("hash", numPartitions, numCols, intputTypes) + case RangePartitioning(ordering, numPartitions) => + new PartitionInfo("range", numPartitions, numCols, intputTypes) + } + + new ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch]( + rddWithPartitionId, + new PartitionIdPassthrough(newPartitioning.numPartitions), + serializer, + shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics), + partitionInfo = partitionInfo, + dataSize = dataSize, + bytesSpilled = bytesSpilled, + numInputRows = numInputRows, + splitTime = splitTime, + spillTime = spillTime) + } + + // gen hash partition expression + def genHashExpr(): (Seq[Expression], Int, Int, Seq[Attribute]) => String = { + (expressions: Seq[Expression], numPartitions: Int, seed: Int, outputAttributes: Seq[Attribute]) => { + val exprIdMap = getExprIdMap(outputAttributes) + val EXP_JSON_FORMATER1 = ("{\"exprType\":\"FUNCTION\",\"returnType\":1,\"function_name\":\"%s\",\"arguments\":[" + + "%s,{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":%d}]}") + val EXP_JSON_FORMATER2 = ("{\"exprType\": 
\"FUNCTION\",\"returnType\":1,\"function_name\":\"%s\", \"arguments\": [%s,%s] }") + var omniExpr: String = "" + expressions.foreach { expr => + val colExpr = rewriteToOmniJsonExpressionLiteral(expr, exprIdMap) + if (omniExpr.isEmpty) { + omniExpr = EXP_JSON_FORMATER1.format("mm3hash", colExpr, seed) + } else { + omniExpr = EXP_JSON_FORMATER2.format("mm3hash", colExpr, omniExpr) + } + } + omniExpr = EXP_JSON_FORMATER1.format("pmod", omniExpr, numPartitions) + logDebug(s"hash omni expression: $omniExpr") + omniExpr + } + } + +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..3edf533c2b933544c2822a540de99343154d8bba --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import java.io.{File, IOException} +import java.util.UUID +import java.util.concurrent.TimeUnit.NANOSECONDS + +import com.huawei.boostkit.spark.ColumnarPluginConfig +import com.huawei.boostkit.spark.Constant.{IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP} +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{addAllAndGetIterator, genSortParam} +import nova.hetu.omniruntime.operator.config.{OperatorConfig, SparkSpillConfig} +import nova.hetu.omniruntime.operator.sort.OmniSortWithExprOperatorFactory +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.util.SparkMemoryUtils +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.Utils + +case class ColumnarSortExec( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan, + testSpillFrequency: Int = 0) + extends UnaryExecNode { + + private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 + + override def supportsColumnar: Boolean = true + + override def nodeName: String = "OmniColumnarSort" + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + override lazy val metrics = Map( + + "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), + "numInputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatchs"), + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "omniCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), + "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "outputDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "output data size"), + "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs")) + + def buildCheck(): Unit = { + genSortParam(child.output, sortOrder) + } + + val sparkConfTmp = sparkContext.conf + + private def generateLocalDirs(conf: SparkConf): Array[File] = { + Utils.getConfiguredLocalDirs(conf).flatMap { rootDir => + val localDir = generateDirs(rootDir, "columnarSortSpill") + Some(localDir) + } + } + + def generateDirs(root: String, namePrefix: String = "spark"):File = { + var attempts = 0 + val maxAttempts = MAX_DIR_CREATION_ATTEMPTS + var dir: File = null + while (dir == null) { + attempts += 1 + if (attempts > maxAttempts) { + throw new IOException("Directory conflict: failed to generate a temp directory for columnarSortSpill " + + "(under " + root + ") after " + maxAttempts + " attempts!") + } + dir = new File(root, namePrefix + "-" + UUID.randomUUID.toString) + if (dir.exists()) { + dir = null + } + } + dir.getCanonicalFile + } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val omniCodegenTime = longMetric("omniCodegenTime") + + val (sourceTypes, ascendings, nullFirsts, sortColsExp) = genSortParam(child.output, sortOrder) + val outputCols = output.indices.toArray + + 
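+ // Each partition builds its own spill setup: a dedicated columnarSortSpill directory is created under the configured local dirs (the concrete dir picked by a hash of the executor id) and wrapped, together with the spill switches and thresholds from ColumnarPluginConfig, into a SparkSpillConfig for the Omni sort operator factory. + // The resulting operator is closed through a leak-safe task completion listener and is fed and drained by addAllAndGetIterator, which also records the add/getOutput metrics.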
child.executeColumnar().mapPartitionsWithIndexInternal { (_, iter) => + val columnarConf = ColumnarPluginConfig.getSessionConf + val sortSpillRowThreshold = columnarConf.columnarSortSpillRowThreshold + val sortSpillDirDiskReserveSize = columnarConf.columnarSortSpillDirDiskReserveSize + val sortSpillEnable = columnarConf.enableSortSpill + val sortlocalDirs: Array[File] = generateLocalDirs(sparkConfTmp) + val hash = Utils.nonNegativeHash(SparkEnv.get.executorId) + val dirId = hash % sortlocalDirs.length + val spillPathDir = sortlocalDirs(dirId).getCanonicalPath + val sparkSpillConf = new SparkSpillConfig(sortSpillEnable, spillPathDir, + sortSpillDirDiskReserveSize, sortSpillRowThreshold) + val startCodegen = System.nanoTime() + val sortOperatorFactory = new OmniSortWithExprOperatorFactory(sourceTypes, outputCols, + sortColsExp, ascendings, nullFirsts, new OperatorConfig(IS_ENABLE_JIT, sparkSpillConf, IS_SKIP_VERIFY_EXP)) + val sortOperator = sortOperatorFactory.createOperator + omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + sortOperator.close() + }) + addAllAndGetIterator(sortOperator, iter, this.schema, + longMetric("addInputTime"), longMetric("numInputVecBatchs"), longMetric("numInputRows"), + longMetric("getOutputTime"), longMetric("numOutputVecBatchs"), longMetric("numOutputRows"), + longMetric("outputDataSize")) + } + } + + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException(s"This operator doesn't support doExecute().") + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..e0d920e014d77c2271a0e7a68f897a975c488dbb --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import java.util.concurrent.TimeUnit.NANOSECONDS +import com.huawei.boostkit.spark.Constant.{IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP} +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{checkOmniJsonWhiteList, getExprIdMap, rewriteToOmniJsonExpressionLiteral, sparkTypeToOmniType} +import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{addAllAndGetIterator, genSortParam} +import nova.hetu.omniruntime.`type`.DataType +import nova.hetu.omniruntime.operator.config.OperatorConfig +import nova.hetu.omniruntime.operator.topn.OmniTopNWithExprOperatorFactory +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition} +import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.execution.ColumnarProjection.dealPartitionData +import org.apache.spark.sql.execution.metric._ +import org.apache.spark.sql.execution.util.SparkMemoryUtils +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +case class ColumnarTakeOrderedAndProjectExec( + limit: Int, + sortOrder: Seq[SortOrder], + projectList: Seq[NamedExpression], + child: SparkPlan) + extends UnaryExecNode { + + override def supportsColumnar: Boolean = true + + override def nodeName: String = "OmniColumnarTakeOrderedAndProject" + + val serializer: Serializer = new ColumnarBatchSerializer( + longMetric("avgReadBatchNumRows"), + longMetric("numOutputRows")) + private lazy val writeMetrics = + SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) + private lazy val readMetrics = + SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) + override lazy val metrics: Map[String, SQLMetric] = Map( + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), + "bytesSpilled" -> SQLMetrics.createSizeMetric(sparkContext, "shuffle bytes spilled"), + "splitTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "totaltime_split"), + "spillTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "shuffle spill time"), + "avgReadBatchNumRows" -> SQLMetrics + .createAverageMetric(sparkContext, "avg read batch num rows"), + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "numOutputRows" -> SQLMetrics + .createMetric(sparkContext, "number of output rows"), + // omni + "outputDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "output data size"), + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), + "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), + "omniCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), + "numInputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatchs"), + "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs"), + "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput") + ) ++ readMetrics ++ writeMetrics + + override def output: Seq[Attribute] = { + projectList.map(_.toAttribute) + } + + override def executeCollect(): Array[InternalRow] = { + throw new UnsupportedOperationException + } + + 
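The operator below implements TakeOrderedAndProject as a two-phase top-N: a per-partition top-N (localTopK), a shuffle of those partial results into a single partition, a final top-N over the merged batches, and an optional Omni project when projectList differs from the child output. As a rough illustration of that pattern only (not of the Omni operators themselves), here is a minimal, self-contained sketch over plain Scala collections; TopNSketch, localTopN and the sample data are invented for this example and are not part of the patch.

object TopNSketch {
  // Keep only the smallest `limit` elements of a single partition (ascending order).
  def localTopN[T](part: Iterator[T], limit: Int)(implicit ord: Ordering[T]): Seq[T] =
    part.toSeq.sorted(ord).take(limit)

  def main(args: Array[String]): Unit = {
    val limit = 3
    // Three "partitions" of input values.
    val partitions = Seq(Seq(9, 1, 5, 7), Seq(3, 8, 2), Seq(6, 4, 10))
    // Phase 1: each partition keeps only its local top-N
    // (what the localTopK RDD computes per partition).
    val partial = partitions.map(p => localTopN(p.iterator, limit))
    // Phase 2: merge all partial results into one "partition" and take the top-N once more
    // (what the SinglePartition shuffle followed by computeTopN does).
    val result = localTopN(partial.flatten.iterator, limit)
    println(result) // List(1, 2, 3)
  }
}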
protected override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException + } + + def buildCheck(): Unit = { + genSortParam(child.output, sortOrder) + val projectEqualChildOutput = projectList == child.output + var omniInputTypes: Array[DataType] = null + var omniExpressions: Array[AnyRef] = null + if (!projectEqualChildOutput) { + omniInputTypes = child.output.map( + exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + omniExpressions = projectList.map( + exp => rewriteToOmniJsonExpressionLiteral(exp, getExprIdMap(child.output))).toArray + checkOmniJsonWhiteList("", omniExpressions) + } + } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val (sourceTypes, ascendings, nullFirsts, sortColsExp) = genSortParam(child.output, sortOrder) + + def computeTopN(iter: Iterator[ColumnarBatch], schema: StructType): Iterator[ColumnarBatch] = { + val startCodegen = System.nanoTime() + val topNOperatorFactory = new OmniTopNWithExprOperatorFactory(sourceTypes, limit, + sortColsExp, ascendings, nullFirsts, new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val topNOperator = topNOperatorFactory.createOperator + longMetric("omniCodegenTime") += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]( _ => { + topNOperator.close() + }) + addAllAndGetIterator(topNOperator, iter, schema, + longMetric("addInputTime"), longMetric("numInputVecBatchs"), longMetric("numInputRows"), + longMetric("getOutputTime"), longMetric("numOutputVecBatchs"), longMetric("numOutputRows"), + longMetric("outputDataSize")) + } + + val localTopK: RDD[ColumnarBatch] = { + child.executeColumnar().mapPartitionsWithIndexInternal { (_, iter) => + computeTopN(iter, this.child.schema) + } + } + + val shuffled = new ShuffledColumnarRDD( + ColumnarShuffleExchangeExec.prepareShuffleDependency( + localTopK, + child.output, + SinglePartition, + serializer, + writeMetrics, + longMetric("dataSize"), + longMetric("bytesSpilled"), + longMetric("numInputRows"), + longMetric("splitTime"), + longMetric("spillTime")), + readMetrics) + val projectEqualChildOutput = projectList == child.output + var omniInputTypes: Array[DataType] = null + var omniExpressions: Array[String] = null + var addInputTime: SQLMetric = null + var omniCodegenTime: SQLMetric = null + var getOutputTime: SQLMetric = null + if (!projectEqualChildOutput) { + omniInputTypes = child.output.map( + exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + omniExpressions = projectList.map( + exp => rewriteToOmniJsonExpressionLiteral(exp, getExprIdMap(child.output))).toArray + addInputTime = longMetric("addInputTime") + omniCodegenTime = longMetric("omniCodegenTime") + getOutputTime = longMetric("getOutputTime") + } + shuffled.mapPartitions { iter => + // TopN = omni-top-n + omni-project + val topN: Iterator[ColumnarBatch] = computeTopN(iter, this.child.schema) + if (!projectEqualChildOutput) { + dealPartitionData(null, null, addInputTime, omniCodegenTime, + getOutputTime, omniInputTypes, omniExpressions, topN, this.schema) + } else { + topN + } + } + } + + override def outputOrdering: Seq[SortOrder] = sortOrder + + override def outputPartitioning: Partitioning = SinglePartition + + override def simpleString(maxFields: Int): String = { + val orderByString = truncatedString(sortOrder, "[", ",", "]", maxFields) + val outputString = truncatedString(output, "[", ",", "]", maxFields) + + s"OmniColumnarTakeOrderedAndProject(limit=$limit, orderBy=$orderByString, 
output=$outputString)" + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..851973bd5c04d42e694465395abe563e3f176ea8 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala @@ -0,0 +1,423 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.concurrent.TimeUnit.NANOSECONDS + +import com.huawei.boostkit.spark.Constant.{IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP} +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._ +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import nova.hetu.omniruntime.`type`.DataType +import nova.hetu.omniruntime.constants.{FunctionType, OmniWindowFrameBoundType, OmniWindowFrameType} +import nova.hetu.omniruntime.operator.config.OperatorConfig +import nova.hetu.omniruntime.operator.window.OmniWindowWithExprOperatorFactory +import nova.hetu.omniruntime.vector.VecBatch +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution.ColumnarProjection.dealPartitionData +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.util.SparkMemoryUtils +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.execution.window.WindowExecBase +import org.apache.spark.sql.types.{DecimalType, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch + +case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], child: SparkPlan) + extends WindowExecBase { + + override def nodeName: String = "OmniColumnarWindow" + + override def supportsColumnar: Boolean = true + + override lazy val metrics = Map( + "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), + "numInputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatchs"), + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "omniCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), + "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in 
omni getOutput"), + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs")) + + override def output: Seq[Attribute] = + child.output ++ windowExpression.map(_.toAttribute) + + override def requiredChildDistribution: Seq[Distribution] = { + if (partitionSpec.isEmpty) { + // Only show warning when the number of bytes is larger than 100 MiB? + logWarning("No Partition Defined for Window operation! Moving all data to a single " + + "partition, this can cause serious performance degradation.") + AllTuples :: Nil + } else ClusteredDistribution(partitionSpec) :: Nil + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException(s"This operator doesn't support doExecute().") + } + + def checkAggFunInOutDataType(funcInDataType: org.apache.spark.sql.types.DataType, funcOutDataType: org.apache.spark.sql.types.DataType): Unit = { + //for decimal, only support decimal64 to decimal128 output + if(funcInDataType.isInstanceOf[DecimalType] && funcOutDataType.isInstanceOf[DecimalType]) { + if (!DecimalType.is64BitDecimalType(funcOutDataType.asInstanceOf[DecimalType])) + throw new UnsupportedOperationException(s"output only support decimal128 type, inDataType:${funcInDataType} outDataType:${funcOutDataType}" ) + } + } + + def getWindowFrameParam(frame: SpecifiedWindowFrame): (OmniWindowFrameType, + OmniWindowFrameBoundType, OmniWindowFrameBoundType, Int, Int) = { + var windowFrameStartChannel = -1 + var windowFrameEndChannel = -1 + val windowFrameType = frame.frameType match { + case RangeFrame => + OmniWindowFrameType.OMNI_FRAME_TYPE_RANGE + case RowFrame => + OmniWindowFrameType.OMNI_FRAME_TYPE_ROWS + } + + val windowFrameStartType = frame.lower match { + case UnboundedPreceding => + OmniWindowFrameBoundType.OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING + case UnboundedFollowing => + OmniWindowFrameBoundType.OMNI_FRAME_BOUND_UNBOUNDED_FOLLOWING + case CurrentRow => + OmniWindowFrameBoundType.OMNI_FRAME_BOUND_CURRENT_ROW + case literal: Literal => + windowFrameStartChannel = literal.value.toString.toInt + OmniWindowFrameBoundType.OMNI_FRAME_BOUND_PRECEDING + } + + val windowFrameEndType = frame.upper match { + case UnboundedPreceding => + OmniWindowFrameBoundType.OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING + case UnboundedFollowing => + OmniWindowFrameBoundType.OMNI_FRAME_BOUND_UNBOUNDED_FOLLOWING + case CurrentRow => + OmniWindowFrameBoundType.OMNI_FRAME_BOUND_CURRENT_ROW + case literal: Literal => + windowFrameEndChannel = literal.value.toString.toInt + OmniWindowFrameBoundType.OMNI_FRAME_BOUND_FOLLOWING + } + (windowFrameType, windowFrameStartType, windowFrameEndType, + windowFrameStartChannel, windowFrameEndChannel) + } + + def buildCheck(): Unit = { + val inputColSize = child.outputSet.size + val sourceTypes = new Array[DataType](inputColSize) + val winExpressions: Seq[Expression] = windowFrameExpressionFactoryPairs.flatMap(_._1) + val windowFunType = new Array[FunctionType](winExpressions.size) + var windowArgKeys = new Array[AnyRef](winExpressions.size) + val windowFunRetType = new Array[DataType](winExpressions.size) + val omniAttrExpsIdMap = getExprIdMap(child.output) + val 
windowFrameTypes = new Array[OmniWindowFrameType](winExpressions.size) + val windowFrameStartTypes = new Array[OmniWindowFrameBoundType](winExpressions.size) + val winddowFrameStartChannels = new Array[Int](winExpressions.size) + val windowFrameEndTypes = new Array[OmniWindowFrameBoundType](winExpressions.size) + val winddowFrameEndChannels = new Array[Int](winExpressions.size) + var attrMap: Map[String, Int] = Map() + val inputIter = child.outputSet.toIterator + var i = 0 + while (inputIter.hasNext) { + val inputAttr = inputIter.next() + sourceTypes(i) = sparkTypeToOmniType(inputAttr.dataType, inputAttr.metadata) + attrMap += (inputAttr.name -> i) + i += 1 + } + + var windowExpressionWithProject = false + winExpressions.zipWithIndex.foreach { case (x, index) => + x.foreach { + case e@WindowExpression(function, spec) => + if (spec.frameSpecification.isInstanceOf[SpecifiedWindowFrame]) { + val winFram = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] + if (winFram.lower != UnboundedPreceding && winFram.lower != CurrentRow) { + throw new UnsupportedOperationException(s"Unsupported Specified frame_start: ${winFram.lower}") + } + if (winFram.upper != UnboundedFollowing && winFram.upper != CurrentRow) { + throw new UnsupportedOperationException(s"Unsupported Specified frame_end: ${winFram.upper}") + } + } + windowFunRetType(index) = sparkTypeToOmniType(function.dataType) + val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] + val winFrameParam = getWindowFrameParam(frame) + windowFrameTypes(index) = winFrameParam._1 + windowFrameStartTypes(index) = winFrameParam._2 + windowFrameEndTypes(index) = winFrameParam._3 + winddowFrameStartChannels(index) = winFrameParam._4 + winddowFrameEndChannels(index) = winFrameParam._5 + function match { + // AggregateWindowFunction + case winfunc: WindowFunction => + windowFunType(index) = toOmniWindowFunType(winfunc) + windowArgKeys = winfunc.children.map( + exp => rewriteToOmniJsonExpressionLiteral(exp, omniAttrExpsIdMap)).toArray + // AggregateExpression + case agg@AggregateExpression(aggFunc, _, _, _, _) => + windowFunType(index) = toOmniAggFunType(agg) + windowArgKeys = aggFunc.children.map( + exp => { + checkAggFunInOutDataType(function.dataType, exp.dataType) + rewriteToOmniJsonExpressionLiteral(exp, omniAttrExpsIdMap) + }).toArray + case _ => throw new UnsupportedOperationException(s"Unsupported window function: ${function}") + } + case _ => + windowExpressionWithProject = true + } + } + + val winExpToReferences = winExpressions.zipWithIndex.map { case (e, i) => + // Results of window expressions will be on the right side of child's output + AttributeReference(String.valueOf(child.output.size + i), e.dataType, e.nullable)().toAttribute + } + val winExpToReferencesMap = winExpressions.zip(winExpToReferences).toMap + val patchedWindowExpression = windowExpression.map(_.transform(winExpToReferencesMap)) + if (windowExpressionWithProject) { + val finalOut = child.output ++ winExpToReferences + val projectInputTypes = finalOut.map( + exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + val projectExpressions: Array[AnyRef] = (child.output ++ patchedWindowExpression).map( + exp => rewriteToOmniJsonExpressionLiteral(exp, getExprIdMap(finalOut))).toArray + checkOmniJsonWhiteList("", projectExpressions) + } + checkOmniJsonWhiteList("", windowArgKeys) + } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val addInputTime = longMetric("addInputTime") + val numInputRows = longMetric("numInputRows") + val 
numInputVecBatchs = longMetric("numInputVecBatchs") + val omniCodegenTime = longMetric("omniCodegenTime") + val numOutputRows = longMetric("numOutputRows") + val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val getOutputTime = longMetric("getOutputTime") + + val inputColSize = child.outputSet.size + val sourceTypes = new Array[DataType](inputColSize) + val sortCols = new Array[Int](orderSpec.size) + val ascendings = new Array[Int](orderSpec.size) + val nullFirsts = new Array[Int](orderSpec.size) + val winExpressions: Seq[Expression] = windowFrameExpressionFactoryPairs.flatMap(_._1) + val windowFunType = new Array[FunctionType](winExpressions.size) + val omminPartitionChannels = new Array[Int](partitionSpec.size) + val preGroupedChannels = new Array[Int](winExpressions.size) + var windowArgKeys = new Array[String](winExpressions.size) + val windowFunRetType = new Array[DataType](winExpressions.size) + val omniAttrExpsIdMap = getExprIdMap(child.output) + val windowFrameTypes = new Array[OmniWindowFrameType](winExpressions.size) + val windowFrameStartTypes = new Array[OmniWindowFrameBoundType](winExpressions.size) + val windowFrameStartChannels = new Array[Int](winExpressions.size) + val windowFrameEndTypes = new Array[OmniWindowFrameBoundType](winExpressions.size) + val windowFrameEndChannels = new Array[Int](winExpressions.size) + + var attrMap: Map[String, Int] = Map() + val inputIter = child.outputSet.toIterator + var i = 0 + while (inputIter.hasNext) { + val inputAttr = inputIter.next() + sourceTypes(i) = sparkTypeToOmniType(inputAttr.dataType, inputAttr.metadata) + attrMap += (inputAttr.name -> i) + i += 1 + } + // partition column parameters + + // sort column parameters + i = 0 + for (sortAttr <- orderSpec) { + if (attrMap.contains(sortAttr.child.asInstanceOf[AttributeReference].name)) { + sortCols(i) = attrMap(sortAttr.child.asInstanceOf[AttributeReference].name) + ascendings(i) = sortAttr.isAscending match { + case true => 1 + case _ => 0 + } + nullFirsts(i) = sortAttr.nullOrdering.sql match { + case "NULLS LAST" => 0 + case _ => 1 + } + } else { + throw new UnsupportedOperationException(s"Unsupported sort col not in inputset: ${sortAttr.nodeName}") + } + i += 1 + } + + i = 0 + // only window column no need to as output + val outputCols = new Array[Int](child.output.size) // 0, 1 + for (outputAttr <- child.output) { + if (attrMap.contains(outputAttr.name)) { + outputCols(i) = attrMap.get(outputAttr.name).get + } else { + throw new UnsupportedOperationException(s"output col not in input cols: ${outputAttr.name}") + } + i += 1 + } + + // partitionSpec: Seq[Expression] + i = 0 + for (partitionAttr <- partitionSpec) { + if (attrMap.contains(partitionAttr.asInstanceOf[AttributeReference].name)) { + omminPartitionChannels(i) = attrMap(partitionAttr.asInstanceOf[AttributeReference].name) + } else { + throw new UnsupportedOperationException(s"output col not in input cols: ${partitionAttr}") + } + i += 1 + } + + var windowExpressionWithProject = false + i = 0 + winExpressions.zipWithIndex.foreach { case (x, index) => + x.foreach { + case e@WindowExpression(function, spec) => + if (spec.frameSpecification.isInstanceOf[SpecifiedWindowFrame]) { + val winFram = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] + if (winFram.lower != UnboundedPreceding && winFram.lower != CurrentRow) { + throw new UnsupportedOperationException(s"Unsupported Specified frame_start: ${winFram.lower}") + } + if (winFram.upper != UnboundedFollowing && winFram.upper != CurrentRow) { + throw new 
UnsupportedOperationException(s"Unsupported Specified frame_end: ${winFram.upper}") + } + } + windowFunRetType(index) = sparkTypeToOmniType(function.dataType) + val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] + val winFrameParam = getWindowFrameParam(frame) + windowFrameTypes(index) = winFrameParam._1 + windowFrameStartTypes(index) = winFrameParam._2 + windowFrameEndTypes(index) = winFrameParam._3 + windowFrameStartChannels(index) = winFrameParam._4 + windowFrameEndChannels(index) = winFrameParam._5 + function match { + // AggregateWindowFunction + case winfunc: WindowFunction => + windowFunType(index) = toOmniWindowFunType(winfunc) + windowArgKeys(index) = null + // AggregateExpression + case agg@AggregateExpression(aggFunc, _, _, _, _) => + windowFunType(index) = toOmniAggFunType(agg) + windowArgKeys(index) = rewriteToOmniJsonExpressionLiteral(aggFunc.children.head, + omniAttrExpsIdMap) + case _ => throw new UnsupportedOperationException(s"Unsupported window function: ${function}") + } + case _ => + windowExpressionWithProject = true + } + } + windowArgKeys = windowArgKeys.filter(key => key != null) + + val winExpToReferences = winExpressions.zipWithIndex.map { case (e, i) => + // Results of window expressions will be on the right side of child's output + AttributeReference(String.valueOf(child.output.size + i), e.dataType, e.nullable)().toAttribute + } + val winExpToReferencesMap = winExpressions.zip(winExpToReferences).toMap + val patchedWindowExpression = windowExpression.map(_.transform(winExpToReferencesMap)) + + val windowExpressionWithProjectConstant = windowExpressionWithProject + child.executeColumnar().mapPartitionsWithIndexInternal { (index, iter) => + val startCodegen = System.nanoTime() + val windowOperatorFactory = new OmniWindowWithExprOperatorFactory(sourceTypes, outputCols, + windowFunType, omminPartitionChannels, preGroupedChannels, sortCols, ascendings, + nullFirsts, 0, 10000, windowArgKeys, windowFunRetType, + windowFrameTypes, windowFrameStartTypes, windowFrameStartChannels, windowFrameEndTypes, + windowFrameEndChannels, new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val windowOperator = windowOperatorFactory.createOperator + omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + + while (iter.hasNext) { + val batch = iter.next() + val input = transColBatchToOmniVecs(batch) + val vecBatch = new VecBatch(input, batch.numRows()) + val startInput = System.nanoTime() + windowOperator.addInput(vecBatch) + addInputTime += NANOSECONDS.toMillis(System.nanoTime() - startInput) + numInputVecBatchs += 1 + numInputRows += batch.numRows() + } + + val startGetOp = System.nanoTime() + val results = windowOperator.getOutput + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + + // close operator + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + windowOperator.close() + }) + + var windowResultSchema = this.schema + if (windowExpressionWithProjectConstant) { + val omnifinalOutSchema = child.output ++ winExpToReferences.map(_.toAttribute) + windowResultSchema = StructType.fromAttributes(omnifinalOutSchema) + } + val outputColSize = outputCols.length + val omniWindowResultIter = new Iterator[ColumnarBatch] { + override def hasNext: Boolean = { + val startGetOp: Long = System.nanoTime() + val hasNext = results.hasNext + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + hasNext + } + + override def next(): ColumnarBatch = { + val startGetOp = System.nanoTime() + val vecBatch 
= results.next() + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( + vecBatch.getRowCount, windowResultSchema, false) + val offset = vecBatch.getVectors.length - vectors.size + vectors.zipWithIndex.foreach { case (vector, i) => + vector.reset() + if (i <= outputColSize - 1) { + vector.setVec(vecBatch.getVectors()(i)) + } else { + vector.setVec(vecBatch.getVectors()(i + offset)) + } + } + // release skip columnns memory + for (i <- outputColSize until outputColSize + offset) { + vecBatch.getVectors()(i).close() + } + numOutputRows += vecBatch.getRowCount + numOutputVecBatchs += 1 + + vecBatch.close() + new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) + } + } + if (windowExpressionWithProjectConstant) { + val finalOut = child.output ++ winExpToReferences + val projectInputTypes = finalOut.map( + exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + val projectExpressions = (child.output ++ patchedWindowExpression).map( + exp => rewriteToOmniJsonExpressionLiteral(exp, getExprIdMap(finalOut))).toArray + dealPartitionData(null, null, addInputTime, omniCodegenTime, + getOutputTime, projectInputTypes, projectExpressions, omniWindowResultIter, this.schema) + } else { + omniWindowResultIter + } + } + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarRDD.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarRDD.scala new file mode 100644 index 0000000000000000000000000000000000000000..76d46aaae2040b094bec0367f73a69786ad493f7 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarRDD.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.{Dependency, MapOutputTrackerMaster, Partition, Partitioner, ShuffleDependency, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * The [[Partition]] used by [[ShuffledRowRDD]]. 
+ */ +private final case class ShuffledColumnarRDDPartition( + index: Int, spec: ShufflePartitionSpec) extends Partition + +class ShuffledColumnarRDD( + var dependency: ShuffleDependency[Int, ColumnarBatch, ColumnarBatch], + metrics: Map[String, SQLMetric], + partitionSpecs: Array[ShufflePartitionSpec]) + extends RDD[ColumnarBatch](dependency.rdd.context, Nil) { + + def this( + dependency: ShuffleDependency[Int, ColumnarBatch, ColumnarBatch], + metrics: Map[String, SQLMetric]) = { + this(dependency, metrics, + Array.tabulate(dependency.partitioner.numPartitions)(i => CoalescedPartitionSpec(i, i + 1))) + } + + dependency.rdd.context.setLocalProperty( + SortShuffleManager.FETCH_SHUFFLE_BLOCKS_IN_BATCH_ENABLED_KEY, + SQLConf.get.fetchShuffleBlocksInBatch.toString) + + override def getDependencies: Seq[Dependency[_]] = List(dependency) + + override val partitioner: Option[Partitioner] = + if (partitionSpecs.forall(_.isInstanceOf[CoalescedPartitionSpec])) { + val indices = partitionSpecs.map(_.asInstanceOf[CoalescedPartitionSpec].startReducerIndex) + // TODO this check is based on assumptions of callers' behavior but is sufficient for now. + if (indices.toSet.size == partitionSpecs.length) { + Some(new CoalescedPartitioner(dependency.partitioner, indices)) + } else { + None + } + } else { + None + } + + override def getPartitions: Array[Partition] = { + Array.tabulate[Partition](partitionSpecs.length) { i => + ShuffledColumnarRDDPartition(i, partitionSpecs(i)) + } + } + + override def getPreferredLocations(partition: Partition): Seq[String] = { + val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + partition.asInstanceOf[ShuffledColumnarRDDPartition].spec match { + case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) => + // TODO order by partition size. + startReducerIndex.until(endReducerIndex).flatMap { reducerIndex => + tracker.getPreferredLocationsForShuffle(dependency, reducerIndex) + } + + case PartialReducerPartitionSpec(_, startMapIndex, endMapIndex, _) => + tracker.getMapLocation(dependency, startMapIndex, endMapIndex) + + case PartialMapperPartitionSpec(mapIndex, _, _) => + tracker.getMapLocation(dependency, mapIndex, mapIndex + 1) + } + } + + override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { + val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() + // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator, + // as well as the `tempMetrics` for basic shuffle metrics. 
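getPreferredLocations above and compute() below both pattern-match on the three ShufflePartitionSpec shapes that adaptive query execution can hand to this RDD. The stand-in case classes below (simplified, not Spark's own definitions) summarize which reducer/map ranges each shape selects; the real code forwards exactly those ranges to shuffleManager.getReader.

```scala
// Simplified stand-ins for Spark's ShufflePartitionSpec variants (illustration only).
sealed trait Spec
case class Coalesced(startReducer: Int, endReducer: Int) extends Spec               // several reducers merged into one partition
case class PartialReducer(reducer: Int, startMap: Int, endMap: Int) extends Spec    // one reducer, a sub-range of map outputs (skew split)
case class PartialMapper(map: Int, startReducer: Int, endReducer: Int) extends Spec // one map output, a range of reducers (local read)

def describe(spec: Spec): String = spec match {
  case Coalesced(s, e)         => s"read reducers [$s, $e) from all map outputs"
  case PartialReducer(r, s, e) => s"read reducer $r, but only map outputs [$s, $e)"
  case PartialMapper(m, s, e)  => s"read reducers [$s, $e) from map output $m only"
}
```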
+ val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics) + val reader = split.asInstanceOf[ShuffledColumnarRDDPartition].spec match { + case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) => + SparkEnv.get.shuffleManager.getReader( + dependency.shuffleHandle, + startReducerIndex, + endReducerIndex, + context, + sqlMetricsReporter) + + case PartialReducerPartitionSpec(reducerIndex, startMapIndex, endMapIndex, _) => + SparkEnv.get.shuffleManager.getReader( + dependency.shuffleHandle, + startMapIndex, + endMapIndex, + reducerIndex, + reducerIndex + 1, + context, + sqlMetricsReporter) + + case PartialMapperPartitionSpec(mapIndex, startReducerIndex, endReducerIndex) => + SparkEnv.get.shuffleManager.getReader( + dependency.shuffleHandle, + mapIndex, + mapIndex + 1, + startReducerIndex, + endReducerIndex, + context, + sqlMetricsReporter) + } + reader.read().asInstanceOf[Iterator[Product2[Int, ColumnarBatch]]].map(_._2) + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala new file mode 100644 index 0000000000000000000000000000000000000000..0e5a7eae6efaac64d59c9effdcd8304d30c5c9fe --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.orc + +import java.io.Serializable +import java.net.URI + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.FileSplit +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.orc.{OrcConf, OrcFile, TypeDescription} +import org.apache.orc.mapreduce.OrcInputFormat + +import org.apache.spark.TaskContext +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.util.SparkMemoryUtils +import org.apache.spark.sql.sources.{DataSourceRegister, Filter} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{SerializableConfiguration, Utils} + +class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializable { + + override def shortName(): String = "orc-native" + + override def toString: String = "ORC-NATIVE" + + override def hashCode(): Int = getClass.hashCode() + + override def equals(other: Any): Boolean = other.isInstanceOf[OmniOrcFileFormat] + + override def inferSchema( + sparkSession: SparkSession, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + OrcUtils.inferSchema(sparkSession, files, options) + } + + override def buildReaderWithPartitionValues( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { + + val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields) + val sqlConf = sparkSession.sessionState.conf + val capacity = sqlConf.orcVectorizedReaderBatchSize + + OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(hadoopConf, sqlConf.caseSensitiveAnalysis) + + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val orcFilterPushDown = sparkSession.sessionState.conf.orcFilterPushDown + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles + + (file: PartitionedFile) => { + val conf = broadcastedConf.value.value + + val filePath = new Path(new URI(file.filePath)) + + val fs = filePath.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + val resultedColPruneInfo = + Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions)) { reader => + OrcUtils.requestedColumnIds( + isCaseSensitive, dataSchema, requiredSchema, reader, conf) + } + + if (resultedColPruneInfo.isEmpty) { + Iterator.empty + } else { + // ORC predicate pushdown + if (orcFilterPushDown) { + OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach { + fileSchema => OrcFilters.createFilter(fileSchema, filters).foreach { f => + OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames) + } + } + } + + val (requestedColIds, canPruneCols) = resultedColPruneInfo.get + val resultSchemaString = OrcUtils.orcResultSchemaString(canPruneCols, + dataSchema, resultSchema, partitionSchema, conf) + assert(requestedColIds.length == requiredSchema.length, + "[BUG] requested column IDs do not match required schema") + val taskConf = new Configuration(conf) + + val fileSplit = new FileSplit(filePath, file.start, 
file.length, Array.empty) + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) + + // read data from vectorized reader + val batchReader = new OmniOrcColumnarBatchReader(capacity) + // SPARK-23399 Register a task completion listener first to call `close()` in all cases. + // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM) + // after opening a file. + val iter = new RecordReaderIterator(batchReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) + val requestedDataColIds = requestedColIds ++ Array.fill(partitionSchema.length)(-1) + val requestedPartitionColIds = + Array.fill(requiredSchema.length)(-1) ++ Range(0, partitionSchema.length) + SparkMemoryUtils.init() + batchReader.initialize(fileSplit, taskAttemptContext) + batchReader.initBatch( + TypeDescription.fromString(resultSchemaString), + resultSchema.fields, + requestedDataColIds, + requestedPartitionColIds, + file.partitionValues) + + iter.asInstanceOf[Iterator[InternalRow]] + } + } + } + + override def prepareWrite( + sparkSession: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + throw new UnsupportedOperationException() + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala new file mode 100644 index 0000000000000000000000000000000000000000..fcd601074d1573d4c66900cbb7563b4258a485a6 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
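initBatch above receives two parallel arrays over the result schema: requestedDataColIds keeps the resolved ORC column id for each data column and -1 for partition columns, while requestedPartitionColIds is -1 for data columns and the partition ordinal otherwise. A worked example of those two expressions for a hypothetical read with two data columns and one partition column:

```scala
// Hypothetical shapes: requiredSchema = [id, name], partitionSchema = [dt],
// and the ORC footer resolved requestedColIds = Array(0, 1).
val requestedColIds = Array(0, 1)
val requiredSchemaLength = 2
val partitionSchemaLength = 1

val requestedDataColIds = requestedColIds ++ Array.fill(partitionSchemaLength)(-1)
val requestedPartitionColIds = Array.fill(requiredSchemaLength)(-1) ++ Range(0, partitionSchemaLength)

// requestedDataColIds.toSeq      == Seq(0, 1, -1)
// requestedPartitionColIds.toSeq == Seq(-1, -1, 0)
```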
+ */ + +package org.apache.spark.sql.execution.datasources.orc + +import java.nio.charset.StandardCharsets.UTF_8 +import java.util.Locale + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.orc.{OrcConf, OrcFile, Reader, TypeDescription, Writer} + +import org.apache.spark.{SPARK_VERSION_SHORT, SparkException} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{SPARK_VERSION_METADATA_KEY, SparkSession} +import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, CharVarcharUtils} +import org.apache.spark.sql.execution.datasources.SchemaMergeUtils +import org.apache.spark.sql.types._ +import org.apache.spark.util.{ThreadUtils, Utils} + +object OrcUtils extends Logging { + + // The extensions for ORC compression codecs + val extensionsForCompressionCodecNames = Map( + "NONE" -> "", + "SNAPPY" -> ".snappy", + "ZLIB" -> ".zlib", + "LZO" -> ".lzo") + + def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { + val origPath = new Path(pathStr) + val fs = origPath.getFileSystem(conf) + val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath) + .filterNot(_.isDirectory) + .map(_.getPath) + .filterNot(_.getName.startsWith("_")) + .filterNot(_.getName.startsWith(".")) + paths + } + + def readSchema(file: Path, conf: Configuration, ignoreCorruptFiles: Boolean) + : Option[TypeDescription] = { + val fs = file.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + try { + val schema = Utils.tryWithResource(OrcFile.createReader(file, readerOptions)) { reader => + reader.getSchema + } + if (schema.getFieldNames.isEmpty) { + None + } else { + Some(schema) + } + } catch { + case e: org.apache.orc.FileFormatException => + if (ignoreCorruptFiles) { + logWarning(s"Skipped the footer in the corrupted file: $file", e) + None + } else { + throw new SparkException(s"Could not read footer for file: $file", e) + } + } + } + + private def toCatalystSchema(schema: TypeDescription): StructType = { + // The Spark query engine has not completely supported CHAR/VARCHAR type yet, and here we + // replace the orc CHAR/VARCHAR with STRING type. + CharVarcharUtils.replaceCharVarcharWithStringInSchema( + CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType]) + } + + def readSchema(sparkSession: SparkSession, files: Seq[FileStatus], options: Map[String, String]) + : Option[StructType] = { + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles + val conf = sparkSession.sessionState.newHadoopConfWithOptions(options) + files.toIterator.map(file => readSchema(file.getPath, conf, ignoreCorruptFiles)).collectFirst { + case Some(schema) => + logDebug(s"Reading schema from file $files, got Hive schema string: $schema") + toCatalystSchema(schema) + } + } + + def readCatalystSchema( + file: Path, + conf: Configuration, + ignoreCorruptFiles: Boolean): Option[StructType] = { + readSchema(file, conf, ignoreCorruptFiles) match { + case Some(schema) => Some(toCatalystSchema(schema)) + + case None => + // Field names is empty or `FileFormatException` was thrown but ignoreCorruptFiles is true. + None + } + } + + /** + * Reads ORC file schemas in multi-threaded manner, using native version of ORC. + * This is visible for testing. 
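inferSchema (further down in this object) only falls back to this parallel footer reading when schema merging is requested; otherwise the first readable footer supplies the schema. A hypothetical usage sketch, assuming the connector jar is on the classpath and the format is registered under the shortName "orc-native" declared in OmniOrcFileFormat; the path and application name below are placeholders:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("orc-native-demo").getOrCreate()

// Single-footer schema inference (mergeSchema off).
val events = spark.read.format("orc-native").load("/tmp/events_orc")

// Parallel footer reading and merging, exercising OrcUtils.readOrcSchemasInParallel.
val merged = spark.read
  .format("orc-native")
  .option("mergeSchema", "true")
  .load("/tmp/events_orc")
```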
+ */ + def readOrcSchemasInParallel( + files: Seq[FileStatus], conf: Configuration, ignoreCorruptFiles: Boolean): Seq[StructType] = { + ThreadUtils.parmap(files, "readingOrcSchemas", 8) { currentFile => + OrcUtils.readSchema(currentFile.getPath, conf, ignoreCorruptFiles).map(toCatalystSchema) + }.flatten + } + + def inferSchema(sparkSession: SparkSession, files: Seq[FileStatus], options: Map[String, String]) + : Option[StructType] = { + val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) + if (orcOptions.mergeSchema) { + SchemaMergeUtils.mergeSchemasInParallel( + sparkSession, options, files, OrcUtils.readOrcSchemasInParallel) + } else { + OrcUtils.readSchema(sparkSession, files, options) + } + } + + /** + * @return Returns the combination of requested column ids from the given ORC file and + * boolean flag to find if the pruneCols is allowed or not. Requested Column id can be + * -1, which means the requested column doesn't exist in the ORC file. Returns None + * if the given ORC file is empty. + */ + def requestedColumnIds( + isCaseSensitive: Boolean, + dataSchema: StructType, + requiredSchema: StructType, + reader: Reader, + conf: Configuration): Option[(Array[Int], Boolean)] = { + val orcFieldNames = reader.getSchema.getFieldNames.asScala + if (orcFieldNames.isEmpty) { + // SPARK-8501: Some old empty ORC files always have an empty schema stored in their footer. + None + } else { + if (orcFieldNames.forall(_.startsWith("_col"))) { + // This is a ORC file written by Hive, no field names in the physical schema, assume the + // physical schema maps to the data scheme by index. + assert(orcFieldNames.length <= dataSchema.length, "The given data schema " + + s"${dataSchema.catalogString} has less fields than the actual ORC physical schema, " + + "no idea which columns were dropped, fail to read.") + // for ORC file written by Hive, no field names + // in the physical schema, there is a need to send the + // entire dataSchema instead of required schema. + // So pruneCols is not done in this case + Some(requiredSchema.fieldNames.map { name => + val index = dataSchema.fieldIndex(name) + if (index < orcFieldNames.length) { + index + } else { + -1 + } + }, false) + } else { + if (isCaseSensitive) { + Some(requiredSchema.fieldNames.zipWithIndex.map { case (name, idx) => + if (orcFieldNames.indexWhere(caseSensitiveResolution(_, name)) != -1) { + idx + } else { + -1 + } + }, true) + } else { + // Do case-insensitive resolution only if in case-insensitive mode + val caseInsensitiveOrcFieldMap = orcFieldNames.groupBy(_.toLowerCase(Locale.ROOT)) + Some(requiredSchema.fieldNames.zipWithIndex.map { case (requiredFieldName, idx) => + caseInsensitiveOrcFieldMap + .get(requiredFieldName.toLowerCase(Locale.ROOT)) + .map { matchedOrcFields => + if (matchedOrcFields.size > 1) { + // Need to fail if there is ambiguity, i.e. more than one field is matched. + val matchedOrcFieldsString = matchedOrcFields.mkString("[", ", ", "]") + reader.close() + throw new RuntimeException(s"""Found duplicate field(s) "$requiredFieldName": """ + + s"$matchedOrcFieldsString in case-insensitive mode") + } else { + idx + } + }.getOrElse(-1) + }, true) + } + } + } + } + + /** + * Add a metadata specifying Spark version. + */ + def addSparkVersionMetadata(writer: Writer): Unit = { + writer.addUserMetadata(SPARK_VERSION_METADATA_KEY, UTF_8.encode(SPARK_VERSION_SHORT)) + } + + /** + * Given a `StructType` object, this methods converts it to corresponding string representation + * in ORC. 
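requestedColumnIds above has two regimes: Hive-written files whose physical columns are all named _col&lt;i&gt; are matched to the data schema by position (and column pruning is disabled), while files with real field names are matched by name, case-sensitively or not. The plain-Scala simulation below mirrors that selection logic on simple string sequences; it is an illustration, not the Spark/ORC API.

```scala
import java.util.Locale

// For each required field, the column id to request, or -1 if the file does not have it.
def resolveIds(orcFieldNames: Seq[String],
               dataFields: Seq[String],
               requiredFields: Seq[String],
               caseSensitive: Boolean): Array[Int] =
  if (orcFieldNames.forall(_.startsWith("_col"))) {
    // Hive writer: no real names in the footer, map by position against the full data schema.
    requiredFields.map { name =>
      val idx = dataFields.indexOf(name)
      if (idx < orcFieldNames.length) idx else -1
    }.toArray
  } else if (caseSensitive) {
    requiredFields.zipWithIndex.map { case (name, idx) =>
      if (orcFieldNames.contains(name)) idx else -1
    }.toArray
  } else {
    val byLowerCase = orcFieldNames.groupBy(_.toLowerCase(Locale.ROOT))
    requiredFields.zipWithIndex.map { case (name, idx) =>
      byLowerCase.get(name.toLowerCase(Locale.ROOT)) match {
        case Some(matches) if matches.size > 1 =>
          sys.error(s"""Found duplicate field(s) "$name" in case-insensitive mode""")
        case Some(_) => idx
        case None    => -1
      }
    }.toArray
  }

// resolveIds(Seq("_col0", "_col1"), Seq("id", "name", "age"), Seq("name", "age"), caseSensitive = false)
//   .toSeq == Seq(1, -1)   // "age" sits at data-schema index 2, beyond the 2 physical columns
```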
+ */ + def orcTypeDescriptionString(dt: DataType): String = dt match { + case s: StructType => + val fieldTypes = s.fields.map { f => + s"${quoteIdentifier(f.name)}:${orcTypeDescriptionString(f.dataType)}" + } + s"struct<${fieldTypes.mkString(",")}>" + case a: ArrayType => + s"array<${orcTypeDescriptionString(a.elementType)}>" + case m: MapType => + s"map<${orcTypeDescriptionString(m.keyType)},${orcTypeDescriptionString(m.valueType)}>" + case _ => dt.catalogString + } + + /** + * Returns the result schema to read from ORC file. In addition, It sets + * the schema string to 'orc.mapred.input.schema' so ORC reader can use later. + * + * @param canPruneCols Flag to decide whether pruned cols schema is send to resultSchema + * or to send the entire dataSchema to resultSchema. + * @param dataSchema Schema of the orc files. + * @param resultSchema Result data schema created after pruning cols. + * @param partitionSchema Schema of partitions. + * @param conf Hadoop Configuration. + * @return Returns the result schema as string. + */ + def orcResultSchemaString( + canPruneCols: Boolean, + dataSchema: StructType, + resultSchema: StructType, + partitionSchema: StructType, + conf: Configuration): String = { + val resultSchemaString = if (canPruneCols) { + OrcUtils.orcTypeDescriptionString(resultSchema) + } else { + OrcUtils.orcTypeDescriptionString(StructType(dataSchema.fields ++ partitionSchema.fields)) + } + OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, resultSchemaString) + resultSchemaString + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..b10b9cc3685630f56025281302d912b222d585cc --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala @@ -0,0 +1,446 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
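For reference back to the OrcUtils code above: orcTypeDescriptionString renders a Catalyst schema as an ORC type string, quoting field names and recursing into arrays and maps, and orcResultSchemaString stores that string under orc.mapred.input.schema for the reader. A small worked example, assuming the patched OrcUtils object is the one on the classpath; the expected string is approximate, derived from quoteIdentifier and catalogString.

```scala
import org.apache.spark.sql.execution.datasources.orc.OrcUtils
import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("id", LongType),
  StructField("tags", ArrayType(StringType)),
  StructField("props", MapType(StringType, IntegerType))))

val orcTypeString = OrcUtils.orcTypeDescriptionString(schema)
// Expected, roughly: struct<`id`:bigint,`tags`:array<string>,`props`:map<string,int>>
```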
+ */ + +package org.apache.spark.sql.execution.joins + +import com.huawei.boostkit.spark.ColumnarPluginConfig + +import java.util.concurrent.TimeUnit.NANOSECONDS +import java.util.Optional +import com.huawei.boostkit.spark.Constant.{IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP} + +import scala.collection.mutable +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.checkOmniJsonWhiteList +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import nova.hetu.omniruntime.`type`.DataType +import nova.hetu.omniruntime.constants.JoinType.OMNI_JOIN_TYPE_INNER +import nova.hetu.omniruntime.operator.config.OperatorConfig +import nova.hetu.omniruntime.operator.join.{OmniHashBuilderWithExprOperatorFactory, OmniLookupJoinWithExprOperatorFactory} +import nova.hetu.omniruntime.vector.VecBatch +import nova.hetu.omniruntime.vector.serialize.VecBatchSerializerFactory +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.util.{MergeIterator, SparkMemoryUtils} +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Performs an inner hash join of two child relations. When the output RDD of this operator is + * being constructed, a Spark job is asynchronously started to calculate the values for the + * broadcast relation. This data is then placed in a Spark broadcast variable. The streamed + * relation is not shuffled. 
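streamedKeyToBuildKeyMapping, built a little further down, is a one-to-many map: the same streamed key expression may pair with several build keys, and expandOutputPartitioning uses it to substitute build keys into the streamed side's partitioning. The fold below reproduces just that grouping step on plain strings; it is an illustration, not the operator's code.

```scala
// Stand-in for streamedKeyToBuildKeyMapping on plain strings.
val streamedKeys = Seq("a", "b", "a")
val buildKeys    = Seq("x", "y", "z")

val mapping: Map[String, Seq[String]] =
  streamedKeys.zip(buildKeys).foldLeft(Map.empty[String, Seq[String]]) {
    case (acc, (streamed, build)) =>
      acc.updated(streamed, acc.getOrElse(streamed, Seq.empty) :+ build)
  }

// mapping == Map("a" -> Seq("x", "z"), "b" -> Seq("y"))
```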
+ */ +case class ColumnarBroadcastHashJoinExec( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan, + isNullAwareAntiJoin: Boolean = false) + extends HashJoin { + + if (isNullAwareAntiJoin) { + require(leftKeys.length == 1, "leftKeys length should be 1") + require(rightKeys.length == 1, "rightKeys length should be 1") + require(joinType == LeftAnti, "joinType must be LeftAnti.") + require(buildSide == BuildRight, "buildSide must be BuildRight.") + require(condition.isEmpty, "null aware anti join optimize condition should be empty.") + } + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "lookupAddInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni lookup addInput"), + "lookupGetOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni lookup getOutput"), + "lookupCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni lookup codegen"), + "buildAddInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni build addInput"), + "buildGetOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni build getOutput"), + "buildCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni build codegen"), + "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs"), + "numMergedVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatchs") + ) + + override def supportsColumnar: Boolean = true + + override def supportCodegen: Boolean = false + + override def nodeName: String = "OmniColumnarBroadcastHashJoin" + + override def requiredChildDistribution: Seq[Distribution] = { + val mode = HashedRelationBroadcastMode(buildBoundKeys, isNullAwareAntiJoin) + buildSide match { + case BuildLeft => + BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil + case BuildRight => + UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil + } + } + + override lazy val outputPartitioning: Partitioning = { + joinType match { + case _: InnerLike if sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 => + streamedPlan.outputPartitioning match { + case h: HashPartitioning => expandOutputPartitioning(h) + case c: PartitioningCollection => expandOutputPartitioning(c) + case other => other + } + case _ => streamedPlan.outputPartitioning + } + } + + // An one-to-many mapping from a streamed key to build keys. + private lazy val streamedKeyToBuildKeyMapping = { + val mapping = mutable.Map.empty[Expression, Seq[Expression]] + streamedKeys.zip(buildKeys).foreach { + case (streamedKey, buildKey) => + val key = streamedKey.canonicalized + mapping.get(key) match { + case Some(v) => mapping.put(key, v :+ buildKey) + case None => mapping.put(key, Seq(buildKey)) + } + } + mapping.toMap + } + + // Expands the given partitioning collection recursively. + private def expandOutputPartitioning(partitioning: PartitioningCollection) + : PartitioningCollection = { + PartitioningCollection(partitioning.partitionings.flatMap { + case h: HashPartitioning => expandOutputPartitioning(h).partitionings + case c: PartitioningCollection => Seq(expandOutputPartitioning(c)) + case other => Seq(other) + }) + } + + // Expands the given hash partitioning by substituting streamed keys with build keys. 
+ // For example, if the expressions for the given partitioning are Seq("a", "b", "c") + // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"), + // the expanded partitioning will have the following expressions: + // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y"). + // The expanded expressions are returned as PartitioningCollection. + private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = { + val maxNumCombinations = sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit + var currentNumCombinations = 0 + + def generateExprCombinations( + current: Seq[Expression], + accumulated: Seq[Expression]): Seq[Seq[Expression]] = { + if (currentNumCombinations >= maxNumCombinations) { + Nil + } else if (current.isEmpty) { + currentNumCombinations += 1 + Seq(accumulated) + } else { + val buildKeysOpt = streamedKeyToBuildKeyMapping.get(current.head.canonicalized) + generateExprCombinations(current.tail, accumulated :+ current.head) ++ + buildKeysOpt.map(_.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b))) + .getOrElse(Nil) + } + } + + PartitioningCollection( + generateExprCombinations(partitioning.expressions, Nil) + .map(HashPartitioning(_, partitioning.numPartitions))) + } + + /** only for operator fusion */ + def getBuildOutput: Seq[Attribute] = { + buildOutput + } + + def getBuildKeys: Seq[Expression] = { + buildKeys + } + + def getBuildPlan: SparkPlan = { + buildPlan + } + + def getStreamedOutput: Seq[Attribute] = { + streamedOutput + } + + def getStreamedKeys: Seq[Expression] = { + streamedKeys + } + + def getStreamPlan: SparkPlan = { + streamedPlan + } + + def buildCheck(): Unit = { + if ("INNER" != joinType.sql) { + throw new UnsupportedOperationException(s"Join-type[${joinType}] is not supported " + + s"in ${this.nodeName}") + } + + val buildTypes = new Array[DataType](buildOutput.size) // {2, 2}, buildOutput:col1#12,col2#13 + buildOutput.zipWithIndex.foreach {case (att, i) => + buildTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata) + } + + val buildJoinColsExp: Array[AnyRef] = buildKeys.map { x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(buildOutput.map(_.toAttribute))) + }.toArray + + val probeTypes = new Array[DataType](streamedOutput.size) + streamedOutput.zipWithIndex.foreach { case (attr, i) => + probeTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) + } + val probeHashColsExp: Array[AnyRef] = streamedKeys.map { x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(streamedOutput.map(_.toAttribute))) + }.toArray + + checkOmniJsonWhiteList("", buildJoinColsExp) + checkOmniJsonWhiteList("", probeHashColsExp) + + condition match { + case Some(expr) => + val filterExpr: String = OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(expr, + OmniExpressionAdaptor.getExprIdMap((streamedOutput ++ buildOutput).map(_.toAttribute))) + checkOmniJsonWhiteList(filterExpr, new Array[AnyRef](0)) + case _ => Optional.empty() + } + } + + /** + * Return true if this stage of the plan supports columnar execution. + */ + + /** + * Produces the result of the query as an `RDD[ColumnarBatch]` if [[supportsColumnar]] returns + * true. By convention the executor that creates a ColumnarBatch is responsible for closing it + * when it is no longer needed. 
This allows input formats to be able to reuse batches if needed. + */ + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + // input/output: {col1#10,col2#11,col1#12,col2#13} + val numOutputRows = longMetric("numOutputRows") + val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val numMergedVecBatchs = longMetric("numMergedVecBatchs") + val buildAddInputTime = longMetric("buildAddInputTime") + val buildCodegenTime = longMetric("buildCodegenTime") + val buildGetOutputTime = longMetric("buildGetOutputTime") + val lookupAddInputTime = longMetric("lookupAddInputTime") + val lookupCodegenTime = longMetric("lookupCodegenTime") + val lookupGetOutputTime = longMetric("lookupGetOutputTime") + + + if ("INNER" != joinType.sql) { + throw new UnsupportedOperationException(s"Join-type[${joinType}] is not supported " + + s"in ${this.nodeName}") + } + + val buildTypes = new Array[DataType](buildOutput.size) // {2,2}, buildOutput:col1#12,col2#13 + buildOutput.zipWithIndex.foreach { case (att, i) => + buildTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata) + } + + // {0}, buildKeys: col1#12 + val buildOutputCols = buildOutput.indices.toArray // {0,1} + val buildJoinColsExp = buildKeys.map { x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(buildOutput.map(_.toAttribute))) + }.toArray + val buildData = buildPlan.executeBroadcast[Array[Array[Byte]]]() + + // TODO: check + val buildOutputTypes = buildTypes // {1,1} + + val probeTypes = new Array[DataType](streamedOutput.size) // {2,2}, streamedOutput:col1#10,col2#11 + streamedOutput.zipWithIndex.foreach { case (attr, i) => + probeTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) + } + val probeOutputCols = streamedOutput.indices.toArray // {0,1} + val probeHashColsExp = streamedKeys.map { x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(streamedOutput.map(_.toAttribute))) + }.toArray + streamedPlan.executeColumnar().mapPartitionsWithIndexInternal { (index, iter) => + val filter: Optional[String] = condition match { + case Some(expr) => + Optional.of(OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(expr, + OmniExpressionAdaptor.getExprIdMap((streamedOutput ++ buildOutput).map(_.toAttribute)))) + case _ => Optional.empty() + } + val startBuildCodegen = System.nanoTime() + val buildOpFactory = new OmniHashBuilderWithExprOperatorFactory(buildTypes, + buildJoinColsExp, filter, 1, new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val buildOp = buildOpFactory.createOperator() + buildCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildCodegen) + + val deserializer = VecBatchSerializerFactory.create() + buildData.value.foreach { input => + val startBuildInput = System.nanoTime() + buildOp.addInput(deserializer.deserialize(input)) + buildAddInputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildInput) + } + val startBuildGetOp = System.nanoTime() + buildOp.getOutput + buildGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildGetOp) + + val startLookupCodegen = System.nanoTime() + val lookupOpFactory = new OmniLookupJoinWithExprOperatorFactory(probeTypes, probeOutputCols, + probeHashColsExp, buildOutputCols, buildOutputTypes, OMNI_JOIN_TYPE_INNER, buildOpFactory, + new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val lookupOp = lookupOpFactory.createOperator() + lookupCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - 
startLookupCodegen) + + // close operator + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + buildOp.close() + lookupOp.close() + buildOpFactory.close() + lookupOpFactory.close() + }) + + val resultSchema = this.schema + val reverse = this.output != (streamedPlan.output ++ buildPlan.output) + var left = 0 + var leftLen = streamedPlan.output.size + var right = streamedPlan.output.size + var rightLen = output.size + if (reverse) { + left = streamedPlan.output.size + leftLen = output.size + right = 0 + rightLen = streamedPlan.output.size + } + + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + val enableJoinBatchMerge: Boolean = columnarConf.enableJoinBatchMerge + val iterBatch = new Iterator[ColumnarBatch] { + private var results: java.util.Iterator[VecBatch] = _ + var res: Boolean = true + + override def hasNext: Boolean = { + while ((results == null || !res) && iter.hasNext) { + val batch = iter.next() + val input = transColBatchToOmniVecs(batch) + val vecBatch = new VecBatch(input, batch.numRows()) + val startlookupInput = System.nanoTime() + lookupOp.addInput(vecBatch) + lookupAddInputTime += NANOSECONDS.toMillis(System.nanoTime() - startlookupInput) + + val startLookupGetOp = System.nanoTime() + results = lookupOp.getOutput + res = results.hasNext + lookupGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startLookupGetOp) + + } + if (results == null) { + false + } else { + if (!res) { + false + } else { + val startLookupGetOp = System.nanoTime() + res = results.hasNext + lookupGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startLookupGetOp) + res + } + } + + } + + override def next(): ColumnarBatch = { + val startLookupGetOp = System.nanoTime() + val result = results.next() + res = results.hasNext + lookupGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startLookupGetOp) + val resultVecs = result.getVectors + val vecs = OmniColumnVector + .allocateColumns(result.getRowCount, resultSchema, false) + var index = 0 + for (i <- left until leftLen) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(i)) + index += 1 + } + for (i <- right until rightLen) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(i)) + index += 1 + } + numOutputRows += result.getRowCount + numOutputVecBatchs += 1 + new ColumnarBatch(vecs.toArray, result.getRowCount) + } + } + + if (enableJoinBatchMerge) { + new MergeIterator(iterBatch, resultSchema, numMergedVecBatchs) + } else { + iterBatch + } + } + } + + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException(s"This operator doesn't support doExecute().") + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + streamedPlan.asInstanceOf[CodegenSupport].inputRDDs() + } + + private def multipleOutputForOneInput: Boolean = joinType match { + case _: InnerLike | LeftOuter | RightOuter => + // For inner and outer joins, one row from the streamed side may produce multiple result rows, + // if the build side has duplicated keys. Note that here we wait for the broadcast to be + // finished, which is a no-op because it's already finished when we wait it in `doProduce`. + !buildPlan.executeBroadcast[HashedRelation]().value.keyIsUnique + + // Other joins types(semi, anti, existence) can at most produce one result row for one input + // row from the streamed side. + case _ => false + } + + // If the streaming side needs to copy result, this join plan needs to copy too. 
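The left/leftLen/right/rightLen bookkeeping in the iterator above exists because the Omni lookup returns result vectors with the streamed-side columns first and the build-side columns after them, while this plan's output may list the build side first. A stand-in sketch of that index arithmetic, showing the order in which result vectors are copied into the output batch:

```scala
// Order in which result vectors are copied into the output ColumnarBatch.
def outputOrder(streamedCols: Int, buildCols: Int, reverse: Boolean): Seq[Int] = {
  val total = streamedCols + buildCols
  val (left, leftLen, right, rightLen) =
    if (!reverse) (0, streamedCols, streamedCols, total) // output = streamed ++ build
    else          (streamedCols, total, 0, streamedCols) // output = build ++ streamed
  (left until leftLen) ++ (right until rightLen)
}

// outputOrder(2, 3, reverse = false) == Seq(0, 1, 2, 3, 4)
// outputOrder(2, 3, reverse = true)  == Seq(2, 3, 4, 0, 1)
```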
Otherwise, + // this join plan only needs to copy result if it may output multiple rows for one input. + override def needCopyResult: Boolean = + streamedPlan.asInstanceOf[CodegenSupport].needCopyResult || multipleOutputForOneInput + + /** + * Returns a tuple of Broadcast of HashedRelation and the variable name for it. + */ + private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = { + throw new UnsupportedOperationException(s"This operator doesn't support prepareBroadcast().") + } + + protected override def prepareRelation(ctx: CodegenContext): HashedRelationInfo = { + throw new UnsupportedOperationException(s"This operator doesn't support prepareRelation().") + } + + protected override def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException(s"This operator doesn't support codegenAnti().") + } + +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..ac5d5fdbc1bc3f6e503ebcd84ff2f88194f162dc --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import java.util.concurrent.TimeUnit.NANOSECONDS +import java.util.Optional +import com.huawei.boostkit.spark.Constant.{IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP} +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import nova.hetu.omniruntime.`type`.DataType +import nova.hetu.omniruntime.constants.JoinType.OMNI_JOIN_TYPE_INNER +import nova.hetu.omniruntime.operator.config.OperatorConfig +import nova.hetu.omniruntime.operator.join.{OmniHashBuilderWithExprOperatorFactory, OmniLookupJoinWithExprOperatorFactory} +import nova.hetu.omniruntime.vector.VecBatch +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.catalyst.optimizer.BuildSide +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.util.SparkMemoryUtils +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.vectorized.ColumnarBatch + +case class ColumnarShuffledHashJoinExec( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) + extends HashJoin with ShuffledJoin { + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "lookupAddInputTime" -> SQLMetrics.createTimingMetric(sparkContext, + "time in omni lookup addInput"), + "lookupGetOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, + "time in omni lookup getOutput"), + "lookupCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, + "time in omni lookup codegen"), + "buildAddInputTime" -> SQLMetrics.createTimingMetric(sparkContext, + "time in omni build addInput"), + "buildGetOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, + "time in omni build getOutput"), + "buildCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, + "time in omni build codegen"), + "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs"), + "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "build side input data size") + ) + + override def supportsColumnar: Boolean = true + + override def supportCodegen: Boolean = false + + override def nodeName: String = "OmniColumnarShuffledHashJoin" + + override def output: Seq[Attribute] = super[ShuffledJoin].output + + override def outputPartitioning: Partitioning = super[ShuffledJoin].outputPartitioning + + override def outputOrdering: Seq[SortOrder] = joinType match { + case FullOuter => Nil + case _ => super.outputOrdering + } + + /** + * This is called by generated Java class, should be public. 
+ */ + def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { + val buildDataSize = longMetric("buildDataSize") + val buildTime = longMetric("buildTime") + val start = System.nanoTime() + val context = TaskContext.get() + val relation = HashedRelation( + iter, + buildBoundKeys, + taskMemoryManager = context.taskMemoryManager(), + // Full outer join needs support for NULL key in HashedRelation. + allowsNullKey = joinType == FullOuter) + buildTime += NANOSECONDS.toMillis(System.nanoTime() - start) + buildDataSize += relation.estimatedSize + // This relation is usually used until the end of task. + context.addTaskCompletionListener[Unit](_ => relation.close()) + relation + } + + protected override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException(s"Join-type[${joinType}] is not supported " + + s"in ${this.nodeName}") + } + + def buildCheck(): Unit = { + if ("INNER" != joinType.sql) { + throw new UnsupportedOperationException(s"Join-type[${joinType}] is not supported " + + s"in ${this.nodeName}") + } + val buildTypes = new Array[DataType](buildOutput.size) // {2,2}, buildOutput:col1#12,col2#13 + buildOutput.zipWithIndex.foreach { case (att, i) => + buildTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata) + } + + buildKeys.map { x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(buildOutput.map(_.toAttribute))) + }.toArray + + val probeTypes = new Array[DataType](streamedOutput.size) + streamedOutput.zipWithIndex.foreach { case (attr, i) => + probeTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) + } + streamedKeys.map { x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(streamedOutput.map(_.toAttribute))) + }.toArray + } + + /** + * Produces the result of the query as an `RDD[ColumnarBatch]` if [[supportsColumnar]] returns + * true. By convention the executor that creates a ColumnarBatch is responsible for closing it + * when it is no longer needed. This allows input formats to be able to reuse batches if needed. 
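Both this operator and the broadcast variant above wrap the Omni lookup operator in the same lazy iterator: hasNext keeps feeding streamed batches until the operator yields output or the input is exhausted, and next() drains one result batch at a time. A minimal, generic sketch of that pattern, with a toy process function standing in for the addInput/getOutput calls:

```scala
// Feed inputs to `process` until it yields output; surface the outputs lazily.
class PullUntilOutput[A, B](input: Iterator[A], process: A => Iterator[B]) extends Iterator[B] {
  private var results: Iterator[B] = Iterator.empty

  override def hasNext: Boolean = {
    while (!results.hasNext && input.hasNext) {
      results = process(input.next())
    }
    results.hasNext
  }

  override def next(): B = {
    if (!hasNext) throw new NoSuchElementException("no more output")
    results.next()
  }
}

// Example: only even inputs produce output.
// new PullUntilOutput[Int, Int](Iterator(1, 2, 3, 4),
//   x => if (x % 2 == 0) Iterator(x * 10) else Iterator.empty).toList == List(20, 40)
```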
+ */ + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val buildAddInputTime = longMetric("buildAddInputTime") + val buildCodegenTime = longMetric("buildCodegenTime") + val buildGetOutputTime = longMetric("buildGetOutputTime") + val lookupAddInputTime = longMetric("lookupAddInputTime") + val lookupCodegenTime = longMetric("lookupCodegenTime") + val lookupGetOutputTime = longMetric("lookupGetOutputTime") + val buildDataSize = longMetric("buildDataSize") + + if ("INNER" != joinType.sql) { + throw new UnsupportedOperationException(s"Join-type[${joinType}] is not supported " + + s"in ${this.nodeName}") + } + val buildTypes = new Array[DataType](buildOutput.size) // {2,2}, buildOutput:col1#12,col2#13 + buildOutput.zipWithIndex.foreach { case (att, i) => + buildTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata) + } + + val buildOutputCols = buildOutput.indices.toArray + val buildJoinColsExp = buildKeys.map { x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(buildOutput.map(_.toAttribute))) + }.toArray + + val buildOutputTypes = buildTypes + + val probeTypes = new Array[DataType](streamedOutput.size) + streamedOutput.zipWithIndex.foreach { case (attr, i) => + probeTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) + } + val probeOutputCols = streamedOutput.indices.toArray + val probeHashColsExp = streamedKeys.map { x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(streamedOutput.map(_.toAttribute))) + }.toArray + + streamedPlan.executeColumnar.zipPartitions(buildPlan.executeColumnar()) { + (streamIter, buildIter) => + val filter: Optional[String] = condition match { + case Some(expr) => + Optional.of(OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(expr, + OmniExpressionAdaptor.getExprIdMap((streamedOutput ++ buildOutput). 
+ map(_.toAttribute)))) + case _ => Optional.empty() + } + val startBuildCodegen = System.nanoTime() + val buildOpFactory = new OmniHashBuilderWithExprOperatorFactory(buildTypes, + buildJoinColsExp, filter, 1, new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val buildOp = buildOpFactory.createOperator() + buildCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildCodegen) + + while (buildIter.hasNext) { + + val cb = buildIter.next() + val vecs = transColBatchToOmniVecs(cb, false) + for (i <- 0 until vecs.length) { + buildDataSize += vecs(i).getRealValueBufCapacityInBytes + buildDataSize += vecs(i).getRealNullBufCapacityInBytes + buildDataSize += vecs(i).getRealOffsetBufCapacityInBytes + } + val startBuildInput = System.nanoTime() + buildOp.addInput(new VecBatch(vecs, cb.numRows())) + buildAddInputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildInput) + } + + val startBuildGetOp = System.nanoTime() + buildOp.getOutput + buildGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildGetOp) + + val startLookupCodegen = System.nanoTime() + val lookupOpFactory = new OmniLookupJoinWithExprOperatorFactory(probeTypes, probeOutputCols, + probeHashColsExp, buildOutputCols, buildOutputTypes, OMNI_JOIN_TYPE_INNER, buildOpFactory, + new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val lookupOp = lookupOpFactory.createOperator() + lookupCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startLookupCodegen) + + // close operator + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + buildOp.close() + lookupOp.close() + buildOpFactory.close() + lookupOpFactory.close() + }) + + val resultSchema = this.schema + val reverse = this.output != (streamedPlan.output ++ buildPlan.output) + var left = 0 + var leftLen = streamedPlan.output.size + var right = streamedPlan.output.size + var rightLen = output.size + if (reverse) { + left = streamedPlan.output.size + leftLen = output.size + right = 0 + rightLen = streamedPlan.output.size + } + + new Iterator[ColumnarBatch] { + private var results: java.util.Iterator[VecBatch] = _ + var res: Boolean = true + + override def hasNext: Boolean = { + while ((results == null || !res) && streamIter.hasNext) { + val batch = streamIter.next() + val input = transColBatchToOmniVecs(batch) + val vecBatch = new VecBatch(input, batch.numRows()) + val startLookupInput = System.nanoTime() + lookupOp.addInput(vecBatch) + lookupAddInputTime += NANOSECONDS.toMillis(System.nanoTime() - startLookupInput) + + val startLookupGetOp = System.nanoTime() + results = lookupOp.getOutput + res = results.hasNext + lookupGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startLookupGetOp) + + } + if (results == null) { + false + } else { + if (!res) { + false + } else { + val startLookupGetOp = System.nanoTime() + res = results.hasNext + lookupGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startLookupGetOp) + res + } + } + + } + + override def next(): ColumnarBatch = { + val startLookupGetOp = System.nanoTime() + val result = results.next() + res = results.hasNext + lookupGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startLookupGetOp) + val resultVecs = result.getVectors + val vecs = OmniColumnVector + .allocateColumns(result.getRowCount, resultSchema, false) + var index = 0 + for (i <- left until leftLen) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(i)) + index += 1 + } + for (i <- right until rightLen) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(i)) + index += 1 + } + 
numOutputRows += result.getRowCount + numOutputVecBatchs += 1 + new ColumnarBatch(vecs.toArray, result.getRowCount) + } + } + } + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + streamedPlan.execute() :: buildPlan.execute() :: Nil + } + + protected override def prepareRelation(ctx: CodegenContext): HashedRelationInfo = { + val thisPlan = ctx.addReferenceObj("plan", this) + val clsName = classOf[HashedRelation].getName + + // Inline mutable state since not many join operations in a task + val relationTerm = ctx.addMutableState(clsName, "relation", + v => s"$v = $thisPlan.buildHashedRelation(inputs[1]);", forceInline = true) + HashedRelationInfo(relationTerm, keyIsUnique = false, isEmpty = false) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..2a16a0dbf2bd1ddab91e410f6def2c8630ff2e8c --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import com.huawei.boostkit.spark.ColumnarPluginConfig + +import java.util.concurrent.TimeUnit.NANOSECONDS +import java.util.Optional +import com.huawei.boostkit.spark.Constant.{IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP} +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.checkOmniJsonWhiteList +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import nova.hetu.omniruntime.`type`.DataType +import nova.hetu.omniruntime.constants.JoinType.OMNI_JOIN_TYPE_INNER +import nova.hetu.omniruntime.operator.config.OperatorConfig +import nova.hetu.omniruntime.operator.join.{OmniSmjBufferedTableWithExprOperatorFactory, OmniSmjStreamedTableWithExprOperatorFactory} +import nova.hetu.omniruntime.vector.{BooleanVec, Decimal128Vec, DoubleVec, IntVec, LongVec, VarcharVec, Vec, VecBatch} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.util.{MergeIterator, SparkMemoryUtils} +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Performs a sort merge join of two child relations. 
+ */ +class ColumnarSortMergeJoinExec( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan, + isSkewJoin: Boolean = false) + extends SortMergeJoinExec( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan, + isSkewJoin: Boolean) with CodegenSupport { + + override def supportsColumnar: Boolean = true + + override def supportCodegen: Boolean = false + + override def nodeName: String = "OmniColumnarSortMergeJoin" + + val SMJ_NEED_ADD_STREAM_TBL_DATA = 2 + val SMJ_NEED_ADD_BUFFERED_TBL_DATA = 3 + val SMJ_NO_RESULT = 4 + val SMJ_FETCH_JOIN_DATA = 5 + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "streamedAddInputTime" -> + SQLMetrics.createMetric(sparkContext, "time in omni streamed addInput"), + "streamedCodegenTime" -> + SQLMetrics.createMetric(sparkContext, "time in omni streamed codegen"), + "bufferedAddInputTime" -> + SQLMetrics.createMetric(sparkContext, "time in omni buffered addInput"), + "bufferedCodegenTime" -> + SQLMetrics.createMetric(sparkContext, "time in omni buffered codegen"), + "getOutputTime" -> + SQLMetrics.createMetric(sparkContext, "time in omni buffered getOutput"), + "numOutputVecBatchs" -> + SQLMetrics.createMetric(sparkContext, "number of output vecBatchs"), + "numMergedVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatchs"), + "numStreamVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of streamed vecBatchs"), + "numBufferVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of buffered vecBatchs") + ) + + def buildCheck(): Unit = { + if ("INNER" != joinType.sql) { + throw new UnsupportedOperationException(s"Join-type[${joinType}] is not supported " + + s"in ${this.nodeName}") + } + + val streamedTypes = new Array[DataType](left.output.size) + left.output.zipWithIndex.foreach { case (attr, i) => + streamedTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) + } + val streamedKeyColsExp: Array[AnyRef] = leftKeys.map { x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(left.output.map(_.toAttribute))) + }.toArray + + val bufferedTypes = new Array[DataType](right.output.size) + right.output.zipWithIndex.foreach { case (attr, i) => + bufferedTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) + } + val bufferedKeyColsExp: Array[AnyRef] = rightKeys.map { x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(right.output.map(_.toAttribute))) + }.toArray + + checkOmniJsonWhiteList("", streamedKeyColsExp) + checkOmniJsonWhiteList("", bufferedKeyColsExp) + + condition match { + case Some(expr) => + val filterExpr: String = OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(expr, + OmniExpressionAdaptor.getExprIdMap(output.map(_.toAttribute))) + checkOmniJsonWhiteList(filterExpr, new Array[AnyRef](0)) + case _ => null + } + } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val numMergedVecBatchs = longMetric("numMergedVecBatchs") + val streamedAddInputTime = longMetric("streamedAddInputTime") + val streamedCodegenTime = longMetric("streamedCodegenTime") + val 
bufferedAddInputTime = longMetric("bufferedAddInputTime") + val bufferedCodegenTime = longMetric("bufferedCodegenTime") + val getOutputTime = longMetric("getOutputTime") + val streamVecBatchs = longMetric("numStreamVecBatchs") + val bufferVecBatchs = longMetric("numBufferVecBatchs") + + if ("INNER" != joinType.sql) { + throw new UnsupportedOperationException(s"Join-type[${joinType}] is not supported " + + s"in ${this.nodeName}") + } + + val streamedTypes = new Array[DataType](left.output.size) + left.output.zipWithIndex.foreach { case (attr, i) => + streamedTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) + } + val streamedKeyColsExp = leftKeys.map { x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(left.output.map(_.toAttribute))) + }.toArray + val streamedOutputChannel = left.output.indices.toArray + + val bufferedTypes = new Array[DataType](right.output.size) + right.output.zipWithIndex.foreach { case (attr, i) => + bufferedTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) + } + val bufferedKeyColsExp = rightKeys.map { x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(right.output.map(_.toAttribute))) + }.toArray + val bufferedOutputChannel = right.output.indices.toArray + + val filterString: String = condition match { + case Some(expr) => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(expr, + OmniExpressionAdaptor.getExprIdMap(output.map(_.toAttribute))) + case _ => null + } + + left.executeColumnar().zipPartitions(right.executeColumnar()) { (streamedIter, bufferedIter) => + val filter: Optional[String] = Optional.ofNullable(filterString) + val startStreamedCodegen = System.nanoTime() + val streamedOpFactory = new OmniSmjStreamedTableWithExprOperatorFactory(streamedTypes, + streamedKeyColsExp, streamedOutputChannel, OMNI_JOIN_TYPE_INNER, filter, new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val streamedOp = streamedOpFactory.createOperator + streamedCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startStreamedCodegen) + + val startBufferedCodegen = System.nanoTime() + val bufferedOpFactory = new OmniSmjBufferedTableWithExprOperatorFactory(bufferedTypes, + bufferedKeyColsExp, bufferedOutputChannel, streamedOpFactory, new OperatorConfig(IS_ENABLE_JIT, IS_SKIP_VERIFY_EXP)) + val bufferedOp = bufferedOpFactory.createOperator + bufferedCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startBufferedCodegen) + + // close operator + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + streamedOp.close() + bufferedOp.close() + bufferedOpFactory.close() + streamedOpFactory.close() + }) + + val resultSchema = this.schema + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + val enableSortMergeJoinBatchMerge: Boolean = columnarConf.enableSortMergeJoinBatchMerge + val iterBatch = new Iterator[ColumnarBatch] { + + var isFinished = !streamedIter.hasNext || !bufferedIter.hasNext + var isStreamedFinished = false + var isBufferedFinished = false + var results: java.util.Iterator[VecBatch] = null + + def checkAndClose() : Unit = { + while(streamedIter.hasNext) { + streamVecBatchs += 1 + streamedIter.next().close() + } + while(bufferedIter.hasNext) { + bufferVecBatchs += 1 + bufferedIter.next().close() + } + } + + override def hasNext: Boolean = { + if (isFinished) { + checkAndClose() + return false + } + if (results != null && results.hasNext) { + return 
true + } + // reset results and find next results + results = null + // Add streamed data first + var inputReturnCode = SMJ_NEED_ADD_STREAM_TBL_DATA + while (inputReturnCode == SMJ_NEED_ADD_STREAM_TBL_DATA + || inputReturnCode == SMJ_NEED_ADD_BUFFERED_TBL_DATA) { + if (inputReturnCode == SMJ_NEED_ADD_STREAM_TBL_DATA) { + val startBuildStreamedInput = System.nanoTime() + if (!isStreamedFinished && streamedIter.hasNext) { + val batch = streamedIter.next() + streamVecBatchs += 1 + val inputVecBatch = transColBatchToVecBatch(batch) + inputReturnCode = streamedOp.addInput(inputVecBatch) + } else { + inputReturnCode = streamedOp.addInput(createEofVecBatch(streamedTypes)) + isStreamedFinished = true + } + streamedAddInputTime += + NANOSECONDS.toMillis(System.nanoTime() - startBuildStreamedInput) + } else { + val startBuildBufferedInput = System.nanoTime() + if (!isBufferedFinished && bufferedIter.hasNext) { + val batch = bufferedIter.next() + bufferVecBatchs += 1 + val inputVecBatch = transColBatchToVecBatch(batch) + inputReturnCode = bufferedOp.addInput(inputVecBatch) + } else { + inputReturnCode = bufferedOp.addInput(createEofVecBatch(bufferedTypes)) + isBufferedFinished = true + } + bufferedAddInputTime += + NANOSECONDS.toMillis(System.nanoTime() - startBuildBufferedInput) + } + } + if (inputReturnCode == SMJ_FETCH_JOIN_DATA) { + val startGetOutputTime = System.nanoTime() + results = bufferedOp.getOutput + var hasNext = results.hasNext + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOutputTime) + if (hasNext) { + return true + } else { + isFinished = true + results = null + checkAndClose() + return false + } + } + + if (inputReturnCode == SMJ_NO_RESULT) { + isFinished = true + results = null + checkAndClose() + return false + } + + throw new UnsupportedOperationException(s"Unknown return code ${inputReturnCode}") + } + + override def next(): ColumnarBatch = { + val startGetOutputTime = System.nanoTime() + val result = results.next() + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOutputTime) + val resultVecs = result.getVectors + val vecs = OmniColumnVector.allocateColumns(result.getRowCount, resultSchema, false) + for (index <- output.indices) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(index)) + } + numOutputVecBatchs += 1 + numOutputRows += result.getRowCount + result.close() + new ColumnarBatch(vecs.toArray, result.getRowCount) + } + + def createEofVecBatch(types: Array[DataType]): VecBatch = { + val vecs: Array[Vec] = new Array[Vec](types.length) + for (i <- types.indices) { + vecs(i) = types(i).getId match { + case DataType.DataTypeId.OMNI_INT | DataType.DataTypeId.OMNI_DATE32 => + new IntVec(0) + case DataType.DataTypeId.OMNI_LONG | DataType.DataTypeId.OMNI_DECIMAL64 => + new LongVec(0) + case DataType.DataTypeId.OMNI_DOUBLE => + new DoubleVec(0) + case DataType.DataTypeId.OMNI_BOOLEAN => + new BooleanVec(0) + case DataType.DataTypeId.OMNI_CHAR | DataType.DataTypeId.OMNI_VARCHAR => + new VarcharVec(0, 0) + case DataType.DataTypeId.OMNI_DECIMAL128 => + new Decimal128Vec(0) + case _ => + throw new IllegalArgumentException(s"VecType [${types(i).getClass.getSimpleName}]" + + s" is not supported in [${getClass.getSimpleName}] yet") + } + } + new VecBatch(vecs, 0) + } + + def transColBatchToVecBatch(columnarBatch: ColumnarBatch): VecBatch = { + val input = transColBatchToOmniVecs(columnarBatch) + new VecBatch(input, columnarBatch.numRows()) + } + } + + if (enableSortMergeJoinBatchMerge) { + new MergeIterator(iterBatch, resultSchema, 
numMergedVecBatchs) + } else { + iterBatch + } + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala new file mode 100644 index 0000000000000000000000000000000000000000..c67d45032589b74ee414625010cba01ba716465b --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.util + +import com.huawei.boostkit.spark.ColumnarPluginConfig + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.sparkTypeToOmniType +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import nova.hetu.omniruntime.`type`.{DataType, VarcharDataType} +import nova.hetu.omniruntime.vector.{BooleanVec, Decimal128Vec, DoubleVec, IntVec, LongVec, ShortVec, VarcharVec, Vec, VecBatch} +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.types.{BooleanType, DateType, DecimalType, DoubleType, IntegerType, LongType, ShortType, StringType, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch + +class MergeIterator(iter: Iterator[ColumnarBatch], localSchema: StructType, + numMergedVecBatchs: SQLMetric) extends Iterator[ColumnarBatch] { + + private val outputQueue = new mutable.Queue[VecBatch] + private val bufferedVecBatch = new ListBuffer[VecBatch]() + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + private val maxBatchSizeInBytes: Int = columnarConf.maxBatchSizeInBytes + private val maxRowCount: Int = columnarConf.maxRowCount + private var totalRows = 0 + private var currentBatchSizeInBytes = 0 + + private def createOmniVectors(schema: StructType, columnSize: Int): Array[Vec] = { + val vecs = new Array[Vec](schema.fields.length) + schema.fields.zipWithIndex.foreach { case (field, index) => + field.dataType match { + case LongType => + vecs(index) = new LongVec(columnSize) + case DateType | IntegerType => + vecs(index) = new IntVec(columnSize) + case ShortType => + vecs(index) = new ShortVec(columnSize) + case DoubleType => + vecs(index) = new DoubleVec(columnSize) + case BooleanType => + vecs(index) = new BooleanVec(columnSize) + case StringType => + val vecType: DataType = sparkTypeToOmniType(field.dataType, field.metadata) + vecs(index) = new VarcharVec(vecType.asInstanceOf[VarcharDataType].getWidth * columnSize, + columnSize) + case dt: DecimalType => + 
if (DecimalType.is64BitDecimalType(dt)) { + vecs(index) = new LongVec(columnSize) + } else { + vecs(index) = new Decimal128Vec(columnSize) + } + case _ => + throw new UnsupportedOperationException("Fail to create omni vector, unsupported fields") + } + } + vecs + } + + private def buffer(vecBatch: VecBatch): Unit = { + var totalSize = 0 + vecBatch.getVectors.zipWithIndex.foreach { + case (vec, i) => + totalSize += vec.getCapacityInBytes + } + currentBatchSizeInBytes += totalSize + totalRows += vecBatch.getRowCount + + bufferedVecBatch.append(vecBatch) + if (isFull()) { + flush() + } + } + + private def merge(resultBatch: VecBatch, bufferedBatch: ListBuffer[VecBatch]): Unit = { + localSchema.fields.zipWithIndex.foreach { case (field, index) => + var offset = 0 + for (elem <- bufferedBatch) { + val src: Vec = elem.getVector(index) + val dest: Vec = resultBatch.getVector(index) + dest.append(src, offset, elem.getRowCount) + offset += elem.getRowCount + src.close() + } + } + } + + private def flush(): Unit = { + + if (bufferedVecBatch.isEmpty) { + return + } + val resultBatch: VecBatch = new VecBatch(createOmniVectors(localSchema, totalRows), totalRows) + merge(resultBatch, bufferedVecBatch) + outputQueue.enqueue(resultBatch) + numMergedVecBatchs += 1 + + bufferedVecBatch.clear() + currentBatchSizeInBytes = 0 + totalRows = 0 + + } + + private def vecBatchToColumnarBatch(vecBatch: VecBatch): ColumnarBatch = { + val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( + vecBatch.getRowCount, localSchema, false) + vectors.zipWithIndex.foreach { case (vector, i) => + vector.reset() + vector.setVec(vecBatch.getVectors()(i)) + } + vecBatch.close() + new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) + } + + override def hasNext: Boolean = { + while (outputQueue.isEmpty && iter.hasNext) { + val batch: ColumnarBatch = iter.next() + val input: Array[Vec] = transColBatchToOmniVecs(batch) + val vecBatch = new VecBatch(input, batch.numRows()) + buffer(vecBatch) + } + + if (outputQueue.isEmpty && bufferedVecBatch.isEmpty) { + false + } else { + true + } + } + + override def next(): ColumnarBatch = { + if (outputQueue.nonEmpty) { + vecBatchToColumnarBatch(outputQueue.dequeue()) + } else if (bufferedVecBatch.nonEmpty) { + flush() + vecBatchToColumnarBatch(outputQueue.dequeue()) + } else { + throw new RuntimeException("bufferedVecBatch and outputQueue are empty") + } + } + + + def isFull(): Boolean = { + totalRows > maxRowCount || currentBatchSizeInBytes >= maxBatchSizeInBytes + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala new file mode 100644 index 0000000000000000000000000000000000000000..6012da931bb3b93ef8a3e6690d42ba3d1e4949e0 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.util + +import nova.hetu.omniruntime.vector.VecAllocator + +import org.apache.spark.{SparkEnv, TaskContext} + +object SparkMemoryUtils { + + private val max: Long = SparkEnv.get.conf.getSizeAsBytes("spark.memory.offHeap.size", "1g") + VecAllocator.setRootAllocatorLimit(max) + + def init(): Unit = {} + + private def getLocalTaskContext: TaskContext = TaskContext.get() + + private def inSparkTask(): Boolean = { + getLocalTaskContext != null + } + + def addLeakSafeTaskCompletionListener[U](f: TaskContext => U): TaskContext = { + if (!inSparkTask()) { + throw new IllegalStateException("Not in a Spark task") + } + getLocalTaskContext.addTaskCompletionListener(f) + } + +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/types/ColumnarBatchSupportUtil.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/types/ColumnarBatchSupportUtil.scala new file mode 100644 index 0000000000000000000000000000000000000000..cc3763164f211cee9083395d2335f65c2a286c91 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/types/ColumnarBatchSupportUtil.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + package org.apache.spark.sql.types + +import org.apache.spark.sql.execution.FileSourceScanExec + import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat + import org.apache.spark.sql.internal.SQLConf + + object ColumnarBatchSupportUtil { + def checkColumnarBatchSupport(conf: SQLConf, plan: FileSourceScanExec): Boolean = { + val isSupportFormat: Boolean = { + plan.relation.fileFormat match { + case _: OrcFileFormat => + conf.orcVectorizedReaderEnabled + case _ => + false + } + } + val supportBatchReader: Boolean = { + val partitionSchema = plan.relation.partitionSchema + val resultSchema = StructType(plan.requiredSchema.fields ++ partitionSchema.fields) + conf.orcVectorizedReaderEnabled && resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) + } + supportBatchReader && isSupportFormat + } + } + diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleCompressionTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleCompressionTest.java new file mode 100644 index 0000000000000000000000000000000000000000..d95be18832b926500b599821b6b6fd0baa8861c5 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleCompressionTest.java @@ -0,0 +1,129 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.spark; + +import com.huawei.boostkit.spark.jni.SparkJniWrapper; + +import java.io.File; +import nova.hetu.omniruntime.type.DataType; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_CHAR; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE32; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE64; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL128; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL64; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DOUBLE; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; +import nova.hetu.omniruntime.type.DataTypeSerializer; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.IOException; + +public class ColumnShuffleCompressionTest extends ColumnShuffleTest { + private static String shuffleDataFile = ""; + + @BeforeClass + public static void runOnceBeforeClass() { + File folder = new File(shuffleTestDir); + if (!folder.exists() && !folder.isDirectory()) { + folder.mkdirs(); + } + } + + @AfterClass + public static void runOnceAfterClass() { + File folder = new File(shuffleTestDir); + if (folder.exists()) { + deleteDir(folder); + } + } + + @Before + public void runBeforeTestMethod() { + + } + + @After + public void runAfterTestMethod() { + File file = new File(shuffleDataFile); + if (file.exists()) { + file.delete(); + } + } + + @Test + public void columnShuffleUncompressedTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_uncompressed_test"; + columnShuffleTestCompress("uncompressed", shuffleDataFile); + } + + @Test + public void columnShuffleSnappyCompressTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_snappy_test"; + columnShuffleTestCompress("snappy", shuffleDataFile); + } + + @Test + public void columnShuffleLz4CompressTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_lz4_test"; + columnShuffleTestCompress("lz4", shuffleDataFile); + } + + @Test + public void columnShuffleZlibCompressTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_zlib_test"; + columnShuffleTestCompress("zlib", shuffleDataFile); + } + + public void columnShuffleTestCompress(String compressType, String dataFile) throws IOException { + DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR, OMNI_CHAR, + OMNI_DATE32, OMNI_DATE64, OMNI_DECIMAL64, OMNI_DECIMAL128}; + DataType[] types = dataTypeId2DataType(idTypes); + String inputType = DataTypeSerializer.serialize(types); + SparkJniWrapper jniWrapper = new SparkJniWrapper(); + int partitionNum = 4; + long splitterId = jniWrapper.nativeMake( + "hash", + partitionNum, + inputType, + types.length, + 1024, //shuffle value_buffer init size + compressType, + dataFile, + 0, + shuffleTestDir, + 64 * 1024, + 4096, + 1024 * 1024 * 1024); + for (int i = 0; i < 999; i++) { + VecBatch vecBatchTmp = buildVecBatch(idTypes, 1000, partitionNum, true, true); + jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); + } + jniWrapper.stop(splitterId); + jniWrapper.close(splitterId); + } + +} diff --git 
a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffPartitionTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffPartitionTest.java new file mode 100644 index 0000000000000000000000000000000000000000..c8fd474137a93ea8831d3dc3ab432e409018cc55 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffPartitionTest.java @@ -0,0 +1,126 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark; + +import com.huawei.boostkit.spark.jni.SparkJniWrapper; + +import java.io.File; +import nova.hetu.omniruntime.type.DataType; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_CHAR; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE32; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE64; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL128; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL64; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DOUBLE; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; +import nova.hetu.omniruntime.type.DataTypeSerializer; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.IOException; + +public class ColumnShuffleDiffPartitionTest extends ColumnShuffleTest { + private static String shuffleDataFile = ""; + + @BeforeClass + public static void runOnceBeforeClass() { + File folder = new File(shuffleTestDir); + if (!folder.exists() && !folder.isDirectory()) { + folder.mkdirs(); + } + } + + @AfterClass + public static void runOnceAfterClass() { + File folder = new File(shuffleTestDir); + if (folder.exists()) { + deleteDir(folder); + } + } + + @Before + public void runBeforeTestMethod() { + + } + + @After + public void runAfterTestMethod() { + File file = new File(shuffleDataFile); + if (file.exists()) { + file.delete(); + } + } + + @Test + public void columnShuffleSinglePartitionTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_singlePartition_test"; + columnShufflePartitionTest("single", shuffleDataFile); + } + + @Test + public void columnShuffleHashPartitionTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_hashPartition_test"; + 
columnShufflePartitionTest("hash", shuffleDataFile); + } + + @Test + public void columnShuffleRangePartitionTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_rangePartition_test"; + columnShufflePartitionTest("range", shuffleDataFile); + } + + public void columnShufflePartitionTest(String partitionType, String dataFile) throws IOException { + DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR, OMNI_CHAR, + OMNI_DATE32, OMNI_DATE64, OMNI_DECIMAL64, OMNI_DECIMAL128}; + DataType[] types = dataTypeId2DataType(idTypes); + String tmpStr = DataTypeSerializer.serialize(types); + SparkJniWrapper jniWrapper = new SparkJniWrapper(); + int partitionNum = 1; + boolean pidVec = true; + if (partitionType.equals("single")){ + pidVec = false; + } + long splitterId = jniWrapper.nativeMake( + partitionType, + 1, + tmpStr, + types.length, + 3, + "lz4", + dataFile, + 0, + shuffleTestDir, + 64 * 1024, + 4096, + 1024 * 1024 * 1024); + for (int i = 0; i < 99; i++) { + VecBatch vecBatchTmp = buildVecBatch(idTypes, 999, partitionNum, true, pidVec); + jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); + } + jniWrapper.stop(splitterId); + jniWrapper.close(splitterId); + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffRowVBTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffRowVBTest.java new file mode 100644 index 0000000000000000000000000000000000000000..dc53fda8a1a04a15bf7ffb9919926d4812208fc0 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffRowVBTest.java @@ -0,0 +1,303 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.spark; + +import com.huawei.boostkit.spark.jni.SparkJniWrapper; + +import java.io.File; +import nova.hetu.omniruntime.type.DataType; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_CHAR; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE32; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE64; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL128; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL64; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DOUBLE; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; +import nova.hetu.omniruntime.type.DataTypeSerializer; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.IOException; + +public class ColumnShuffleDiffRowVBTest extends ColumnShuffleTest { + private static String shuffleDataFile = ""; + + @BeforeClass + public static void runOnceBeforeClass() { + File folder = new File(shuffleTestDir); + if (!folder.exists() && !folder.isDirectory()) { + folder.mkdirs(); + } + } + + @AfterClass + public static void runOnceAfterClass() { + File folder = new File(shuffleTestDir); + if (folder.exists()) { + deleteDir(folder); + } + } + + @Before + public void runBeforeTestMethod() { + + } + + @After + public void runAfterTestMethod() { + File file = new File(shuffleDataFile); + if (file.exists()) { + file.delete(); + } + } + + @Test + public void columnShuffleMixColTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_MixCol_test"; + DataType.DataTypeId[] idTypes = {OMNI_LONG, OMNI_DOUBLE, OMNI_INT, OMNI_VARCHAR, OMNI_CHAR, + OMNI_DATE32, OMNI_DATE64, OMNI_DECIMAL64, OMNI_DECIMAL128}; + DataType[] types = dataTypeId2DataType(idTypes); + String tmpStr = DataTypeSerializer.serialize(types); + SparkJniWrapper jniWrapper = new SparkJniWrapper(); + int partitionNum = 4; + long splitterId = jniWrapper.nativeMake( + "hash", + partitionNum, + tmpStr, + types.length, + 3, //shuffle value_buffer init size + "lz4", + shuffleDataFile, + 0, + shuffleTestDir, + 64 * 1024, + 4096, + 1024 * 1024 * 1024); + for (int i = 0; i < 999; i++) { + VecBatch vecBatchTmp = buildVecBatch(idTypes, 999, partitionNum, true, true); + jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); + } + jniWrapper.stop(splitterId); + jniWrapper.close(splitterId); + } + + @Test + public void columnShuffleVarCharFirstTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_varCharFirst_test"; + DataType.DataTypeId[] idTypes = {OMNI_VARCHAR, OMNI_LONG, OMNI_DOUBLE, OMNI_INT, OMNI_CHAR, + OMNI_DATE32, OMNI_DATE64, OMNI_DECIMAL64, OMNI_DECIMAL128}; + DataType[] types = dataTypeId2DataType(idTypes); + String tmpStr = DataTypeSerializer.serialize(types); + SparkJniWrapper jniWrapper = new SparkJniWrapper(); + int partitionNum = 4; + long splitterId = jniWrapper.nativeMake( + "hash", + partitionNum, + tmpStr, + types.length, + 3, //shuffle value_buffer init size + "lz4", + shuffleDataFile, + 0, + shuffleTestDir, + 0, + 4096, + 1024*1024*1024); + for (int i = 0; i < 999; i++) { + VecBatch vecBatchTmp = buildVecBatch(idTypes, 999, partitionNum, true, true); + 
jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); + } + jniWrapper.stop(splitterId); + jniWrapper.close(splitterId); + } + + @Test + public void columnShuffle1Row1024VBTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_1row1024vb_test"; + DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR, OMNI_CHAR, + OMNI_DATE32, OMNI_DATE64, OMNI_DECIMAL64, OMNI_DECIMAL128}; + DataType[] types = dataTypeId2DataType(idTypes); + String tmpStr = DataTypeSerializer.serialize(types); + SparkJniWrapper jniWrapper = new SparkJniWrapper(); + int partitionNum = 4; + long splitterId = jniWrapper.nativeMake( + "hash", + partitionNum, + tmpStr, + types.length, + 3, //shuffle value_buffer init size + "lz4", + shuffleDataFile, + 0, + shuffleTestDir, + 64 * 1024, + 4096, + 1024 * 1024 * 1024); + for (int i = 0; i < 1024; i++) { + VecBatch vecBatchTmp = buildVecBatch(idTypes, 1, partitionNum, false, true); + jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); + } + jniWrapper.stop(splitterId); + jniWrapper.close(splitterId); + } + + @Test + public void columnShuffle1024Row1VBTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_1024row1vb_test"; + DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR, OMNI_CHAR, + OMNI_DATE32, OMNI_DATE64, OMNI_DECIMAL64, OMNI_DECIMAL128}; + DataType[] types = dataTypeId2DataType(idTypes); + String tmpStr = DataTypeSerializer.serialize(types); + SparkJniWrapper jniWrapper = new SparkJniWrapper(); + int partitionNum = 4; + long splitterId = jniWrapper.nativeMake( + "hash", + partitionNum, + tmpStr, + types.length, + 3, //shuffle value_buffer init size + "lz4", + shuffleDataFile, + 0, + shuffleTestDir, + 64 * 1024, + 4096, + 1024 * 1024 * 1024); + for (int i = 0; i < 1; i++) { + VecBatch vecBatchTmp = buildVecBatch(idTypes, 1024, partitionNum, false, true); + jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); + } + jniWrapper.stop(splitterId); + jniWrapper.close(splitterId); + } + + @Test + public void columnShuffleChangeRowVBTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_changeRow_test"; + DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR, OMNI_CHAR}; + DataType[] types = dataTypeId2DataType(idTypes); + String tmpStr = DataTypeSerializer.serialize(types); + SparkJniWrapper jniWrapper = new SparkJniWrapper(); + int numPartition = 4; + long splitterId = jniWrapper.nativeMake( + "hash", + numPartition, + tmpStr, + types.length, + 3, //shuffle value_buffer init size + "lz4", + shuffleDataFile, + 0, + shuffleTestDir, + 64 * 1024, + 4096, + 1024 * 1024 * 1024); + for (int i = 1; i < 1000; i++) { + VecBatch vecBatchTmp = buildVecBatch(idTypes, i, numPartition, false, true); + jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); + } + jniWrapper.stop(splitterId); + jniWrapper.close(splitterId); + } + + @Test + public void columnShuffleVarChar1RowVBTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_varChar1Row_test"; + DataType.DataTypeId[] idTypes = {OMNI_VARCHAR}; + DataType[] types = dataTypeId2DataType(idTypes); + String tmpStr = DataTypeSerializer.serialize(types); + SparkJniWrapper jniWrapper = new SparkJniWrapper(); + int partitionNum = 4; + long splitterId = jniWrapper.nativeMake( + "hash", + partitionNum, + tmpStr, + types.length, + 3, //shuffle value_buffer init size + "lz4", + shuffleDataFile, + 0, + 
shuffleTestDir, + 64 * 1024, + 4096, + 1024 * 1024 * 1024); + VecBatch vecBatchTmp1 = new VecBatch(buildValChar(3, "N")); + jniWrapper.split(splitterId, vecBatchTmp1.getNativeVectorBatch()); + VecBatch vecBatchTmp2 = new VecBatch(buildValChar(2, "F")); + jniWrapper.split(splitterId, vecBatchTmp2.getNativeVectorBatch()); + VecBatch vecBatchTmp3 = new VecBatch(buildValChar(3, "N")); + jniWrapper.split(splitterId, vecBatchTmp3.getNativeVectorBatch()); + VecBatch vecBatchTmp4 = new VecBatch(buildValChar(2, "F")); + jniWrapper.split(splitterId, vecBatchTmp4.getNativeVectorBatch()); + VecBatch vecBatchTmp5 = new VecBatch(buildValChar(2, "F")); + jniWrapper.split(splitterId, vecBatchTmp5.getNativeVectorBatch()); + VecBatch vecBatchTmp6 = new VecBatch(buildValChar(2, "F")); + jniWrapper.split(splitterId, vecBatchTmp6.getNativeVectorBatch()); + VecBatch vecBatchTmp7 = new VecBatch(buildValChar(1, "R")); + jniWrapper.split(splitterId, vecBatchTmp7.getNativeVectorBatch()); + jniWrapper.stop(splitterId); + jniWrapper.close(splitterId); + } + + @Test + public void columnShuffleFix1RowVBTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_fix1Row_test"; + DataType.DataTypeId[] idTypes = {OMNI_INT}; + DataType[] types = dataTypeId2DataType(idTypes); + String tmpStr = DataTypeSerializer.serialize(types); + SparkJniWrapper jniWrapper = new SparkJniWrapper(); + int partitionNum = 4; + long splitterId = jniWrapper.nativeMake( + "hash", + partitionNum, + tmpStr, + types.length, + 3, //shuffle value_buffer init size + "lz4", + shuffleDataFile, + 0, + shuffleTestDir, + 64 * 1024, + 4096, + 1024 * 1024 * 1024); + VecBatch vecBatchTmp1 = new VecBatch(buildValInt(3, 1)); + jniWrapper.split(splitterId, vecBatchTmp1.getNativeVectorBatch()); + VecBatch vecBatchTmp2 = new VecBatch(buildValInt(2, 2)); + jniWrapper.split(splitterId, vecBatchTmp2.getNativeVectorBatch()); + VecBatch vecBatchTmp3 = new VecBatch(buildValInt(3, 3)); + jniWrapper.split(splitterId, vecBatchTmp3.getNativeVectorBatch()); + VecBatch vecBatchTmp4 = new VecBatch(buildValInt(2, 4)); + jniWrapper.split(splitterId, vecBatchTmp4.getNativeVectorBatch()); + VecBatch vecBatchTmp5 = new VecBatch(buildValInt(2, 5)); + jniWrapper.split(splitterId, vecBatchTmp5.getNativeVectorBatch()); + VecBatch vecBatchTmp6 = new VecBatch(buildValInt(1, 6)); + jniWrapper.split(splitterId, vecBatchTmp6.getNativeVectorBatch()); + VecBatch vecBatchTmp7 = new VecBatch(buildValInt(3, 7)); + jniWrapper.split(splitterId, vecBatchTmp7.getNativeVectorBatch()); + jniWrapper.stop(splitterId); + jniWrapper.close(splitterId); + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleGBSizeTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleGBSizeTest.java new file mode 100644 index 0000000000000000000000000000000000000000..2ef81ac49e545aa617136b9d4f3e7e769ea34652 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleGBSizeTest.java @@ -0,0 +1,255 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark; + +import com.huawei.boostkit.spark.jni.SparkJniWrapper; + +import java.io.File; +import nova.hetu.omniruntime.type.DataType; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_CHAR; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE32; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE64; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL128; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL64; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DOUBLE; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; +import nova.hetu.omniruntime.type.DataTypeSerializer; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +import java.io.IOException; + +public class ColumnShuffleGBSizeTest extends ColumnShuffleTest { + private static String shuffleDataFile = ""; + + @BeforeClass + public static void runOnceBeforeClass() { + File folder = new File(shuffleTestDir); + if (!folder.exists() && !folder.isDirectory()) { + folder.mkdirs(); + } + } + + @AfterClass + public static void runOnceAfterClass() { + File folder = new File(shuffleTestDir); + if (folder.exists()) { + deleteDir(folder); + } + } + + @Before + public void runBeforeTestMethod() { + + } + + @After + public void runAfterTestMethod() { + File file = new File(shuffleDataFile); + if (file.exists()) { + file.delete(); + } + } + + @Test + public void columnShuffleFixed1GBTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_fixed1GB_test"; + DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE}; + DataType[] types = dataTypeId2DataType(idTypes); + String tmpStr = DataTypeSerializer.serialize(types); + SparkJniWrapper jniWrapper = new SparkJniWrapper(); + int partitionNum = 3; + long splitterId = jniWrapper.nativeMake( + "hash", + partitionNum, + tmpStr, + types.length, + 4096, //shuffle value_buffer init size + "lz4", + shuffleDataFile, + 0, + shuffleTestDir, + 64 * 1024, + 4096, + 1024 * 1024 * 1024); + for (int i = 0; i < 6 * 1024; i++) { + VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, partitionNum, false, true); + jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); + } + jniWrapper.stop(splitterId); + jniWrapper.close(splitterId); + } + + @Ignore + public void columnShuffleFixed10GBTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_fixed10GB_test"; + DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE}; + DataType[] types = dataTypeId2DataType(idTypes); + String tmpStr = 
DataTypeSerializer.serialize(types); + SparkJniWrapper jniWrapper = new SparkJniWrapper(); + int partitionNum = 3; + long splitterId = jniWrapper.nativeMake( + "hash", + partitionNum, + tmpStr, + types.length, + 4096, //shuffle value_buffer init size + "lz4", + shuffleDataFile, + 0, + shuffleTestDir, + 64 * 1024, + 4096, + 1024 * 1024 * 1024); + for (int i = 0; i < 10 * 8 * 1024; i++) { + VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, partitionNum, false, true); + jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); + } + jniWrapper.stop(splitterId); + jniWrapper.close(splitterId); + } + + @Test + public void columnShuffleVarChar1GBTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_varChar1GB_test"; + DataType.DataTypeId[] idTypes = {OMNI_VARCHAR, OMNI_VARCHAR, OMNI_VARCHAR}; + DataType[] types = dataTypeId2DataType(idTypes); + String tmpStr = DataTypeSerializer.serialize(types); + SparkJniWrapper jniWrapper = new SparkJniWrapper(); + int partitionNum = 4; + long splitterId = jniWrapper.nativeMake( + "hash", + partitionNum, + tmpStr, + types.length, + 1024, //shuffle value_buffer init size + "lz4", + shuffleDataFile, + 0, + shuffleTestDir, + 64 * 1024, + 4096, + 1024 * 1024 * 1024); + // Do not split the same VecBatch twice: the split interface frees the VecBatch's memory, so splitting it again would double-free it and cause a core dump + for (int i = 0; i < 99; i++) { + VecBatch vecBatchTmp = buildVecBatch(idTypes, 999, partitionNum, false, true); + jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); + } + jniWrapper.stop(splitterId); + jniWrapper.close(splitterId); + } + + @Ignore + public void columnShuffleVarChar10GBTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_varChar10GB_test"; + DataType.DataTypeId[] idTypes = {OMNI_VARCHAR, OMNI_VARCHAR, OMNI_VARCHAR}; + DataType[] types = dataTypeId2DataType(idTypes); + String tmpStr = DataTypeSerializer.serialize(types); + SparkJniWrapper jniWrapper = new SparkJniWrapper(); + int partitionNum = 4; + long splitterId = jniWrapper.nativeMake( + "hash", + partitionNum, + tmpStr, + types.length, + 1024, //shuffle value_buffer init size + "lz4", + shuffleDataFile, + 0, + shuffleTestDir, + 64 * 1024, + 4096, + 1024 * 1024 * 1024); + for (int i = 0; i < 10 * 3 * 999; i++) { + VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, partitionNum, false, true); + jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); + } + jniWrapper.stop(splitterId); + jniWrapper.close(splitterId); + } + + @Test + public void columnShuffleMix1GBTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_mix1GB_test"; + DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR, OMNI_CHAR, + OMNI_DATE32, OMNI_DATE64, OMNI_DECIMAL64, OMNI_DECIMAL128}; + DataType[] types = dataTypeId2DataType(idTypes); + String tmpStr = DataTypeSerializer.serialize(types); + SparkJniWrapper jniWrapper = new SparkJniWrapper(); + int partitionNum = 4; + long splitterId = jniWrapper.nativeMake( + "hash", + partitionNum, + tmpStr, + types.length, + 4096, //shuffle value_buffer init size + "lz4", + shuffleDataFile, + 0, + shuffleTestDir, + 64 * 1024, + 4096, + 1024 * 1024 * 1024); + // Do not split the same VecBatch twice: the split interface frees the VecBatch's memory, so splitting it again would double-free it and cause a core dump + for (int i = 0; i < 6 * 999; i++) { + VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, partitionNum, false, true); + jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); + } + jniWrapper.stop(splitterId); + jniWrapper.close(splitterId); + } + + @Ignore + public void columnShuffleMix10GBTest() throws 
IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_mix10GB_test"; + DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR, OMNI_CHAR, + OMNI_DATE32, OMNI_DATE64, OMNI_DECIMAL64, OMNI_DECIMAL128}; + DataType[] types = dataTypeId2DataType(idTypes); + String tmpStr = DataTypeSerializer.serialize(types); + SparkJniWrapper jniWrapper = new SparkJniWrapper(); + int partitionNum = 4; + long splitterId = jniWrapper.nativeMake( + "hash", + partitionNum, + tmpStr, + types.length, + 4096, //shuffle value_buffer init size + "lz4", + shuffleDataFile, + 0, + shuffleTestDir, + 64 * 1024, + 4096, + 1024 * 1024 * 1024); + for (int i = 0; i < 3 * 9 * 999; i++) { + VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, partitionNum, false, true); + jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); + } + jniWrapper.stop(splitterId); + jniWrapper.close(splitterId); + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleNullTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleNullTest.java new file mode 100644 index 0000000000000000000000000000000000000000..98fc18dd8f3237928cc066887e6fcb2205686692 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleNullTest.java @@ -0,0 +1,197 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.spark; + +import com.huawei.boostkit.spark.jni.SparkJniWrapper; + +import java.io.File; +import nova.hetu.omniruntime.type.DataType; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_CHAR; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE32; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE64; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL128; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL64; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DOUBLE; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; +import nova.hetu.omniruntime.type.DataTypeSerializer; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.IOException; + +public class ColumnShuffleNullTest extends ColumnShuffleTest { + private static String shuffleDataFile = ""; + + @BeforeClass + public static void runOnceBeforeClass() { + File folder = new File(shuffleTestDir); + if (!folder.exists() && !folder.isDirectory()) { + folder.mkdirs(); + } + } + + @AfterClass + public static void runOnceAfterClass() { + File folder = new File(shuffleTestDir); + if (folder.exists()) { + deleteDir(folder); + } + } + + @Before + public void runBeforeTestMethod() { + + } + + @After + public void runAfterTestMethod() { + File file = new File(shuffleDataFile); + if (file.exists()) { + file.delete(); + } + } + + @Test + public void columnShuffleFixNullTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_fixNull_test"; + DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE}; + DataType[] types = dataTypeId2DataType(idTypes); + String tmpStr = DataTypeSerializer.serialize(types); + SparkJniWrapper jniWrapper = new SparkJniWrapper(); + int numPartition = 4; + long splitterId = jniWrapper.nativeMake( + "hash", + numPartition, + tmpStr, + types.length, + 3, //shuffle value_buffer init size + "lz4", + shuffleDataFile, + 0, + shuffleTestDir, + 64 * 1024, + 4096, + 1024 * 1024 * 1024); + // Do not split the same VecBatch twice: the split interface frees the VecBatch's memory, so splitting it again would double-free it and cause a core dump + for (int i = 0; i < 1; i++) { + VecBatch vecBatchTmp = buildVecBatch(idTypes, 9, numPartition, true, true); + jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); + } + jniWrapper.stop(splitterId); + jniWrapper.close(splitterId); + } + + @Test + public void columnShuffleVarCharNullTest() throws IOException { + shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_fixNull_test"; + DataType.DataTypeId[] idTypes = {OMNI_VARCHAR, OMNI_VARCHAR, OMNI_VARCHAR,OMNI_VARCHAR}; + DataType[] types = dataTypeId2DataType(idTypes); + String tmpStr = DataTypeSerializer.serialize(types); + SparkJniWrapper jniWrapper = new SparkJniWrapper(); + int numPartition = 4; + long splitterId = jniWrapper.nativeMake( + "hash", + numPartition, + tmpStr, + types.length, + 3, //shuffle value_buffer init size + "lz4", + shuffleDataFile, + 0, + shuffleTestDir, + 64 * 1024, + 4096, + 1024 * 1024 * 1024); + // Do not split the same VecBatch twice: the split interface frees the VecBatch's memory, so splitting it again would double-free it and cause a core dump + for (int i = 0; i < 1; i++) { + VecBatch vecBatchTmp = buildVecBatch(idTypes, 9, numPartition, true, true); + jniWrapper.split(splitterId, 
vecBatchTmp.getNativeVectorBatch());
+        }
+        jniWrapper.stop(splitterId);
+        jniWrapper.close(splitterId);
+    }
+
+    @Test
+    public void columnShuffleMixNullTest() throws IOException {
+        shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_MixNull_test";
+        DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR, OMNI_CHAR,
+            OMNI_DATE32, OMNI_DATE64, OMNI_DECIMAL64, OMNI_DECIMAL128};
+        DataType[] types = dataTypeId2DataType(idTypes);
+        String tmpStr = DataTypeSerializer.serialize(types);
+        SparkJniWrapper jniWrapper = new SparkJniWrapper();
+        int numPartition = 4;
+        long splitterId = jniWrapper.nativeMake(
+                "hash",
+                numPartition,
+                tmpStr,
+                types.length,
+                3, //shuffle value_buffer init size
+                "lz4",
+                shuffleDataFile,
+                0,
+                shuffleTestDir,
+                64 * 1024,
+                4096,
+                1024 * 1024 * 1024);
+        // Do not split the same VecBatch twice: split() frees the batch's native memory, so a second split would double-free it and core dump.
+        for (int i = 0; i < 1; i++) {
+            VecBatch vecBatchTmp = buildVecBatch(idTypes, 9, numPartition, true, true);
+            jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch());
+        }
+        jniWrapper.stop(splitterId);
+        jniWrapper.close(splitterId);
+    }
+
+    @Test
+    public void columnShuffleMixNullFullTest() throws IOException {
+        shuffleDataFile = shuffleTestDir + "/shuffle_dataFile_MixNullFull_test";
+        DataType.DataTypeId[] idTypes = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR, OMNI_CHAR,
+            OMNI_DATE32, OMNI_DATE64, OMNI_DECIMAL64, OMNI_DECIMAL128};
+        DataType[] types = dataTypeId2DataType(idTypes);
+        String tmpStr = DataTypeSerializer.serialize(types);
+        SparkJniWrapper jniWrapper = new SparkJniWrapper();
+        int numPartition = 4;
+        long splitterId = jniWrapper.nativeMake(
+                "hash",
+                numPartition,
+                tmpStr,
+                types.length,
+                3, //shuffle value_buffer init size
+                "lz4",
+                shuffleDataFile,
+                0,
+                shuffleTestDir,
+                64 * 1024,
+                4096,
+                1024 * 1024 * 1024);
+        for (int i = 0; i < 1; i++) {
+            VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, numPartition, true, true);
+            jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch());
+        }
+        jniWrapper.stop(splitterId);
+        jniWrapper.close(splitterId);
+    }
+}
diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleTest.java
new file mode 100644
index 0000000000000000000000000000000000000000..74fccca66fad64dac9c96ae5f60591de40e92012
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleTest.java
@@ -0,0 +1,220 @@
+/*
+ * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved.
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. 
+ */ + +package com.huawei.boostkit.spark; + +import java.io.File; +import nova.hetu.omniruntime.type.CharDataType; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.Date32DataType; +import nova.hetu.omniruntime.type.Date64DataType; +import nova.hetu.omniruntime.type.Decimal128DataType; +import nova.hetu.omniruntime.type.Decimal64DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.vector.Decimal128Vec; +import nova.hetu.omniruntime.vector.DoubleVec; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.VarcharVec; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecBatch; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +abstract class ColumnShuffleTest { + public static String shuffleTestDir = "/tmp/shuffleTests"; + + public DataType[] dataTypeId2DataType(DataType.DataTypeId[] idTypes) { + DataType[] types = new DataType[idTypes.length]; + for(int i = 0; i < idTypes.length; i++) { + switch (idTypes[i]) { + case OMNI_INT: { + types[i] = IntDataType.INTEGER; + break; + } + case OMNI_LONG: { + types[i] = LongDataType.LONG; + break; + } + case OMNI_DOUBLE: { + types[i] = DoubleDataType.DOUBLE; + break; + } + case OMNI_VARCHAR: { + types[i] = VarcharDataType.VARCHAR; + break; + } + case OMNI_CHAR: { + types[i] = CharDataType.CHAR; + break; + } + case OMNI_DATE32: { + types[i] = Date32DataType.DATE32; + break; + } + case OMNI_DATE64: { + types[i] = Date64DataType.DATE64; + break; + } + case OMNI_DECIMAL64: { + types[i] = Decimal64DataType.DECIMAL64; + break; + } + case OMNI_DECIMAL128: { + types[i] = Decimal128DataType.DECIMAL128; // Or types[i] = new Decimal128DataType(2, 0); + break; + } + default: { + throw new UnsupportedOperationException("Unsupported type : " + idTypes[i]); + } + } + } + return types; + } + + public VecBatch buildVecBatch(DataType.DataTypeId[] idTypes, int rowNum, int partitionNum, boolean mixHalfNull, boolean withPidVec) { + List columns = new ArrayList<>(); + Vec tmpVec = null; + // prepare pidVec + if (withPidVec) { + IntVec pidVec = new IntVec(rowNum); + for (int i = 0; i < rowNum; i++) { + pidVec.set(i, i % partitionNum); + } + columns.add(pidVec); + } + + for(int i = 0; i < idTypes.length; i++) { + switch (idTypes[i]) { + case OMNI_INT: + case OMNI_DATE32:{ + tmpVec = new IntVec(rowNum); + for (int j = 0; j < rowNum; j++) { + ((IntVec)tmpVec).set(j, j + 1); + if (mixHalfNull && (j % 2) == 0) { + tmpVec.setNull(j); + } + } + break; + } + case OMNI_LONG: + case OMNI_DECIMAL64: + case OMNI_DATE64: { + tmpVec = new LongVec(rowNum); + for (int j = 0; j < rowNum; j++) { + ((LongVec)tmpVec).set(j, j + 1); + if (mixHalfNull && (j % 2) == 0) { + tmpVec.setNull(j); + } + } + break; + } + case OMNI_DOUBLE: { + tmpVec = new DoubleVec(rowNum); + for (int j = 0; j < rowNum; j++) { + ((DoubleVec)tmpVec).set(j, j + 1); + if (mixHalfNull && (j % 2) == 0) { + tmpVec.setNull(j); + } + } + break; + } + case OMNI_VARCHAR: + case OMNI_CHAR: { + tmpVec = new VarcharVec(rowNum * 16, rowNum); + for (int j = 0; j < rowNum; j++) { + ((VarcharVec)tmpVec).set(j, ("VAR_" + (j + 1) + "_END").getBytes(StandardCharsets.UTF_8)); + if (mixHalfNull && (j % 2) == 0) { + tmpVec.setNull(j); + } + } + break; + } + case OMNI_DECIMAL128: { + 
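+                    // A Decimal128 value is supplied as a pair of longs per row; rows left null in arr below become null entries in the Decimal128Vec built from it.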
long[][] arr = new long[rowNum][2]; + for (int j = 0; j < rowNum; j++) { + arr[j][0] = 2 * j; + arr[j][1] = 2 * j + 1; + if (mixHalfNull && (j % 2) == 0) { + arr[j] = null; + } + } + tmpVec = createDecimal128Vec(arr); + break; + } + default: { + throw new UnsupportedOperationException("Unsupported type : " + idTypes[i]); + } + } + columns.add(tmpVec); + } + return new VecBatch(columns); + } + + public Decimal128Vec createDecimal128Vec(long[][] data) { + Decimal128Vec result = new Decimal128Vec(data.length); + for (int i = 0; i < data.length; i++) { + if (data[i] == null) { + result.setNull(i); + } else { + result.set(i, new long[]{data[i][0], data[i][1]}); + } + } + return result; + } + + public List buildValInt(int pid, int val) { + IntVec c0 = new IntVec(1); + IntVec c1 = new IntVec(1); + c0.set(0, pid); + c1.set(0, val); + List columns = new ArrayList<>(); + columns.add(c0); + columns.add(c1); + return columns; + } + + public List buildValChar(int pid, String varChar) { + IntVec c0 = new IntVec(1); + VarcharVec c1 = new VarcharVec(8, 1); + c0.set(0, pid); + c1.set(0, varChar.getBytes(StandardCharsets.UTF_8)); + List columns = new ArrayList<>(); + columns.add(c0); + columns.add(c1); + return columns; + } + + public static boolean deleteDir(File dir) { + if (dir.isDirectory()) { + String[] children = dir.list(); + for (int i=0; i includedColumns = new ArrayList(); + // type long + includedColumns.add("i_item_sk"); + // type char 16 + includedColumns.add("i_item_id"); + // type char 200 + includedColumns.add("i_item_desc"); + // type int + includedColumns.add("i_current_price"); + job.put("includedColumns", includedColumns.toArray()); + + orcColumnarBatchJniReader.recordReader = orcColumnarBatchJniReader.initializeRecordReader(orcColumnarBatchJniReader.reader, job); + assertTrue(orcColumnarBatchJniReader.recordReader != 0); + } + + public void initBatch() { + orcColumnarBatchJniReader.batchReader = orcColumnarBatchJniReader.initializeBatch(orcColumnarBatchJniReader.recordReader, 4096); + assertTrue(orcColumnarBatchJniReader.batchReader != 0); + } + + @Test + public void testNext() { + int[] typeId = new int[4]; + long[] vecNativeId = new long[4]; + long rtn = orcColumnarBatchJniReader.recordReaderNext(orcColumnarBatchJniReader.recordReader, orcColumnarBatchJniReader.reader, orcColumnarBatchJniReader.batchReader, typeId, vecNativeId); + assertTrue(rtn == 4096); + LongVec vec1 = new LongVec(vecNativeId[0]); + VarcharVec vec2 = new VarcharVec(vecNativeId[1]); + VarcharVec vec3 = new VarcharVec(vecNativeId[2]); + IntVec vec4 = new IntVec(vecNativeId[3]); + assertTrue(vec1.get(10) == 11); + String tmp1 = new String(vec2.get(4080)); + assertTrue(tmp1.equals("AAAAAAAABPPAAAAA")); + String tmp2 = new String(vec3.get(4070)); + assertTrue(tmp2.equals("Particular, arab cases shall like less current, different names. Computers start for the changes. 
Scottish, trying exercises operate marks; long, supreme miners may ro")); + assertTrue(0 == vec4.get(1000)); + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java new file mode 100644 index 0000000000000000000000000000000000000000..853867f8d2ea593f18fe1300c530cc647cb8bb58 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java @@ -0,0 +1,101 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.jni; + +import junit.framework.TestCase; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.VarcharVec; +import nova.hetu.omniruntime.vector.Vec; +import org.json.JSONObject; +import org.junit.After; +import org.junit.Before; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.runners.MethodSorters; + +import java.io.File; +import java.util.ArrayList; + +import static org.junit.Assert.*; + +@FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) +public class OrcColumnarBatchJniReaderNotPushDownTest extends TestCase { + public OrcColumnarBatchJniReader orcColumnarBatchJniReader; + + @Before + public void setUp() throws Exception { + orcColumnarBatchJniReader = new OrcColumnarBatchJniReader(); + initReaderJava(); + initRecordReaderJava(); + initBatch(); + } + + @After + public void tearDown() throws Exception { + System.out.println("orcColumnarBatchJniReader test finished"); + } + + public void initReaderJava() { + JSONObject job = new JSONObject(); + job.put("serializedTail",""); + job.put("tailLocation",9223372036854775807L); + File directory = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0"); + System.out.println(directory.getAbsolutePath()); + orcColumnarBatchJniReader.reader = orcColumnarBatchJniReader.initializeReader(directory.getAbsolutePath(), job); + assertTrue(orcColumnarBatchJniReader.reader != 0); + } + + public void initRecordReaderJava() { + JSONObject job = new JSONObject(); + job.put("include",""); + job.put("offset", 0); + job.put("length", 3345152); + + ArrayList includedColumns = new ArrayList(); + includedColumns.add("i_item_sk"); + includedColumns.add("i_item_id"); + job.put("includedColumns", includedColumns.toArray()); + + orcColumnarBatchJniReader.recordReader = 
orcColumnarBatchJniReader.initializeRecordReader(orcColumnarBatchJniReader.reader, job); + assertTrue(orcColumnarBatchJniReader.recordReader != 0); + } + + public void initBatch() { + orcColumnarBatchJniReader.batchReader = orcColumnarBatchJniReader.initializeBatch(orcColumnarBatchJniReader.recordReader, 4096); + assertTrue(orcColumnarBatchJniReader.batchReader != 0); + } + + @Test + public void testNext() { + int[] typeId = new int[2]; + long[] vecNativeId = new long[2]; + long rtn = orcColumnarBatchJniReader.recordReaderNext(orcColumnarBatchJniReader.recordReader, orcColumnarBatchJniReader.reader, orcColumnarBatchJniReader.batchReader, typeId, vecNativeId); + assertTrue(rtn == 4096); + LongVec vec1 = new LongVec(vecNativeId[0]); + VarcharVec vec2 = new VarcharVec(vecNativeId[1]); + assertTrue(vec1.get(4090) == 4091); + assertTrue(vec1.get(4000) == 4001); + String tmp1 = new String(vec2.get(4090)); + String tmp2 = new String(vec2.get(4000)); + assertTrue(tmp1.equals("AAAAAAAAKPPAAAAA")); + assertTrue(tmp2.equals("AAAAAAAAAKPAAAAA")); + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java new file mode 100644 index 0000000000000000000000000000000000000000..8bdb4ce2bfb121a542ab9c1203827221348e0618 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java @@ -0,0 +1,154 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.spark.jni; + +import static org.junit.Assert.*; +import junit.framework.TestCase; +import org.apache.hadoop.mapred.join.ArrayListBackedIterator; +import org.apache.orc.OrcFile.ReaderOptions; +import org.apache.orc.Reader.Options; +import org.hamcrest.Condition; +import org.json.JSONObject; +import org.junit.After; +import org.junit.Before; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.runners.MethodSorters; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.VarcharVec; +import nova.hetu.omniruntime.vector.Vec; + +import java.io.File; +import java.lang.reflect.Array; +import java.util.ArrayList; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) +public class OrcColumnarBatchJniReaderPushDownTest extends TestCase { + public OrcColumnarBatchJniReader orcColumnarBatchJniReader; + + @Before + public void setUp() throws Exception { + orcColumnarBatchJniReader = new OrcColumnarBatchJniReader(); + initReaderJava(); + initRecordReaderJava(); + initBatch(); + } + + @After + public void tearDown() throws Exception { + System.out.println("orcColumnarBatchJniReader test finished"); + } + + public void initReaderJava() { + JSONObject job = new JSONObject(); + job.put("serializedTail",""); + job.put("tailLocation",9223372036854775807L); + File directory = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0"); + System.out.println(directory.getAbsolutePath()); + orcColumnarBatchJniReader.reader = orcColumnarBatchJniReader.initializeReader(directory.getAbsolutePath(), job); + assertTrue(orcColumnarBatchJniReader.reader != 0); + } + + public void initRecordReaderJava() { + JSONObject job = new JSONObject(); + job.put("include",""); + job.put("offset", 0); + job.put("length", 3345152); + + ArrayList childList1 = new ArrayList(); + JSONObject child1 = new JSONObject(); + child1.put("op", 3); + child1.put("leaf", "leaf-0"); + childList1.add(child1); + JSONObject subChild1 = new JSONObject(); + subChild1.put("op", 2); + subChild1.put("child", childList1); + + ArrayList childList2 = new ArrayList(); + JSONObject child2 = new JSONObject(); + child2.put("op", 3); + child2.put("leaf", "leaf-1"); + childList2.add(child2); + JSONObject subChild2 = new JSONObject(); + subChild2.put("op", 2); + subChild2.put("child", childList2); + + ArrayList childs = new ArrayList(); + childs.add(subChild1); + childs.add(subChild2); + + JSONObject expressionTree = new JSONObject(); + expressionTree.put("op", 1); + expressionTree.put("child", childs); + job.put("expressionTree", expressionTree); + + JSONObject leaves = new JSONObject(); + JSONObject leaf0 = new JSONObject(); + leaf0.put("op", 6); + leaf0.put("name", "i_item_sk"); + leaf0.put("type", 0); + leaf0.put("literal", ""); + leaf0.put("literalList", new ArrayList()); + + JSONObject leaf1 = new JSONObject(); + leaf1.put("op", 3); + leaf1.put("name", "i_item_sk"); + leaf1.put("type", 0); + leaf1.put("literal", "100"); + leaf1.put("literalList", new ArrayList()); + + leaves.put("leaf-0", leaf0); + leaves.put("leaf-1", leaf1); + job.put("leaves", leaves); + + ArrayList includedColumns = new ArrayList(); + includedColumns.add("i_item_sk"); + includedColumns.add("i_item_id"); + job.put("includedColumns", includedColumns.toArray()); + + orcColumnarBatchJniReader.recordReader = 
orcColumnarBatchJniReader.initializeRecordReader(orcColumnarBatchJniReader.reader, job); + assertTrue(orcColumnarBatchJniReader.recordReader != 0); + } + + public void initBatch() { + orcColumnarBatchJniReader.batchReader = orcColumnarBatchJniReader.initializeBatch(orcColumnarBatchJniReader.recordReader, 4096); + assertTrue(orcColumnarBatchJniReader.batchReader != 0); + } + + @Test + public void testNext() { + int[] typeId = new int[2]; + long[] vecNativeId = new long[2]; + long rtn = orcColumnarBatchJniReader.recordReaderNext(orcColumnarBatchJniReader.recordReader, orcColumnarBatchJniReader.reader, orcColumnarBatchJniReader.batchReader, typeId, vecNativeId); + assertTrue(rtn == 4096); + LongVec vec1 = new LongVec(vecNativeId[0]); + VarcharVec vec2 = new VarcharVec(vecNativeId[1]); + assertTrue(11 == vec1.get(10)); + assertTrue(21 == vec1.get(20)); + String tmp1 = new String(vec2.get(10)); + String tmp2 = new String(vec2.get(20)); + assertTrue(tmp1.equals("AAAAAAAAKAAAAAAA")); + assertTrue(tmp2.equals("AAAAAAAAEBAAAAAA")); + } + +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java new file mode 100644 index 0000000000000000000000000000000000000000..fe0afb3ca0fcf92ef3c4c028659be120201df7bc --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java @@ -0,0 +1,108 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.spark.jni; + +import junit.framework.TestCase; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.VarcharVec; +import org.json.JSONObject; +import org.junit.After; +import org.junit.Before; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.runners.MethodSorters; + +import java.io.File; +import java.util.ArrayList; + +import static org.junit.Assert.*; + +@FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) +public class OrcColumnarBatchJniReaderSparkORCNotPushDownTest extends TestCase { + public OrcColumnarBatchJniReader orcColumnarBatchJniReader; + + @Before + public void setUp() throws Exception { + orcColumnarBatchJniReader = new OrcColumnarBatchJniReader(); + initReaderJava(); + initRecordReaderJava(); + initBatch(); + } + + @After + public void tearDown() throws Exception { + System.out.println("orcColumnarBatchJniReader test finished"); + } + + public void initReaderJava() { + JSONObject job = new JSONObject(); + job.put("serializedTail",""); + job.put("tailLocation",9223372036854775807L); + File directory = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/part-00000-2d6ca713-08b0-4b40-828c-f7ee0c81bb9a-c000.snappy.orc"); + System.out.println(directory.getAbsolutePath()); + orcColumnarBatchJniReader.reader = orcColumnarBatchJniReader.initializeReader(directory.getAbsolutePath(), job); + assertTrue(orcColumnarBatchJniReader.reader != 0); + } + + public void initRecordReaderJava() { + JSONObject job = new JSONObject(); + job.put("include",""); + job.put("offset", 0); + job.put("length", 3345152); + + ArrayList includedColumns = new ArrayList(); + // type long + includedColumns.add("i_item_sk"); + // type char 16 + includedColumns.add("i_item_id"); + // type char 200 + includedColumns.add("i_item_desc"); + // type int + includedColumns.add("i_current_price"); + job.put("includedColumns", includedColumns.toArray()); + + orcColumnarBatchJniReader.recordReader = orcColumnarBatchJniReader.initializeRecordReader(orcColumnarBatchJniReader.reader, job); + assertTrue(orcColumnarBatchJniReader.recordReader != 0); + } + + public void initBatch() { + orcColumnarBatchJniReader.batchReader = orcColumnarBatchJniReader.initializeBatch(orcColumnarBatchJniReader.recordReader, 4096); + assertTrue(orcColumnarBatchJniReader.batchReader != 0); + } + + @Test + public void testNext() { + int[] typeId = new int[4]; + long[] vecNativeId = new long[4]; + long rtn = orcColumnarBatchJniReader.recordReaderNext(orcColumnarBatchJniReader.recordReader, orcColumnarBatchJniReader.reader, orcColumnarBatchJniReader.batchReader, typeId, vecNativeId); + assertTrue(rtn == 4096); + LongVec vec1 = new LongVec(vecNativeId[0]); + VarcharVec vec2 = new VarcharVec(vecNativeId[1]); + VarcharVec vec3 = new VarcharVec(vecNativeId[2]); + IntVec vec4 = new IntVec(vecNativeId[3]); + + assertTrue(vec1.get(4095) == 4096); + String tmp1 = new String(vec2.get(4095)); + assertTrue(tmp1.equals("AAAAAAAAAAABAAAA")); + String tmp2 = new String(vec3.get(4095)); + assertTrue(tmp2.equals("Find")); + assertTrue(vec4.get(4095) == 6); + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java new file mode 100644 index 
0000000000000000000000000000000000000000..f26e7603f33fb5ba66474de650e53cabf1044793 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java @@ -0,0 +1,156 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.jni; + +import junit.framework.TestCase; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.VarcharVec; +import nova.hetu.omniruntime.vector.Vec; +import org.json.JSONObject; +import org.junit.After; +import org.junit.Before; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.runners.MethodSorters; + +import java.io.File; +import java.util.ArrayList; + +import static org.junit.Assert.*; + +@FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) +public class OrcColumnarBatchJniReaderSparkORCPushDownTest extends TestCase { + public OrcColumnarBatchJniReader orcColumnarBatchJniReader; + + @Before + public void setUp() throws Exception { + orcColumnarBatchJniReader = new OrcColumnarBatchJniReader(); + initReaderJava(); + initRecordReaderJava(); + initBatch(); + } + + @After + public void tearDown() throws Exception { + System.out.println("orcColumnarBatchJniReader test finished"); + } + + public void initReaderJava() { + JSONObject job = new JSONObject(); + job.put("serializedTail",""); + job.put("tailLocation",9223372036854775807L); + File directory = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/part-00000-2d6ca713-08b0-4b40-828c-f7ee0c81bb9a-c000.snappy.orc"); + System.out.println(directory.getAbsolutePath()); + orcColumnarBatchJniReader.reader = orcColumnarBatchJniReader.initializeReader(directory.getAbsolutePath(), job); + assertTrue(orcColumnarBatchJniReader.reader != 0); + } + + public void initRecordReaderJava() { + JSONObject job = new JSONObject(); + job.put("include",""); + job.put("offset", 0); + job.put("length", 3345152); + + ArrayList childList1 = new ArrayList(); + JSONObject child1 = new JSONObject(); + child1.put("op", 3); + child1.put("leaf", "leaf-0"); + childList1.add(child1); + JSONObject subChild1 = new JSONObject(); + subChild1.put("op", 2); + subChild1.put("child", childList1); + + ArrayList childList2 = new ArrayList(); + JSONObject child2 = new JSONObject(); + child2.put("op", 3); + child2.put("leaf", "leaf-1"); + childList2.add(child2); + JSONObject subChild2 = new JSONObject(); + subChild2.put("op", 2); + subChild2.put("child", childList2); + + ArrayList childs = new ArrayList(); + childs.add(subChild1); + childs.add(subChild2); + + JSONObject 
expressionTree = new JSONObject(); + expressionTree.put("op", 1); + expressionTree.put("child", childs); + job.put("expressionTree", expressionTree); + + JSONObject leaves = new JSONObject(); + JSONObject leaf0 = new JSONObject(); + leaf0.put("op", 6); + leaf0.put("name", "i_item_sk"); + leaf0.put("type", 0); + leaf0.put("literal", ""); + leaf0.put("literalList", new ArrayList()); + + JSONObject leaf1 = new JSONObject(); + leaf1.put("op", 3); + leaf1.put("name", "i_item_sk"); + leaf1.put("type", 0); + leaf1.put("literal", "100"); + leaf1.put("literalList", new ArrayList()); + + leaves.put("leaf-0", leaf0); + leaves.put("leaf-1", leaf1); + job.put("leaves", leaves); + + ArrayList includedColumns = new ArrayList(); + // type long + includedColumns.add("i_item_sk"); + // type char 16 + includedColumns.add("i_item_id"); + // type char 200 + includedColumns.add("i_item_desc"); + // type int + includedColumns.add("i_current_price"); + job.put("includedColumns", includedColumns.toArray()); + + orcColumnarBatchJniReader.recordReader = orcColumnarBatchJniReader.initializeRecordReader(orcColumnarBatchJniReader.reader, job); + assertTrue(orcColumnarBatchJniReader.recordReader != 0); + } + + public void initBatch() { + orcColumnarBatchJniReader.batchReader = orcColumnarBatchJniReader.initializeBatch(orcColumnarBatchJniReader.recordReader, 4096); + assertTrue(orcColumnarBatchJniReader.batchReader != 0); + } + + @Test + public void testNext() { + int[] typeId = new int[4]; + long[] vecNativeId = new long[4]; + long rtn = orcColumnarBatchJniReader.recordReaderNext(orcColumnarBatchJniReader.recordReader, orcColumnarBatchJniReader.reader, orcColumnarBatchJniReader.batchReader, typeId, vecNativeId); + assertTrue(rtn == 4096); + LongVec vec1 = new LongVec(vecNativeId[0]); + VarcharVec vec2 = new VarcharVec(vecNativeId[1]); + VarcharVec vec3 = new VarcharVec(vecNativeId[2]); + IntVec vec4 = new IntVec(vecNativeId[3]); + + assertTrue(vec1.get(10) == 11); + String tmp1 = new String(vec2.get(4080)); + assertTrue(tmp1.equals("AAAAAAAABPPAAAAA")); + String tmp2 = new String(vec3.get(4070)); + assertTrue(tmp2.equals("Particular, arab cases shall like less current, different names. Computers start for the changes. Scottish, trying exercises operate marks; long, supreme miners may ro")); + assertTrue(vec4.get(1000) == 0); + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java new file mode 100644 index 0000000000000000000000000000000000000000..e55c6f8f12c28cea52235c29bd65dba3874f63df --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java @@ -0,0 +1,119 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.jni; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import junit.framework.TestCase; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.VarcharVec; +import nova.hetu.omniruntime.vector.Vec; +import org.apache.commons.codec.binary.Base64; +import org.apache.hadoop.hive.ql.io.sarg.SearchArgument; +import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentImpl; +import org.apache.orc.OrcConf; +import org.apache.orc.OrcFile; +import org.apache.orc.Reader; +import org.apache.orc.TypeDescription; +import org.apache.orc.mapred.OrcInputFormat; +import org.json.JSONObject; +import org.junit.After; +import org.junit.Before; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.runners.MethodSorters; +import org.apache.hadoop.conf.Configuration; +import java.io.File; +import java.util.ArrayList; +import org.apache.orc.Reader.Options; + +import static org.junit.Assert.*; + +@FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) +public class OrcColumnarBatchJniReaderTest extends TestCase { + public Configuration conf = new Configuration(); + public OrcColumnarBatchJniReader orcColumnarBatchJniReader; + public int batchSize = 4096; + + @Before + public void setUp() throws Exception { + Configuration conf = new Configuration(); + TypeDescription schema = + TypeDescription.fromString("struct<`i_item_sk`:bigint,`i_item_id`:string>"); + Options options = new Options(conf) + .range(0, Integer.MAX_VALUE) + .useZeroCopy(false) + .skipCorruptRecords(false) + .tolerateMissingSchema(true); + + options.schema(schema); + options.include(OrcInputFormat.parseInclude(schema, + null)); + String kryoSarg = "AQEAb3JnLmFwYWNoZS5oYWRvb3AuaGl2ZS5xbC5pby5zYXJnLkV4cHJlc3Npb25UcmXlAQEBamF2YS51dGlsLkFycmF5TGlz9AECAQABAQEBAQEAAQAAAAEEAAEBAwEAAQEBAQEBAAEAAAIIAAEJAAEBAgEBAQIBAscBb3JnLmFwYWNoZS5oYWRvb3AuaGl2ZS5xbC5pby5zYXJnLlNlYXJjaEFyZ3VtZW50SW1wbCRQcmVkaWNhdGVMZWFmSW1wbAEBaV9pdGVtX3PrAAABBwEBAQIBEAkAAAEEEg=="; + String sargColumns = "i_item_sk,i_item_id,i_rec_start_date,i_rec_end_date,i_item_desc,i_current_price,i_wholesale_cost,i_brand_id,i_brand,i_class_id,i_class,i_category_id,i_category,i_manufact_id,i_manufact,i_size,i_formulation,i_color,i_units,i_container,i_manager_id,i_product_name"; + if (kryoSarg != null && sargColumns != null) { + byte[] sargBytes = Base64.decodeBase64(kryoSarg); + SearchArgument sarg = + new Kryo().readObject(new Input(sargBytes), SearchArgumentImpl.class); + options.searchArgument(sarg, sargColumns.split(",")); + sarg.getExpression().toString(); + } + + orcColumnarBatchJniReader = new OrcColumnarBatchJniReader(); + initReaderJava(); + initRecordReaderJava(options); + initBatch(options); + } + + @After + public void tearDown() throws Exception { + System.out.println("orcColumnarBatchJniReader test finished"); + } + + public void initReaderJava() { + OrcFile.ReaderOptions readerOptions = OrcFile.readerOptions(conf); + File directory = new 
File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0"); + String path = directory.getAbsolutePath(); + orcColumnarBatchJniReader.reader = orcColumnarBatchJniReader.initializeReaderJava(path, readerOptions); + assertTrue(orcColumnarBatchJniReader.reader != 0); + } + + public void initRecordReaderJava(Options options) { + orcColumnarBatchJniReader.recordReader = orcColumnarBatchJniReader.initializeRecordReaderJava(options); + assertTrue(orcColumnarBatchJniReader.recordReader != 0); + } + + public void initBatch(Options options) { + orcColumnarBatchJniReader.initBatchJava(batchSize); + assertTrue(orcColumnarBatchJniReader.batchReader != 0); + } + + @Test + public void testNext() { + Vec[] vecs = new Vec[2]; + long rtn = orcColumnarBatchJniReader.next(vecs); + assertTrue(rtn == 4096); + assertTrue(((LongVec) vecs[0]).get(0) == 1); + String str = new String(((VarcharVec) vecs[1]).get(0)); + assertTrue(str.equals("AAAAAAAABAAAAAAA")); + } + +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0 b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0 new file mode 100644 index 0000000000000000000000000000000000000000..65e4e602cebab6ce7ca576d8ab20b1ae8841c981 Binary files /dev/null and b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0 differ diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/orcsrc/part-00000-2d6ca713-08b0-4b40-828c-f7ee0c81bb9a-c000.snappy.orc b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/orcsrc/part-00000-2d6ca713-08b0-4b40-828c-f7ee0c81bb9a-c000.snappy.orc new file mode 100644 index 0000000000000000000000000000000000000000..a79c7be758d63ce6f56b16deed0765e85c96a866 Binary files /dev/null and b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/orcsrc/part-00000-2d6ca713-08b0-4b40-828c-f7ee0c81bb9a-c000.snappy.orc differ diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarNativeReaderTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarNativeReaderTest.java new file mode 100644 index 0000000000000000000000000000000000000000..fc7a2e2d9518ea2c6f556123a2645e808d79d373 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarNativeReaderTest.java @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.orc; + +import junit.framework.TestCase; +import org.apache.orc.Reader.Options; +import org.apache.hadoop.conf.Configuration; +import org.junit.After; +import org.junit.Before; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.runners.MethodSorters; + +@FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) +public class OrcColumnarNativeReaderTest extends TestCase{ + + @Before + public void setUp() throws Exception { + } + + @After + public void tearDown() throws Exception { + System.out.println("OrcColumnarNativeReaderTest test finished"); + } + + @Test + public void testBuildOptions() { + Configuration conf = new Configuration(); + Options options = OrcColumnarNativeReader.buildOptions(conf,0,1024); + assertTrue(options.getLength() == 1024L); + assertTrue(options.getOffset() == 0L); + } + +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/org/apache/spark/sql/execution/vectorized/OmniColumnVectorTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/org/apache/spark/sql/execution/vectorized/OmniColumnVectorTest.java new file mode 100644 index 0000000000000000000000000000000000000000..4a36d7b3fda8948958383bfc00e5a50136537248 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/java/org/apache/spark/sql/execution/vectorized/OmniColumnVectorTest.java @@ -0,0 +1,135 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.vectorized; + +import junit.framework.TestCase; +import nova.hetu.omniruntime.vector.*; +import org.apache.orc.Reader.Options; +import org.apache.hadoop.conf.Configuration; +import org.apache.spark.sql.execution.datasources.orc.OrcColumnarNativeReader; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.junit.After; +import org.junit.Before; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.runners.MethodSorters; + +import javax.validation.constraints.AssertTrue; + +import static org.junit.Assert.*; + +@FixMethodOrder(value = MethodSorters.NAME_ASCENDING) +public class OmniColumnVectorTest extends TestCase { + + @Before + public void setUp() throws Exception { + } + + @After + public void tearDown() throws Exception { + System.out.println("OmniColumnVectorTest test finished"); + } + + + @Test + public void testNewOmniColumnVector() { + OmniColumnVector vecTmp = new OmniColumnVector(4096, DataTypes.LongType, true); + LongVec vecLong = new LongVec(4096); + vecTmp.setVec(vecLong); + vecTmp.putLong(0, 123L); + assertTrue(vecTmp.getLong(0) == 123L); + assertTrue(vecTmp.getVec() != null); + vecTmp.close(); + + OmniColumnVector vecTmp1 = new OmniColumnVector(4096, DataTypes.IntegerType, true); + IntVec vecInt = new IntVec(4096); + vecTmp1.setVec(vecInt); + vecTmp1.putInt(0, 123); + assertTrue(vecTmp1.getInt(0) == 123); + assertTrue(vecTmp1.getVec() != null); + vecTmp1.close(); + + OmniColumnVector vecTmp3 = new OmniColumnVector(4096, DataTypes.BooleanType, true); + BooleanVec vecBoolean = new BooleanVec(4096); + vecTmp3.setVec(vecBoolean); + vecTmp3.putBoolean(0, true); + assertTrue(vecTmp3.getBoolean(0) == true); + assertTrue(vecTmp3.getVec() != null); + vecTmp3.close(); + + OmniColumnVector vecTmp4 = new OmniColumnVector(4096, DataTypes.BooleanType, false); + BooleanVec vecBoolean1 = new BooleanVec(4096); + vecTmp4.setVec(vecBoolean1); + vecTmp4.putBoolean(0, true); + assertTrue(vecTmp4.getBoolean(0) == true); + assertTrue(vecTmp4.getVec() != null); + vecTmp4.close(); + } + + @Test + public void testGetsPuts() { + OmniColumnVector vecTmp = new OmniColumnVector(4096, DataTypes.LongType, true); + LongVec vecLong = new LongVec(4096); + vecTmp.setVec(vecLong); + vecTmp.putLongs(0, 10, 123L); + long[] gets = vecTmp.getLongs(0, 10); + for (long i : gets) { + assertTrue(i == 123L); + } + assertTrue(vecTmp.getVec() != null); + vecTmp.close(); + + OmniColumnVector vecTmp1 = new OmniColumnVector(4096, DataTypes.IntegerType, true); + IntVec vecInt = new IntVec(4096); + vecTmp1.setVec(vecInt); + vecTmp1.putInts(0, 10, 123); + int[] getInts = vecTmp1.getInts(0, 10); + for (int i : getInts) { + assertTrue(i == 123); + } + assertTrue(vecTmp1.getVec() != null); + vecTmp1.close(); + + OmniColumnVector vecTmp3 = new OmniColumnVector(4096, DataTypes.BooleanType, true); + BooleanVec vecBoolean = new BooleanVec(4096); + vecTmp3.setVec(vecBoolean); + vecTmp3.putBooleans(0, 10, true); + boolean[] getBools = vecTmp3.getBooleans(0, 10); + for (boolean i : getBools) { + assertTrue(i == true); + } + assertTrue(vecTmp3.getVec() != null); + vecTmp3.close(); + + OmniColumnVector vecTmp4 = new OmniColumnVector(4096, DataTypes.BooleanType, false); + BooleanVec vecBoolean1 = new BooleanVec(4096); + vecTmp4.setVec(vecBoolean1); + vecTmp4.putBooleans(0, 10, true); + boolean[] getBools1 = vecTmp4.getBooleans(0, 10); + for (boolean i : getBools1) { + assertTrue(i == true); + } + 
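+        // Spot-check a single element of vecTmp4 as well before it is closed.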
System.out.println(vecTmp4.getBoolean(0)); + assertTrue(vecTmp4.getBoolean(0) == true); + assertTrue(vecTmp4.getVec() != null); + vecTmp4.close(); + } + +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties b/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties new file mode 100644 index 0000000000000000000000000000000000000000..441e271149b6d51e2873656034878d465cfb1436 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/resources/HiveResource.properties @@ -0,0 +1,12 @@ +# +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. +# + +hive.metastore.uris=thrift://server1:9083 +spark.sql.warehouse.dir=/user/hive/warehouse +spark.memory.offHeap.size=8G +spark.sql.codegen.wholeStage=false +spark.sql.extensions=com.huawei.boostkit.spark.ColumnarPlugin +spark.shuffle.manager=org.apache.spark.shuffle.sort.ColumnarShuffleManager +spark.sql.orc.impl=native +hive.db=tpcds_bin_partitioned_orc_2 \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q1.sql b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q1.sql new file mode 100644 index 0000000000000000000000000000000000000000..6478818e67814d2bc3c3a2239bb7a0000f27b1e0 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q1.sql @@ -0,0 +1,14 @@ +select i_item_id + ,i_item_desc + ,i_current_price +from item, inventory, date_dim, store_sales +where i_current_price between 76 and 76+30 +and inv_item_sk = i_item_sk +and d_date_sk=inv_date_sk +and d_date between cast('1998-06-29' as date) and cast('1998-08-29' as date) +and i_manufact_id in (512,409,677,16) +and inv_quantity_on_hand between 100 and 500 +and ss_item_sk = i_item_sk +group by i_item_id,i_item_desc,i_current_price +order by i_item_id +limit 100; \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q10.sql b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q10.sql new file mode 100644 index 0000000000000000000000000000000000000000..9ac4277eba4447c7205ed294fb038bf4c17955a5 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q10.sql @@ -0,0 +1,36 @@ +select + i_brand_id brand_id, + i_brand brand, + i_manufact_id, + i_manufact, + sum(ss_ext_sales_price) ext_price +from + date_dim, + store_sales, + item, + customer, + customer_address, + store +where + d_date_sk = ss_sold_date_sk + and ss_item_sk = i_item_sk + and i_manager_id = 7 + and d_moy = 11 + and d_year = 1999 + and ss_customer_sk = c_customer_sk + and c_current_addr_sk = ca_address_sk + and substr(ca_zip,1,5) <> substr(s_zip,1,5) + and ss_store_sk = s_store_sk + and ss_sold_date_sk between 2451484 and 2451513 -- partition key filter +group by + i_brand, + i_brand_id, + i_manufact_id, + i_manufact +order by + ext_price desc, + i_brand, + i_brand_id, + i_manufact_id, + i_manufact +limit 100; \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q2.sql b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q2.sql new file mode 100644 index 0000000000000000000000000000000000000000..5a2ade87aa05decff9262402dfe547910211d730 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q2.sql @@ -0,0 +1,48 @@ +with v1 as ( + select i_category, i_brand, + 
s_store_name, s_company_name, + d_year, d_moy, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) over + (partition by i_category, i_brand, + s_store_name,s_company_name,d_year) + avg_monthly_sales, + rank() over + (partition by i_category, i_brand, + s_store_name,s_company_name + order by d_year,d_moy) rn + from item, store_sales, date_dim, store + where ss_item_sk = i_item_sk and + ss_sold_date_sk = d_date_sk and + ss_store_sk = s_store_sk and + ( + d_year = 2000 or + ( d_year = 2000-1 and d_moy =12) or + ( d_year = 2000+1 and d_moy =1) + ) + group by i_category, i_brand, + s_store_name, s_company_name, + d_year, d_moy), + v2 as( + select v1.i_category, v1.i_brand + ,v1.d_year + ,v1.avg_monthly_sales + ,v1.sum_sales, v1_lag.sum_sales psum, v1_lead.sum_sales nsum + from v1, v1 v1_lag, v1 v1_lead + where v1.i_category = v1_lag.i_category and + v1.i_category = v1_lead.i_category and + v1.i_brand = v1_lag.i_brand and + v1.i_brand = v1_lead.i_brand and + v1.s_store_name = v1_lag.s_store_name and + v1.s_store_name = v1_lead.s_store_name and + v1.s_company_name = v1_lag.s_company_name and + v1.s_company_name = v1_lead.s_company_name and + v1.rn = v1_lag.rn + 1 and + v1.rn = v1_lead.rn -1) +select * +from v2 +where d_year = 2000 and + avg_monthly_sales > 0 and + case when avg_monthly_sales > 0 then abs(sum_sales - avg_monthly_sales) / avg_monthly_sales else null end > 0.1 +order by sum_sales - avg_monthly_sales, d_year +limit 100; \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q3.sql b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q3.sql new file mode 100644 index 0000000000000000000000000000000000000000..33bd52ce6e07c6b7d214f23ce4cc2ab6bc23c707 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q3.sql @@ -0,0 +1,34 @@ +select + * +from + (select + i_manufact_id, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) over (partition by i_manufact_id) avg_quarterly_sales + from + item, + store_sales, + date_dim, + store + where + ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and ss_store_sk = s_store_sk + and d_month_seq in (1212, 1212 + 1, 1212 + 2, 1212 + 3, 1212 + 4, 1212 +5, 1212 + 6, 1212+7, 1212 + 8, 1212 + 9, 1212 + 10, 1212 + 11) + and ((i_category in ('Books', 'Children', 'Electronics') + and i_class in ('personal', 'portable', 'reference', 'self-help') + and i_brand in ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')) + or (i_category in ('Women', 'Music', 'Men') + and i_class in ('accessories', 'classical', 'fragrances', 'pants') + and i_brand in ('amalgimporto #1', 'edu packscholar #1', 'exportiimporto #1', 'importoamalg #1'))) + group by + i_manufact_id, + d_qoy + ) tmp1 +where + case when avg_quarterly_sales > 0 then abs (sum_sales -avg_quarterly_sales) / avg_quarterly_sales else null end > 0.1 +order by + avg_quarterly_sales, + sum_sales, + i_manufact_id +limit 100; \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q4.sql b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q4.sql new file mode 100644 index 0000000000000000000000000000000000000000..258c73813f4fb2f1f911c678c7f6996c05f9c15d --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q4.sql @@ -0,0 +1,35 @@ +select i_brand_id brand_id, i_brand brand,t_hour,t_minute, sum(ext_price) ext_price +from item, 
(select ws_ext_sales_price as ext_price, + ws_sold_date_sk as sold_date_sk, + ws_item_sk as sold_item_sk, + ws_sold_time_sk as time_sk + from web_sales,date_dim + where d_date_sk = ws_sold_date_sk + and d_moy=12 + and d_year=2001 + union all + select cs_ext_sales_price as ext_price, + cs_sold_date_sk as sold_date_sk, + cs_item_sk as sold_item_sk, + cs_sold_time_sk as time_sk + from catalog_sales,date_dim + where d_date_sk = cs_sold_date_sk + and d_moy=12 + and d_year=2001 + union all + select ss_ext_sales_price as ext_price, + ss_sold_date_sk as sold_date_sk, + ss_item_sk as sold_item_sk, + ss_sold_time_sk as time_sk + from store_sales,date_dim + where d_date_sk = ss_sold_date_sk + and d_moy=12 + and d_year=2001 + ) as tmp,time_dim +where + sold_item_sk = i_item_sk + and time_sk = t_time_sk + and i_manager_id=1 + and (t_meal_time = 'breakfast' or t_meal_time = 'dinner') +group by i_brand, i_brand_id,t_hour,t_minute +order by ext_price desc, brand_id; \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q5.sql b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q5.sql new file mode 100644 index 0000000000000000000000000000000000000000..4a8c7bc9d70ba5c1a3c11ede3881206851066e56 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q5.sql @@ -0,0 +1,20 @@ +select + c_customer_id as customer_id + ,c_last_name || ', ' || c_first_name as customername + from + customer + ,customer_address + ,customer_demographics + ,household_demographics + ,income_band + ,store_returns + where ca_city = 'Hopewell' + and c_current_addr_sk = ca_address_sk + and ib_lower_bound >= 32287 + and ib_upper_bound <= 82287 + and ib_income_band_sk = hd_income_band_sk + and cd_demo_sk = c_current_cdemo_sk + and hd_demo_sk = c_current_hdemo_sk + and sr_cdemo_sk = cd_demo_sk + order by customer_id + limit 100; \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q6.sql b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q6.sql new file mode 100644 index 0000000000000000000000000000000000000000..221c169e32482c68ebbe5b9011cc4b43934ee8c2 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q6.sql @@ -0,0 +1,25 @@ +select * +from (select i_manager_id + ,sum(ss_sales_price) sum_sales + ,avg(sum(ss_sales_price)) over (partition by i_manager_id) avg_monthly_sales + from item + ,store_sales + ,date_dim + ,store + where ss_item_sk = i_item_sk +and ss_sold_date_sk = d_date_sk +and ss_sold_date_sk between 2452123 and 2452487 +and ss_store_sk = s_store_sk +and d_month_seq in (1219,1219+1,1219+2,1219+3,1219+4,1219+5,1219+6,1219+7,1219+8,1219+9,1219+10,1219+11) +and (( i_category in ('Books','Children','Electronics') + and i_class in ('personal','portable','reference','self-help') + and i_brand in ('scholaramalgamalg #14','scholaramalgamalg #7', 'exportiunivamalg #9','scholaramalgamalg #9')) +or( i_category in ('Women','Music','Men') + and i_class in ('accessories','classical','fragrances','pants') + and i_brand in ('amalgimporto #1','edu packscholar #1','exportiimporto #1', 'importoamalg #1'))) +group by i_manager_id, d_moy) tmp1 +where case when avg_monthly_sales > 0 then abs (sum_sales - avg_monthly_sales) / avg_monthly_sales else null end > 0.1 +order by i_manager_id + ,avg_monthly_sales + ,sum_sales +limit 100; \ No newline at end of file diff --git 
a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q7.sql b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q7.sql new file mode 100644 index 0000000000000000000000000000000000000000..a42e5d9887c3e53bfd1570c496f2aab4b41ed5a3 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q7.sql @@ -0,0 +1,33 @@ +select + substr(w_warehouse_name,1,20) + ,sm_type + ,cc_name + ,sum(case when (cs_ship_date_sk - cs_sold_date_sk <= 30 ) then 1 else 0 end) as D30_days + ,sum(case when (cs_ship_date_sk - cs_sold_date_sk > 30) and + (cs_ship_date_sk - cs_sold_date_sk <= 60) then 1 else 0 end ) as D31_60_days + ,sum(case when (cs_ship_date_sk - cs_sold_date_sk > 60) and + (cs_ship_date_sk - cs_sold_date_sk <= 90) then 1 else 0 end) as D61_90_days + ,sum(case when (cs_ship_date_sk - cs_sold_date_sk > 90) and + (cs_ship_date_sk - cs_sold_date_sk <= 120) then 1 else 0 end) as D91_120_days + ,sum(case when (cs_ship_date_sk - cs_sold_date_sk > 120) then 1 else 0 end) as D120_days +from + catalog_sales + ,warehouse + ,ship_mode + ,call_center + ,date_dim +where + d_month_seq between 1202 and 1202 + 11 +-- equivalent to 2451605 2451969 +and cs_ship_date_sk = d_date_sk +and cs_warehouse_sk = w_warehouse_sk +and cs_ship_mode_sk = sm_ship_mode_sk +and cs_call_center_sk = cc_call_center_sk +group by + substr(w_warehouse_name,1,20) + ,sm_type + ,cc_name +order by substr(w_warehouse_name,1,20) + ,sm_type + ,cc_name +limit 100 ; \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q8.sql b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q8.sql new file mode 100644 index 0000000000000000000000000000000000000000..564b59b2460ae127197a808dbcff39e69c10e649 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q8.sql @@ -0,0 +1,41 @@ +select + * +from + (select + i_category, + i_class, + i_brand, + s_store_name, + s_company_name, + d_moy, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) over (partition by i_category, i_brand, s_store_name, s_company_name) avg_monthly_sales + from + item, + store_sales, + date_dim, + store + where + ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and ss_store_sk = s_store_sk + and d_year in (2000) + and ((i_category in ('Home', 'Books', 'Electronics') + and i_class in ('wallpaper', 'parenting', 'musical')) + or (i_category in ('Shoes', 'Jewelry', 'Men') + and i_class in ('womens', 'birdal', 'pants'))) + and ss_sold_date_sk between 2451545 and 2451910 -- partition key filter + group by + i_category, + i_class, + i_brand, + s_store_name, + s_company_name, + d_moy + ) tmp1 +where + case when (avg_monthly_sales <> 0) then (abs(sum_sales - avg_monthly_sales) / avg_monthly_sales) else null end > 0.1 +order by + sum_sales - avg_monthly_sales, + s_store_name +limit 100; \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q9.sql b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q9.sql new file mode 100644 index 0000000000000000000000000000000000000000..26350730a79bebcd470f73806282efd707b4d7df --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/resources/query-sqls/q9.sql @@ -0,0 +1,44 @@ +select + c_last_name, + c_first_name, + substr(s_city,1,30), + ss_ticket_number, + amt, + profit +from + (select + ss_ticket_number, + ss_customer_sk, + store.s_city, + sum(ss_coupon_amt) amt, + 
sum(ss_net_profit) profit + from + store_sales, + date_dim, + store, + household_demographics + where + store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and (household_demographics.hd_dep_count = 8 + or household_demographics.hd_vehicle_count >0) + and date_dim.d_dow = 1 + and date_dim.d_year in (1998,1998+1,1998+2) + and store.s_number_employees between 200 and 295 + and ss_sold_date_sk between 2450819 and 2451904 + group by + ss_ticket_number, + ss_customer_sk, + ss_addr_sk, + store.s_city + ) ms, + customer +where + ss_customer_sk = c_customer_sk +order by + c_last_name, + c_first_name, + substr(s_city,1,30), + profit +limit 100; \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/test-data/shuffle_spilled_mix_100batch_4096rows_lz4 b/omnioperator/omniop-spark-extension/java/src/test/resources/test-data/shuffle_spilled_mix_100batch_4096rows_lz4 new file mode 100644 index 0000000000000000000000000000000000000000..18c123b569c7abd71053a558dde23e811039806b Binary files /dev/null and b/omnioperator/omniop-spark-extension/java/src/test/resources/test-data/shuffle_spilled_mix_100batch_4096rows_lz4 differ diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/test-data/shuffle_spilled_mix_1batch_100rows_snappy b/omnioperator/omniop-spark-extension/java/src/test/resources/test-data/shuffle_spilled_mix_1batch_100rows_snappy new file mode 100644 index 0000000000000000000000000000000000000000..712a82d4aadad6eb5f3650df0209e23b5d83e668 Binary files /dev/null and b/omnioperator/omniop-spark-extension/java/src/test/resources/test-data/shuffle_spilled_mix_1batch_100rows_snappy differ diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/test-data/shuffle_spilled_mix_1batch_100rows_uncompressed b/omnioperator/omniop-spark-extension/java/src/test/resources/test-data/shuffle_spilled_mix_1batch_100rows_uncompressed new file mode 100644 index 0000000000000000000000000000000000000000..2f835a7c95af1c746f1e4a50f5ef5e26f9fafb1c Binary files /dev/null and b/omnioperator/omniop-spark-extension/java/src/test/resources/test-data/shuffle_spilled_mix_1batch_100rows_uncompressed differ diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/test-data/shuffle_spilled_mix_1batch_100rows_zlib b/omnioperator/omniop-spark-extension/java/src/test/resources/test-data/shuffle_spilled_mix_1batch_100rows_zlib new file mode 100644 index 0000000000000000000000000000000000000000..e89b125edd874cc1fe287a0e19960b30d90a10fa Binary files /dev/null and b/omnioperator/omniop-spark-extension/java/src/test/resources/test-data/shuffle_spilled_mix_1batch_100rows_zlib differ diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/test-data/shuffle_split_fixed_singlePartition_someNullCol b/omnioperator/omniop-spark-extension/java/src/test/resources/test-data/shuffle_split_fixed_singlePartition_someNullCol new file mode 100644 index 0000000000000000000000000000000000000000..3cec85e7ed2efb44dc54b888cdab203a8bb7b405 Binary files /dev/null and b/omnioperator/omniop-spark-extension/java/src/test/resources/test-data/shuffle_split_fixed_singlePartition_someNullCol differ diff --git a/omnioperator/omniop-spark-extension/java/src/test/resources/test-data/shuffle_split_fixed_singlePartition_someNullRow b/omnioperator/omniop-spark-extension/java/src/test/resources/test-data/shuffle_split_fixed_singlePartition_someNullRow 
new file mode 100644 index 0000000000000000000000000000000000000000..8e0c78f75e44cc82c9286330dd26777d21f4140d Binary files /dev/null and b/omnioperator/omniop-spark-extension/java/src/test/resources/test-data/shuffle_split_fixed_singlePartition_someNullRow differ diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/Vectorized/OmniColumnVectorSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/Vectorized/OmniColumnVectorSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..a9fc4452338869610ed8800d0417eddf9a148989 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/Vectorized/OmniColumnVectorSuite.scala @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.vectorized + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.types._ + +class OmniColumnVectorSuite extends SparkFunSuite { + test("int") { + val schema = new StructType().add("int", IntegerType); + val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns(4, schema, true) + vectors(0).putInt(0, 1) + vectors(0).putInt(1, 2) + vectors(0).putInt(2, 3) + vectors(0).putInt(3, 4) + assert(1 == vectors(0).getInt(0)) + assert(2 == vectors(0).getInt(1)) + assert(3 == vectors(0).getInt(2)) + assert(4 == vectors(0).getInt(3)) + vectors(0).close() + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptorSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptorSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..76c8cfc455f54e77a12194702a2ec9b978cad0ec --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptorSuite.scala @@ -0,0 +1,271 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.expression + +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{getExprIdMap, rewriteToOmniExpressionLiteral, rewriteToOmniJsonExpressionLiteral} +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Max, Min, Sum} +import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType, StringType} + +/** + * Tests for OmniExpressionAdaptor expression rewriting. + * + * @author w00630100 + * @since 2022-02-21 + */ +class OmniExpressionAdaptorSuite extends SparkFunSuite { + var allAttribute = Seq(AttributeReference("a", IntegerType)(), + AttributeReference("b", IntegerType)(), AttributeReference("c", BooleanType)(), + AttributeReference("d", BooleanType)(), AttributeReference("e", IntegerType)(), + AttributeReference("f", StringType)(), AttributeReference("g", StringType)()) + + // todo: CaseWhen,InSet + test("expression rewrite") { + checkExpressionRewrite("$operator$ADD:1(#0,#1)", Add(allAttribute(0), allAttribute(1))) + checkExpressionRewrite("$operator$ADD:1(#0,1:1)", Add(allAttribute(0), Literal(1))) + + checkExpressionRewrite("$operator$SUBTRACT:1(#0,#1)", + Subtract(allAttribute(0), allAttribute(1))) + checkExpressionRewrite("$operator$SUBTRACT:1(#0,1:1)", Subtract(allAttribute(0), Literal(1))) + + checkExpressionRewrite("$operator$MULTIPLY:1(#0,#1)", + Multiply(allAttribute(0), allAttribute(1))) + checkExpressionRewrite("$operator$MULTIPLY:1(#0,1:1)", Multiply(allAttribute(0), Literal(1))) + + checkExpressionRewrite("$operator$DIVIDE:1(#0,#1)", Divide(allAttribute(0), allAttribute(1))) + checkExpressionRewrite("$operator$DIVIDE:1(#0,1:1)", Divide(allAttribute(0), Literal(1))) + + checkExpressionRewrite("$operator$MODULUS:1(#0,#1)", + Remainder(allAttribute(0), allAttribute(1))) + checkExpressionRewrite("$operator$MODULUS:1(#0,1:1)", Remainder(allAttribute(0), Literal(1))) + + checkExpressionRewrite("$operator$GREATER_THAN:4(#0,#1)", + GreaterThan(allAttribute(0), allAttribute(1))) + checkExpressionRewrite("$operator$GREATER_THAN:4(#0,1:1)", + GreaterThan(allAttribute(0), Literal(1))) + + checkExpressionRewrite("$operator$GREATER_THAN_OR_EQUAL:4(#0,#1)", + GreaterThanOrEqual(allAttribute(0), allAttribute(1))) + checkExpressionRewrite("$operator$GREATER_THAN_OR_EQUAL:4(#0,1:1)", + GreaterThanOrEqual(allAttribute(0), Literal(1))) + + checkExpressionRewrite("$operator$LESS_THAN:4(#0,#1)", + LessThan(allAttribute(0), allAttribute(1))) + checkExpressionRewrite("$operator$LESS_THAN:4(#0,1:1)", + LessThan(allAttribute(0), Literal(1))) + + checkExpressionRewrite("$operator$LESS_THAN_OR_EQUAL:4(#0,#1)", + LessThanOrEqual(allAttribute(0), allAttribute(1))) + checkExpressionRewrite("$operator$LESS_THAN_OR_EQUAL:4(#0,1:1)", + LessThanOrEqual(allAttribute(0), Literal(1))) + + checkExpressionRewrite("$operator$EQUAL:4(#0,#1)", EqualTo(allAttribute(0), allAttribute(1))) + checkExpressionRewrite("$operator$EQUAL:4(#0,1:1)", EqualTo(allAttribute(0), Literal(1))) + + checkExpressionRewrite("OR:4(#2,#3)", Or(allAttribute(2), allAttribute(3))) + 
checkExpressionRewrite("OR:4(#2,3:1)", Or(allAttribute(2), Literal(3))) + + checkExpressionRewrite("AND:4(#2,#3)", And(allAttribute(2), allAttribute(3))) + checkExpressionRewrite("AND:4(#2,3:1)", And(allAttribute(2), Literal(3))) + + checkExpressionRewrite("not:4(#3)", Not(allAttribute(3))) + + checkExpressionRewrite("IS_NOT_NULL:4(#4)", IsNotNull(allAttribute(4))) + + checkExpressionRewrite("substr:15(#5,#0,#1)", + Substring(allAttribute(5), allAttribute(0), allAttribute(1))) + + checkExpressionRewrite("CAST:2(#1)", Cast(allAttribute(1), LongType)) + + checkExpressionRewrite("abs:1(#0)", Abs(allAttribute(0))) + + checkExpressionRewrite("SUM:2(#0)", Sum(allAttribute(0))) + + checkExpressionRewrite("MAX:1(#0)", Max(allAttribute(0))) + + checkExpressionRewrite("AVG:3(#0)", Average(allAttribute(0))) + + checkExpressionRewrite("MIN:1(#0)", Min(allAttribute(0))) + + checkExpressionRewrite("IN:4(#0,#0,#1)", + In(allAttribute(0), Seq(allAttribute(0), allAttribute(1)))) + + // checkExpressionRewrite("IN:4(#0, #0, #1)", InSet(allAttribute(0), Set(allAttribute(0), allAttribute(1)))) + } + + test("json expression rewrite") { + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"ADD\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}}", + Add(allAttribute(0), allAttribute(1))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"ADD\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", + Add(allAttribute(0), Literal(1))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"SUBTRACT\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}}", + Subtract(allAttribute(0), allAttribute(1))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"SUBTRACT\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", + Subtract(allAttribute(0), Literal(1))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"MULTIPLY\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}}", + Multiply(allAttribute(0), allAttribute(1))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"MULTIPLY\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", + Multiply(allAttribute(0), Literal(1))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"DIVIDE\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}}", + Divide(allAttribute(0), allAttribute(1))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"DIVIDE\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", + Divide(allAttribute(0), Literal(1))) + + 
checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"MODULUS\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}}", + Remainder(allAttribute(0), allAttribute(1))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"MODULUS\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", + Remainder(allAttribute(0), Literal(1))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4," + + "\"operator\":\"GREATER_THAN\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}}", + GreaterThan(allAttribute(0), allAttribute(1))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4," + + "\"operator\":\"GREATER_THAN\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", + GreaterThan(allAttribute(0), Literal(1))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4," + + "\"operator\":\"GREATER_THAN_OR_EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}}", + GreaterThanOrEqual(allAttribute(0), allAttribute(1))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4," + + "\"operator\":\"GREATER_THAN_OR_EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", + GreaterThanOrEqual(allAttribute(0), Literal(1))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}}", + LessThan(allAttribute(0), allAttribute(1))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", + LessThan(allAttribute(0), Literal(1))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4," + + "\"operator\":\"LESS_THAN_OR_EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}}", + LessThanOrEqual(allAttribute(0), allAttribute(1))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4," + + "\"operator\":\"LESS_THAN_OR_EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", + LessThanOrEqual(allAttribute(0), Literal(1))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}}", + EqualTo(allAttribute(0), allAttribute(1))) + + 
checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", + EqualTo(allAttribute(0), Literal(1))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"OR\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":4,\"colVal\":2}," + + "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":4,\"colVal\":3}}", + Or(allAttribute(2), allAttribute(3))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"OR\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":4,\"colVal\":2}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":3}}", + Or(allAttribute(2), Literal(3))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"AND\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":4,\"colVal\":2}," + + "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":4,\"colVal\":3}}", + And(allAttribute(2), allAttribute(3))) + + checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"AND\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":4,\"colVal\":2}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":3}}", + And(allAttribute(2), Literal(3))) + + checkJsonExprRewrite("{\"exprType\":\"UNARY\",\"returnType\":4, \"operator\":\"not\"," + + "\"expr\":{\"exprType\":\"IS_NULL\",\"returnType\":4," + + "\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":4}]}}", + IsNotNull(allAttribute(4))) + + checkJsonExprRewrite("{\"exprType\":\"FUNCTION\",\"returnType\":2,\"function_name\":\"CAST\"," + + "\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}]}", + Cast(allAttribute(1), LongType)) + + checkJsonExprRewrite("{\"exprType\":\"FUNCTION\",\"returnType\":1,\"function_name\":\"abs\"," + + " \"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}]}", + Abs(allAttribute(0))) + } + + protected def checkExpressionRewrite(expected: Any, expression: Expression): Unit = { + { + val runResult = rewriteToOmniExpressionLiteral(expression, getExprIdMap(allAttribute)) + if (!expected.equals(runResult)) { + fail(s"expression($expression) not match with expected value:$expected," + + s"running value:$runResult") + } + } + } + + protected def checkJsonExprRewrite(expected: Any, expression: Expression): Unit = { + val runResult = rewriteToOmniJsonExpressionLiteral(expression, getExprIdMap(allAttribute)) + if (!expected.equals(runResult)) { + fail(s"expression($expression) not match with expected value:$expected," + + s"running value:$runResult") + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..0a08416ff04ba4772f8141e418e51757bc237f12 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala @@ -0,0 +1,82 @@ +/* + * Copyright (C) 2021-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.hive + +import java.util.Properties + +import com.huawei.boostkit.spark.hive.util.HiveResourceRunner +import org.apache.log4j.{Level, LogManager} +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.sql.SparkSession + +/** + * @since 2021/12/15 + */ +class HiveResourceSuite extends SparkFunSuite { + private val QUERY_SQLS = "query-sqls" + private var spark: SparkSession = _ + private var runner: HiveResourceRunner = _ + + override def beforeAll(): Unit = { + val properties = new Properties() + properties.load(this.getClass.getClassLoader.getResourceAsStream("HiveResource.properties")) + + spark = SparkSession.builder() + .appName("test-sql-context") + .master("local[2]") + .config(readConf(properties)) + .enableHiveSupport() + .getOrCreate() + LogManager.getRootLogger.setLevel(Level.WARN) + runner = new HiveResourceRunner(spark, QUERY_SQLS) + + val hiveDb = properties.getProperty("hive.db") + spark.sql(if (hiveDb == null) "use default" else s"use $hiveDb") + } + + override def afterAll(): Unit = { + super.afterAll() + } + + test("queryBySparkSql-HiveDataSource") { + runner.runQuery("q1", 1) + runner.runQuery("q2", 1) + runner.runQuery("q3", 1) + runner.runQuery("q4", 1) + runner.runQuery("q5", 1) + runner.runQuery("q6", 1) + runner.runQuery("q7", 1) + runner.runQuery("q8", 1) + runner.runQuery("q9", 1) + runner.runQuery("q10", 1) + } + + def readConf(properties: Properties): SparkConf = { + val conf = new SparkConf() + val wholeStage = properties.getProperty("spark.sql.codegen.wholeStage") + val offHeapSize = properties.getProperty("spark.memory.offHeap.size") + conf.set("hive.metastore.uris", properties.getProperty("hive.metastore.uris")) + .set("spark.sql.warehouse.dir", properties.getProperty("spark.sql.warehouse.dir")) + .set("spark.memory.offHeap.size", if (offHeapSize == null) "8G" else offHeapSize) + .set("spark.sql.codegen.wholeStage", if (wholeStage == null) "false" else wholeStage) + .set("spark.sql.extensions", properties.getProperty("spark.sql.extensions")) + .set("spark.shuffle.manager", properties.getProperty("spark.shuffle.manager")) + .set("spark.sql.orc.impl", properties.getProperty("spark.sql.orc.impl")) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/hive/util/HiveResourceRunner.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/hive/util/HiveResourceRunner.scala new file mode 100644 index 0000000000000000000000000000000000000000..84e12f6bd5b6f63b57ecba62a27c20cc5e6698fc --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/hive/util/HiveResourceRunner.scala @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2021-2022. Huawei Technologies Co., Ltd. All rights reserved. 
+ * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.hive.util + +import java.io.{File, FilenameFilter} +import java.nio.charset.StandardCharsets + +import org.apache.commons.io.FileUtils +import org.apache.spark.sql.{Row, SparkSession} + +class HiveResourceRunner(val spark: SparkSession, val resource: String) { + val caseIds = HiveResourceRunner.parseCaseIds(HiveResourceRunner.locateResourcePath(resource), + ".sql") + + def runQuery(caseId: String, roundId: Int, explain: Boolean = false): Unit = { + val path = "%s/%s.sql".format(resource, caseId) + val absolute = HiveResourceRunner.locateResourcePath(path) + val sql = FileUtils.readFileToString(new File(absolute), StandardCharsets.UTF_8) + println("Running query %s (round %d)... ".format(caseId, roundId)) + val df = spark.sql(sql) + if (explain) { + df.explain(extended = true) + } + val result: Array[Row] = df.head(100) + result.foreach(row => println(row)) + } +} + +object HiveResourceRunner { + private def parseCaseIds(dir: String, suffix: String): List[String] = { + val folder = new File(dir) + if (!folder.exists()) { + throw new IllegalArgumentException("dir does not exist: " + dir) + } + folder + .listFiles(new FilenameFilter { + override def accept(dir: File, name: String): Boolean = name.endsWith(suffix) + }) + .map(f => f.getName) + .map(n => n.substring(0, n.lastIndexOf(suffix))) + .sortBy(s => { + //fill with leading zeros + "%s%s".format(new String((0 until 16 - s.length).map(_ => '0').toArray), s) + }) + .toList + } + + private def locateResourcePath(resource: String): String = { + classOf[HiveResourceRunner].getClassLoader.getResource("") + .getPath.concat(File.separator).concat(resource) + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerDisableCompressSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerDisableCompressSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..b4f8fa1d25b9cd6d67089c3848a1ea49c92dba49 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerDisableCompressSuite.scala @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.io.FileInputStream + +import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.vectorized.ColumnarBatch + +class ColumnShuffleSerializerDisableCompressSuite extends SparkFunSuite with SharedSparkSession { + + private var avgBatchNumRows: SQLMetric = _ + private var outputNumRows: SQLMetric = _ + + override def sparkConf: SparkConf = + super.sparkConf + .setAppName("test ColumnarShuffleDeSerializer disable compressed") + .set("spark.shuffle.compress", "false") + + override def beforeEach(): Unit = { + avgBatchNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, + "test serializer avg read batch num rows") + outputNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, + "test serializer number of output rows") + } + + test("columnar shuffle deserialize no null uncompressed") { + val input = getTestResourcePath("test-data/shuffle_spilled_mix_1batch_100rows_uncompressed") + val serializer = + new ColumnarBatchSerializer(avgBatchNumRows, outputNumRows).newInstance() + val deserializedStream = + serializer.deserializeStream(new FileInputStream(input)) + + val kv = deserializedStream.asKeyValueIterator + var length = 0 + kv.foreach { + case (_, batch: ColumnarBatch) => + length += 1 + assert(batch.numRows == 100) + assert(batch.numCols == 4) + (0 until batch.numCols).foreach { i => + val valueVector = + batch + .column(i) + .asInstanceOf[OmniColumnVector] + .getVec + assert(valueVector.getSize == batch.numRows) + } + batch.close() + } + assert(length == 1) + deserializedStream.close() + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerLz4Suite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerLz4Suite.scala new file mode 100644 index 0000000000000000000000000000000000000000..f6960b828f06cd064c4111505947944b8bd5c343 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerLz4Suite.scala @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.io.FileInputStream + +import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.vectorized.ColumnarBatch + +class ColumnShuffleSerializerLz4Suite extends SparkFunSuite with SharedSparkSession { + private var avgBatchNumRows: SQLMetric = _ + private var outputNumRows: SQLMetric = _ + + override def sparkConf: SparkConf = + super.sparkConf + .setAppName("test ColumnarShuffleDeSerializer") + .set("spark.shuffle.compress", "true") + .set("spark.io.compression.codec", "lz4") + + override def beforeEach(): Unit = { + avgBatchNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, + "test serializer avg read batch num rows") + outputNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, + "test serializer number of output rows") + } + + test("columnar shuffle deserialize no null lz4 compressed") { + val input = getTestResourcePath("test-data/shuffle_spilled_mix_100batch_4096rows_lz4") + val serializer = + new ColumnarBatchSerializer(avgBatchNumRows, outputNumRows).newInstance() + val deserializedStream = + serializer.deserializeStream(new FileInputStream(input)) + + val kv = deserializedStream.asKeyValueIterator + var length = 0 + kv.foreach { + case (_, batch: ColumnarBatch) => + length += 1 + assert(batch.numRows == 4096) + assert(batch.numCols == 4) + (0 until batch.numCols).foreach { i => + val valueVector = + batch + .column(i) + .asInstanceOf[OmniColumnVector] + .getVec + assert(valueVector.getSize == batch.numRows) + } + batch.close() + } + assert(length == 100) + deserializedStream.close() + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerSnappySuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerSnappySuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..278214d4c51c1bae90a56729037b1ed7cb2be52a --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerSnappySuite.scala @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.io.FileInputStream + +import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.vectorized.ColumnarBatch + +class ColumnShuffleSerializerSnappySuite extends SparkFunSuite with SharedSparkSession { + private var avgBatchNumRows: SQLMetric = _ + private var outputNumRows: SQLMetric = _ + + override def sparkConf: SparkConf = + super.sparkConf + .setAppName("test ColumnarShuffleDeSerializer") + .set("spark.shuffle.compress", "true") + .set("spark.io.compression.codec", "snappy") + + override def beforeEach(): Unit = { + avgBatchNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, + "test serializer avg read batch num rows") + outputNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, + "test serializer number of output rows") + } + + test("columnar shuffle deserialize no null snappy compressed") { + val input = getTestResourcePath("test-data/shuffle_spilled_mix_1batch_100rows_snappy") + val serializer = + new ColumnarBatchSerializer(avgBatchNumRows, outputNumRows).newInstance() + val deserializedStream = + serializer.deserializeStream(new FileInputStream(input)) + + val kv = deserializedStream.asKeyValueIterator + var length = 0 + kv.foreach { + case (_, batch: ColumnarBatch) => + length += 1 + assert(batch.numRows == 100) + assert(batch.numCols == 4) + (0 until batch.numCols).foreach { i => + val valueVector = + batch + .column(i) + .asInstanceOf[OmniColumnVector] + .getVec + assert(valueVector.getSize == batch.numRows) + } + batch.close() + } + assert(length == 1) + deserializedStream.close() + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..51e3466b674141d1181f238c34db2e4b5013c16b --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerSuite.scala @@ -0,0 +1,102 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.io.FileInputStream + +import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.vectorized.ColumnarBatch + +class ColumnShuffleSerializerSuite extends SparkFunSuite with SharedSparkSession { + private var avgBatchNumRows: SQLMetric = _ + private var outputNumRows: SQLMetric = _ + + override def sparkConf: SparkConf = + super.sparkConf + .setAppName("test ColumnarShuffleDeSerializer") + .set("spark.shuffle.compress", "true") + .set("spark.io.compression.codec", "lz4") + + override def beforeEach(): Unit = { + avgBatchNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, + "test serializer avg read batch num rows") + outputNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, + "test serializer number of output rows") + } + + test("columnar shuffle deserialize some row nullable value lz4 compressed") { + val input = getTestResourcePath("test-data/shuffle_split_fixed_singlePartition_someNullRow") + val serializer = + new ColumnarBatchSerializer(avgBatchNumRows, outputNumRows).newInstance() + val deserializedStream = + serializer.deserializeStream(new FileInputStream(input)) + + val kv = deserializedStream.asKeyValueIterator + var length = 0 + kv.foreach { + case (_, batch: ColumnarBatch) => + length += 1 + assert(batch.numRows == 600) + assert(batch.numCols == 4) + (0 until batch.numCols).foreach { i => + val valueVector = + batch + .column(i) + .asInstanceOf[OmniColumnVector] + .getVec + assert(valueVector.getSize == batch.numRows) + } + batch.close() + } + assert(length == 1) + deserializedStream.close() + } + + test("columnar shuffle deserialize some col nullable value lz4 compressed") { + val input = getTestResourcePath("test-data/shuffle_split_fixed_singlePartition_someNullCol") + val serializer = + new ColumnarBatchSerializer(avgBatchNumRows, outputNumRows).newInstance() + val deserializedStream = + serializer.deserializeStream(new FileInputStream(input)) + + val kv = deserializedStream.asKeyValueIterator + var length = 0 + kv.foreach { + case (_, batch: ColumnarBatch) => + length += 1 + assert(batch.numRows == 600) + assert(batch.numCols == 4) + (0 until batch.numCols).foreach { i => + val valueVector = + batch + .column(i) + .asInstanceOf[OmniColumnVector] + .getVec + assert(valueVector.getSize == batch.numRows) + } + batch.close() + } + assert(length == 1) + deserializedStream.close() + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerZlibSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerZlibSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..08a4343283262fe6b99854cb3da52b38251c49ae --- /dev/null +++ 
b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerZlibSuite.scala @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.io.FileInputStream + +import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.vectorized.ColumnarBatch + +class ColumnShuffleSerializerZlibSuite extends SparkFunSuite with SharedSparkSession { + private var avgBatchNumRows: SQLMetric = _ + private var outputNumRows: SQLMetric = _ + + override def sparkConf: SparkConf = + super.sparkConf + .setAppName("test ColumnarShuffleDeSerializer zlib compressed") + .set("spark.shuffle.compress", "true") + .set("spark.io.compression.codec", "zlib") + + override def beforeEach(): Unit = { + avgBatchNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, + "test serializer avg read batch num rows") + outputNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, + "test serializer number of output rows") + } + + test("columnar shuffle deserialize no null zlib compressed") { + val input = getTestResourcePath("test-data/shuffle_spilled_mix_1batch_100rows_zlib") + val serializer = + new ColumnarBatchSerializer(avgBatchNumRows, outputNumRows).newInstance() + val deserializedStream = + serializer.deserializeStream(new FileInputStream(input)) + + val kv = deserializedStream.asKeyValueIterator + var length = 0 + kv.foreach { + case (_, batch: ColumnarBatch) => + length += 1 + assert(batch.numRows == 100) + assert(batch.numCols == 4) + (0 until batch.numCols).foreach { i => + val valueVector = + batch + .column(i) + .asInstanceOf[OmniColumnVector] + .getVec + assert(valueVector.getSize == batch.numRows) + } + batch.close() + } + assert(length == 1) + deserializedStream.close() + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..6dddff494d23f9869398fe19cb49767d1d31967d --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala @@ -0,0 +1,297 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. 
+ * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.io.{File, FileInputStream} + +import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer +import com.huawei.boostkit.spark.vectorized.PartitionInfo +import nova.hetu.omniruntime.`type`.Decimal64DataType +import nova.hetu.omniruntime.vector._ +import org.apache.spark.{HashPartitioner, SparkConf, TaskContext} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.shuffle.sort.ColumnarShuffleHandle +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{DecimalType, IntegerType, LongType} +import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch} +import org.apache.spark.util.Utils +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.ArgumentMatchers.{any, anyInt, anyLong} +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Mockito.{doAnswer, when} +import org.mockito.invocation.InvocationOnMock + +class ColumnarShuffleWriterSuite extends SharedSparkSession { + @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _ + @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ + @Mock(answer = RETURNS_SMART_NULLS) private var dependency + : ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = _ + + override def sparkConf: SparkConf = + super.sparkConf + .setAppName("test ColumnarShuffleWriter") + .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager") + + private var taskMetrics: TaskMetrics = _ + private var tempDir: File = _ + private var outputFile: File = _ + + private var shuffleHandle: ColumnarShuffleHandle[Int, ColumnarBatch] = _ + private val numPartitions = 11 + + protected var avgBatchNumRows: SQLMetric = _ + protected var outputNumRows: SQLMetric = _ + + override def beforeEach(): Unit = { + super.beforeEach() + + avgBatchNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, + "test serializer avg read batch num rows") + outputNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, + "test serializer number of output rows") + + tempDir = Utils.createTempDir() + outputFile = File.createTempFile("shuffle", null, tempDir) + taskMetrics = new TaskMetrics + + MockitoAnnotations.initMocks(this) + + shuffleHandle = + new ColumnarShuffleHandle[Int, ColumnarBatch](shuffleId = 0, dependency = dependency) + + val inputTypes = "[{\"id\":\"OMNI_INT\",\"width\":0,\"precision\":0,\"scale\":0,\"dateUnit\":\"DAY\",\"timeUnit\":\"SEC\"}," + + 
"{\"id\":\"OMNI_INT\",\"width\":0,\"precision\":0,\"scale\":0,\"dateUnit\":\"DAY\",\"timeUnit\":\"SEC\"}," + + "{\"id\":\"OMNI_DECIMAL64\",\"width\":0,\"precision\":18,\"scale\":3,\"dateUnit\":\"DAY\",\"timeUnit\":\"SEC\"}," + + "{\"id\":\"OMNI_DECIMAL128\",\"width\":0,\"precision\":28,\"scale\":11,\"dateUnit\":\"DAY\",\"timeUnit\":\"SEC\"}]" + + when(dependency.partitioner).thenReturn(new HashPartitioner(numPartitions)) + when(dependency.serializer).thenReturn(new JavaSerializer(sparkConf)) + when(dependency.partitionInfo).thenReturn( + new PartitionInfo("hash", numPartitions, 4, inputTypes)) + // inputTypes e.g: + // [{"id":"OMNI_INT","width":0,"precision":0,"scale":0,"dateUnit":"DAY","timeUnit":"SEC"}, + // {"id":"OMNI_INT","width":0,"precision":0,"scale":0,"dateUnit":"DAY","timeUnit":"SEC"}] + when(dependency.dataSize) + .thenReturn(SQLMetrics.createSizeMetric(spark.sparkContext, "data size")) + when(dependency.bytesSpilled) + .thenReturn(SQLMetrics.createSizeMetric(spark.sparkContext, "shuffle bytes spilled")) + when(dependency.numInputRows) + .thenReturn(SQLMetrics.createMetric(spark.sparkContext, "number of input rows")) + when(dependency.splitTime) + .thenReturn(SQLMetrics.createNanoTimingMetric(spark.sparkContext, "totaltime_split")) + when(dependency.spillTime) + .thenReturn(SQLMetrics.createNanoTimingMetric(spark.sparkContext, "totaltime_spill")) + when(taskContext.taskMetrics()).thenReturn(taskMetrics) + when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) + + doAnswer { (invocationOnMock: InvocationOnMock) => + val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] + if (tmp != null) { + outputFile.delete + tmp.renameTo(outputFile) + } + null + }.when(blockResolver) + .writeIndexFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])) + } + + override def afterEach(): Unit = { + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } + } + + override def afterAll(): Unit = { + super.afterAll() + } + + test("write empty iterator") { + val writer = new ColumnarShuffleWriter[Int, ColumnarBatch]( + blockResolver, + shuffleHandle, + 0, // MapId + taskContext.taskMetrics().shuffleWriteMetrics) + writer.write(Iterator.empty) + writer.stop( /* success = */ true) + + assert(writer.getPartitionLengths.sum === 0) + assert(outputFile.exists()) + assert(outputFile.length() === 0) + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics + assert(shuffleWriteMetrics.bytesWritten === 0) + assert(shuffleWriteMetrics.recordsWritten === 0) + assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) + } + + test("write empty column batch") { + val vectorPid0 = ColumnarShuffleWriterSuite.initOmniColumnIntVector() + val vector0_1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector() + val vector0_2 = ColumnarShuffleWriterSuite.initOmniColumnIntVector() + val vector0_3 = ColumnarShuffleWriterSuite.initOmniColumnDecimal64Vector() + val vector0_4 = ColumnarShuffleWriterSuite.initOmniColumnDecimal128Vector() + + val vectorPid1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector() + val vector1_1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector() + val vector1_2 = ColumnarShuffleWriterSuite.initOmniColumnIntVector() + val vector1_3 = ColumnarShuffleWriterSuite.initOmniColumnDecimal64Vector() + val vector1_4 = ColumnarShuffleWriterSuite.initOmniColumnDecimal128Vector() + + val cb0 = ColumnarShuffleWriterSuite.makeColumnarBatch( + vectorPid0.getVec.getSize,List(vectorPid0, vector0_1, 
vector0_2, vector0_3, vector0_4)) + val cb1 = ColumnarShuffleWriterSuite.makeColumnarBatch( + vectorPid1.getVec.getSize,List(vectorPid1, vector1_1, vector1_2, vector1_3, vector1_4)) + + def records: Iterator[(Int, ColumnarBatch)] = Iterator((0, cb0), (0, cb1)) + + val writer = new ColumnarShuffleWriter[Int, ColumnarBatch]( + blockResolver, + shuffleHandle, + 0L, // MapId + taskContext.taskMetrics().shuffleWriteMetrics) + + writer.write(records) + writer.stop(success = true) + assert(writer.getPartitionLengths.sum === 0) + assert(outputFile.exists()) + assert(outputFile.length() === 0) + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics + assert(shuffleWriteMetrics.bytesWritten === 0) + assert(shuffleWriteMetrics.recordsWritten === 0) + assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) + } + + test("write with some empty partitions") { + val vectorPid0 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(0, 0, 1, 1) + val vector0_1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(null, null, null, null) + val vector0_2 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(100, 100, null, null) + val vector0_3 = ColumnarShuffleWriterSuite.initOmniColumnDecimal64Vector(100L, 100L, 100L, 100L) + val vector0_4 = ColumnarShuffleWriterSuite.initOmniColumnDecimal128Vector(Array(100L, 100L), Array(100L, 100L), null, null) + val cb0 = ColumnarShuffleWriterSuite.makeColumnarBatch( + vectorPid0.getVec.getSize,List(vectorPid0, vector0_1, vector0_2, vector0_3, vector0_4)) + + val vectorPid1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(0, 0, 1, 1) + val vector1_1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(null, null, null, null) + val vector1_2 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(100, 100, null, null) + val vector1_3 = ColumnarShuffleWriterSuite.initOmniColumnDecimal64Vector(100L, 100L, 100L, 100L) + val vector1_4 = ColumnarShuffleWriterSuite.initOmniColumnDecimal128Vector(Array(100L, 100L), Array(100L, 100L), null, null) + val cb1 = ColumnarShuffleWriterSuite.makeColumnarBatch( + vectorPid1.getVec.getSize,List(vectorPid1, vector1_1, vector1_2, vector1_3, vector1_4)) + + def records: Iterator[(Int, ColumnarBatch)] = Iterator((0, cb0), (0, cb1)) + + val writer = new ColumnarShuffleWriter[Int, ColumnarBatch]( + blockResolver, + shuffleHandle, + 0L, // MapId + taskContext.taskMetrics().shuffleWriteMetrics) + + writer.write(records) + writer.stop(success = true) + + assert(writer.getPartitionLengths.sum === outputFile.length()) + assert(writer.getPartitionLengths.count(_ == 0L) === (numPartitions - 2)) + // should be (numPartitions - 2) zero length files + + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics + assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) + assert(shuffleWriteMetrics.recordsWritten === records.length) + + assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) + + val serializer = new ColumnarBatchSerializer(avgBatchNumRows, outputNumRows).newInstance() + val deserializedStream = serializer.deserializeStream(new FileInputStream(outputFile)) + + try { + val kv = deserializedStream.asKeyValueIterator + var length = 0 + kv.foreach { + case (_, batch: ColumnarBatch) => + length += 1 + assert(batch.numRows == 4) + assert(batch.numCols == 4) + (0 until batch.numCols).foreach { i => + val valueVector = + batch + .column(i) + .asInstanceOf[OmniColumnVector] + .getVec + assert(valueVector.getSize == batch.numRows) + } + 
batch.close() + } + assert(length == 2) + } finally { + deserializedStream.close() + } + + } +} + +object ColumnarShuffleWriterSuite { + def initOmniColumnIntVector(values: Integer*): OmniColumnVector = { + val length = values.length + val vecTmp = new IntVec(length) + (0 until length).foreach { i => + if (values(i) != null) { + vecTmp.set(i, values(i).asInstanceOf[Int]) + } + } + val colVecTmp = new OmniColumnVector(length, IntegerType, false) + colVecTmp.setVec(vecTmp) + colVecTmp + } + + def initOmniColumnDecimal64Vector(values: java.lang.Long*): OmniColumnVector = { + val length = values.length + val vecTmp = new LongVec(length) + (0 until length).foreach { i => + if (values(i) != null) { + vecTmp.set(i, values(i).asInstanceOf[Long]) + } + } + val colVecTmp = new OmniColumnVector(length, DecimalType(18, 3), false) + colVecTmp.setVec(vecTmp) + colVecTmp + } + + def initOmniColumnDecimal128Vector(values: Array[Long]*): OmniColumnVector = { + val length = values.length + val vecTmp = new Decimal128Vec(length) + (0 until length).foreach { i => + if (values(i) != null) { + vecTmp.set(i, values(i)) + } + } + val colVecTmp = new OmniColumnVector(length, DecimalType(28, 11), false) + colVecTmp.setVec(vecTmp) + colVecTmp + } + + def makeColumnarBatch(rowNum: Int, vectors: List[ColumnVector]): ColumnarBatch = { + new ColumnarBatch(vectors.toArray, rowNum) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarAggregateBenchmark.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarAggregateBenchmark.scala new file mode 100644 index 0000000000000000000000000000000000000000..a1f113b1dfa515c28d780892f835278985eea01b --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarAggregateBenchmark.scala @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.benchmark + + +object ColumnarAggregateBenchmark extends ColumnarBasedBenchmark { + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + + val N = if (mainArgs.isEmpty) { + 500L << 20 + } else { + mainArgs(0).toLong + } + + runBenchmark("stat functions") { + spark.range(N).groupBy().agg("id" -> "sum").explain() + columnarBenchmark(s"spark.range(${N}).groupBy().agg(id -> sum)", N) { + spark.range(N).groupBy().agg("id" -> "sum").noop() + } + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarBasedBenchmark.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarBasedBenchmark.scala new file mode 100644 index 0000000000000000000000000000000000000000..402932161e87e095e1a50ec19f77652a91629b78 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarBasedBenchmark.scala @@ -0,0 +1,56 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.benchmark
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Common base class for running columnar benchmarks.
+ */
+abstract class ColumnarBasedBenchmark extends SqlBasedBenchmark {
+  /** Runs function `f` under three scenarios: Spark whole-stage codegen on, whole-stage codegen off, and omni-columnar processing. */
+  final def columnarBenchmark(name: String, cardinality: Long)(f: => Unit): Unit = {
+    val benchmark = new Benchmark(name, cardinality, output = output)
+    if (getSparkSession.conf.getOption("spark.sql.extensions").isDefined)
+    {
+      benchmark.addCase(s"$name omniruntime wholestage off", numIters = 5) { _ =>
+        withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
+          f
+        }
+      }
+    }
+    else
+    {
+      benchmark.addCase(s"$name Spark wholestage off", numIters = 5) { _ =>
+        withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
+          f
+        }
+      }
+      benchmark.addCase(s"$name Spark wholestage on", numIters = 5) { _ =>
+        withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") {
+          f
+        }
+      }
+    }
+
+    benchmark.run()
+  }
+}
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarFilterBenchmark.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarFilterBenchmark.scala
new file mode 100644
index 0000000000000000000000000000000000000000..98e8596fe2b6c44cf869bd071385998cfa27bb54
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarFilterBenchmark.scala
@@ -0,0 +1,39 @@
+/*
+ * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved.
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.benchmark + +import org.apache.spark.sql.benchmark.ColumnarAggregateBenchmark.spark + +object ColumnarFilterBenchmark extends ColumnarBasedBenchmark { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + + val N = if (mainArgs.isEmpty) { + 500L << 20 + } else { + mainArgs(0).toLong + } + + runBenchmark("filter with API") { + spark.range(N).filter("id > 100").explain() + columnarBenchmark(s"spark.range(${N}).filter(id > 100)", N) { + spark.range(N).filter("id > 100").noop() + } + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarJoinBenchmark.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarJoinBenchmark.scala new file mode 100644 index 0000000000000000000000000000000000000000..55eda5db03c85a9a6bff206555ebb40518764132 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarJoinBenchmark.scala @@ -0,0 +1,63 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.benchmark + +import org.apache.spark.sql.functions._ + +object ColumnarJoinBenchmark extends ColumnarBasedBenchmark { + def broadcastHashJoinLongKey(rowsA: Long): Unit = { + val rowsB = 1 << 16 + val dim = spark.range(rowsB).selectExpr("id as k", "id as v") + val df = spark.range(rowsA).join(dim.hint("broadcast"), (col("id") % rowsB) === col("k")) + df.explain() + columnarBenchmark(s"broadcastHashJoinLongKey spark.range(${rowsA}).join(spark.range(${rowsB}))", rowsA) { + df.noop() + } + } + + def sortMergeJoin(rowsA: Long, rowsB: Long): Unit = { + val df1 = spark.range(rowsA).selectExpr(s"id * 2 as k1") + val df2 = spark.range(rowsB).selectExpr(s"id * 3 as k2") + val df = df1.join(df2.hint("mergejoin"), col("k1") === col("k2")) + df.explain() + columnarBenchmark(s"sortMergeJoin spark.range(${rowsA}).join(spark.range(${rowsB}))", rowsA) { + df.noop() + } + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + + val rowsA = if (mainArgs.isEmpty) { + 20 << 20 + } else { + mainArgs(0).toLong + } + + val rowsB = if (mainArgs.isEmpty || mainArgs.length < 2) { + 1 << 16 + } else { + mainArgs(1).toLong + } + + runBenchmark("Join Benchmark") { + broadcastHashJoinLongKey(rowsA) + sortMergeJoin(rowsA, rowsB) + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarProjectBenchmark.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarProjectBenchmark.scala new file mode 100644 index 0000000000000000000000000000000000000000..2540ccbc243fc415ae43ad6275c5c4e6c0b27304 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarProjectBenchmark.scala @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.benchmark + +object ColumnarProjectBenchmark extends ColumnarBasedBenchmark { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + val N = if (mainArgs.isEmpty) { + 500L << 18 + } else { + mainArgs(0).toLong + } + + runBenchmark("project with API") { + spark.range(N).selectExpr("id as p").explain() + columnarBenchmark(s"spark.range(${N}).selectExpr(id as p)", N) { + spark.range(N).selectExpr("id as p").noop() + } + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarRangeBenchmark.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarRangeBenchmark.scala new file mode 100644 index 0000000000000000000000000000000000000000..134ec158ee5cf80a235e05373324191ee8fc701a --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarRangeBenchmark.scala @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.benchmark + +object ColumnarRangeBenchmark extends ColumnarBasedBenchmark { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + + val N = if (mainArgs.isEmpty) { + 500L << 20 + } else { + mainArgs(0).toLong + } + + runBenchmark("range with API") { + spark.range(N).explain() + columnarBenchmark(s"spark.range(${N})", N) { + spark.range(N).noop() + } + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarSortBenchmark.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarSortBenchmark.scala new file mode 100644 index 0000000000000000000000000000000000000000..99781adf1f5e81663d2d3a0310bcb197ad84a980 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarSortBenchmark.scala @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.benchmark + +object ColumnarSortBenchmark extends ColumnarBasedBenchmark { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + + val N = if (mainArgs.isEmpty) { + 500L << 20 + } else { + mainArgs(0).toLong + } + + runBenchmark("sort with API") { + val value = spark.range(N) + value.sort(value("id").desc).explain() + columnarBenchmark(s"spark.range(${N}).sort(id.desc)", N) { + val value = spark.range(N) + value.sort(value("id").desc).noop() + } + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarTopNBenchmark.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarTopNBenchmark.scala new file mode 100644 index 0000000000000000000000000000000000000000..64d3d0ee4dcfe2234e8a4db49e87e5102dc85181 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarTopNBenchmark.scala @@ -0,0 +1,40 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.benchmark + +object ColumnarTopNBenchmark extends ColumnarBasedBenchmark { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + + val N = if (mainArgs.isEmpty) { + 500L << 20 + } else { + mainArgs(0).toLong + } + + runBenchmark("topN with API") { + val value = spark.range(N) + value.sort(value("id").desc).limit(20).explain() + + columnarBenchmark(s"spark.range(${N}).sort(id.desc).limit(20)", N) { + val value = spark.range(N) + value.sort(value("id").desc).limit(20).noop() + } + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarUnionBenchmark.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarUnionBenchmark.scala new file mode 100644 index 0000000000000000000000000000000000000000..24b98c9d572a727289b91e61c4c48a2a639dda58 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/benchmark/ColumnarUnionBenchmark.scala @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. 
+ * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.benchmark + +object ColumnarUnionBenchmark extends ColumnarBasedBenchmark { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + val N = if (mainArgs.isEmpty) { + 5L << 15 + } else { + mainArgs(0).toLong + } + + val M = if (mainArgs.isEmpty || mainArgs.length < 2) { + 10L << 15 + } else { + mainArgs(1).toLong + } + + runBenchmark("union with API") { + val rangeM = spark.range(M) + spark.range(N).union(rangeM).explain() + columnarBenchmark(s"spark.range(${N}).union(spark.range(${M}))", N) { + spark.range(N).union(rangeM).noop() + } + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExecSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..f362d85e5feda8c32e7075b8742bcc19a45b30a9 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExecSuite.scala @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, StructType} + +class ColumnarExecSuite extends ColumnarSparkPlanTest { + private lazy val df = spark.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, 2.0, false), + Row(1, 2.0, false), + Row(2, 1.0, false), + Row(null, null, false), + Row(null, 5.0, false), + Row(6, null, false) + )), new StructType().add("a", IntegerType).add("b", DoubleType) + .add("c", BooleanType)) + + test("validate columnar transfer exec happened") { + val res = df.filter("a > 1") + print(res.queryExecution.executedPlan) + assert(res.queryExecution.executedPlan.find(_.isInstanceOf[RowToOmniColumnarExec]).isDefined, s"RowToOmniColumnarExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") + } + + test("validate data type convert") { + val res = df.filter("a > 1") + print(res.queryExecution.executedPlan) + + checkAnswer( + df.filter("a > 1"), + Row(2, 1.0, false) :: Row(6, null, false) :: Nil) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarFilterExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarFilterExecSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..42360758eede5774e5fba5c2fabc055185a6da23 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarFilterExecSuite.scala @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.expressions.Expression + +class ColumnarFilterExecSuite extends ColumnarSparkPlanTest { + import testImplicits.{localSeqToDatasetHolder, newProductEncoder} + + private var inputDf: DataFrame = _ + private var inputDfWithNull: DataFrame = _ + + protected override def beforeAll(): Unit = { + super.beforeAll() + inputDf = Seq[(String, String, java.lang.Integer, java.lang.Double)]( + ("abc", "", 4, 2.0), + ("", "Hello", 1, 1.0), + (" add", "World", 8, 3.0), + (" yeah ", "yeah", 10, 8.0) + ).toDF("a", "b", "c", "d") + + inputDfWithNull = Seq[(String, String, java.lang.Integer, java.lang.Double)]( + (null, "", 4, 2.0), + (null, null, 1, 1.0), + (" add", "World", 8, null), + (" yeah ", "yeah", 10, 8.0), + (" yeah ", "yeah", 10, 8.0) + ).toDF("a", "b", "c", "d") + } + + test("validate columnar filter exec happened") { + val res = inputDf.filter("c > 1") + print(res.queryExecution.executedPlan) + val isColumnarFilterHappen = res.queryExecution.executedPlan + .find(_.isInstanceOf[ColumnarFilterExec]).isDefined + val isColumnarConditionProjectHappen = res.queryExecution.executedPlan + .find(_.isInstanceOf[ColumnarConditionProjectExec]).isDefined + assert(isColumnarFilterHappen || isColumnarConditionProjectHappen, s"ColumnarFilterExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") + } + + test("columnar filter is equal to native") { + val expr: Expression = (inputDf.col("c") > 3).expr + checkThatPlansAgreeTemplate(expr = expr, df = inputDf) + } + + test("columnar filter is equal to native with null") { + val expr: Expression = (inputDfWithNull.col("c") > 3 && inputDfWithNull.col("d").isNotNull).expr + checkThatPlansAgreeTemplate(expr = expr, df = inputDfWithNull) + } + + def checkThatPlansAgreeTemplate(expr: Expression, df: DataFrame): Unit = { + checkThatPlansAgree( + df, + (child: SparkPlan) => + ColumnarFilterExec(expr, child = child), + (child: SparkPlan) => + FilterExec(expr, child = child), + sortAnswers = false) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..a425584f8ecee838ad1a52feab96ff00f17f5bfb --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.functions.{sum, count} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row} + +class ColumnarHashAggregateExecSuite extends ColumnarSparkPlanTest { + private var df: DataFrame = _ + + protected override def beforeAll(): Unit = { + super.beforeAll() + df = spark.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, 2.0, 1L, "a"), + Row(1, 2.0, 2L, null), + Row(2, 1.0, 3L, "c"), + Row(null, null, 6L, "e"), + Row(null, 5.0, 7L, "f") + )), new StructType().add("a", IntegerType).add("b", DoubleType) + .add("c", LongType).add("d", StringType)) + } + + test("validate columnar hashAgg exec happened") { + val res = df.groupBy("a").agg(sum("b")) + assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") + } + + test("check columnar hashAgg result") { + val res = testData2.groupBy("a").agg(sum("b")) + checkAnswer( + res, + Seq(Row(1, 3), Row(2, 3), Row(3, 3)) + ) + } + + test("check columnar hashAgg result with null") { + val res = df.filter(df("a").isNotNull && df("d").isNotNull).groupBy("a").agg(sum("b")) + checkAnswer( + res, + Seq(Row(1, 2.0), Row(2, 1.0)) + ) + } + + test("test count(*)/count(1)") { + val res1 = df.agg(count("*")) + assert(res1.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n${res1.queryExecution.executedPlan}") + checkAnswer( + res1, + Seq(Row(5)) + ) + + val res2 = df.groupBy("a").agg(count("*")) + checkAnswer( + res2, + Seq(Row(1, 2), Row(2, 1), Row(null, 2)) + ) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..978b2f3ba68fe1aa37edadc61028e352b1ed9996 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala @@ -0,0 +1,112 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.optimizer.BuildRight +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ColumnarBroadcastHashJoinExec, ColumnarSortMergeJoinExec} +import org.apache.spark.sql.functions.col + +// refer to joins package +class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { + import testImplicits.{localSeqToDatasetHolder, newProductEncoder} + + private var left: DataFrame = _ + private var right: DataFrame = _ + private var leftWithNull: DataFrame = _ + private var rightWithNull: DataFrame = _ + + protected override def beforeAll(): Unit = { + super.beforeAll() + left = Seq[(String, String, java.lang.Integer, java.lang.Double)]( + ("abc", "", 4, 2.0), + ("", "Hello", 1, 1.0), + (" add", "World", 8, 3.0), + (" yeah ", "yeah", 10, 8.0) + ).toDF("a", "b", "q", "d") + + right = Seq[(String, String, java.lang.Integer, java.lang.Double)]( + ("abc", "", 4, 1.0), + ("", "Hello", 2, 2.0), + (" add", "World", 1, 3.0), + (" yeah ", "yeah", 0, 4.0) + ).toDF("a", "b", "c", "d") + + leftWithNull = Seq[(String, String, java.lang.Integer, java.lang.Double)]( + ("abc", null, 4, 2.0), + ("", "Hello", null, 1.0), + (" add", "World", 8, 3.0), + (" yeah ", "yeah", 10, 8.0) + ).toDF("a", "b", "q", "d") + + rightWithNull = Seq[(String, String, java.lang.Integer, java.lang.Double)]( + ("abc", "", 4, 1.0), + ("", "Hello", 2, 2.0), + (" add", null, 1, null), + (" yeah ", null, null, 4.0) + ).toDF("a", "b", "c", "d") + } + + test("validate columnar broadcastHashJoin exec happened") { + val res = left.join(right.hint("broadcast"), col("q") === col("c")) + assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarBroadcastHashJoinExec]).isDefined, s"ColumnarBroadcastHashJoinExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") + } + + test("validate columnar sortMergeJoin exec happened") { + val res = left.join(right.hint("mergejoin"), col("q") === col("c")) + assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarSortMergeJoinExec]).isDefined, s"ColumnarSortMergeJoinExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") + } + + test("columnar broadcastHashJoin is equal to native") { + val df = left.join(right.hint("broadcast"), col("q") === col("c")) + val leftKeys = Seq(left.col("q").expr) + val rightKeys = Seq(right.col("c").expr) + checkThatPlansAgreeTemplate(df, leftKeys, rightKeys) + } + + test("columnar sortMergeJoin is equal to native") { + val df = left.join(right.hint("mergejoin"), col("q") === col("c")) + val leftKeys = Seq(left.col("q").expr) + val rightKeys = Seq(right.col("c").expr) + checkThatPlansAgreeTemplate(df, leftKeys, rightKeys) + } + + test("columnar broadcastHashJoin is equal to native with null") { + val df = leftWithNull.join(rightWithNull.hint("broadcast"), + col("q").isNotNull === col("c").isNotNull) + val leftKeys = Seq(leftWithNull.col("q").isNotNull.expr) + val rightKeys = Seq(rightWithNull.col("c").isNotNull.expr) + checkThatPlansAgreeTemplate(df, leftKeys, rightKeys) + } + + def checkThatPlansAgreeTemplate(df: DataFrame, leftKeys: Seq[Expression], + rightKeys: Seq[Expression]): Unit = { + checkThatPlansAgree( + df, + (child: SparkPlan) => + ColumnarBroadcastHashJoinExec(leftKeys, rightKeys, Inner, + BuildRight, None, child, child), + (child: SparkPlan) => + 
BroadcastHashJoinExec(leftKeys, rightKeys, Inner, + BuildRight, None, child, child), + sortAnswers = false) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarProjectExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarProjectExecSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..ce39461cd5bba1920509eda6be231467f2488e1c --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarProjectExecSuite.scala @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.expressions.NamedExpression + +class ColumnarProjectExecSuite extends ColumnarSparkPlanTest { + import testImplicits.{localSeqToDatasetHolder, newProductEncoder} + + private var inputDf: DataFrame = _ + private var inputDfWithNull: DataFrame = _ + + protected override def beforeAll(): Unit = { + super.beforeAll() + inputDf = Seq[(String, String, java.lang.Integer, java.lang.Double)]( + ("abc", "", 4, 2.0), + ("", "Hello", 1, 1.0), + (" add", "World", 8, 3.0), + (" yeah ", "yeah", 10, 8.0) + ).toDF("a", "b", "c", "d") + + inputDfWithNull = Seq[(String, String, java.lang.Integer, java.lang.Double)]( + (null, "", 4, 2.0), + (null, null, 1, 1.0), + (" add", "World", 8, 3.0), + (" yeah ", "yeah", 10, 8.0) + ).toDF("a", "b", "c", "d") + } + + test("validate columnar project exec happened") { + val res = inputDf.selectExpr("a as t") + assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, + s"ColumnarProjectExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") + } + + test("columnar project is equal to native") { + val projectList: Seq[NamedExpression] = Seq(inputDf.col("a").as("abc").expr.asInstanceOf[NamedExpression]) + checkThatPlansAgreeTemplate(projectList, inputDf) + } + + test("columnar project is equal to native with null") { + val projectList: Seq[NamedExpression] = Seq(inputDfWithNull.col("a").as("abc").expr.asInstanceOf[NamedExpression]) + checkThatPlansAgreeTemplate(projectList, inputDfWithNull) + } + + def checkThatPlansAgreeTemplate(projectList: Seq[NamedExpression], df: DataFrame): Unit = { + checkThatPlansAgree( + df, + (child: SparkPlan) => + ColumnarProjectExec(projectList, child = child), + (child: SparkPlan) => + ProjectExec(projectList, child = child), + sortAnswers = false) + } +} diff --git 
a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarRangeExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarRangeExecSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..79699046aee4600572ceca73be7f89113045b26e
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarRangeExecSuite.scala
@@ -0,0 +1,27 @@
+/*
+ * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved.
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+// refer to DataFrameRangeSuite; class name matches the file name ColumnarRangeExecSuite.scala
+class ColumnarRangeExecSuite extends ColumnarSparkPlanTest {
+  test("validate columnar range exec happened") {
+    val res = spark.range(0, 10, 1)
+    assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarRangeExec]).isDefined, s"ColumnarRangeExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}")
+  }
+}
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..91fe50455e4f184069adee0be316d63658f4af16
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecSuite.scala
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved.
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.execution + +class ColumnarShuffleExchangeExecSuite extends ColumnarSparkPlanTest { + import testImplicits.{localSeqToDatasetHolder, newProductEncoder} + + protected override def beforeAll(): Unit = { + super.beforeAll() + } + + test("validate columnar shuffleExchange exec worked") { + val inputDf = Seq[(String, java.lang.Integer, java.lang.Double)] ( + ("Sam", 12, 9.1), + ("Bob", 13, 9.3), + ("Ted", 10, 8.9) + ).toDF("name", "age", "point") + val res = inputDf.sort(inputDf("age").asc) + assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarSortExec]).isDefined, + s"ColumnarSortExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") + + assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffleExchangeExec]).isDefined, + s"ColumnarShuffleExchangeExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarSortExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarSortExecSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..cecf846af71efdfe1f25e064f600025c56feaeae --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarSortExecSuite.scala @@ -0,0 +1,70 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import java.lang + +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.SortOrder + +class ColumnarSortExecSuite extends ColumnarSparkPlanTest { + import testImplicits.{localSeqToDatasetHolder, newProductEncoder} + + test("validate columnar sort exec happened") { + val inputDf = Seq[(String, java.lang.Integer, java.lang.Double)]( + ("Hello", 4, 2.0), + ("Hello", 1, 1.0), + ("World", 8, 3.0) + ).toDF("a", "b", "c") + val res = inputDf.sort(inputDf("b").asc) + assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarSortExec]).isDefined, s"ColumnarSortExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") + } + + test("columnar sort is equal to native sort") { + val df = Seq[(String, java.lang.Integer, java.lang.Double)]( + ("Hello", 4, 2.0), + ("Hello", 1, 1.0), + ("World", 8, 3.0) + ).toDF("a", "b", "c") + val sortOrder = Stream('a.asc, 'b.asc, 'c.asc) + checkThatPlansAgreeTemplate(input = df, sortOrder = sortOrder) + } + + test("columnar sort is equal to native sort with null") { + val dfWithNull = Seq[(String, Integer, lang.Double)]( + ("Hello", 4, 2.0), + (null, 1, 1.0), + ("World", null, 3.0), + ("World", 8, 3.0) + ).toDF("a", "b", "c") + val sortOrder = Stream('a.asc, 'b.asc, 'c.asc) + checkThatPlansAgreeTemplate(input = dfWithNull, sortOrder = sortOrder) + } + + def checkThatPlansAgreeTemplate(input: DataFrame, sortOrder: Seq[SortOrder]): Unit = { + checkThatPlansAgree( + input, + (child: SparkPlan) => + ColumnarSortExec(sortOrder, global = true, child = child), + (child: SparkPlan) => + SortExec(sortOrder, global = true, child = child), + sortAnswers = false) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarSparkPlanTest.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarSparkPlanTest.scala new file mode 100644 index 0000000000000000000000000000000000000000..b58f9ee3bd853630645aff388eb51ed4d440817d --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarSparkPlanTest.scala @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row} +import org.apache.spark.sql.catalyst.util.stackTraceToString +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.test.SharedSparkSession + +private[sql] abstract class ColumnarSparkPlanTest extends SparkPlanTest with SharedSparkSession { + // setup basic columnar configuration for columnar exec + override def sparkConf: SparkConf = super.sparkConf + .set(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key, "com.huawei.boostkit.spark.ColumnarPlugin") + .set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "false") + .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager") + + protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + val analyzedDF = try df catch { + case ae: AnalysisException => + if (ae.plan.isDefined) { + fail( + s""" + |Failed to analyze query: $ae + |${ae.plan.get} + | + |${stackTraceToString(ae)} + |""".stripMargin) + } else { + throw ae + } + } + assertEmptyMissingInput(analyzedDF) + QueryTest.checkAnswer(analyzedDF, expectedAnswer) + } + + private def assertEmptyMissingInput(query: Dataset[_]): Unit = { + assert(query.queryExecution.analyzed.missingInput.isEmpty, + s"The analyzed logical plan has missing inputs:\n${query.queryExecution.analyzed}") + assert(query.queryExecution.optimizedPlan.missingInput.isEmpty, + s"The optimized logical plan has missing inputs:\n${query.queryExecution.optimizedPlan}") + assert(query.queryExecution.executedPlan.missingInput.isEmpty, + s"The physical plan has missing inputs:\n${query.queryExecution.executedPlan}") + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNExecSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..3fb8a1bf9c6306362687d4d975dabfd44f43b1b7 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNExecSuite.scala @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.dsl.expressions.DslSymbol +import org.apache.spark.sql.catalyst.expressions.{NamedExpression, SortOrder} + +// refer to TakeOrderedAndProjectSuite +class ColumnarTopNExecSuite extends ColumnarSparkPlanTest { + import testImplicits.{localSeqToDatasetHolder, newProductEncoder} + + private var inputDf: DataFrame = _ + private var inputDfWithNull: DataFrame = _ + + protected override def beforeAll(): Unit = { + super.beforeAll() + inputDf = Seq[(java.lang.Integer, java.lang.Double, String)]( + (4, 2.0, "abc"), + (1, 1.0, "aaa"), + (8, 3.0, "ddd"), + (10, 8.0, "") + ).toDF("a", "b", "c") + + inputDfWithNull = Seq[(String, String, java.lang.Integer, java.lang.Double)]( + ("abc", "", 4, 2.0), + ("", null, 1, 1.0), + (" add", "World", 8, null), + (" yeah ", "yeah", 10, 8.0) + ).toDF("a", "b", "c", "d") + } + + test("validate columnar topN exec happened") { + val res = inputDf.sort("a").limit(2) + assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarTakeOrderedAndProjectExec]).isDefined, s"ColumnarTakeOrderedAndProjectExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") + } + + test("columnar topN is equal to native") { + val limit = 3 + val sortOrder = Stream('a.asc, 'b.desc) + val projectList = Seq(inputDf.col("a").as("abc").expr.asInstanceOf[NamedExpression]) + checkThatPlansAgreeTemplate(inputDf, limit, sortOrder, projectList) + } + + test("columnar topN is equal to native with null") { + val res = inputDfWithNull.orderBy("a", "b").selectExpr("c + 1", "d + 2").limit(2) + checkAnswer(res, Seq(Row(2, 3.0), Row(9, null))) + } + + def checkThatPlansAgreeTemplate(df: DataFrame, limit: Int, sortOrder: Seq[SortOrder], + projectList: Seq[NamedExpression]): Unit = { + checkThatPlansAgree( + df, + input => ColumnarTakeOrderedAndProjectExec(limit, sortOrder, projectList, input), + input => TakeOrderedAndProjectExec(limit, sortOrder, projectList, input), + sortAnswers = false) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarUnionExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarUnionExecSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..9539d9448aad5b33a3277225a6b11b69c5b97913 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarUnionExecSuite.scala @@ -0,0 +1,80 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.{DataFrame, Row} + +class ColumnarUnionExecSuite extends ColumnarSparkPlanTest { + import testImplicits.{localSeqToDatasetHolder, newProductEncoder} + + private var left: DataFrame = _ + private var right: DataFrame = _ + + protected override def beforeAll(): Unit = { + super.beforeAll() + left = Seq[(String, String, java.lang.Integer, java.lang.Double)]( + ("abc", "", 4, 2.0), + ("", "Hello", 1, 1.0), + (" add", "World", 8, 3.0), + (" yeah ", "yeah", 10, 8.0) + ).toDF("a", "b", "c", "d") + + right = Seq[(String, String, java.lang.Integer, java.lang.Double)]( + (null, "", 4, 2.0), + (null, null, 1, 1.0), + (" add", "World", 8, 3.0), + (" yeah ", "yeah", 10, 8.0) + ).toDF("a", "b", "c", "d") + } + + test("validate columnar union exec happened") { + val res = left.union(right) + assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarUnionExec]).isDefined, s"ColumnarUnionExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") + } + + test("columnar union is equal to expected") { + val expected = Array(Row("abc", "", 4, 2.0), + Row("", "Hello", 1, 1.0), + Row(" add", "World", 8, 3.0), + Row(" yeah ", "yeah", 10, 8.0), + Row(null, "", 4, 2.0), + Row(null, null, 1, 1.0), + Row(" add", "World", 8, 3.0), + Row(" yeah ", "yeah", 10, 8.0)) + val res = left.union(right) + val result: Array[Row] = res.head(8) + assertResult(expected)(result) + } + + test("columnar union is equal to native with null") { + val df = left.union(right) + val children = Seq(left.queryExecution.executedPlan, right.queryExecution.executedPlan) + checkThatPlansAgreeTemplate(df, children) + } + + def checkThatPlansAgreeTemplate(df: DataFrame, child: Seq[SparkPlan]): Unit = { + checkThatPlansAgree( + df, + (_: SparkPlan) => + ColumnarUnionExec(child), + (_: SparkPlan) => + UnionExec(child), + sortAnswers = false) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarWindowExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarWindowExecSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..f8c712c4b0824f2693fca3700c5ea9d1f0ffe359 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarWindowExecSuite.scala @@ -0,0 +1,62 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSparkSession + +// refer to DataFrameWindowFramesSuite +class ColumnarWindowExecSuite extends ColumnarSparkPlanTest with SharedSparkSession { + import testImplicits._ + + private var inputDf: DataFrame = _ + + protected override def beforeAll(): Unit = { + super.beforeAll() + inputDf = Seq( + ("abc", "", 4, 2.0), + ("", "Hello", 1, 1.0), + (" add", "World", 8, 3.0), + (" yeah ", "yeah", 10, 8.0), + ("abc", "", 10, 8.0) + ).toDF("a", "b", "c", "d") + } + + test("validate columnar window exec happened") { + val res1 = Window.partitionBy("a").orderBy('c.desc) + val res2 = inputDf.withColumn("max", max("c").over(res1)) + res2.head(10).foreach(row => println(row)) + assert(res2.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarWindowExec]).isDefined, s"ColumnarWindowExec not happened, executedPlan as follows: \n${res2.queryExecution.executedPlan}") + } + + // todo: window check answer + // test("lead/lag with negative offset") { + // val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + // val window = Window.partitionBy($"key").orderBy($"value") + // + // checkAnswer( + // df.select( + // $"key", + // lead("value", -1).over(window), + // lag("value", -1).over(window)), + // Row(1, null, "3") :: Row(1, "1", null) :: Row(2, null, "4") :: Row(2, "2", null) :: Nil) + // } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarHashAggregateExecSqlSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarHashAggregateExecSqlSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..1e0862c4a41ecae4a7147eab6da46c5c38f64fbb --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarHashAggregateExecSqlSuite.scala @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution.forsql
+
+import org.apache.spark.sql.execution.{ColumnarHashAggregateExec, ColumnarSparkPlanTest}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{DataFrame, Row}
+
+class ColumnarHashAggregateExecSqlSuite extends ColumnarSparkPlanTest {
+  private var df: DataFrame = _
+
+  protected override def beforeAll(): Unit = {
+    super.beforeAll()
+    df = spark.createDataFrame(
+      sparkContext.parallelize(Seq(
+        Row(1, 2.0, 1L, "a"),
+        Row(1, 2.0, 2L, null),
+        Row(2, 1.0, 3L, "c"),
+        Row(null, null, 6L, "e"),
+        Row(null, 5.0, 7L, "f")
+      )), new StructType().add("a", IntegerType).add("b", DoubleType)
+        .add("c", LongType).add("d", StringType))
+  }
+
+  test("test count(*)/count(1)") {
+    df.createOrReplaceTempView("test_table")
+    val res1 = spark.sql("select count(*) from test_table")
+    assert(res1.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n${res1.queryExecution.executedPlan}")
+    checkAnswer(
+      res1,
+      Seq(Row(5))
+    )
+
+    val res2 = spark.sql("select count(1) from test_table")
+    checkAnswer(
+      res2,
+      Seq(Row(5))
+    )
+
+    val res3 = spark.sql("select count(-1) from test_table")
+    checkAnswer(
+      res3,
+      Seq(Row(5))
+    )
+
+    val res4 = spark.sql("select max(a), count(1) from test_table")
+    checkAnswer(
+      res4,
+      Seq(Row(2, 5))
+    )
+
+    val res5 = spark.sql("select max(a), count(1), min(a) from test_table")
+    checkAnswer(
+      res5,
+      Seq(Row(2, 5, 1))
+    )
+  }
+}
diff --git a/omnioperator/omniop-spark-extension/pom.xml b/omnioperator/omniop-spark-extension/pom.xml
new file mode 100644
index 0000000000000000000000000000000000000000..df3a0b67864ee43bb626f75da4799db93ed87b7c
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/pom.xml
@@ -0,0 +1,147 @@ + + + + 4.0.0 + + com.huawei.kunpeng + boostkit-omniop-spark-parent + pom + 3.1.1-1.0.0 + + BoostKit Spark Native Sql Engine Extension Parent Pom + + + 2.12.10 + 2.12 + 3.1.1 + 2.7.4 + 1.2.0-SNAPSHOT + UTF-8 + UTF-8 + 3.15.8 + FALSE + 1.0.0 + + + java + + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + provided + + + org.apache.arrow + arrow-vector + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + org.apache.hadoop + hadoop-client + + + org.apache.curator + curator-recipes + + + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${spark.version} + provided + + + org.apache.hadoop + hadoop-client + ${hadoop.version} + provided + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-databind + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + io.netty + netty + + + + + com.huawei.boostkit + boostkit-omniop-bindings + ${omniruntime.version} + provided + + + com.google.protobuf + protobuf-java + ${protobuf.version} + + + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark.version} + test-jar + test + + + org.apache.hadoop + hadoop-client + + + org.apache.curator + curator-recipes + + + + + junit + junit + 4.12 + test + + + + + + + hadoop-3.2 + + 3.2.0 + + + + + \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/scalastyle-config.xml b/omnioperator/omniop-spark-extension/scalastyle-config.xml new file mode 100644 index 0000000000000000000000000000000000000000..8e693067d047702e69f0a44b0582002ecc7caae4 --- /dev/null +++ b/omnioperator/omniop-spark-extension/scalastyle-config.xml @@ -0,0 +1,404 @@ + + + + 
Scalastyle standard configuration + + + + + + + + + + + + + + + + + + + + + + + + true + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW + + + + + + ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW + + + + + + + + + ^FunSuite[A-Za-z]*$ + Tests must extend org.apache.spark.SparkFunSuite instead. + + + + + ^println$ + + + + + spark(.sqlContext)?.sparkContext.hadoopConfiguration + + + + + @VisibleForTesting + + + + + Runtime\.getRuntime\.addShutdownHook + + + + + mutable\.SynchronizedBuffer + + + + + Class\.forName + + + + + Await\.result + + + + + Await\.ready + + + + + (\.toUpperCase|\.toLowerCase)(?!(\(|\(Locale.ROOT\))) + + + + + throw new \w+Error\( + + + + + + JavaConversions + Instead of importing implicits in scala.collection.JavaConversions._, import + scala.collection.JavaConverters._ and use .asScala / .asJava methods + + + + org\.apache\.commons\.lang\. + Use Commons Lang 3 classes (package org.apache.commons.lang3.*) instead + of Commons Lang 2 (package org.apache.commons.lang.*) + + + + + FileSystem.get\([a-zA-Z_$][a-zA-Z_$0-9]*\) + + + + + extractOpt + Use jsonOption(x).map(.extract[T]) instead of .extractOpt[T], as the latter + is slower. + + + + + java,scala,3rdParty,spark + javax?\..* + scala\..* + (?!org\.apache\.spark\.).* + org\.apache\.spark\..* + + + + + + COMMA + + + + + + \)\{ + + + + + (?m)^(\s*)/[*][*].*$(\r|)\n^\1 [*] + Use Javadoc style indentation for multiline comments + + + + case[^\n>]*=>\s*\{ + Omit braces in case clauses. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 800> + + + + + 30 + + + + + 10 + + + + + 50 + + + + + + + + + + + -1,0,1,2,3 + + +